Mark challenge keys with whether client was ipv4 or ipv6, allow retrying IPv4 -> IPv6 happy eyeballs automatically

This commit is contained in:
WeebDataHoarder
2025-04-08 14:07:24 +02:00
parent f80d6ebd15
commit 186904e020
2 changed files with 68 additions and 18 deletions

View File

@@ -3,6 +3,8 @@ package lib
import (
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"github.com/go-jose/go-jose/v4/jwt"
"net"
"net/http"
@@ -52,12 +54,44 @@ func getRequestAddress(r *http.Request, clientHeader string) net.IP {
return net.ParseIP(ipStr)
}
func (state *State) GetChallengeKeyForRequest(challengeName string, until time.Time, r *http.Request) []byte {
type ChallengeKey []byte
const ChallengeKeySize = sha256.Size
func (k *ChallengeKey) Set(flags ChallengeKeyFlags) {
(*k)[0] |= uint8(flags)
}
func (k *ChallengeKey) Get(flags ChallengeKeyFlags) ChallengeKeyFlags {
return ChallengeKeyFlags((*k)[0] & uint8(flags))
}
func (k *ChallengeKey) Unset(flags ChallengeKeyFlags) {
(*k)[0] = (*k)[0] & ^(uint8(flags))
}
type ChallengeKeyFlags uint8
const (
ChallengeKeyFlagIsIPv4 = ChallengeKeyFlags(1 << iota)
)
func ChallengeKeyFromString(s string) (ChallengeKey, error) {
b, err := hex.DecodeString(s)
if err != nil {
return nil, err
}
if len(b) != ChallengeKeySize {
return nil, errors.New("invalid challenge key")
}
return ChallengeKey(b), nil
}
func (state *State) GetChallengeKeyForRequest(challengeName string, until time.Time, r *http.Request) ChallengeKey {
address := getRequestAddress(r, state.Settings.ClientIpHeader)
hasher := sha256.New()
hasher.Write([]byte("challenge\x00"))
hasher.Write([]byte(challengeName))
hasher.Write([]byte{0})
hasher.Write(getRequestAddress(r, state.Settings.ClientIpHeader).To16())
hasher.Write(address.To16())
hasher.Write([]byte{0})
// specific headers
@@ -78,5 +112,13 @@ func (state *State) GetChallengeKeyForRequest(challengeName string, until time.T
hasher.Write(state.publicKey)
hasher.Write([]byte{0})
return hasher.Sum(nil)
sum := ChallengeKey(hasher.Sum(nil))
sum[0] = 0
if address.To4() != nil {
// Is IPv4, mark
sum.Set(ChallengeKeyFlagIsIPv4)
}
return ChallengeKey(sum)
}

View File

@@ -566,7 +566,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
redirect, err := utils.EnsureNoOpenRedirect(r.FormValue("redirect"))
if err != nil {
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusInternalServerError, err, "")
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadRequest, err, "")
return
}
@@ -585,29 +585,37 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
if ok, err := c.Verify(key, result, r); err != nil {
return err
} else if !ok {
state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", redirect)
utils.ClearCookie(utils.CookiePrefix+challengeName, w)
data.Challenges[c.Id] = challenge.VerifyResultFAIL
state.SolveChallenge(key, challenge.VerifyResultFAIL)
state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", redirect)
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName), redirect)
return nil
}
// catch happy eyeballs IPv4 -> IPv6 migration, re-direct to try again
if resultKey, err := ChallengeKeyFromString(result); err == nil && resultKey.Get(ChallengeKeyFlagIsIPv4) > 0 && key.Get(ChallengeKeyFlagIsIPv4) == 0 {
state.logger(r).Warn("challenge passed", "challenge", challengeName, "redirect", redirect)
token, err := c.IssueChallengeToken(state.privateKey, key, []byte(result), data.Expires)
if err != nil {
utils.ClearCookie(utils.CookiePrefix+challengeName, w)
} else {
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName), redirect)
return nil
}
} else {
utils.SetCookie(utils.CookiePrefix+challengeName, token, data.Expires, w)
}
data.Challenges[c.Id] = challenge.VerifyResultPASS
state.logger(r).Warn("challenge passed", "challenge", challengeName, "redirect", redirect)
state.SolveChallenge(key, challenge.VerifyResultPASS)
token, err := c.IssueChallengeToken(state.privateKey, key, []byte(result), data.Expires)
if err != nil {
utils.ClearCookie(utils.CookiePrefix+challengeName, w)
} else {
utils.SetCookie(utils.CookiePrefix+challengeName, token, data.Expires, w)
}
data.Challenges[c.Id] = challenge.VerifyResultPASS
state.SolveChallenge(key, challenge.VerifyResultPASS)
}
switch httpCode {
case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
if redirect == "" {
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadRequest, errors.New("no redirect found"), "")
return nil
}
http.Redirect(w, r, redirect, httpCode)
default:
w.Header().Set("Content-Type", mimeType)