Move backends to cmd args, allow setting private key seed via parameter or ENV var
This commit is contained in:
51
lib/http.go
51
lib/http.go
@@ -3,7 +3,6 @@ package lib
|
||||
import (
|
||||
"bytes"
|
||||
"codeberg.org/meta/gzipped/v2"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
@@ -11,16 +10,12 @@ import (
|
||||
"fmt"
|
||||
"git.gammaspectra.live/git/go-away/embed"
|
||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||
"git.gammaspectra.live/git/go-away/utils"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"html/template"
|
||||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -72,34 +67,6 @@ func initTemplate(name, data string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func makeReverseProxy(target string) (*httputil.ReverseProxy, error) {
|
||||
u, err := url.Parse(target)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse target URL: %w", err)
|
||||
}
|
||||
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
|
||||
if u.Scheme == "unix" {
|
||||
// clean path up so we don't use the socket path in proxied requests
|
||||
addr := u.Path
|
||||
u.Path = ""
|
||||
// tell transport how to dial unix sockets
|
||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
dialer := net.Dialer{}
|
||||
return dialer.DialContext(ctx, "unix", addr)
|
||||
}
|
||||
// tell transport how to handle the unix url scheme
|
||||
transport.RegisterProtocol("unix", utils.UnixRoundTripper{Transport: transport})
|
||||
}
|
||||
|
||||
rp := httputil.NewSingleHostReverseProxy(u)
|
||||
rp.Transport = transport
|
||||
|
||||
return rp, nil
|
||||
}
|
||||
|
||||
func (state *State) challengePage(w http.ResponseWriter, id string, status int, challenge string, params map[string]any) error {
|
||||
input := make(map[string]any)
|
||||
input["Id"] = id
|
||||
@@ -158,10 +125,10 @@ func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration
|
||||
}
|
||||
}
|
||||
|
||||
func (state *State) getLogger(r *http.Request) *slog.Logger {
|
||||
func GetLoggerForRequest(r *http.Request) *slog.Logger {
|
||||
return slog.With(
|
||||
"request_id", r.Header.Get("X-Away-Id"),
|
||||
"remote_address", state.GetRequestAddress(r),
|
||||
"remote_address", getRequestAddress(r),
|
||||
"user_agent", r.UserAgent(),
|
||||
"host", r.Host,
|
||||
"path", r.URL.Path,
|
||||
@@ -172,13 +139,13 @@ func (state *State) getLogger(r *http.Request) *slog.Logger {
|
||||
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
|
||||
backend, ok := state.Backends[host]
|
||||
backend, ok := state.Settings.Backends[host]
|
||||
if !ok {
|
||||
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
lg := state.getLogger(r)
|
||||
lg := GetLoggerForRequest(r)
|
||||
|
||||
start := time.Now()
|
||||
|
||||
@@ -186,7 +153,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
env := map[string]any{
|
||||
"host": host,
|
||||
"method": r.Method,
|
||||
"remoteAddress": state.GetRequestAddress(r),
|
||||
"remoteAddress": getRequestAddress(r),
|
||||
"userAgent": r.UserAgent(),
|
||||
"path": r.URL.Path,
|
||||
"query": func() map[string]string {
|
||||
@@ -292,7 +259,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if rule.Action == policy.RuleActionCHECK {
|
||||
goto nextRule
|
||||
}
|
||||
state.getLogger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
|
||||
GetLoggerForRequest(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
|
||||
|
||||
// we pass the challenge early!
|
||||
r.Header.Set(fmt.Sprintf("X-Away-Challenge-%s-Verify", challengeName), "PASS")
|
||||
@@ -407,15 +374,15 @@ func (state *State) setupRoutes() error {
|
||||
state.addTiming(w, "challenge-verify", "Verify client challenge", time.Since(start))
|
||||
|
||||
if err != nil {
|
||||
state.getLogger(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
GetLoggerForRequest(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
return err
|
||||
} else if !ok {
|
||||
state.getLogger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
ClearCookie(CookiePrefix+challengeName, w)
|
||||
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
|
||||
return nil
|
||||
}
|
||||
state.getLogger(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
GetLoggerForRequest(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||
|
||||
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user