Allow multiple backends

This commit is contained in:
WeebDataHoarder
2025-04-02 19:23:09 +02:00
parent 8d9d5a8ab3
commit 150927e7ba
6 changed files with 92 additions and 53 deletions

View File

@@ -1,52 +1,20 @@
package main package main
import ( import (
"context"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"git.gammaspectra.live/git/go-away/lib" "git.gammaspectra.live/git/go-away/lib"
"git.gammaspectra.live/git/go-away/lib/network"
"git.gammaspectra.live/git/go-away/lib/policy" "git.gammaspectra.live/git/go-away/lib/policy"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"log" "log"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"net/url"
"os" "os"
"strconv" "strconv"
) )
func makeReverseProxy(target string) (http.Handler, error) {
u, err := url.Parse(target)
if err != nil {
return nil, fmt.Errorf("failed to parse target URL: %w", err)
}
transport := http.DefaultTransport.(*http.Transport).Clone()
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
if u.Scheme == "unix" {
// clean path up so we don't use the socket path in proxied requests
addr := u.Path
u.Path = ""
// tell transport how to dial unix sockets
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
dialer := net.Dialer{}
return dialer.DialContext(ctx, "unix", addr)
}
// tell transport how to handle the unix url scheme
transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport})
}
rp := httputil.NewSingleHostReverseProxy(u)
rp.Transport = transport
return rp, nil
}
func setupListener(network, address, socketMode string) (net.Listener, string) { func setupListener(network, address, socketMode string) (net.Listener, string) {
formattedAddress := "" formattedAddress := ""
switch network { switch network {
@@ -88,15 +56,11 @@ func main() {
slogLevel := flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)") slogLevel := flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
target := flag.String("target", "http://localhost:80", "target to reverse proxy to")
policyFile := flag.String("policy", "", "path to policy YAML file") policyFile := flag.String("policy", "", "path to policy YAML file")
challengeTemplate := flag.String("challenge-template", "anubis", "name of the challenge template to use") challengeTemplate := flag.String("challenge-template", "anubis", "name of the challenge template to use")
flag.Parse() flag.Parse()
_, _, _, _ = bind, bindNetwork, socketMode, target
{ {
var programLevel slog.Level var programLevel slog.Level
if err := (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil { if err := (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil {
@@ -119,19 +83,13 @@ func main() {
log.Fatal(fmt.Errorf("failed to read policy file: %w", err)) log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
} }
var policy policy.Policy var p policy.Policy
if err = yaml.Unmarshal(policyData, &policy); err != nil { if err = yaml.Unmarshal(policyData, &p); err != nil {
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err)) log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
} }
backend, err := makeReverseProxy(*target) state, err := lib.NewState(p, lib.StateSettings{
if err != nil {
log.Fatal(fmt.Errorf("failed to create reverse proxy for %s: %w", *target, err))
}
state, err := lib.NewState(policy, lib.StateSettings{
Backend: backend,
PackagePath: "git.gammaspectra.live/git/go-away/cmd", PackagePath: "git.gammaspectra.live/git/go-away/cmd",
ChallengeTemplate: *challengeTemplate, ChallengeTemplate: *challengeTemplate,
}) })
@@ -144,7 +102,6 @@ func main() {
slog.Info( slog.Info(
"listening", "listening",
"url", listenUrl, "url", listenUrl,
"target", *target,
) )
server := http.Server{ server := http.Server{

View File

@@ -3,16 +3,21 @@ package lib
import ( import (
"bytes" "bytes"
"codeberg.org/meta/gzipped/v2" "codeberg.org/meta/gzipped/v2"
"context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
go_away "git.gammaspectra.live/git/go-away" go_away "git.gammaspectra.live/git/go-away"
"git.gammaspectra.live/git/go-away/lib/network"
"git.gammaspectra.live/git/go-away/lib/policy" "git.gammaspectra.live/git/go-away/lib/policy"
"github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types"
"html/template" "html/template"
"maps" "maps"
"net"
"net/http" "net/http"
"net/http/httputil"
"net/url"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
@@ -54,6 +59,34 @@ func init() {
} }
} }
func makeReverseProxy(target string) (http.Handler, error) {
u, err := url.Parse(target)
if err != nil {
return nil, fmt.Errorf("failed to parse target URL: %w", err)
}
transport := http.DefaultTransport.(*http.Transport).Clone()
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
if u.Scheme == "unix" {
// clean path up so we don't use the socket path in proxied requests
addr := u.Path
u.Path = ""
// tell transport how to dial unix sockets
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
dialer := net.Dialer{}
return dialer.DialContext(ctx, "unix", addr)
}
// tell transport how to handle the unix url scheme
transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport})
}
rp := httputil.NewSingleHostReverseProxy(u)
rp.Transport = transport
return rp, nil
}
func (state *State) challengePage(w http.ResponseWriter, status int, challenge string, params map[string]any) error { func (state *State) challengePage(w http.ResponseWriter, status int, challenge string, params map[string]any) error {
input := make(map[string]any) input := make(map[string]any)
input["Random"] = cacheBust input["Random"] = cacheBust
@@ -104,8 +137,17 @@ func (state *State) errorPage(w http.ResponseWriter, status int, err error) erro
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
host := r.Host
backend, ok := state.Backends[host]
if !ok {
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
return
}
//TODO better matcher! combo ast? //TODO better matcher! combo ast?
env := map[string]any{ env := map[string]any{
"host": host,
"method": r.Method, "method": r.Method,
"remoteAddress": state.GetRequestAddress(r), "remoteAddress": state.GetRequestAddress(r),
"userAgent": r.UserAgent(), "userAgent": r.UserAgent(),
@@ -127,6 +169,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
} }
for _, rule := range state.Rules { for _, rule := range state.Rules {
// skip rules that have host match
if rule.Host != nil && *rule.Host != host {
continue
}
if out, _, err := rule.Program.Eval(env); err != nil { if out, _, err := rule.Program.Eval(env); err != nil {
//TODO error //TODO error
panic(err) panic(err)
@@ -136,7 +182,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
default: default:
panic(fmt.Errorf("unknown action %s", rule.Action)) panic(fmt.Errorf("unknown action %s", rule.Action))
case policy.RuleActionPASS: case policy.RuleActionPASS:
state.Backend.ServeHTTP(w, r) backend.ServeHTTP(w, r)
return return
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK: case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity) expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
@@ -154,7 +200,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
} }
// we passed the challenge! // we passed the challenge!
//TODO log? //TODO log?
state.Backend.ServeHTTP(w, r) backend.ServeHTTP(w, r)
return return
} }
} }
@@ -174,7 +220,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
goto nextRule goto nextRule
} }
// we pass the challenge early! // we pass the challenge early!
state.Backend.ServeHTTP(w, r) backend.ServeHTTP(w, r)
return return
} }
} else { } else {
@@ -197,7 +243,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
nextRule: nextRule:
} }
state.Backend.ServeHTTP(w, r) backend.ServeHTTP(w, r)
return return
} }

View File

@@ -43,4 +43,6 @@ type Policy struct {
Challenges map[string]Challenge `yaml:"challenges"` Challenges map[string]Challenge `yaml:"challenges"`
Rules []Rule `yaml:"rules"` Rules []Rule `yaml:"rules"`
Backends map[string]string `json:"backends"`
} }

View File

@@ -12,6 +12,7 @@ const (
type Rule struct { type Rule struct {
Name string `yaml:"name"` Name string `yaml:"name"`
Host *string `yaml:"host"`
Conditions []string `yaml:"conditions"` Conditions []string `yaml:"conditions"`
Action string `yaml:"action"` Action string `yaml:"action"`

View File

@@ -40,7 +40,7 @@ type State struct {
Settings StateSettings Settings StateSettings
UrlPath string UrlPath string
Mux *http.ServeMux Mux *http.ServeMux
Backend http.Handler Backends map[string]http.Handler
Networks map[string]cidranger.Ranger Networks map[string]cidranger.Ranger
@@ -61,6 +61,8 @@ type RuleState struct {
Name string Name string
Hash string Hash string
Host *string
Program cel.Program Program cel.Program
Action policy.RuleAction Action policy.RuleAction
Challenges []string Challenges []string
@@ -94,7 +96,6 @@ type ChallengeState struct {
} }
type StateSettings struct { type StateSettings struct {
Backend http.Handler
PackagePath string PackagePath string
ChallengeTemplate string ChallengeTemplate string
} }
@@ -108,7 +109,16 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
}, },
} }
state.UrlPath = "/.well-known/." + state.Settings.PackagePath state.UrlPath = "/.well-known/." + state.Settings.PackagePath
state.Backend = settings.Backend
state.Backends = make(map[string]http.Handler)
for k, v := range p.Backends {
backend, err := makeReverseProxy(v)
if err != nil {
return nil, fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err)
}
state.Backends[k] = backend
}
state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader) state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader)
if err != nil { if err != nil {
@@ -492,6 +502,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
state.RulesEnv, err = cel.NewEnv( state.RulesEnv, err = cel.NewEnv(
cel.DefaultUTCTimeZone(true), cel.DefaultUTCTimeZone(true),
cel.Variable("remoteAddress", cel.BytesType), cel.Variable("remoteAddress", cel.BytesType),
cel.Variable("host", cel.StringType),
cel.Variable("method", cel.StringType), cel.Variable("method", cel.StringType),
cel.Variable("userAgent", cel.StringType), cel.Variable("userAgent", cel.StringType),
cel.Variable("path", cel.StringType), cel.Variable("path", cel.StringType),
@@ -565,12 +576,18 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
for _, rule := range p.Rules { for _, rule := range p.Rules {
hasher := sha256.New() hasher := sha256.New()
hasher.Write([]byte(rule.Name)) hasher.Write([]byte(rule.Name))
hasher.Write([]byte{0})
if rule.Host != nil {
hasher.Write([]byte(*rule.Host))
}
hasher.Write([]byte{0})
hasher.Write(privateKeyFingerprint[:]) hasher.Write(privateKeyFingerprint[:])
sum := hasher.Sum(nil) sum := hasher.Sum(nil)
r := RuleState{ r := RuleState{
Name: rule.Name, Name: rule.Name,
Hash: hex.EncodeToString(sum[:8]), Hash: hex.EncodeToString(sum[:8]),
Host: rule.Host,
Action: policy.RuleAction(strings.ToUpper(rule.Action)), Action: policy.RuleAction(strings.ToUpper(rule.Action)),
Challenges: rule.Challenges, Challenges: rule.Challenges,
} }

View File

@@ -1,4 +1,8 @@
# Define backends to use. Rules can be done generally, or only applying to specific hosts
backends:
git.gammaspectra.live: http://gitea:3000
# Define networks to be used later below # Define networks to be used later below
networks: networks:
# todo: support direct ASN lookups # todo: support direct ASN lookups
@@ -218,6 +222,18 @@ conditions:
# user activity tab # user activity tab
- 'path.matches("^/[^/]") && "tab" in query && query.tab == "activity"' - 'path.matches("^/[^/]") && "tab" in query && query.tab == "activity"'
# Rules and conditions are served this environment
# remoteAddress (net.IP) - Connecting client remote address from headers or properties
# host (string) - HTTP Host
# method (string) - HTTP Method/Verb
# userAgent (string) - HTTP User-Agent header
# path (string) - HTTP request Path
# query (map[string]string) - HTTP request Query arguments
# headers (map[string]string) - HTTP request headers
#
# Additionally these functions are available
# inNetwork(networkName string, address net.IP) bool
# inNetwork(networkCIDR string, address net.IP) bool
rules: rules:
- name: undesired-networks - name: undesired-networks
conditions: conditions: