condition: generalize AST compilation, hot load network prefix blocks as needed, walk the AST and detect and preload networks

This commit is contained in:
WeebDataHoarder
2025-05-01 02:35:27 +02:00
parent 6e47cec540
commit d6c29846df
6 changed files with 223 additions and 132 deletions

View File

@@ -1,37 +1,37 @@
networks: networks:
# aws-cloud: aws-cloud:
# - url: https://ip-ranges.amazonaws.com/ip-ranges.json - url: https://ip-ranges.amazonaws.com/ip-ranges.json
# jq-path: '(.prefixes[] | select(has("ip_prefix")) | .ip_prefix), (.prefixes[] | select(has("ipv6_prefix")) | .ipv6_prefix)' jq-path: '(.prefixes[] | select(has("ip_prefix")) | .ip_prefix), (.prefixes[] | select(has("ipv6_prefix")) | .ipv6_prefix)'
# google-cloud: google-cloud:
# - url: https://www.gstatic.com/ipranges/cloud.json - url: https://www.gstatic.com/ipranges/cloud.json
# jq-path: '(.prefixes[] | select(has("ipv4Prefix")) | .ipv4Prefix), (.prefixes[] | select(has("ipv6Prefix")) | .ipv6Prefix)' jq-path: '(.prefixes[] | select(has("ipv4Prefix")) | .ipv4Prefix), (.prefixes[] | select(has("ipv6Prefix")) | .ipv6Prefix)'
# oracle-cloud: oracle-cloud:
# - url: https://docs.oracle.com/en-us/iaas/tools/public_ip_ranges.json - url: https://docs.oracle.com/en-us/iaas/tools/public_ip_ranges.json
# jq-path: '.regions[] | .cidrs[] | .cidr' jq-path: '.regions[] | .cidrs[] | .cidr'
# azure-cloud: azure-cloud:
# # todo: https://www.microsoft.com/en-us/download/details.aspx?id=56519 does not provide direct JSON # todo: https://www.microsoft.com/en-us/download/details.aspx?id=56519 does not provide direct JSON
# - url: https://raw.githubusercontent.com/femueller/cloud-ip-ranges/refs/heads/master/microsoft-azure-ip-ranges.json - url: https://raw.githubusercontent.com/femueller/cloud-ip-ranges/refs/heads/master/microsoft-azure-ip-ranges.json
# jq-path: '.values[] | .properties.addressPrefixes[]' jq-path: '.values[] | .properties.addressPrefixes[]'
#
# digitalocean: digitalocean:
# - url: https://www.digitalocean.com/geo/google.csv - url: https://www.digitalocean.com/geo/google.csv
# regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
# linode: linode:
# - url: https://geoip.linode.com/ - url: https://geoip.linode.com/
# regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
# vultr: vultr:
# - url: "https://geofeed.constant.com/?json" - url: "https://geofeed.constant.com/?json"
# jq-path: '.subnets[] | .ip_prefix' jq-path: '.subnets[] | .ip_prefix'
# cloudflare: cloudflare:
# - url: https://www.cloudflare.com/ips-v4 - url: https://www.cloudflare.com/ips-v4
# regex: "(?P<prefix>[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+/[0-9]+)" regex: "(?P<prefix>[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+/[0-9]+)"
# - url: https://www.cloudflare.com/ips-v6 - url: https://www.cloudflare.com/ips-v6
# regex: "(?P<prefix>[0-9a-f:]+::/[0-9]+)" regex: "(?P<prefix>[0-9a-f:]+::/[0-9]+)"
#
# icloud-private-relay: icloud-private-relay:
# - url: https://mask-api.icloud.com/egress-ip-ranges.csv - url: https://mask-api.icloud.com/egress-ip-ranges.csv
# regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+)," regex: "(?P<prefix>(([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+)|([0-9a-f:]+::))/[0-9]+),"
# tunnelbroker-relay: tunnelbroker-relay:
# # HE Tunnelbroker # HE Tunnelbroker
# - url: https://tunnelbroker.net/export/google - url: https://tunnelbroker.net/export/google
# regex: "(?P<prefix>([0-9a-f:]+::)/[0-9]+)," regex: "(?P<prefix>([0-9a-f:]+::)/[0-9]+),"

View File

@@ -11,7 +11,6 @@ import (
"github.com/go-jose/go-jose/v4/jwt" "github.com/go-jose/go-jose/v4/jwt"
"github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/ast"
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"io" "io"
"math/rand/v2" "math/rand/v2"
"net/http" "net/http"
@@ -68,20 +67,10 @@ func (r Register) Create(state StateInterface, name string, pol policy.Challenge
} }
if len(conditions) > 0 { if len(conditions) > 0 {
ast, err := http_cel.NewAst(state.ProgramEnv(), http_cel.OperatorOr, conditions...) var err error
reg.Condition, err = state.RegisterCondition(http_cel.OperatorOr, conditions...)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("error compiling conditions: %v", err) return nil, 0, fmt.Errorf("error compiling condition: %w", err)
}
if out := ast.OutputType(); out == nil {
return nil, 0, fmt.Errorf("error compiling conditions: no output")
} else if out != types.BoolType {
return nil, 0, fmt.Errorf("error compiling conditions: output type is not bool")
}
reg.Condition, err = http_cel.ProgramAst(state.ProgramEnv(), ast)
if err != nil {
return nil, 0, fmt.Errorf("error compiling program: %v", err)
} }
} }

View File

@@ -86,7 +86,7 @@ func (r VerifyResult) String() string {
} }
type StateInterface interface { type StateInterface interface {
ProgramEnv() *cel.Env RegisterCondition(operator string, conditions ...string) (cel.Program, error)
Client() *http.Client Client() *http.Client
PrivateKey() ed25519.PrivateKey PrivateKey() ed25519.PrivateKey

View File

@@ -4,6 +4,7 @@ import (
http_cel "codeberg.org/gone/http-cel" http_cel "codeberg.org/gone/http-cel"
"fmt" "fmt"
"github.com/google/cel-go/cel" "github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/ref"
"log/slog" "log/slog"
@@ -55,7 +56,7 @@ func (state *State) initConditions() (err error) {
} }
return types.Bool(ipNet.Contains(ip)) return types.Bool(ipNet.Contains(ip))
} else { } else {
ok, err := network.Contains(ip) ok, err := network().Contains(ip)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -96,7 +97,7 @@ func (state *State) initConditions() (err error) {
} }
return types.Bool(ipNet.Contains(ip)) return types.Bool(ipNet.Contains(ip))
} else { } else {
ok, err := network.Contains(ip) ok, err := network().Contains(ip)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -111,3 +112,113 @@ func (state *State) initConditions() (err error) {
} }
return nil return nil
} }
func (state *State) RegisterCondition(operator string, conditions ...string) (cel.Program, error) {
compiledAst, err := http_cel.NewAst(state.ProgramEnv(), operator, conditions...)
if err != nil {
return nil, err
}
if out := compiledAst.OutputType(); out == nil {
return nil, fmt.Errorf("no output")
} else if out != types.BoolType {
return nil, fmt.Errorf("output type is not bool")
}
walkExpr(compiledAst.NativeRep().Expr(), func(e ast.Expr) {
if e.Kind() == ast.CallKind {
call := e.AsCall()
switch call.FunctionName() {
// deprecated
case "inNetwork":
args := call.Args()
if call.IsMemberFunction() && len(args) == 2 {
// we have a network select function
switch args[1].Kind() {
case ast.LiteralKind:
lit := args[1].AsLiteral()
if lit.Type() == types.StringType {
if fn, ok := state.networks[lit.Value().(string)]; ok {
// preload
fn()
}
}
}
}
case "network":
args := call.Args()
if call.IsMemberFunction() && len(args) == 1 {
// we have a network select function
switch args[0].Kind() {
case ast.LiteralKind:
lit := args[0].AsLiteral()
if lit.Type() == types.StringType {
if fn, ok := state.networks[lit.Value().(string)]; ok {
// preload
fn()
}
}
}
}
}
}
})
return http_cel.ProgramAst(state.ProgramEnv(), compiledAst)
}
func walkExpr(e ast.Expr, fn func(ast.Expr)) {
fn(e)
switch e.Kind() {
case ast.CallKind:
ee := e.AsCall()
walkExpr(ee.Target(), fn)
for _, arg := range ee.Args() {
walkExpr(arg, fn)
}
case ast.ComprehensionKind:
ee := e.AsComprehension()
walkExpr(ee.Result(), fn)
walkExpr(ee.IterRange(), fn)
walkExpr(ee.AccuInit(), fn)
walkExpr(ee.LoopCondition(), fn)
walkExpr(ee.LoopStep(), fn)
case ast.ListKind:
ee := e.AsList()
for _, element := range ee.Elements() {
walkExpr(element, fn)
}
case ast.MapKind:
ee := e.AsMap()
for _, entry := range ee.Entries() {
switch entry.Kind() {
case ast.MapEntryKind:
eee := entry.AsMapEntry()
walkExpr(eee.Key(), fn)
walkExpr(eee.Value(), fn)
case ast.StructFieldKind:
eee := entry.AsStructField()
walkExpr(eee.Value(), fn)
}
}
case ast.SelectKind:
ee := e.AsSelect()
walkExpr(ee.Operand(), fn)
case ast.StructKind:
ee := e.AsStruct()
for _, field := range ee.Fields() {
switch field.Kind() {
case ast.MapEntryKind:
eee := field.AsMapEntry()
walkExpr(eee.Key(), fn)
walkExpr(eee.Value(), fn)
case ast.StructFieldKind:
eee := field.AsStructField()
walkExpr(eee.Value(), fn)
}
}
}
}

View File

@@ -66,20 +66,9 @@ func NewRuleState(state challenge.StateInterface, r policy.Rule, replacer *strin
conditions = append(conditions, cond) conditions = append(conditions, cond)
} }
ast, err := http_cel.NewAst(state.ProgramEnv(), http_cel.OperatorOr, conditions...) program, err := state.RegisterCondition(http_cel.OperatorOr, conditions...)
if err != nil { if err != nil {
return RuleState{}, fmt.Errorf("error compiling conditions: %w", err) return RuleState{}, fmt.Errorf("error compiling condition: %w", err)
}
if out := ast.OutputType(); out == nil {
return RuleState{}, fmt.Errorf("error compiling conditions: no output")
} else if out != types.BoolType {
return RuleState{}, fmt.Errorf("error compiling conditions: output type is not bool")
}
program, err := http_cel.ProgramAst(state.ProgramEnv(), ast)
if err != nil {
return RuleState{}, fmt.Errorf("error compiling program: %w", err)
} }
rule.Condition = program rule.Condition = program
} }

View File

@@ -24,6 +24,7 @@ import (
"path" "path"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
@@ -40,7 +41,7 @@ type State struct {
opt settings.Settings opt settings.Settings
settings policy.StateSettings settings policy.StateSettings
networks map[string]cidranger.Ranger networks map[string]func() cidranger.Ranger
challenges challenge.Register challenges challenge.Register
@@ -54,6 +55,7 @@ type State struct {
} }
func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) { func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSettings) (handler http.Handler, err error) {
state := new(State) state := new(State)
state.close = make(chan struct{}) state.close = make(chan struct{})
state.settings = settings state.settings = settings
@@ -114,89 +116,89 @@ func NewState(p policy.Policy, opt settings.Settings, settings policy.StateSetti
return nil, fmt.Errorf("no template defined for %s", state.opt.ChallengeTemplate) return nil, fmt.Errorf("no template defined for %s", state.opt.ChallengeTemplate)
} }
state.networks = make(map[string]cidranger.Ranger) state.networks = make(map[string]func() cidranger.Ranger)
networkCache := utils.CachePrefix(state.Settings().Cache, "networks/") networkCache := utils.CachePrefix(state.Settings().Cache, "networks/")
for k, network := range p.Networks { for k, network := range p.Networks {
state.networks[k] = sync.OnceValue[cidranger.Ranger](func() cidranger.Ranger {
ranger := cidranger.NewPCTrieRanger()
for i, e := range network {
prefixes, err := func() ([]net.IPNet, error) {
var useCache bool
ranger := cidranger.NewPCTrieRanger() cacheKey := fmt.Sprintf("%s-%d-", k, i)
for i, e := range network { if e.Url != nil {
prefixes, err := func() ([]net.IPNet, error) { slog.Debug("loading network url list", "network", k, "url", *e.Url)
var useCache bool useCache = true
sum := sha256.Sum256([]byte(*e.Url))
cacheKey += hex.EncodeToString(sum[:4])
} else if e.ASN != nil {
slog.Debug("loading ASN", "network", k, "asn", *e.ASN)
useCache = true
cacheKey += strconv.FormatInt(int64(*e.ASN), 10)
}
cacheKey := fmt.Sprintf("%s-%d-", k, i) var cached []net.IPNet
if e.Url != nil { if useCache && networkCache != nil {
slog.Debug("loading network url list", "network", k, "url", *e.Url) //TODO: add randomness
useCache = true cachedData, err := networkCache.Get(cacheKey, time.Hour*24)
sum := sha256.Sum256([]byte(*e.Url)) var l []string
cacheKey += hex.EncodeToString(sum[:4]) _ = json.Unmarshal(cachedData, &l)
} else if e.ASN != nil { for _, n := range l {
slog.Debug("loading ASN", "network", k, "asn", *e.ASN) _, ipNet, err := net.ParseCIDR(n)
useCache = true if err == nil {
cacheKey += strconv.FormatInt(int64(*e.ASN), 10) cached = append(cached, *ipNet)
} }
}
var cached []net.IPNet
if useCache && networkCache != nil {
//TODO: add randomness
cachedData, err := networkCache.Get(cacheKey, time.Hour*24)
var l []string
_ = json.Unmarshal(cachedData, &l)
for _, n := range l {
_, ipNet, err := net.ParseCIDR(n)
if err == nil { if err == nil {
cached = append(cached, *ipNet) // use
return cached, nil
} }
} }
if err == nil {
// use
return cached, nil
prefixes, err := e.FetchPrefixes(state.client, state.radb)
if err != nil {
if len(cached) > 0 {
// use cached meanwhile
return cached, err
}
return nil, err
} }
} if useCache && networkCache != nil {
var l []string
prefixes, err := e.FetchPrefixes(state.client, state.radb) for _, n := range prefixes {
l = append(l, n.String())
}
cachedData, err := json.Marshal(l)
if err == nil {
_ = networkCache.Set(cacheKey, cachedData)
}
}
return prefixes, nil
}()
if err != nil { if err != nil {
if len(cached) > 0 { if e.Url != nil {
// use cached meanwhile slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err)
return cached, err } else if e.ASN != nil {
slog.Error("error loading ASN", "network", k, "asn", *e.ASN, "error", err)
} else {
slog.Error("error loading list", "network", k, "error", err)
} }
return nil, err continue
} }
if useCache && networkCache != nil { for _, prefix := range prefixes {
var l []string err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix))
for _, n := range prefixes { if err != nil {
l = append(l, n.String()) slog.Error("error inserting prefix", "network", k, "prefix", prefix.String(), "error", err)
} }
cachedData, err := json.Marshal(l)
if err == nil {
_ = networkCache.Set(cacheKey, cachedData)
}
}
return prefixes, nil
}()
if err != nil {
if e.Url != nil {
slog.Error("error loading network list", "network", k, "url", *e.Url, "error", err)
} else if e.ASN != nil {
slog.Error("error loading ASN", "network", k, "asn", *e.ASN, "error", err)
} else {
slog.Error("error loading list", "network", k, "error", err)
}
continue
}
for _, prefix := range prefixes {
err = ranger.Insert(cidranger.NewBasicRangerEntry(prefix))
if err != nil {
return nil, fmt.Errorf("networks %s: error inserting prefix %s: %v", k, prefix.String(), err)
} }
} }
}
slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len()) slog.Warn("loaded network prefixes", "network", k, "count", ranger.Len())
return ranger
state.networks[k] = ranger })
} }
err = state.initConditions() err = state.initConditions()