常见问题解答 (FAQ)
问题
How to create an entity from a struct T?
How to create a struct (or a mutation) level validator?
How to write an audit-log extension?
How to write custom predicates?
How to add custom predicates to the codegen assets?
How to define a network address field in PostgreSQL?
How to customize time fields to type DATETIME in MySQL?
How to use a custom generator of IDs?
How to use a custom XID globally unique ID?
How to define a spatial data type field in MySQL?
How to extend the generated models?
How to extend the generated builders?
How to store Protobuf objects in a BLOB column?
How to add CHECK constraints to table?
How to define a custom precision numeric field?
How to configure two or more DB to separate read and write?
How to configure json.Marshal to inline the edges keys in the top level object?
回答
如何从 T 结构体创建实体?
不同的构建器不支持从给定的结构体 T 设置实体字段(或边)。原因是,在更新数据库时无法区分零值和真实值(例如 &ent.T{Age: 0, Name: ""})。设置这些值可能会在数据库中写入错误值或更新不必要的列。
然而, external template 选项允许你通过添加自定义逻辑来扩展默认的代码生成资源。例如,为了为每个创建构建器生成一个方法,该方法接受一个结构体作为输入并配置构建器,请使用以下模板:
{{ range $n := $.Nodes }}
{{ $builder := $n.CreateName }}
{{ $receiver := $n.CreateReceiver }}
func ({{ $receiver }} *{{ $builder }}) Set{{ $n.Name }}(input *{{ $n.Name }}) *{{ $builder }} {
{{- range $f := $n.Fields }}
{{- $setter := print "Set" $f.StructField }}
{{ $receiver }}.{{ $setter }}(input.{{ $f.StructField }})
{{- end }}
return {{ $receiver }}
}
{{ end }}
如何创建 mutation 级别的验证器?
要实现一个 mutation 级别的验证器,你可以使用 schema hooks 来验证单个实体类型上的更改,或者使用 transaction hooks 来验证在多个实体类型上执行的 mutation(例如 GraphQL mutation)。例如:
// A VersionHook is a dummy example for a hook that validates the "version" field
// is incremented by 1 on each update. Note that this is just a dummy example, and
// it doesn't promise consistency in the database.
func VersionHook() ent.Hook {
type OldSetVersion interface {
SetVersion(int)
Version() (int, bool)
OldVersion(context.Context) (int, error)
}
return func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
ver, ok := m.(OldSetVersion)
if !ok {
return next.Mutate(ctx, m)
}
oldV, err := ver.OldVersion(ctx)
if err != nil {
return nil, err
}
curV, exists := ver.Version()
if !exists {
return nil, fmt.Errorf("version field is required in update mutation")
}
if curV != oldV+1 {
return nil, fmt.Errorf("version field must be incremented by 1")
}
// Add an SQL predicate that validates the "version" column is equal
// to "oldV" (ensure it wasn't changed during the mutation by others).
return next.Mutate(ctx, m)
})
}
}
如何编写审计日志扩展?
编写此类扩展的首选方式是使用 ent.Mixin。使用 Fields 选项来设置所有导入混合模式的模式共享的字段,并使用 Hooks 选项为在这些模式上应用的所有 mutation 附加一个 mutation-hook。以下是基于 仓库 issue-tracker 的讨论得到的示例:
// AuditMixin implements the ent.Mixin for sharing
// audit-log capabilities with package schemas.
type AuditMixin struct{
mixin.Schema
}
// Fields of the AuditMixin.
func (AuditMixin) Fields() []ent.Field {
return []ent.Field{
field.Time("created_at").
Immutable().
Default(time.Now),
field.Int("created_by").
Optional(),
field.Time("updated_at").
Default(time.Now).
UpdateDefault(time.Now),
field.Int("updated_by").
Optional(),
}
}
// Hooks of the AuditMixin.
func (AuditMixin) Hooks() []ent.Hook {
return []ent.Hook{
hooks.AuditHook,
}
}
// A AuditHook is an example for audit-log hook.
func AuditHook(next ent.Mutator) ent.Mutator {
// AuditLogger wraps the methods that are shared between all mutations of
// schemas that embed the AuditLog mixin. The variable "exists" is true, if
// the field already exists in the mutation (e.g. was set by a different hook).
type AuditLogger interface {
SetCreatedAt(time.Time)
CreatedAt() (value time.Time, exists bool)
SetCreatedBy(int)
CreatedBy() (id int, exists bool)
SetUpdatedAt(time.Time)
UpdatedAt() (value time.Time, exists bool)
SetUpdatedBy(int)
UpdatedBy() (id int, exists bool)
}
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
ml, ok := m.(AuditLogger)
if !ok {
return nil, fmt.Errorf("unexpected audit-log call from mutation type %T", m)
}
usr, err := viewer.UserFromContext(ctx)
if err != nil {
return nil, err
}
switch op := m.Op(); {
case op.Is(ent.OpCreate):
ml.SetCreatedAt(time.Now())
if _, exists := ml.CreatedBy(); !exists {
ml.SetCreatedBy(usr.ID)
}
case op.Is(ent.OpUpdateOne | ent.OpUpdate):
ml.SetUpdatedAt(time.Now())
if _, exists := ml.UpdatedBy(); !exists {
ml.SetUpdatedBy(usr.ID)
}
}
return next.Mutate(ctx, m)
})
}
如何编写自定义谓词?
用户可以在查询执行前提供自定义谓词。例如:
pets := client.Pet.
Query().
Where(predicate.Pet(func(s *sql.Selector) {
s.Where(sql.InInts(pet.OwnerColumn, 1, 2, 3))
})).
AllX(ctx)
users := client.User.
Query().
Where(predicate.User(func(s *sql.Selector) {
s.Where(sqljson.ValueContains(user.FieldTags, "tag"))
})).
AllX(ctx)
如需了解更多示例,请访问 predicates 页面,或在仓库 issue-tracker 中搜索更高级的示例,如 issue-842。
如何将自定义谓词添加到代码生成资产?
template 选项允许通过扩展或覆盖默认的代码生成资产来实现此功能。为了为 上例 生成类型安全的谓词,请按如下方式使用模板选项:
{{/* A template that adds the "<F>Glob" predicate for all string fields. */}}
{{ define "where/additional/strings" }}
{{ range $f := $.Fields }}
{{ if $f.IsString }}
{{ $func := print $f.StructField "Glob" }}
// {{ $func }} applies the Glob predicate on the {{ quote $f.Name }} field.
func {{ $func }}(pattern string) predicate.{{ $.Name }} {
return predicate.{{ $.Name }}(func(s *sql.Selector) {
s.Where(sql.P(func(b *sql.Builder) {
b.Ident(s.C({{ $f.Constant }})).WriteString(" glob" ).Arg(pattern)
}))
})
}
{{ end }}
{{ end }}
{{ end }}
如何在 PostgreSQL 中定义网络地址字段?
GoType 与 SchemaType 选项允许用户定义数据库特定字段。例如,为了定义 macaddr 字段,请使用以下配置:
func (T) Fields() []ent.Field {
return []ent.Field{
field.String("mac").
GoType(&MAC{}).
SchemaType(map[string]string{
dialect.Postgres: "macaddr",
}).
Validate(func(s string) error {
_, err := net.ParseMAC(s)
return err
}),
}
}
// MAC represents a physical hardware address.
type MAC struct {
net.HardwareAddr
}
// Scan implements the Scanner interface.
func (m *MAC) Scan(value any) (err error) {
switch v := value.(type) {
case nil:
case []byte:
m.HardwareAddr, err = net.ParseMAC(string(v))
case string:
m.HardwareAddr, err = net.ParseMAC(v)
default:
err = fmt.Errorf("unexpected type %T", v)
}
return
}
// Value implements the driver Valuer interface.
func (m MAC) Value() (driver.Value, error) {
return m.HardwareAddr.String(), nil
}
如果数据库不支持 macaddr 类型(例如测试中的 SQLite),则该字段会回退到其本机类型(即 string)。
inet 示例:
func (T) Fields() []ent.Field {
return []ent.Field{
field.String("ip").
GoType(&Inet{}).
SchemaType(map[string]string{
dialect.Postgres: "inet",
}).
Validate(func(s string) error {
if net.ParseIP(s) == nil {
return fmt.Errorf("invalid value for ip %q", s)
}
return nil
}),
}
}
// Inet represents a single IP address
type Inet struct {
net.IP
}
// Scan implements the Scanner interface
func (i *Inet) Scan(value any) (err error) {
switch v := value.(type) {
case nil:
case []byte:
if i.IP = net.ParseIP(string(v)); i.IP == nil {
err = fmt.Errorf("invalid value for ip %q", v)
}
case string:
if i.IP = net.ParseIP(v); i.IP == nil {
err = fmt.Errorf("invalid value for ip %q", v)
}
default:
err = fmt.Errorf("unexpected type %T", v)
}
return
}
// Value implements the driver Valuer interface
func (i Inet) Value() (driver.Value, error) {
return i.IP.String(), nil
}
如何将时间字段自定义为 MySQL 的 DATETIME 类型?
Time 字段在默认情况下使用 MySQL TIMESTAMP 类型,范围为 '1970-01-01 00:00:01' UTC 到 '2038-01-19 03:14:07' UTC(参见 MySQL docs)。
若想自定义时间字段为更宽范围,可使用 MySQL DATETIME:
field.Time("birth_date").
Optional().
SchemaType(map[string]string{
dialect.MySQL: "datetime",
}),
如何使用自定义 ID 生成器?
如果你使用自定义 ID 生成器而不是在数据库中使用自增 ID(例如 Twitter 的 Snowflake),你需要编写一个自定义 ID 字段,该字段在资源创建时自动调用生成器。
实现方式可以使用 DefaultFunc 或 schema 钩子,取决于使用场景。如果生成器不返回错误,DefaultFunc 更简洁;若想在资源创建时捕获错误,则使用 hook 更合适。如何使用 DefaultFunc 的示例见 ID 字段。
下面是使用 hook 的示例,以 sonyflake 为例:
// BaseMixin to be shared will all different schemas.
type BaseMixin struct {
mixin.Schema
}
// Fields of the Mixin.
func (BaseMixin) Fields() []ent.Field {
return []ent.Field{
field.Uint64("id"),
}
}
// Hooks of the Mixin.
func (BaseMixin) Hooks() []ent.Hook {
return []ent.Hook{
hook.On(IDHook(), ent.OpCreate),
}
}
func IDHook() ent.Hook {
sf := sonyflake.NewSonyflake(sonyflake.Settings{})
type IDSetter interface {
SetID(uint64)
}
return func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
is, ok := m.(IDSetter)
if !ok {
return nil, fmt.Errorf("unexpected mutation %T", m)
}
id, err := sf.NextID()
if err != nil {
return nil, err
}
is.SetID(id)
return next.Mutate(ctx, m)
})
}
}
// User holds the schema definition for the User entity.
type User struct {
ent.Schema
}
// Mixin of the User.
func (User) Mixin() []ent.Mixin {
return []ent.Mixin{
// Embed the BaseMixin in the user schema.
BaseMixin{},
}
}
如何使用自定义 XID 全局唯一 ID?
包 xid 是一个全局唯一 ID 生成库,使用 Mongo Object ID 算法生成 12 字节、20 个字符的 ID,且无需配置。xid 包已实现 Ent 所需的 database/sql sql.Scanner 和 driver.Valuer 接口。
若要在任何字符串字段中存储 XID,请使用 GoType 设置:
// Fields of type T.
func (T) Fields() []ent.Field {
return []ent.Field{
field.String("id").
GoType(xid.ID{}).
DefaultFunc(xid.New),
}
}
或者,以可复用的 Mixin 方式跨多个模式使用:
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
"github.com/rs/xid"
)
// BaseMixin to be shared will all different schemas.
type BaseMixin struct {
mixin.Schema
}
// Fields of the User.
func (BaseMixin) Fields() []ent.Field {
return []ent.Field{
field.String("id").
GoType(xid.ID{}).
DefaultFunc(xid.New),
}
}
// User holds the schema definition for the User entity.
type User struct {
ent.Schema
}
// Mixin of the User.
func (User) Mixin() []ent.Mixin {
return []ent.Mixin{
// Embed the BaseMixin in the user schema.
BaseMixin{},
}
}
如需在 gqlgen 中使用扩展的 ID(XIDs),请参阅 issue tracker 中的配置说明。
如何在 MySQL 中定义空间数据类型字段?
GoType 与 SchemaType 选项允许用户定义数据库特定字段。例如,为了定义 POINT 字段,请使用以下配置:
// Fields of the Location.
func (Location) Fields() []ent.Field {
return []ent.Field{
field.String("name"),
field.Other("coords", &Point{}).
SchemaType(Point{}.SchemaType()),
}
}
package schema
import (
"database/sql/driver"
"fmt"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"github.com/paulmach/orb"
"github.com/paulmach/orb/encoding/wkb"
)
// A Point consists of (X,Y) or (Lat, Lon) coordinates
// and it is stored in MySQL the POINT spatial data type.
type Point [2]float64
// Scan implements the Scanner interface.
func (p *Point) Scan(value any) error {
bin, ok := value.([]byte)
if !ok {
return fmt.Errorf("invalid binary value for point")
}
var op orb.Point
if err := wkb.Scanner(&op).Scan(bin[4:]); err != nil {
return err
}
p[0], p[1] = op.X(), op.Y()
return nil
}
// Value implements the driver Valuer interface.
func (p Point) Value() (driver.Value, error) {
op := orb.Point{p[0], p[1]}
return wkb.Value(op).Value()
}
// FormatParam implements the sql.ParamFormatter interface to tell the SQL
// builder that the placeholder for a Point parameter needs to be formatted.
func (p Point) FormatParam(placeholder string, info *sql.StmtInfo) string {
if info.Dialect == dialect.MySQL {
return "ST_GeomFromWKB(" + placeholder + ")"
}
return placeholder
}
// SchemaType defines the schema-type of the Point object.
func (Point) SchemaType() map[string]string {
return map[string]string{
dialect.MySQL: "POINT",
}
}
完整示例请参阅 example repository。
如何扩展生成的模型?
Ent 支持使用自定义模板扩展生成的类型(全局类型和模型)。例如,为了为生成的模型添加附加结构体字段或方法,可以像下面的示例一样覆盖 model/fields/additional 模板:
{{# 例子请参考 https://github.com/ent/ent/blob/dd4792f5b30bdd2db0d9a593a977a54cb3f0c1ce/examples/entcpkg/ent/template/static.tmpl }}
若自定义字段/方法需要额外导入,可以通过自定义模板添加这些导入:
{{- define "import/additional/field_types" -}}
"github.com/path/to/your/custom/type"
{{- end -}}
{{- define "import/additional/client_dependencies" -}}
"github.com/path/to/your/custom/type"
{{- end -}}
如何扩展生成的构建器?
参见 注入外部依赖 部分,或查看 GitHub 示例:entcpkg。
如何在 BLOB 列中存储 Protobuf 对象?
假设我们有一个已定义的 Protobuf 消息:
syntax = "proto3";
package pb;
option go_package = "project/pb";
message Hi {
string Greeting = 1;
}
我们为生成的 Protobuf 结构体添加接收者方法,使其实现 ValueScanner:
func (x *Hi) Value() (driver.Value, error) {
return proto.Marshal(x)
}
func (x *Hi) Scan(src any) error {
if src == nil {
return nil
}
if b, ok := src.([]byte); ok {
if err := proto.Unmarshal(b, x); err != nil {
return err
}
return nil
}
return fmt.Errorf("unexpected type %T", src)
}
在模式中添加一个新的 field.Bytes,并将生成的 Protobuf 结构体设为其底层 GoType:
// Fields of the Message.
func (Message) Fields() []ent.Field {
return []ent.Field{
field.Bytes("hi").
GoType(&pb.Hi{}),
}
}
测试验证正常工作:
package main
import (
"context"
"testing"
"project/ent/enttest"
"project/pb"
_ "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
func TestMain(t *testing.T) {
client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1")
defer client.Close()
msg := client.Message.Create().
SetHi(&pb.Hi{
Greeting: "hello",
}).
SaveX(context.TODO())
ret := client.Message.GetX(context.TODO(), msg.ID)
require.Equal(t, "hello", ret.Hi.Greeting)
}
如何给表添加 CHECK 约束?
entsql.Annotation 选项允许在 CREATE TABLE 语句中添加自定义 CHECK 约束。要为架构添加 CHECK 约束,请参照以下示例:
func (User) Annotations() []schema.Annotation {
return []schema.Annotation{
&entsql.Annotation{
// The `Check` option allows adding an
// unnamed CHECK constraint to table DDL.
Check: "website <> 'entgo.io'",
// The `Checks` option allows adding multiple CHECK constraints
// to table creation. The keys are used as the constraint names.
Checks: map[string]string{
"valid_nickname": "nickname <> firstname",
"valid_firstname": "length(first_name) > 1",
},
},
}
}
如何定义自定义精度数值字段?
通过 GoType 与 SchemaType 可以定义自定义精度数值字段。例如,定义使用 big.Int 的字段:
func (T) Fields() []ent.Field {
return []ent.Field{
field.Int("precise").
GoType(new(BigInt)).
SchemaType(map[string]string{
dialect.SQLite: "numeric(78, 0)",
dialect.Postgres: "numeric(78, 0)",
}),
}
}
type BigInt struct {
big.Int
}
func (b *BigInt) Scan(src any) error {
var i sql.NullString
if err := i.Scan(src); err != nil {
return err
}
if !i.Valid {
return nil
}
if _, ok := b.Int.SetString(i.String, 10); ok {
return nil
}
return fmt.Errorf("could not scan type %T with value %v into BigInt", src, src)
}
func (b *BigInt) Value() (driver.Value, error) {
return b.String(), nil
}
如何配置两个或更多 DB 以分离读写?
你可以用自己的驱动包装 dialect.Driver 并实现此逻辑。例如:
func main() {
// ...
wd, err := sql.Open(dialect.MySQL, "root:pass@tcp(<addr>)/<database>?parseTime=True")
if err != nil {
log.Fatal(err)
}
rd, err := sql.Open(dialect.MySQL, "readonly:pass@tcp(<addr>)/<database>?parseTime=True")
if err != nil {
log.Fatal(err)
}
client := ent.NewClient(ent.Driver(&multiDriver{w: wd, r: rd}))
defer client.Close()
// Use the client here.
}
type multiDriver struct {
r, w dialect.Driver
}
var _ dialect.Driver = (*multiDriver)(nil)
func (d *multiDriver) Query(ctx context.Context, query string, args, v any) error {
e := d.r
// Mutation statements that use the RETURNING clause.
if ent.QueryFromContext(ctx) == nil {
e = d.w
}
return e.Query(ctx, query, args, v)
}
func (d *multiDriver) Exec(ctx context.Context, query string, args, v any) error {
return d.w.Exec(ctx, query, args, v)
}
func (d *multiDriver) Tx(ctx context.Context) (dialect.Tx, error) {
return d.w.Tx(ctx)
}
func (d *multiDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) {
return d.w.(interface {
BeginTx(context.Context, *sql.TxOptions) (dialect.Tx, error)
}).BeginTx(ctx, opts)
}
func (d *multiDriver) Close() error {
rerr := d.r.Close()
werr := d.w.Close()
if rerr != nil {
return rerr
}
if werr != nil {
return werr
}
return nil
}
func (d *multiDriver) Dialect() string {
return d.r.Dialect()
}
如何配置 json.Marshal 以将 edges 键内联到顶层对象?
要在不包含 edges 属性的情况下编码实体,用户可以按以下两步操作:
- 丢弃 Ent 生成的默认
edges标记。 - 使用自定义
MarshalJSON方法扩展生成的模型。
这两步可以通过 codegen extensions 自动完成,完整的工作示例位于 examples/jsonencode 目录。
//go:build ignore
// +build ignore
package main
import (
"log"
"entgo.io/ent/entc"
"entgo.io/ent/entc/gen"
"entgo.io/ent/schema/edge"
)
func main() {
opts := []entc.Option{
entc.Extensions{
&EncodeExtension{},
),
}
err := entc.Generate("./schema", &gen.Config{}, opts...)
if err != nil {
log.Fatalf("running ent codegen: %v", err)
}
}
// EncodeExtension is an implementation of entc.Extension that adds a MarshalJSON
// method to each generated type <T> and inlines the Edges field to the top level JSON.
type EncodeExtension struct {
entc.DefaultExtension
}
// Templates of the extension.
func (e *EncodeExtension) Templates() []*gen.Template {
return []*gen.Template{
gen.MustParse(gen.NewTemplate("model/additional/jsonencode").
Parse(`
{{ if $.Edges }}
// MarshalJSON implements the json.Marshaler interface.
func ({{ $.Receiver }} *{{ $.Name }}) MarshalJSON() ([]byte, error) {
type Alias {{ $.Name }}
return json.Marshal(&struct {
*Alias
{{ $.Name }}Edges
}{
Alias: (*Alias)({{ $.Receiver }}),
{{ $.Name }}Edges: {{ $.Receiver }}.Edges,
})
}
{{ end }}
`)),
}
}
// Hooks of the extension.
func (e *EncodeExtension) Hooks() []gen.Hook {
return []gen.Hook{
func(next gen.Generator) gen.Generator {
return gen.GenerateFunc(func(g *gen.Graph) error {
tag := edge.Annotation{StructTag: `json:"-"`}
for _, n := range g.Nodes {
n.Annotations.Set(tag.Name(), tag)
}
return next.Generate(g)
})
},
}
}