Minor cleanup of challenges code, bind session http cookies to issued tokens

This commit is contained in:
WeebDataHoarder
2025-04-07 19:00:53 +02:00
parent 0968e6feae
commit e08a5697f6
6 changed files with 323 additions and 183 deletions

View File

@@ -3,6 +3,7 @@ package lib
import (
"bytes"
"codeberg.org/meta/gzipped/v2"
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
@@ -11,6 +12,7 @@ import (
"git.gammaspectra.live/git/go-away/embed"
"git.gammaspectra.live/git/go-away/lib/challenge"
"git.gammaspectra.live/git/go-away/lib/policy"
"git.gammaspectra.live/git/go-away/utils"
"github.com/google/cel-go/common/types"
"html/template"
"io"
@@ -88,7 +90,7 @@ func (state *State) challengePage(w http.ResponseWriter, id string, status int,
err := templates["challenge-"+state.Settings.ChallengeTemplate+".gohtml"].Execute(buf, input)
if err != nil {
_ = state.errorPage(w, id, http.StatusInternalServerError, err)
_ = state.errorPage(w, id, http.StatusInternalServerError, err, "")
} else {
w.WriteHeader(status)
_, _ = w.Write(buf.Bytes())
@@ -96,7 +98,7 @@ func (state *State) challengePage(w http.ResponseWriter, id string, status int,
return nil
}
func (state *State) errorPage(w http.ResponseWriter, id string, status int, err error) error {
func (state *State) errorPage(w http.ResponseWriter, id string, status int, err error, redirect string) error {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
buf := bytes.NewBuffer(make([]byte, 0, 8192))
@@ -110,6 +112,7 @@ func (state *State) errorPage(w http.ResponseWriter, id string, status int, err
"Title": "Oh no! " + http.StatusText(status),
"HideSpinner": true,
"Challenge": "",
"Redirect": redirect,
})
if err2 != nil {
panic(err2)
@@ -144,6 +147,8 @@ func (state *State) logger(r *http.Request) *slog.Logger {
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
host := r.Host
data := RequestDataFromContext(r.Context())
backend, ok := state.Settings.Backends[host]
if !ok {
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
@@ -190,13 +195,14 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
fail := func(code int, err error) {
state.addTiming(w, "rule-eval", "Evaluate access rules", ruleEvalDuration)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), code, err)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), code, err, "")
}
setAwayState := func(rule RuleState) {
r.Header.Set("X-Away-Rule", rule.Name)
r.Header.Set("X-Away-Hash", rule.Hash)
r.Header.Set("X-Away-Action", string(rule.Action))
data.Headers(state, r.Header)
}
for _, rule := range state.Rules {
@@ -224,39 +230,31 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
serve()
return
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
start = time.Now()
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
for _, challengeName := range rule.Challenges {
key := state.GetChallengeKeyForRequest(challengeName, expiry, r)
ok, err := state.VerifyChallengeToken(challengeName, key, w, r)
if !ok || err != nil {
if !errors.Is(err, http.ErrNoCookie) {
ClearCookie(CookiePrefix+challengeName, w)
}
for _, challengeId := range rule.Challenges {
if result := data.Challenges[challengeId]; !result.Ok() {
continue
} else {
if rule.Action == policy.RuleActionCHECK {
goto nextRule
}
// we passed the challenge!
lg.Debug("request passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
// we passed the challenge!
lg.Debug("request passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", state.Challenges[challengeId].Name)
setAwayState(rule)
serve()
return
}
}
state.addTiming(w, "challenge-token-check", "Verify client challenge tokens", time.Since(start))
// none matched, issue first challenge in priority
for _, challengeName := range rule.Challenges {
c := state.Challenges[challengeName]
for _, challengeId := range rule.Challenges {
c := state.Challenges[challengeId]
if c.ServeChallenge != nil {
result := c.ServeChallenge(w, r, state.GetChallengeKeyForRequest(challengeName, expiry, r), expiry)
result := c.ServeChallenge(w, r, state.GetChallengeKeyForRequest(c.Name, data.Expires, r), data.Expires)
switch result {
case challenge.ResultStop:
lg.Info("request challenged", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
lg.Info("request challenged", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", c.Name)
return
case challenge.ResultContinue:
continue
@@ -264,12 +262,12 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
if rule.Action == policy.RuleActionCHECK {
goto nextRule
}
state.logger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
state.logger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", c.Name)
data.Challenges[c.Id] = challenge.VerifyResultOK
// we pass the challenge early!
r.Header.Set(fmt.Sprintf("X-Away-Challenge-%s-Verify", challengeName), "PASS")
lg.Debug("request passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
lg.Debug("request passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", c.Name)
setAwayState(rule)
serve()
return
@@ -347,7 +345,7 @@ func (state *State) setupRoutes() error {
state.Mux.Handle("GET "+state.UrlPath+"/assets/", http.StripPrefix(state.UrlPath, gzipped.FileServer(gzipped.FS(embed.AssetsFs))))
for challengeName, c := range state.Challenges {
for _, c := range state.Challenges {
if c.ServeStatic != nil {
state.Mux.Handle("GET "+c.Path+"/static/", c.ServeStatic)
}
@@ -365,43 +363,47 @@ func (state *State) setupRoutes() error {
} else if c.Verify != nil {
state.Mux.HandleFunc(fmt.Sprintf("GET %s/verify-challenge", c.Path), func(w http.ResponseWriter, r *http.Request) {
err := func() (err error) {
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
key := state.GetChallengeKeyForRequest(challengeName, expiry, r)
data := RequestDataFromContext(r.Context())
key := state.GetChallengeKeyForRequest(c.Name, data.Expires, r)
result := r.FormValue("result")
requestId, err := hex.DecodeString(r.FormValue("requestId"))
if err == nil {
// override
r.Header.Set("X-Away-Id", hex.EncodeToString(requestId))
}
start := time.Now()
ok, err := c.Verify(key, result)
ok, err := c.Verify(key, result, r)
state.addTiming(w, "challenge-verify", "Verify client challenge", time.Since(start))
if err != nil {
state.logger(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
state.logger(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", c.Name, "redirect", r.FormValue("redirect"))
return err
} else if !ok {
state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
ClearCookie(CookiePrefix+challengeName, w)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
state.logger(r).Warn("challenge failed", "challenge", c.Name, "redirect", r.FormValue("redirect"))
utils.ClearCookie(utils.CookiePrefix+c.Name, w)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", c.Name), r.FormValue("redirect"))
return nil
}
state.logger(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
state.logger(r).Info("challenge passed", "challenge", c.Name, "redirect", r.FormValue("redirect"))
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
token, err := c.IssueChallengeToken(state.privateKey, key, []byte(result), data.Expires)
if err != nil {
ClearCookie(CookiePrefix+challengeName, w)
utils.ClearCookie(utils.CookiePrefix+c.Name, w)
} else {
SetCookie(CookiePrefix+challengeName, token, expiry, w)
utils.SetCookie(utils.CookiePrefix+c.Name, token, data.Expires, w)
}
data.Challenges[c.Id] = challenge.VerifyResultPASS
http.Redirect(w, r, r.FormValue("redirect"), http.StatusTemporaryRedirect)
return nil
}()
if err != nil {
ClearCookie(CookiePrefix+challengeName, w)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusInternalServerError, err)
utils.ClearCookie(utils.CookiePrefix+c.Name, w)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusInternalServerError, err, r.FormValue("redirect"))
return
}
})
@@ -410,3 +412,54 @@ func (state *State) setupRoutes() error {
return nil
}
func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var data RequestData
// generate random id, todo: is this fast?
_, _ = rand.Read(data.Id[:])
data.Challenges = make(map[challenge.Id]challenge.VerifyResult, len(state.Challenges))
data.Expires = time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
for _, c := range state.Challenges {
key := state.GetChallengeKeyForRequest(c.Name, data.Expires, r)
result, err := c.VerifyChallengeToken(state.publicKey, key, r)
if err != nil && !errors.Is(err, http.ErrNoCookie) {
// clear invalid cookie
utils.ClearCookie(utils.CookiePrefix+c.Name, w)
}
data.Challenges[c.Id] = result
}
r.Header.Set("X-Away-Id", hex.EncodeToString(data.Id[:]))
r = r.WithContext(context.WithValue(r.Context(), "_goaway_data", &data))
state.Mux.ServeHTTP(w, r)
}
func RequestDataFromContext(ctx context.Context) *RequestData {
return ctx.Value("_goaway_data").(*RequestData)
}
type RequestData struct {
Id [16]byte
Expires time.Time
Challenges map[challenge.Id]challenge.VerifyResult
}
func (d *RequestData) HasValidChallenge(id challenge.Id) bool {
return d.Challenges[id].Ok()
}
func (d *RequestData) Headers(state *State, headers http.Header) {
for id, result := range d.Challenges {
if result.Ok() {
c, ok := state.Challenges[id]
if !ok {
panic("challenge not found")
}
headers.Set(fmt.Sprintf("X-Away-Challenge-%s-Result", c.Name), result.String())
}
}
}