Files
go-away/lib/challenge/data.go

500 lines
13 KiB
Go

package challenge
import (
"bytes"
http_cel "codeberg.org/gone/http-cel"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"git.gammaspectra.live/git/go-away/utils"
"github.com/go-jose/go-jose/v4"
"github.com/go-jose/go-jose/v4/jwt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/traits"
"maps"
unsaferand "math/rand/v2"
"net/http"
"net/netip"
"net/textproto"
"strings"
"time"
)
type requestDataContextKey struct {
}
func RequestDataFromContext(ctx context.Context) *RequestData {
val := ctx.Value(requestDataContextKey{})
if val == nil {
return nil
}
return val.(*RequestData)
}
type RequestId [16]byte
func (id RequestId) String() string {
return hex.EncodeToString(id[:])
}
type RequestData struct {
Id RequestId
Time time.Time
ChallengeVerify map[Id]VerifyResult
ChallengeState map[Id]VerifyState
ChallengeMap TokenChallengeMap
challengeMapModified bool
RemoteAddress netip.AddrPort
State StateInterface
cookieName string
issuedChallenge string
ExtraHeaders http.Header
r *http.Request
fp map[string]string
header traits.Mapper
query traits.Mapper
opts map[string]string
}
func CreateRequestData(r *http.Request, state StateInterface) (*http.Request, *RequestData) {
var data RequestData
// generate random id, todo: is this fast?
_, _ = rand.Read(data.Id[:])
data.RemoteAddress = utils.GetRequestAddress(r, state.Settings().ClientIpHeader)
data.ChallengeVerify = make(map[Id]VerifyResult, len(state.GetChallenges()))
data.ChallengeState = make(map[Id]VerifyState, len(state.GetChallenges()))
data.Time = time.Now().UTC()
data.State = state
data.ExtraHeaders = make(http.Header)
data.fp = make(map[string]string, 2)
if fp := utils.GetTLSFingerprint(r); fp != nil {
if ja3nPtr := fp.JA3N(); ja3nPtr != nil {
ja3n := ja3nPtr.String()
data.fp["ja3n"] = ja3n
}
if ja4Ptr := fp.JA4(); ja4Ptr != nil {
ja4 := ja4Ptr.String()
data.fp["ja4"] = ja4
}
}
q := r.URL.Query()
if q.Has(QueryArgChallenge) {
data.issuedChallenge = q.Get(QueryArgChallenge)
}
// delete query parameters that were set by go-away
for k := range q {
if strings.HasPrefix(k, QueryArgPrefix) {
q.Del(k)
}
}
data.query = http_cel.NewValuesMap(q)
data.header = http_cel.NewMIMEMap(textproto.MIMEHeader(r.Header))
data.opts = make(map[string]string)
r = r.WithContext(context.WithValue(r.Context(), requestDataContextKey{}, &data))
r = utils.SetRemoteAddress(r, data.RemoteAddress)
data.r = r
data.cookieName = utils.DefaultCookiePrefix + hex.EncodeToString(data.cookieHostKey()) + "-state"
return r, &data
}
func (d *RequestData) ResolveName(name string) (any, bool) {
switch name {
case "host":
return d.r.Host, true
case "method":
return d.r.Method, true
case "remoteAddress":
return d.RemoteAddress.Addr().AsSlice(), true
case "userAgent":
return d.r.UserAgent(), true
case "path":
return d.r.URL.Path, true
case "query":
return d.query, true
case "headers":
return d.header, true
case "fp":
return d.fp, true
default:
return nil, false
}
}
func (d *RequestData) Parent() cel.Activation {
return nil
}
func (d *RequestData) NetworkPrefix() netip.Addr {
address := d.RemoteAddress.Addr().Unmap()
if address.Is4() {
// Take a /24 for IPv4
prefix, _ := address.Prefix(24)
return prefix.Addr()
} else {
// Take a /64 for IPv6
prefix, _ := address.Prefix(64)
return prefix.Addr()
}
}
const (
RequestOptBackendHost = "backend-host"
RequestOptProxyMetaTags = "proxy-meta-tags"
RequestOptProxySafeLinkTags = "proxy-safe-link-tags"
)
func (d *RequestData) SetOpt(n, v string) {
d.opts[n] = v
}
func (d *RequestData) GetOpt(n, def string) string {
v, ok := d.opts[n]
if !ok {
return def
}
return v
}
func (d *RequestData) GetOptBool(n string, def bool) bool {
v, ok := d.opts[n]
if !ok {
return def
}
switch v {
case "true", "t", "1", "yes", "yep", "y", "ok":
return true
case "false", "f", "0", "no", "nope", "n", "err":
return false
default:
return def
}
}
func (d *RequestData) BackendHost() (http.Handler, string) {
host := d.r.Host
if opt := d.GetOpt(RequestOptBackendHost, ""); opt != "" && opt != host {
host = d.r.Host
}
return d.State.GetBackend(host), host
}
func (d *RequestData) ClearChallengeToken(reg *Registration) {
delete(d.ChallengeMap, reg.Name)
d.challengeMapModified = true
}
func (d *RequestData) IssueChallengeToken(reg *Registration, key Key, result []byte, until time.Time, ok bool) {
d.ChallengeMap[reg.Name] = TokenChallenge{
Key: key[:],
Result: result,
Ok: ok,
Expiry: jwt.NumericDate(until.Unix()),
IssuedAt: jwt.NumericDate(time.Now().UTC().Unix()),
}
d.challengeMapModified = true
}
var ErrVerifyKeyMismatch = errors.New("verify: key mismatch")
var ErrVerifyVerifyMismatch = errors.New("verify: verification mismatch")
var ErrTokenExpired = errors.New("token: expired")
func (d *RequestData) VerifyChallengeToken(reg *Registration, token TokenChallenge, expectedKey Key) (VerifyResult, VerifyState, error) {
if token.Expiry.Time().Compare(time.Now()) < 0 {
return VerifyResultFail, VerifyStateNone, ErrTokenExpired
}
if token.NotBefore.Time().Compare(time.Now()) > 0 {
return VerifyResultFail, VerifyStateNone, errors.New("token not valid yet")
}
if bytes.Compare(expectedKey[:], token.Key) != 0 {
return VerifyResultFail, VerifyStateNone, ErrVerifyKeyMismatch
}
if reg.Verify != nil {
if unsaferand.Float64() < reg.VerifyProbability {
// random spot check
if ok, err := reg.Verify(expectedKey, token.Result, d.r); err != nil {
return VerifyResultFail, VerifyStateFull, err
} else if ok == VerifyResultNotOK {
return VerifyResultNotOK, VerifyStateFull, nil
} else if !ok.Ok() {
return ok, VerifyStateFull, ErrVerifyVerifyMismatch
} else {
return ok, VerifyStateFull, nil
}
}
}
if !token.Ok {
return VerifyResultNotOK, VerifyStateBrief, nil
}
return VerifyResultOK, VerifyStateBrief, nil
}
func (d *RequestData) verifyChallenge(reg *Registration, key Key) (verifyResult VerifyResult, verifyState VerifyState, err error) {
token, ok := d.ChallengeMap[reg.Name]
if !ok {
verifyResult = VerifyResultFail
verifyState = VerifyStateNone
} else {
verifyResult, verifyState, err = d.VerifyChallengeToken(reg, token, key)
if err != nil && !errors.Is(err, http.ErrNoCookie) {
// clear invalid state
d.ClearChallengeToken(reg)
}
// prevent evaluating the challenge if not solved
if !verifyResult.Ok() && reg.Condition != nil {
out, _, err := reg.Condition.Eval(d)
// verify eligibility
if err != nil {
d.State.Logger(d.r).Error(err.Error(), "challenge", reg.Name)
} else if out != nil && out.Type() == types.BoolType {
if out.Equal(types.True) != types.True {
// skip challenge match due to precondition!
verifyResult = VerifyResultSkip
return verifyResult, verifyState, err
}
}
}
}
if !verifyResult.Ok() && d.issuedChallenge == reg.Name {
// we issued the challenge, must skip to prevent loops
verifyResult = VerifyResultSkip
}
return verifyResult, verifyState, err
}
func (d *RequestData) EvaluateChallenges(w http.ResponseWriter, r *http.Request) {
challengeMap, err := d.verifyChallengeState()
if err != nil {
if !errors.Is(err, http.ErrNoCookie) {
//clear invalid cookie and continue
utils.ClearCookie(d.cookieName, w, r)
}
challengeMap = make(TokenChallengeMap)
}
d.ChallengeMap = challengeMap
for _, reg := range d.State.GetChallenges() {
key := GetChallengeKeyForRequest(d.State, reg, d.Expiration(reg.Duration), r)
verifyResult, verifyState, err := d.verifyChallenge(reg, key)
if err != nil {
// clear invalid state
d.ClearChallengeToken(reg)
}
d.ChallengeVerify[reg.Id()] = verifyResult
d.ChallengeState[reg.Id()] = verifyState
}
}
func (d *RequestData) Expiration(duration time.Duration) time.Time {
return d.Time.Add(duration).Round(duration)
}
func (d *RequestData) HasValidChallenge(id Id) bool {
return d.ChallengeVerify[id].Ok()
}
func (d *RequestData) ResponseHeaders(w http.ResponseWriter) {
// send these to client so we consistently get the headers
//w.Header().Set("Accept-CH", "Sec-CH-UA, Sec-CH-UA-Platform")
//w.Header().Set("Critical-CH", "Sec-CH-UA, Sec-CH-UA-Platform")
if d.State.Settings().MainName != "" {
w.Header().Add("Via", fmt.Sprintf("%s %s@%s", d.r.Proto, d.State.Settings().MainName, d.State.Settings().MainVersion))
}
if d.challengeMapModified {
expiration := d.Expiration(DefaultDuration)
if token, err := d.issueChallengeState(expiration); err == nil {
utils.SetCookie(d.cookieName, token, expiration, w, d.r)
} else {
d.State.Logger(d.r).Error("error while issuing cookie", "error", err)
}
}
}
func (d *RequestData) RequestHeaders(headers http.Header) {
headers.Set("X-Away-Id", d.Id.String())
if d.State.Settings().BackendIpHeader != "" {
if d.State.Settings().ClientIpHeader != "" {
headers.Del(d.State.Settings().ClientIpHeader)
}
headers.Set(d.State.Settings().BackendIpHeader, d.RemoteAddress.Addr().Unmap().String())
}
for id, result := range d.ChallengeVerify {
if result.Ok() {
c, ok := d.State.GetChallenge(id)
if !ok {
panic("challenge not found")
}
headers.Set(fmt.Sprintf("X-Away-Challenge-%s-Result", c.Name), result.String())
headers.Set(fmt.Sprintf("X-Away-Challenge-%s-State", c.Name), d.ChallengeState[id].String())
}
}
if ja4, ok := d.fp["fp4"]; ok {
headers.Set("X-TLS-Fingerprint-JA4", ja4)
}
if ja3n, ok := d.fp["ja3n"]; ok {
headers.Set("X-TLS-Fingerprint-JA3N", ja3n)
}
maps.Copy(headers, d.ExtraHeaders)
}
type Token struct {
State TokenChallengeMap `json:"state"`
Expiry jwt.NumericDate `json:"exp,omitempty"`
NotBefore jwt.NumericDate `json:"nbf,omitempty"`
IssuedAt jwt.NumericDate `json:"iat,omitempty"`
}
type TokenChallengeMap map[string]TokenChallenge
type TokenChallenge struct {
Key []byte `json:"key"`
Result []byte `json:"result,omitempty"`
Ok bool `json:"ok"`
Expiry jwt.NumericDate `json:"exp,omitempty"`
NotBefore jwt.NumericDate `json:"nbf,omitempty"`
IssuedAt jwt.NumericDate `json:"iat,omitempty"`
}
func (d *RequestData) verifyChallengeStateCookie(cookie *http.Cookie) (TokenChallengeMap, error) {
cookie, err := d.r.Cookie(d.cookieName)
if err != nil {
return nil, err
}
if cookie == nil {
return nil, http.ErrNoCookie
}
encryptedToken, err := jwt.ParseSignedAndEncrypted(cookie.Value,
[]jose.KeyAlgorithm{jose.DIRECT},
[]jose.ContentEncryption{jose.A256GCM},
[]jose.SignatureAlgorithm{jose.EdDSA},
)
if err != nil {
return nil, err
}
signedToken, err := encryptedToken.Decrypt(d.cookieKey())
if err != nil {
return nil, err
}
var i Token
err = signedToken.Claims(d.State.PublicKey(), &i)
if err != nil {
return nil, err
}
if i.Expiry.Time().Compare(time.Now()) < 0 {
return nil, ErrTokenExpired
}
if i.NotBefore.Time().Compare(time.Now()) > 0 {
return nil, errors.New("token not valid yet")
}
return i.State, nil
}
func (d *RequestData) verifyChallengeState() (state TokenChallengeMap, err error) {
cookies := d.r.CookiesNamed(d.cookieName)
if len(cookies) == 0 {
return nil, http.ErrNoCookie
}
for _, cookie := range cookies {
state, err = d.verifyChallengeStateCookie(cookie)
if err == nil {
return state, nil
}
}
return state, err
}
func (d *RequestData) issueChallengeState(until time.Time) (string, error) {
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.EdDSA,
Key: d.State.PrivateKey(),
}, nil)
if err != nil {
return "", err
}
encrypter, err := jose.NewEncrypter(jose.A256GCM, jose.Recipient{
Algorithm: jose.DIRECT,
Key: d.cookieKey(),
}, (&jose.EncrypterOptions{
Compression: jose.DEFLATE,
}).WithContentType("JWT"))
if err != nil {
return "", err
}
return jwt.SignedAndEncrypted(signer, encrypter).Claims(Token{
State: d.ChallengeMap,
Expiry: jwt.NumericDate(until.Unix()),
NotBefore: jwt.NumericDate(time.Now().UTC().AddDate(0, 0, -1).Unix()),
IssuedAt: jwt.NumericDate(time.Now().UTC().Unix()),
}).Serialize()
}
func (d *RequestData) cookieKey() []byte {
sum := sha256.New()
sum.Write([]byte(d.r.Host))
sum.Write([]byte{0})
sum.Write(d.NetworkPrefix().AsSlice())
sum.Write([]byte{0})
sum.Write(d.State.PrivateKey())
sum.Write([]byte{0})
// version/compressor
sum.Write([]byte("1.0/DEFLATE"))
sum.Write([]byte{0})
return sum.Sum(nil)
}
func (d *RequestData) cookieHostKey() []byte {
sum := sha256.New()
sum.Write([]byte(d.r.Host))
sum.Write([]byte{0})
sum.Write(d.NetworkPrefix().AsSlice())
sum.Write([]byte{0})
return sum.Sum(nil)[:6]
}