From 2cb59723714c38737e89f9a3a0b56082735200bc Mon Sep 17 00:00:00 2001 From: WeebDataHoarder Date: Sat, 3 May 2025 15:55:13 +0200 Subject: [PATCH] challenges/context: allow setting request headers towards the backend --- lib/action/context.go | 20 +++++++++++++++++--- lib/challenge/data.go | 28 +++++++++++++++++----------- lib/http.go | 4 +--- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/lib/action/context.go b/lib/action/context.go index be3c4d9..fb96b96 100644 --- a/lib/action/context.go +++ b/lib/action/context.go @@ -7,6 +7,7 @@ import ( "github.com/goccy/go-yaml/ast" "log/slog" "net/http" + "net/textproto" ) func init() { @@ -33,8 +34,9 @@ func init() { var ContextDefaultSettings = ContextSettings{} type ContextSettings struct { - ContextSet map[string]string `yaml:"context-set"` - ResponseHeaders map[string]string `yaml:"response-headers"` + ContextSet map[string]string `yaml:"context-set"` + ResponseHeaders map[string][]string `yaml:"response-headers"` + RequestHeaders map[string][]string `yaml:"request-headers"` } type Context struct { @@ -48,7 +50,19 @@ func (a Context) Handle(logger *slog.Logger, w http.ResponseWriter, r *http.Requ } for k, v := range a.opts.ResponseHeaders { - w.Header().Set(k, v) + // do this to allow unsetting values that are sent automatically + w.Header()[textproto.CanonicalMIMEHeaderKey(k)] = nil + for _, val := range v { + w.Header().Add(k, val) + } + } + + for k, v := range a.opts.RequestHeaders { + // do this to allow unsetting values that are sent automatically + r.Header[textproto.CanonicalMIMEHeaderKey(k)] = nil + for _, val := range v { + r.Header.Add(k, val) + } } return true, nil diff --git a/lib/challenge/data.go b/lib/challenge/data.go index cc52627..b327655 100644 --- a/lib/challenge/data.go +++ b/lib/challenge/data.go @@ -230,17 +230,6 @@ func (d *RequestData) EvaluateChallenges(w http.ResponseWriter, r *http.Request) d.ChallengeVerify[reg.Id()] = verifyResult d.ChallengeState[reg.Id()] = verifyState } - - if d.State.Settings().BackendIpHeader != "" { - if d.State.Settings().ClientIpHeader != "" { - r.Header.Del(d.State.Settings().ClientIpHeader) - } - r.Header.Set(d.State.Settings().BackendIpHeader, d.RemoteAddress.String()) - } - - // send these to client so we consistently get the headers - //w.Header().Set("Accept-CH", "Sec-CH-UA, Sec-CH-UA-Platform") - //w.Header().Set("Critical-CH", "Sec-CH-UA, Sec-CH-UA-Platform") } func (d *RequestData) Expiration(duration time.Duration) time.Time { @@ -251,9 +240,26 @@ func (d *RequestData) HasValidChallenge(id Id) bool { return d.ChallengeVerify[id].Ok() } +func (d *RequestData) ResponseHeaders(headers http.Header) { + // send these to client so we consistently get the headers + //w.Header().Set("Accept-CH", "Sec-CH-UA, Sec-CH-UA-Platform") + //w.Header().Set("Critical-CH", "Sec-CH-UA, Sec-CH-UA-Platform") + + if d.State.Settings().MainName != "" { + headers.Add("Via", fmt.Sprintf("%s %s@%s", d.r.Proto, d.State.Settings().MainName, d.State.Settings().MainVersion)) + } +} + func (d *RequestData) RequestHeaders(headers http.Header) { headers.Set("X-Away-Id", d.Id.String()) + if d.State.Settings().BackendIpHeader != "" { + if d.State.Settings().ClientIpHeader != "" { + headers.Del(d.State.Settings().ClientIpHeader) + } + headers.Set(d.State.Settings().BackendIpHeader, d.RemoteAddress.String()) + } + for id, result := range d.ChallengeVerify { if result.Ok() { c, ok := d.State.GetChallenge(id) diff --git a/lib/http.go b/lib/http.go index 3fdde30..13fdc0d 100644 --- a/lib/http.go +++ b/lib/http.go @@ -323,9 +323,7 @@ func (state *State) ServeHTTP(w http.ResponseWriter, r *http.Request) { data.EvaluateChallenges(w, r) - if state.Settings().MainName != "" { - w.Header().Add("Via", fmt.Sprintf("%s %s@%s", r.Proto, state.Settings().MainName, state.Settings().MainVersion)) - } + data.ResponseHeaders(w.Header()) state.Mux.ServeHTTP(w, r) }