Allow multiple backends

This commit is contained in:
WeebDataHoarder
2025-04-02 19:23:09 +02:00
parent 8d9d5a8ab3
commit 150927e7ba
6 changed files with 92 additions and 53 deletions

View File

@@ -3,16 +3,21 @@ package lib
import (
"bytes"
"codeberg.org/meta/gzipped/v2"
"context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
go_away "git.gammaspectra.live/git/go-away"
"git.gammaspectra.live/git/go-away/lib/network"
"git.gammaspectra.live/git/go-away/lib/policy"
"github.com/google/cel-go/common/types"
"html/template"
"maps"
"net"
"net/http"
"net/http/httputil"
"net/url"
"path/filepath"
"strings"
"time"
@@ -54,6 +59,34 @@ func init() {
}
}
func makeReverseProxy(target string) (http.Handler, error) {
u, err := url.Parse(target)
if err != nil {
return nil, fmt.Errorf("failed to parse target URL: %w", err)
}
transport := http.DefaultTransport.(*http.Transport).Clone()
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
if u.Scheme == "unix" {
// clean path up so we don't use the socket path in proxied requests
addr := u.Path
u.Path = ""
// tell transport how to dial unix sockets
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
dialer := net.Dialer{}
return dialer.DialContext(ctx, "unix", addr)
}
// tell transport how to handle the unix url scheme
transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport})
}
rp := httputil.NewSingleHostReverseProxy(u)
rp.Transport = transport
return rp, nil
}
func (state *State) challengePage(w http.ResponseWriter, status int, challenge string, params map[string]any) error {
input := make(map[string]any)
input["Random"] = cacheBust
@@ -104,8 +137,17 @@ func (state *State) errorPage(w http.ResponseWriter, status int, err error) erro
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
host := r.Host
backend, ok := state.Backends[host]
if !ok {
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
return
}
//TODO better matcher! combo ast?
env := map[string]any{
"host": host,
"method": r.Method,
"remoteAddress": state.GetRequestAddress(r),
"userAgent": r.UserAgent(),
@@ -127,6 +169,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
}
for _, rule := range state.Rules {
// skip rules that have host match
if rule.Host != nil && *rule.Host != host {
continue
}
if out, _, err := rule.Program.Eval(env); err != nil {
//TODO error
panic(err)
@@ -136,7 +182,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
default:
panic(fmt.Errorf("unknown action %s", rule.Action))
case policy.RuleActionPASS:
state.Backend.ServeHTTP(w, r)
backend.ServeHTTP(w, r)
return
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
@@ -154,7 +200,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
}
// we passed the challenge!
//TODO log?
state.Backend.ServeHTTP(w, r)
backend.ServeHTTP(w, r)
return
}
}
@@ -174,7 +220,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
goto nextRule
}
// we pass the challenge early!
state.Backend.ServeHTTP(w, r)
backend.ServeHTTP(w, r)
return
}
} else {
@@ -197,7 +243,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
nextRule:
}
state.Backend.ServeHTTP(w, r)
backend.ServeHTTP(w, r)
return
}