跳到主要内容

拦截器(Interceptors)

拦截器是用于各类 Ent 查询的执行中间件。与钩子(hooks)不同,拦截器作用于读取路径,并以接口形式实现,允许它们在查询的不同阶段进行拦截和修改,从而对查询行为提供更精细的控制。例如,可参见下文中的遍历器接口

定义拦截器

要定义 Interceptor,用户可以声明一个实现了 Intercept 方法的结构体,或使用预定义的 ent.InterceptFunc 适配器。

ent.InterceptFunc(func(next ent.Querier) ent.Querier {
return ent.QuerierFunc(func(ctx context.Context, query ent.Query) (ent.Value, error) {
// 在查询执行前执行某些操作。
value, err := next.Query(ctx, query)
// 在查询执行后执行某些操作。
return value, err
})
})

在上面的示例中,ent.Query 表示生成的查询构建器(例如 ent.<T>Query),访问其方法需要进行类型断言。例如:

ent.InterceptFunc(func(next ent.Querier) ent.Querier {
return ent.QuerierFunc(func(ctx context.Context, query ent.Query) (ent.Value, error) {
if q, ok := query.(*ent.UserQuery); ok {
q.Where(user.Name("a8m"))
}
return next.Query(ctx, query)
})
})

然而,由 intercept 功能标志生成的实用程序支持创建可应用于任何查询类型的通用拦截器。intercept 功能标志可通过以下两种方式之一添加到项目中:

配置

如果使用默认的 go generate 配置,请在 ent/generate.go 文件中添加 --feature intercept 选项,如下所示:

ent/generate.go
package ent

//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature intercept ./schema

建议同时添加 schema/snapshot 功能标志和 intercept 标志以提升开发体验,例如:

//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature intercept,schema/snapshot ./schema

拦截器注册

信息

请注意,与 schema hooks 类似,如果在模式中使用了 Interceptors 选项,则必须在主包中添加以下导入,因为模式包和生成的 ent 包之间可能存在循环导入:

import _ "<project>/ent/runtime"

使用生成的 intercept

将功能标志添加到项目后,即可使用 intercept 包创建拦截器:

client.Intercept(
intercept.Func(func(ctx context.Context, q intercept.Query) error {
// 将所有查询限制为 1000 条记录。
q.Limit(1000)
return nil
})
)

定义遍历器

在某些情况下,需要拦截图遍历并在继续处理查询返回的节点之前修改其构建器。例如,在下面的查询中,我们希望确保系统中任何图遍历仅遍历 active 用户:

intercept.TraverseUser(func(ctx context.Context, q *ent.UserQuery) error {
q.Where(user.Active(true))
return nil
})

定义并注册此类遍历器后,它将在系统中的所有图遍历中生效。例如:

func TestTypedTraverser(t *testing.T) {
ctx := context.Background()
client := enttest.Open(t, dialect.SQLite, "file:ent?mode=memory&_fk=1")
defer client.Close()
a8m, nat := client.User.Create().SetName("a8m").SaveX(ctx), client.User.Create().SetName("nati").SetActive(false).SaveX(ctx)
client.Pet.CreateBulk(
client.Pet.Create().SetName("a").SetOwner(a8m),
client.Pet.Create().SetName("b").SetOwner(a8m),
client.Pet.Create().SetName("c").SetOwner(nat),
).ExecX(ctx)

// 获取所有用户的宠物。
if n := client.User.Query().QueryPets().CountX(ctx); n != 3 {
t.Errorf("got %d pets, want 3", n)
}

// 添加过滤非活跃用户的拦截器。
client.User.Intercept(
intercept.TraverseUser(func(ctx context.Context, q *ent.UserQuery) error {
q.Where(user.Active(true))
return nil
}),
)

// 仅返回活跃用户的宠物。
if n := client.User.Query().QueryPets().CountX(ctx); n != 2 {
t.Errorf("got %d pets, want 2", n)
}
}

拦截器 vs. 遍历器

InterceptorsTraversers 都可用于修改查询行为,但它们在执行的不同阶段起作用。拦截器作为中间件,允许在查询执行前修改查询,并在记录从数据库返回后修改记录。因此,它们仅应用于查询的最终阶段——即在数据库上实际执行语句期间。另一方面,遍历器在更早的阶段被调用,在图遍历的每一步,允许它们在中间查询和最终查询连接在一起之前进行修改。

总之,遍历函数更适合为图遍历添加默认过滤器,而拦截函数更适合为应用程序实现日志记录或缓存功能。

client.User.Query().
QueryGroups(). // 应用用户遍历函数。
QueryPosts(). // 应用群组遍历函数。
All(ctx) // 应用帖子遍历和拦截函数。

示例

软删除

软删除模式是拦截器和钩子的常见用例。下面的示例演示如何使用 ent.Mixin 将此类功能添加到项目中的所有模式:

// SoftDeleteMixin 为模式实现软删除模式。
type SoftDeleteMixin struct {
mixin.Schema
}

// SoftDeleteMixin 的字段。
func (SoftDeleteMixin) Fields() []ent.Field {
return []ent.Field{
field.Time("delete_time").
Optional(),
}
}

type softDeleteKey struct{}

// SkipSoftDelete 返回一个跳过软删除拦截器/变更器的新上下文。
func SkipSoftDelete(parent context.Context) context.Context {
return context.WithValue(parent, softDeleteKey{}, true)
}

// SoftDeleteMixin 的拦截器。
func (d SoftDeleteMixin) Interceptors() []ent.Interceptor {
return []ent.Interceptor{
intercept.TraverseFunc(func(ctx context.Context, q intercept.Query) error {
// 跳过软删除,即包含软删除的实体。
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
return nil
}
d.P(q)
return nil
}),
}
}

// SoftDeleteMixin 的钩子。
func (d SoftDeleteMixin) Hooks() []ent.Hook {
return []ent.Hook{
hook.On(
func(next ent.Mutator) ent.Mutator {
return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
// 跳过软删除,即永久删除实体。
if skip, _ := ctx.Value(softDeleteKey{}).(bool); skip {
return next.Mutate(ctx, m)
}
mx, ok := m.(interface {
SetOp(ent.Op)
Client() *gen.Client
SetDeleteTime(time.Time)
WhereP(...func(*sql.Selector))
})
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
d.P(mx)
mx.SetOp(ent.OpUpdate)
mx.SetDeleteTime(time.Now())
return mx.Client().Mutate(ctx, m)
})
},
ent.OpDeleteOne|ent.OpDelete,
),
}
}

// P 向查询和变更添加存储层谓词。
func (d SoftDeleteMixin) P(w interface{ WhereP(...func(*sql.Selector)) }) {
w.WhereP(
sql.FieldIsNull(d.Fields()[0].Descriptor().Name),
)
}

限制记录数量

以下示例演示如何使用拦截器函数限制从数据库返回的记录数量:

client.Intercept(
intercept.Func(func(ctx context.Context, q intercept.Query) error {
// LimitInterceptor 将从数据库返回的记录数限制为 1000,
// 前提是未显式设置 Limit。
if ent.QueryFromContext(ctx).Limit == nil {
q.Limit(1000)
}
return nil
}),
)

多项目支持

下面的示例演示如何编写可在多个项目中使用的通用拦截器:

// 项目级示例。使用 "entgo" 包强调该拦截器不依赖任何生成的代码。
func SharedLimiter[Q interface{ Limit(int) }](f func(entgo.Query) (Q, error), limit int) entgo.Interceptor {
return entgo.InterceptFunc(func(next entgo.Querier) entgo.Querier {
return entgo.QuerierFunc(func(ctx context.Context, query entgo.Query) (entgo.Value, error) {
l, err := f(query)
if err != nil {
return nil, err
}
l.Limit(limit)
// LimitInterceptor 将从数据库返回的记录数限制为配置的值,
// 前提是未显式设置 Limit。
if entgo.QueryFromContext(ctx).Limit == nil {
l.Limit(limit)
}
return next.Query(ctx, query)
})
})
}