From baf9df9f0a4bb9343ad37473be8ec99f162d1348 Mon Sep 17 00:00:00 2001 From: WeebDataHoarder <57538841+WeebDataHoarder@users.noreply.github.com> Date: Tue, 8 Apr 2025 11:40:16 +0200 Subject: [PATCH] Allow conditions on challenges, and early hint deadline --- examples/forgejo.yml | 4 ++ lib/challenge.go | 12 ++++ lib/challenge/state.go | 13 ++-- lib/conditions.go | 69 +++++++++++++++++++++ lib/http.go | 73 ++++++++++++++-------- lib/policy/challenge.go | 7 ++- lib/state.go | 134 ++++++++++++++++------------------------ 7 files changed, 197 insertions(+), 115 deletions(-) create mode 100644 lib/conditions.go diff --git a/examples/forgejo.yml b/examples/forgejo.yml index f12e58a..45997cc 100644 --- a/examples/forgejo.yml +++ b/examples/forgejo.yml @@ -116,12 +116,16 @@ challenges: # Challenges with a redirect via Link header with rel=preload and early hints (non-JS, requires HTTP parsing, fetching and logic) # Works on HTTP/2 and above! self-preload-link: + # doesn't seem to work reliably on other stuff that firefox + # userAgent.contains("Firefox/") && + condition: '"Sec-Fetch-Mode" in headers && headers["Sec-Fetch-Mode"] == "navigate"' mode: "preload-link" runtime: # verifies that result = key mode: "key" probability: 0.1 parameters: + preload-early-hint-deadline: 3s key-code: 200 key-mime: text/css key-content: "" diff --git a/lib/challenge.go b/lib/challenge.go index 9114c60..1b22eb8 100644 --- a/lib/challenge.go +++ b/lib/challenge.go @@ -20,6 +20,18 @@ type ChallengeInformation struct { IssuedAt *jwt.NumericDate `json:"iat,omitempty"` } +func getRequestScheme(r *http.Request) string { + if proto := r.Header.Get("X-Forwarded-Proto"); proto == "http" || proto == "https" { + return proto + } + + if r.TLS != nil { + return "https" + } + + return "http" +} + func getRequestAddress(r *http.Request, clientHeader string) net.IP { var ipStr string if clientHeader != "" { diff --git a/lib/challenge/state.go b/lib/challenge/state.go index 878be17..e319c81 100644 --- a/lib/challenge/state.go +++ b/lib/challenge/state.go @@ -7,6 +7,7 @@ import ( "git.gammaspectra.live/git/go-away/utils" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" + "github.com/google/cel-go/cel" "math/rand/v2" "net/http" "time" @@ -26,9 +27,10 @@ const ( type Id int type Challenge struct { - Id Id - Name string - Path string + Id Id + Program cel.Program + Name string + Path string Verify func(key []byte, result string, r *http.Request) (bool, error) VerifyProbability float64 @@ -86,6 +88,7 @@ type VerifyResult int const ( VerifyResultNONE = VerifyResult(iota) VerifyResultFAIL + VerifyResultSKIP // VerifyResultPASS Client just passed this challenge VerifyResultPASS @@ -95,7 +98,7 @@ const ( ) func (r VerifyResult) Ok() bool { - return r > VerifyResultFAIL + return r >= VerifyResultPASS } func (r VerifyResult) String() string { @@ -104,6 +107,8 @@ func (r VerifyResult) String() string { return "NONE" case VerifyResultFAIL: return "FAIL" + case VerifyResultSKIP: + return "SKIP" case VerifyResultPASS: return "PASS" case VerifyResultOK: diff --git a/lib/conditions.go b/lib/conditions.go new file mode 100644 index 0000000..322ad5b --- /dev/null +++ b/lib/conditions.go @@ -0,0 +1,69 @@ +package lib + +import ( + "fmt" + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "net" +) + +func (state *State) initConditions() (err error) { + state.RulesEnv, err = cel.NewEnv( + cel.DefaultUTCTimeZone(true), + cel.Variable("remoteAddress", cel.BytesType), + cel.Variable("host", cel.StringType), + cel.Variable("method", cel.StringType), + cel.Variable("userAgent", cel.StringType), + cel.Variable("path", cel.StringType), + cel.Variable("query", cel.MapType(cel.StringType, cel.StringType)), + // http.Header + cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)), + //TODO: dynamic type? + cel.Function("inNetwork", + cel.Overload("inNetwork_string_ip", + []*cel.Type{cel.StringType, cel.AnyType}, + cel.BoolType, + cel.BinaryBinding(func(lhs ref.Val, rhs ref.Val) ref.Val { + var ip net.IP + switch v := rhs.Value().(type) { + case []byte: + ip = v + case net.IP: + ip = v + case string: + ip = net.ParseIP(v) + } + + if ip == nil { + panic(fmt.Errorf("invalid ip %v", rhs.Value())) + } + + val, ok := lhs.Value().(string) + if !ok { + panic(fmt.Errorf("invalid value %v", lhs.Value())) + } + + network, ok := state.Networks[val] + if !ok { + _, ipNet, err := net.ParseCIDR(val) + if err != nil { + panic("network not found") + } + return types.Bool(ipNet.Contains(ip)) + } else { + ok, err := network.Contains(ip) + if err != nil { + panic(err) + } + return types.Bool(ok) + } + }), + ), + ), + ) + if err != nil { + return err + } + return nil +} diff --git a/lib/http.go b/lib/http.go index 1340c80..44c123e 100644 --- a/lib/http.go +++ b/lib/http.go @@ -159,29 +159,6 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { start := time.Now() - //TODO better matcher! combo ast? - env := map[string]any{ - "host": host, - "method": r.Method, - "remoteAddress": getRequestAddress(r, state.Settings.ClientIpHeader), - "userAgent": r.UserAgent(), - "path": r.URL.Path, - "query": func() map[string]string { - result := make(map[string]string) - for k, v := range r.URL.Query() { - result[k] = strings.Join(v, ",") - } - return result - }(), - "headers": func() map[string]string { - result := make(map[string]string) - for k, v := range r.Header { - result[k] = strings.Join(v, ",") - } - return result - }(), - } - state.addTiming(w, "rule-env", "Setup the rule environment", time.Since(start)) var ( @@ -211,7 +188,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { continue } start = time.Now() - out, _, err := rule.Program.Eval(env) + out, _, err := rule.Program.Eval(data.ProgramEnv) ruleEvalDuration += time.Since(start) if err != nil { @@ -230,7 +207,6 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { serve() return case policy.RuleActionCHALLENGE, policy.RuleActionCHECK: - for _, challengeId := range rule.Challenges { if result := data.Challenges[challengeId]; !result.Ok() { continue @@ -249,6 +225,11 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { // none matched, issue first challenge in priority for _, challengeId := range rule.Challenges { + result := data.Challenges[challengeId] + if result.Ok() || result == challenge.VerifyResultSKIP { + // skip already ok'd challenges for some reason, and also skip skipped challenges + continue + } c := state.Challenges[challengeId] if c.ServeChallenge != nil { result := c.ServeChallenge(w, r, state.GetChallengeKeyForRequest(c.Name, data.Expires, r), data.Expires) @@ -264,7 +245,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { } state.logger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", c.Name) - data.Challenges[c.Id] = challenge.VerifyResultOK + // set pass if caller didn't set one + if !data.Challenges[c.Id].Ok() { + data.Challenges[c.Id] = challenge.VerifyResultPASS + } // we pass the challenge early! lg.Debug("request passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", c.Name) @@ -425,6 +409,27 @@ func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) { _, _ = rand.Read(data.Id[:]) data.Challenges = make(map[challenge.Id]challenge.VerifyResult, len(state.Challenges)) data.Expires = time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity) + data.ProgramEnv = map[string]any{ + "host": r.Host, + "method": r.Method, + "remoteAddress": getRequestAddress(r, state.Settings.ClientIpHeader), + "userAgent": r.UserAgent(), + "path": r.URL.Path, + "query": func() map[string]string { + result := make(map[string]string) + for k, v := range r.URL.Query() { + result[k] = strings.Join(v, ",") + } + return result + }(), + "headers": func() map[string]string { + result := make(map[string]string) + for k, v := range r.Header { + result[k] = strings.Join(v, ",") + } + return result + }(), + } for _, c := range state.Challenges { key := state.GetChallengeKeyForRequest(c.Name, data.Expires, r) @@ -433,6 +438,21 @@ func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) { // clear invalid cookie utils.ClearCookie(utils.CookiePrefix+c.Name, w) } + + // prevent the challenge if not solved + if !result.Ok() && c.Program != nil { + out, _, err := c.Program.Eval(data.ProgramEnv) + // verify eligibility + if err != nil { + state.logger(r).Error(err.Error(), "challenge", c.Name) + } else if out != nil && out.Type() == types.BoolType { + if out.Equal(types.True) != types.True { + // skip challenge match! + result = challenge.VerifyResultSKIP + continue + } + } + } data.Challenges[c.Id] = result } @@ -449,6 +469,7 @@ func RequestDataFromContext(ctx context.Context) *RequestData { type RequestData struct { Id [16]byte + ProgramEnv map[string]any Expires time.Time Challenges map[challenge.Id]challenge.VerifyResult } diff --git a/lib/policy/challenge.go b/lib/policy/challenge.go index f1f973d..cac0892 100644 --- a/lib/policy/challenge.go +++ b/lib/policy/challenge.go @@ -1,9 +1,10 @@ package policy type Challenge struct { - Mode string `yaml:"mode"` - Asset *string `yaml:"asset,omitempty"` - Url *string `yaml:"url,omitempty"` + Conditions []string `yaml:"conditions"` + Mode string `yaml:"mode"` + Asset *string `yaml:"asset,omitempty"` + Url *string `yaml:"url,omitempty"` Parameters map[string]string `json:"parameters,omitempty"` Runtime struct { diff --git a/lib/state.go b/lib/state.go index 820163a..68fbb35 100644 --- a/lib/state.go +++ b/lib/state.go @@ -20,15 +20,12 @@ import ( "git.gammaspectra.live/git/go-away/utils" "git.gammaspectra.live/git/go-away/utils/inline" "github.com/google/cel-go/cel" - "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" "github.com/tetratelabs/wazero/api" "github.com/yl2chen/cidranger" "html/template" "io" "io/fs" "log/slog" - "net" "net/http" "net/http/httputil" "net/url" @@ -197,13 +194,56 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) state.Wasm = wasm.NewRunner(true) + err = state.initConditions() + if err != nil { + return nil, err + } + + var replacements []string + for k, entries := range p.Conditions { + ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, entries...) + if err != nil { + return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err) + } + + cond, err := cel.AstToString(ast) + if err != nil { + return nil, fmt.Errorf("conditions %s: error printing condition: %v", k, err) + } + + replacements = append(replacements, fmt.Sprintf("($%s)", k)) + replacements = append(replacements, "("+cond+")") + } + conditionReplacer := strings.NewReplacer(replacements...) + state.Challenges = make(map[challenge.Id]challenge.Challenge) idCounter := challenge.Id(1) for challengeName, p := range p.Challenges { + + // allow nesting + var conditions []string + for _, cond := range p.Conditions { + cond = conditionReplacer.Replace(cond) + conditions = append(conditions, cond) + } + + var program cel.Program + if len(conditions) > 0 { + ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, conditions...) + if err != nil { + return nil, fmt.Errorf("challenge %s: error compiling conditions: %v", challengeName, err) + } + program, err = state.RulesEnv.Program(ast) + if err != nil { + return nil, fmt.Errorf("challenge %s: error compiling program: %v", challengeName, err) + } + } + c := challenge.Challenge{ Id: idCounter, + Program: program, Name: challengeName, Path: fmt.Sprintf("%s/challenge/%s", state.UrlPath, challengeName), VerifyProbability: p.Runtime.Probability, @@ -383,13 +423,16 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) return challenge.ResultStop } case "preload-link": - deadline := time.Second * 5 + deadline, _ := time.ParseDuration(p.Parameters["preload-early-hint-deadline"]) + if deadline == 0 { + deadline = time.Second * 3 + } c.ServeChallenge = func(w http.ResponseWriter, r *http.Request, key []byte, expiry time.Time) challenge.Result { // this only works on HTTP/2 and HTTP/3 if r.ProtoMajor < 2 { - // this can happen if we are an upgraded request + // this can happen if we are an upgraded request from HTTP/1.1 to HTTP/2 in H2C if _, ok := w.(http.Pusher); !ok { return challenge.ResultContinue } @@ -397,18 +440,19 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) data := RequestDataFromContext(r.Context()) redirectUri := new(url.URL) + redirectUri.Scheme = getRequestScheme(r) + redirectUri.Host = r.Host redirectUri.Path = c.Path + "/verify-challenge" values := make(url.Values) values.Set("result", hex.EncodeToString(key)) - values.Set("redirect", r.URL.String()) values.Set("requestId", r.Header.Get("X-Away-Id")) redirectUri.RawQuery = values.Encode() - w.Header().Set("Link", fmt.Sprintf("<%s>; rel=preload; as=style; fetchpriority=high", redirectUri.String())) + w.Header().Set("Link", fmt.Sprintf("<%s>; rel=\"preload\"; as=\"style\"; fetchpriority=high", redirectUri.String())) defer func() { - // remove old header header! + // remove old header so it won't show on response! w.Header().Del("Link") }() w.WriteHeader(http.StatusEarlyHints) @@ -656,80 +700,6 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) state.Challenges[c.Id] = c } - state.RulesEnv, err = cel.NewEnv( - cel.DefaultUTCTimeZone(true), - cel.Variable("remoteAddress", cel.BytesType), - cel.Variable("host", cel.StringType), - cel.Variable("method", cel.StringType), - cel.Variable("userAgent", cel.StringType), - cel.Variable("path", cel.StringType), - cel.Variable("query", cel.MapType(cel.StringType, cel.StringType)), - // http.Header - cel.Variable("headers", cel.MapType(cel.StringType, cel.StringType)), - //TODO: dynamic type? - cel.Function("inNetwork", - cel.Overload("inNetwork_string_ip", - []*cel.Type{cel.StringType, cel.AnyType}, - cel.BoolType, - cel.BinaryBinding(func(lhs ref.Val, rhs ref.Val) ref.Val { - var ip net.IP - switch v := rhs.Value().(type) { - case []byte: - ip = v - case net.IP: - ip = v - case string: - ip = net.ParseIP(v) - } - - if ip == nil { - panic(fmt.Errorf("invalid ip %v", rhs.Value())) - } - - val, ok := lhs.Value().(string) - if !ok { - panic(fmt.Errorf("invalid value %v", lhs.Value())) - } - - network, ok := state.Networks[val] - if !ok { - _, ipNet, err := net.ParseCIDR(val) - if err != nil { - panic("network not found") - } - return types.Bool(ipNet.Contains(ip)) - } else { - ok, err := network.Contains(ip) - if err != nil { - panic(err) - } - return types.Bool(ok) - } - }), - ), - ), - ) - if err != nil { - return nil, err - } - - var replacements []string - for k, entries := range p.Conditions { - ast, err := condition.FromStrings(state.RulesEnv, condition.OperatorOr, entries...) - if err != nil { - return nil, fmt.Errorf("conditions %s: error compiling conditions: %v", k, err) - } - - cond, err := cel.AstToString(ast) - if err != nil { - return nil, fmt.Errorf("conditions %s: error printing condition: %v", k, err) - } - - replacements = append(replacements, fmt.Sprintf("($%s)", k)) - replacements = append(replacements, "("+cond+")") - } - conditionReplacer := strings.NewReplacer(replacements...) - for _, rule := range p.Rules { hasher := sha256.New() hasher.Write([]byte(rule.Name))