225 lines
5.2 KiB
Go
225 lines
5.2 KiB
Go
package lib
|
|
|
|
import (
|
|
http_cel "codeberg.org/gone/http-cel"
|
|
"fmt"
|
|
"github.com/google/cel-go/cel"
|
|
"github.com/google/cel-go/common/ast"
|
|
"github.com/google/cel-go/common/types"
|
|
"github.com/google/cel-go/common/types/ref"
|
|
"log/slog"
|
|
"net"
|
|
)
|
|
|
|
func (state *State) initConditions() (err error) {
|
|
state.programEnv, err = http_cel.NewEnvironment(
|
|
|
|
cel.Variable("fp", cel.MapType(cel.StringType, cel.StringType)),
|
|
cel.Function("inDNSBL",
|
|
cel.Overload("inDNSBL_ip",
|
|
[]*cel.Type{cel.AnyType},
|
|
cel.BoolType,
|
|
cel.UnaryBinding(func(val ref.Val) ref.Val {
|
|
slog.Error("inDNSBL function has been deprecated, replace with dnsbl challenge")
|
|
return types.Bool(false)
|
|
}),
|
|
),
|
|
),
|
|
|
|
cel.Function("network",
|
|
cel.MemberOverload("netIP_network_string",
|
|
[]*cel.Type{cel.BytesType, cel.StringType},
|
|
cel.BoolType,
|
|
cel.BinaryBinding(func(lhs ref.Val, rhs ref.Val) ref.Val {
|
|
var ip net.IP
|
|
switch v := lhs.Value().(type) {
|
|
case []byte:
|
|
ip = v
|
|
case net.IP:
|
|
ip = v
|
|
}
|
|
|
|
if ip == nil {
|
|
panic(fmt.Errorf("invalid ip %v", lhs.Value()))
|
|
}
|
|
|
|
val, ok := rhs.Value().(string)
|
|
if !ok {
|
|
panic(fmt.Errorf("invalid network value %v", rhs.Value()))
|
|
}
|
|
|
|
network, ok := state.networks[val]
|
|
if !ok {
|
|
_, ipNet, err := net.ParseCIDR(val)
|
|
if err != nil {
|
|
panic("network not found")
|
|
}
|
|
return types.Bool(ipNet.Contains(ip))
|
|
} else {
|
|
ok, err := network().Contains(ip)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return types.Bool(ok)
|
|
}
|
|
}),
|
|
),
|
|
),
|
|
|
|
cel.Function("inNetwork",
|
|
cel.Overload("inNetwork_string_ip",
|
|
[]*cel.Type{cel.StringType, cel.BytesType},
|
|
cel.BoolType,
|
|
cel.BinaryBinding(func(lhs ref.Val, rhs ref.Val) ref.Val {
|
|
var ip net.IP
|
|
switch v := rhs.Value().(type) {
|
|
case []byte:
|
|
ip = v
|
|
case net.IP:
|
|
ip = v
|
|
}
|
|
|
|
if ip == nil {
|
|
panic(fmt.Errorf("invalid ip %v", rhs.Value()))
|
|
}
|
|
|
|
val, ok := lhs.Value().(string)
|
|
if !ok {
|
|
panic(fmt.Errorf("invalid value %v", lhs.Value()))
|
|
}
|
|
slog.Debug(fmt.Sprintf("inNetwork function has been deprecated and will be removed in a future release, use remoteAddress.network(\"%s\") instead", val))
|
|
|
|
network, ok := state.networks[val]
|
|
if !ok {
|
|
_, ipNet, err := net.ParseCIDR(val)
|
|
if err != nil {
|
|
panic("network not found")
|
|
}
|
|
return types.Bool(ipNet.Contains(ip))
|
|
} else {
|
|
ok, err := network().Contains(ip)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return types.Bool(ok)
|
|
}
|
|
}),
|
|
),
|
|
),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (state *State) RegisterCondition(operator string, conditions ...string) (cel.Program, error) {
|
|
compiledAst, err := http_cel.NewAst(state.ProgramEnv(), operator, conditions...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if out := compiledAst.OutputType(); out == nil {
|
|
return nil, fmt.Errorf("no output")
|
|
} else if out != types.BoolType {
|
|
return nil, fmt.Errorf("output type is not bool")
|
|
}
|
|
|
|
walkExpr(compiledAst.NativeRep().Expr(), func(e ast.Expr) {
|
|
if e.Kind() == ast.CallKind {
|
|
call := e.AsCall()
|
|
switch call.FunctionName() {
|
|
// deprecated
|
|
case "inNetwork":
|
|
args := call.Args()
|
|
if !call.IsMemberFunction() && len(args) == 2 {
|
|
// we have a network select function
|
|
switch args[1].Kind() {
|
|
case ast.LiteralKind:
|
|
lit := args[1].AsLiteral()
|
|
if lit.Type() == types.StringType {
|
|
if fn, ok := state.networks[lit.Value().(string)]; ok {
|
|
// preload
|
|
fn()
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
case "network":
|
|
args := call.Args()
|
|
if call.IsMemberFunction() && len(args) == 1 {
|
|
// we have a network select function
|
|
switch args[0].Kind() {
|
|
case ast.LiteralKind:
|
|
lit := args[0].AsLiteral()
|
|
if lit.Type() == types.StringType {
|
|
if fn, ok := state.networks[lit.Value().(string)]; ok {
|
|
// preload
|
|
fn()
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
return http_cel.ProgramAst(state.ProgramEnv(), compiledAst)
|
|
}
|
|
|
|
func walkExpr(e ast.Expr, fn func(ast.Expr)) {
|
|
fn(e)
|
|
|
|
switch e.Kind() {
|
|
case ast.CallKind:
|
|
ee := e.AsCall()
|
|
walkExpr(ee.Target(), fn)
|
|
for _, arg := range ee.Args() {
|
|
walkExpr(arg, fn)
|
|
}
|
|
case ast.ComprehensionKind:
|
|
ee := e.AsComprehension()
|
|
walkExpr(ee.Result(), fn)
|
|
walkExpr(ee.IterRange(), fn)
|
|
walkExpr(ee.AccuInit(), fn)
|
|
walkExpr(ee.LoopCondition(), fn)
|
|
walkExpr(ee.LoopStep(), fn)
|
|
case ast.ListKind:
|
|
ee := e.AsList()
|
|
for _, element := range ee.Elements() {
|
|
walkExpr(element, fn)
|
|
}
|
|
case ast.MapKind:
|
|
ee := e.AsMap()
|
|
for _, entry := range ee.Entries() {
|
|
switch entry.Kind() {
|
|
case ast.MapEntryKind:
|
|
eee := entry.AsMapEntry()
|
|
walkExpr(eee.Key(), fn)
|
|
walkExpr(eee.Value(), fn)
|
|
case ast.StructFieldKind:
|
|
eee := entry.AsStructField()
|
|
walkExpr(eee.Value(), fn)
|
|
}
|
|
}
|
|
case ast.SelectKind:
|
|
ee := e.AsSelect()
|
|
walkExpr(ee.Operand(), fn)
|
|
case ast.StructKind:
|
|
ee := e.AsStruct()
|
|
for _, field := range ee.Fields() {
|
|
switch field.Kind() {
|
|
case ast.MapEntryKind:
|
|
eee := field.AsMapEntry()
|
|
walkExpr(eee.Key(), fn)
|
|
walkExpr(eee.Value(), fn)
|
|
case ast.StructFieldKind:
|
|
eee := field.AsStructField()
|
|
walkExpr(eee.Value(), fn)
|
|
}
|
|
}
|
|
}
|
|
}
|