Rearranged wasm challenge utils
This commit is contained in:
@@ -82,7 +82,9 @@ func main() {
|
||||
socketMode := flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.")
|
||||
|
||||
slogLevel := flag.String("slog-level", "WARN", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
|
||||
debug := flag.Bool("debug", false, "debug mode with logs and server timings")
|
||||
debugMode := flag.Bool("debug", false, "debug mode with logs and server timings")
|
||||
|
||||
clientIpHeader := flag.String("client-ip-header", "", "Client HTTP header to fetch their IP address from (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)")
|
||||
|
||||
policyFile := flag.String("policy", "", "path to policy YAML file")
|
||||
challengeTemplate := flag.String("challenge-template", "anubis", "name or path of the challenge template to use (anubis, forgejo)")
|
||||
@@ -110,7 +112,7 @@ func main() {
|
||||
leveler.Set(programLevel)
|
||||
|
||||
h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: *debug,
|
||||
AddSource: *debugMode,
|
||||
Level: leveler,
|
||||
})
|
||||
slog.SetDefault(slog.New(h))
|
||||
@@ -182,11 +184,12 @@ func main() {
|
||||
|
||||
state, err := lib.NewState(p, lib.StateSettings{
|
||||
Backends: createdBackends,
|
||||
Debug: *debug,
|
||||
Debug: *debugMode,
|
||||
PackageName: *packageName,
|
||||
ChallengeTemplate: *challengeTemplate,
|
||||
ChallengeTemplateTheme: *challengeTemplateTheme,
|
||||
PrivateKeySeed: seed,
|
||||
ClientIpHeader: *clientIpHeader,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge/wasm"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge/wasm/interface"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -18,7 +19,7 @@ func main() {
|
||||
makeChallenge := flag.String("make-challenge", "", "Path to contents for MakeChallenge input")
|
||||
makeChallengeOutput := flag.String("make-challenge-out", "", "Path to contents for expected MakeChallenge output")
|
||||
verifyChallenge := flag.String("verify-challenge", "", "Path to contents for VerifyChallenge input")
|
||||
verifyChallengeOutput := flag.Uint64("verify-challenge-out", uint64(challenge.VerifyChallengeOutputOK), "Path to contents for expected VerifyChallenge output")
|
||||
verifyChallengeOutput := flag.Uint64("verify-challenge-out", uint64(_interface.VerifyChallengeOutputOK), "Path to contents for expected VerifyChallenge output")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
@@ -32,7 +33,7 @@ func main() {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
runner := challenge.NewRunner(true)
|
||||
runner := wasm.NewRunner(true)
|
||||
defer runner.Close()
|
||||
|
||||
err = runner.Compile("test", wasmData)
|
||||
@@ -44,7 +45,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var makeIn challenge.MakeChallengeInput
|
||||
var makeIn _interface.MakeChallengeInput
|
||||
err = json.Unmarshal(makeData, &makeIn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -54,7 +55,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var makeOut challenge.MakeChallengeOutput
|
||||
var makeOut _interface.MakeChallengeOutput
|
||||
err = json.Unmarshal(makeOutData, &makeOut)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -64,7 +65,7 @@ func main() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var verifyIn challenge.VerifyChallengeInput
|
||||
var verifyIn _interface.VerifyChallengeInput
|
||||
err = json.Unmarshal(verifyData, &verifyIn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -75,7 +76,7 @@ func main() {
|
||||
}
|
||||
|
||||
err = runner.Instantiate("test", func(ctx context.Context, mod api.Module) error {
|
||||
out, err := challenge.MakeChallengeCall(ctx, mod, makeIn)
|
||||
out, err := wasm.MakeChallengeCall(ctx, mod, makeIn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -90,13 +91,13 @@ func main() {
|
||||
}
|
||||
|
||||
err = runner.Instantiate("test", func(ctx context.Context, mod api.Module) error {
|
||||
out, err := challenge.VerifyChallengeCall(ctx, mod, verifyIn)
|
||||
out, err := wasm.VerifyChallengeCall(ctx, mod, verifyIn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if out != challenge.VerifyChallengeOutput(*verifyChallengeOutput) {
|
||||
return fmt.Errorf("verify output did not match expected output, got %d expected %d", out, challenge.VerifyChallengeOutput(*verifyChallengeOutput))
|
||||
if out != _interface.VerifyChallengeOutput(*verifyChallengeOutput) {
|
||||
return fmt.Errorf("verify output did not match expected output, got %d expected %d", out, _interface.VerifyChallengeOutput(*verifyChallengeOutput))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge/wasm/interface"
|
||||
"git.gammaspectra.live/git/go-away/utils/inline"
|
||||
"math/bits"
|
||||
"strconv"
|
||||
@@ -31,8 +31,8 @@ func getChallenge(key []byte, params map[string]string) ([]byte, uint64) {
|
||||
}
|
||||
|
||||
//go:wasmexport MakeChallenge
|
||||
func MakeChallenge(in challenge.Allocation) (out challenge.Allocation) {
|
||||
return challenge.MakeChallengeDecode(func(in challenge.MakeChallengeInput, out *challenge.MakeChallengeOutput) {
|
||||
func MakeChallenge(in _interface.Allocation) (out _interface.Allocation) {
|
||||
return _interface.MakeChallengeDecode(func(in _interface.MakeChallengeInput, out *_interface.MakeChallengeOutput) {
|
||||
c, difficulty := getChallenge(in.Key, in.Parameters)
|
||||
|
||||
// create target
|
||||
@@ -63,24 +63,24 @@ func MakeChallenge(in challenge.Allocation) (out challenge.Allocation) {
|
||||
}
|
||||
|
||||
//go:wasmexport VerifyChallenge
|
||||
func VerifyChallenge(in challenge.Allocation) (out challenge.VerifyChallengeOutput) {
|
||||
return challenge.VerifyChallengeDecode(func(in challenge.VerifyChallengeInput) challenge.VerifyChallengeOutput {
|
||||
func VerifyChallenge(in _interface.Allocation) (out _interface.VerifyChallengeOutput) {
|
||||
return _interface.VerifyChallengeDecode(func(in _interface.VerifyChallengeInput) _interface.VerifyChallengeOutput {
|
||||
c, difficulty := getChallenge(in.Key, in.Parameters)
|
||||
|
||||
result := make([]byte, inline.DecodedLen(len(in.Result)))
|
||||
n, err := inline.Decode(result, in.Result)
|
||||
if err != nil {
|
||||
return challenge.VerifyChallengeOutputError
|
||||
return _interface.VerifyChallengeOutputError
|
||||
}
|
||||
result = result[:n]
|
||||
|
||||
if len(result) < 8 {
|
||||
return challenge.VerifyChallengeOutputError
|
||||
return _interface.VerifyChallengeOutputError
|
||||
}
|
||||
|
||||
// verify we used same challenge
|
||||
if subtle.ConstantTimeCompare(result[:len(result)-8], c) != 1 {
|
||||
return challenge.VerifyChallengeOutputFailed
|
||||
return _interface.VerifyChallengeOutputFailed
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(result)
|
||||
@@ -95,9 +95,9 @@ func VerifyChallenge(in challenge.Allocation) (out challenge.VerifyChallengeOutp
|
||||
}
|
||||
|
||||
if leadingZeroesCount < int(difficulty) {
|
||||
return challenge.VerifyChallengeOutputFailed
|
||||
return _interface.VerifyChallengeOutputFailed
|
||||
}
|
||||
|
||||
return challenge.VerifyChallengeOutputOK
|
||||
return _interface.VerifyChallengeOutputOK
|
||||
}, in)
|
||||
}
|
||||
|
||||
Binary file not shown.
@@ -25,12 +25,17 @@ type ChallengeInformation struct {
|
||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
||||
}
|
||||
|
||||
func getRequestAddress(r *http.Request) net.IP {
|
||||
//TODO: verified upstream
|
||||
ipStr := r.Header.Get("X-Real-Ip")
|
||||
if ipStr == "" {
|
||||
ipStr = strings.Split(r.Header.Get("X-Forwarded-For"), ",")[0]
|
||||
func getRequestAddress(r *http.Request, clientHeader string) net.IP {
|
||||
var ipStr string
|
||||
if clientHeader != "" {
|
||||
ipStr = r.Header.Get(clientHeader)
|
||||
}
|
||||
if ipStr != "" {
|
||||
// handle X-Forwarded-For
|
||||
ipStr = strings.Split(ipStr, ",")[0]
|
||||
}
|
||||
|
||||
// fallback
|
||||
if ipStr == "" {
|
||||
parts := strings.Split(r.RemoteAddr, ":")
|
||||
// drop port
|
||||
@@ -44,7 +49,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h
|
||||
hasher.Write([]byte("challenge\x00"))
|
||||
hasher.Write([]byte(name))
|
||||
hasher.Write([]byte{0})
|
||||
hasher.Write(getRequestAddress(r).To16())
|
||||
hasher.Write(getRequestAddress(r, state.Settings.ClientIpHeader).To16())
|
||||
hasher.Write([]byte{0})
|
||||
|
||||
// specific headers
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package challenge
|
||||
package _interface
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
10
lib/challenge/wasm/interface/interface_generic.go
Normal file
10
lib/challenge/wasm/interface/interface_generic.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !tinygo || !wasip1
|
||||
|
||||
package _interface
|
||||
|
||||
func PtrToBytes(ptr uint32, size uint32) []byte { panic("not implemented") }
|
||||
func BytesToPtr(s []byte) (uint32, uint32) { panic("not implemented") }
|
||||
func BytesToLeakedPtr(s []byte) (uint32, uint32) { panic("not implemented") }
|
||||
func PtrToString(ptr uint32, size uint32) string { panic("not implemented") }
|
||||
func StringToPtr(s string) (uint32, uint32) { panic("not implemented") }
|
||||
func StringToLeakedPtr(s string) (uint32, uint32) { panic("not implemented") }
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build tinygo
|
||||
|
||||
package challenge
|
||||
package _interface
|
||||
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
@@ -1,10 +1,7 @@
|
||||
//go:build !tinygo || !wasip1
|
||||
|
||||
package challenge
|
||||
package wasm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/tetratelabs/wazero"
|
||||
@@ -123,73 +120,3 @@ func (r *Runner) Instantiate(key string, f func(ctx context.Context, mod api.Mod
|
||||
|
||||
return f(r.context, mod)
|
||||
}
|
||||
|
||||
func MakeChallengeCall(ctx context.Context, mod api.Module, in MakeChallengeInput) (*MakeChallengeOutput, error) {
|
||||
makeChallengeFunc := mod.ExportedFunction("MakeChallenge")
|
||||
malloc := mod.ExportedFunction("malloc")
|
||||
free := mod.ExportedFunction("free")
|
||||
|
||||
inData, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer free.Call(ctx, mallocResult[0])
|
||||
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
|
||||
return nil, errors.New("could not write memory")
|
||||
}
|
||||
result, err := makeChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resultPtr := Allocation(result[0])
|
||||
outData, ok := mod.Memory().Read(resultPtr.Pointer(), resultPtr.Size())
|
||||
if !ok {
|
||||
return nil, errors.New("could not read result")
|
||||
}
|
||||
defer free.Call(ctx, uint64(resultPtr.Pointer()))
|
||||
|
||||
var out MakeChallengeOutput
|
||||
err = json.Unmarshal(outData, &out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func VerifyChallengeCall(ctx context.Context, mod api.Module, in VerifyChallengeInput) (VerifyChallengeOutput, error) {
|
||||
verifyChallengeFunc := mod.ExportedFunction("VerifyChallenge")
|
||||
malloc := mod.ExportedFunction("malloc")
|
||||
free := mod.ExportedFunction("free")
|
||||
|
||||
inData, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return VerifyChallengeOutputError, err
|
||||
}
|
||||
|
||||
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
|
||||
if err != nil {
|
||||
return VerifyChallengeOutputError, err
|
||||
}
|
||||
defer free.Call(ctx, mallocResult[0])
|
||||
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
|
||||
return VerifyChallengeOutputError, errors.New("could not write memory")
|
||||
}
|
||||
result, err := verifyChallengeFunc.Call(ctx, uint64(NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
|
||||
if err != nil {
|
||||
return VerifyChallengeOutputError, err
|
||||
}
|
||||
|
||||
return VerifyChallengeOutput(result[0]), nil
|
||||
}
|
||||
|
||||
func PtrToBytes(ptr uint32, size uint32) []byte { panic("not implemented") }
|
||||
func BytesToPtr(s []byte) (uint32, uint32) { panic("not implemented") }
|
||||
func BytesToLeakedPtr(s []byte) (uint32, uint32) { panic("not implemented") }
|
||||
func PtrToString(ptr uint32, size uint32) string { panic("not implemented") }
|
||||
func StringToPtr(s string) (uint32, uint32) { panic("not implemented") }
|
||||
func StringToLeakedPtr(s string) (uint32, uint32) { panic("not implemented") }
|
||||
72
lib/challenge/wasm/utils.go
Normal file
72
lib/challenge/wasm/utils.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package wasm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge/wasm/interface"
|
||||
"github.com/tetratelabs/wazero/api"
|
||||
)
|
||||
|
||||
func MakeChallengeCall(ctx context.Context, mod api.Module, in _interface.MakeChallengeInput) (*_interface.MakeChallengeOutput, error) {
|
||||
makeChallengeFunc := mod.ExportedFunction("MakeChallenge")
|
||||
malloc := mod.ExportedFunction("malloc")
|
||||
free := mod.ExportedFunction("free")
|
||||
|
||||
inData, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer free.Call(ctx, mallocResult[0])
|
||||
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
|
||||
return nil, errors.New("could not write memory")
|
||||
}
|
||||
result, err := makeChallengeFunc.Call(ctx, uint64(_interface.NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resultPtr := _interface.Allocation(result[0])
|
||||
outData, ok := mod.Memory().Read(resultPtr.Pointer(), resultPtr.Size())
|
||||
if !ok {
|
||||
return nil, errors.New("could not read result")
|
||||
}
|
||||
defer free.Call(ctx, uint64(resultPtr.Pointer()))
|
||||
|
||||
var out _interface.MakeChallengeOutput
|
||||
err = json.Unmarshal(outData, &out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func VerifyChallengeCall(ctx context.Context, mod api.Module, in _interface.VerifyChallengeInput) (_interface.VerifyChallengeOutput, error) {
|
||||
verifyChallengeFunc := mod.ExportedFunction("VerifyChallenge")
|
||||
malloc := mod.ExportedFunction("malloc")
|
||||
free := mod.ExportedFunction("free")
|
||||
|
||||
inData, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return _interface.VerifyChallengeOutputError, err
|
||||
}
|
||||
|
||||
mallocResult, err := malloc.Call(ctx, uint64(len(inData)))
|
||||
if err != nil {
|
||||
return _interface.VerifyChallengeOutputError, err
|
||||
}
|
||||
defer free.Call(ctx, mallocResult[0])
|
||||
if !mod.Memory().Write(uint32(mallocResult[0]), inData) {
|
||||
return _interface.VerifyChallengeOutputError, errors.New("could not write memory")
|
||||
}
|
||||
result, err := verifyChallengeFunc.Call(ctx, uint64(_interface.NewAllocation(uint32(mallocResult[0]), uint32(len(inData)))))
|
||||
if err != nil {
|
||||
return _interface.VerifyChallengeOutputError, err
|
||||
}
|
||||
|
||||
return _interface.VerifyChallengeOutput(result[0]), nil
|
||||
}
|
||||
20
lib/http.go
20
lib/http.go
@@ -125,10 +125,10 @@ func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration
|
||||
}
|
||||
}
|
||||
|
||||
func GetLoggerForRequest(r *http.Request) *slog.Logger {
|
||||
func GetLoggerForRequest(r *http.Request, clientHeader string) *slog.Logger {
|
||||
return slog.With(
|
||||
"request_id", r.Header.Get("X-Away-Id"),
|
||||
"remote_address", getRequestAddress(r),
|
||||
"remote_address", getRequestAddress(r, clientHeader),
|
||||
"user_agent", r.UserAgent(),
|
||||
"host", r.Host,
|
||||
"path", r.URL.Path,
|
||||
@@ -136,6 +136,10 @@ func GetLoggerForRequest(r *http.Request) *slog.Logger {
|
||||
)
|
||||
}
|
||||
|
||||
func (state *State) logger(r *http.Request) *slog.Logger {
|
||||
return GetLoggerForRequest(r, state.Settings.ClientIpHeader)
|
||||
}
|
||||
|
||||
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
|
||||
@@ -145,7 +149,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
lg := GetLoggerForRequest(r)
|
||||
lg := state.logger(r)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
@@ -153,7 +157,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
env := map[string]any{
|
||||
"host": host,
|
||||
"method": r.Method,
|
||||
"remoteAddress": getRequestAddress(r),
|
||||
"remoteAddress": getRequestAddress(r, state.Settings.ClientIpHeader),
|
||||
"userAgent": r.UserAgent(),
|
||||
"path": r.URL.Path,
|
||||
"query": func() map[string]string {
|
||||
@@ -259,7 +263,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if rule.Action == policy.RuleActionCHECK {
|
||||
goto nextRule
|
||||
}
|
||||
GetLoggerForRequest(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
|
||||
state.logger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
|
||||
|
||||
// we pass the challenge early!
|
||||
r.Header.Set(fmt.Sprintf("X-Away-Challenge-%s-Verify", challengeName), "PASS")
|
||||
@@ -374,15 +378,15 @@ func (state *State) setupRoutes() error {
|
||||
state.addTiming(w, "challenge-verify", "Verify client challenge", time.Since(start))
|
||||
|
||||
if err != nil {
|
||||
GetLoggerForRequest(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
state.logger(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
return err
|
||||
} else if !ok {
|
||||
GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
|
||||
return nil
|
||||
}
|
||||
GetLoggerForRequest(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
state.logger(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
|
||||
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
||||
if err != nil {
|
||||
|
||||
26
lib/state.go
26
lib/state.go
@@ -12,7 +12,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.gammaspectra.live/git/go-away/embed"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge/wasm"
|
||||
"git.gammaspectra.live/git/go-away/lib/challenge/wasm/interface"
|
||||
"git.gammaspectra.live/git/go-away/lib/condition"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"git.gammaspectra.live/git/go-away/utils/inline"
|
||||
@@ -44,7 +45,7 @@ type State struct {
|
||||
|
||||
Networks map[string]cidranger.Ranger
|
||||
|
||||
Wasm *challenge.Runner
|
||||
Wasm *wasm.Runner
|
||||
|
||||
Challenges map[string]ChallengeState
|
||||
|
||||
@@ -101,6 +102,7 @@ type StateSettings struct {
|
||||
PackageName string
|
||||
ChallengeTemplate string
|
||||
ChallengeTemplateTheme string
|
||||
ClientIpHeader string
|
||||
}
|
||||
|
||||
func NewState(p policy.Policy, settings StateSettings) (state *State, err error) {
|
||||
@@ -118,7 +120,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
||||
if proxy, ok := backend.(*httputil.ReverseProxy); ok {
|
||||
if proxy.ErrorHandler == nil {
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
GetLoggerForRequest(r).Error(err.Error())
|
||||
state.logger(r).Error(err.Error())
|
||||
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadGateway, err)
|
||||
}
|
||||
}
|
||||
@@ -186,7 +188,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
||||
state.Networks[k] = ranger
|
||||
}
|
||||
|
||||
state.Wasm = challenge.NewRunner(true)
|
||||
state.Wasm = wasm.NewRunner(true)
|
||||
|
||||
state.Challenges = make(map[string]ChallengeState)
|
||||
|
||||
@@ -429,12 +431,12 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
||||
if ok, err := c.Verify(key, result); err != nil {
|
||||
return err
|
||||
} else if !ok {
|
||||
GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
state.logger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
|
||||
return nil
|
||||
}
|
||||
GetLoggerForRequest(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
state.logger(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
|
||||
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
||||
if err != nil {
|
||||
@@ -476,7 +478,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
||||
c.MakeChallenge = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
err := state.Wasm.Instantiate(challengeName, func(ctx context.Context, mod api.Module) (err error) {
|
||||
|
||||
in := challenge.MakeChallengeInput{
|
||||
in := _interface.MakeChallengeInput{
|
||||
Key: state.GetChallengeKeyForRequest(challengeName, time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity), r),
|
||||
Parameters: p.Parameters,
|
||||
Headers: inline.MIMEHeader(r.Header),
|
||||
@@ -486,7 +488,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := challenge.MakeChallengeCall(ctx, mod, in)
|
||||
out, err := wasm.MakeChallengeCall(ctx, mod, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -508,21 +510,21 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
||||
|
||||
c.Verify = func(key []byte, result string) (ok bool, err error) {
|
||||
err = state.Wasm.Instantiate(challengeName, func(ctx context.Context, mod api.Module) (err error) {
|
||||
in := challenge.VerifyChallengeInput{
|
||||
in := _interface.VerifyChallengeInput{
|
||||
Key: key,
|
||||
Parameters: p.Parameters,
|
||||
Result: []byte(result),
|
||||
}
|
||||
|
||||
out, err := challenge.VerifyChallengeCall(ctx, mod, in)
|
||||
out, err := wasm.VerifyChallengeCall(ctx, mod, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if out == challenge.VerifyChallengeOutputError {
|
||||
if out == _interface.VerifyChallengeOutputError {
|
||||
return errors.New("error checking challenge")
|
||||
}
|
||||
ok = out == challenge.VerifyChallengeOutputOK
|
||||
ok = out == _interface.VerifyChallengeOutputOK
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user