Allow multiple backends
This commit is contained in:
49
cmd/away.go
49
cmd/away.go
@@ -1,52 +1,20 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"git.gammaspectra.live/git/go-away/lib"
|
"git.gammaspectra.live/git/go-away/lib"
|
||||||
"git.gammaspectra.live/git/go-away/lib/network"
|
|
||||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func makeReverseProxy(target string) (http.Handler, error) {
|
|
||||||
u, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse target URL: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
|
|
||||||
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
|
|
||||||
if u.Scheme == "unix" {
|
|
||||||
// clean path up so we don't use the socket path in proxied requests
|
|
||||||
addr := u.Path
|
|
||||||
u.Path = ""
|
|
||||||
// tell transport how to dial unix sockets
|
|
||||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
||||||
dialer := net.Dialer{}
|
|
||||||
return dialer.DialContext(ctx, "unix", addr)
|
|
||||||
}
|
|
||||||
// tell transport how to handle the unix url scheme
|
|
||||||
transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport})
|
|
||||||
}
|
|
||||||
|
|
||||||
rp := httputil.NewSingleHostReverseProxy(u)
|
|
||||||
rp.Transport = transport
|
|
||||||
|
|
||||||
return rp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupListener(network, address, socketMode string) (net.Listener, string) {
|
func setupListener(network, address, socketMode string) (net.Listener, string) {
|
||||||
formattedAddress := ""
|
formattedAddress := ""
|
||||||
switch network {
|
switch network {
|
||||||
@@ -88,15 +56,11 @@ func main() {
|
|||||||
|
|
||||||
slogLevel := flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
|
slogLevel := flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
|
||||||
|
|
||||||
target := flag.String("target", "http://localhost:80", "target to reverse proxy to")
|
|
||||||
|
|
||||||
policyFile := flag.String("policy", "", "path to policy YAML file")
|
policyFile := flag.String("policy", "", "path to policy YAML file")
|
||||||
challengeTemplate := flag.String("challenge-template", "anubis", "name of the challenge template to use")
|
challengeTemplate := flag.String("challenge-template", "anubis", "name of the challenge template to use")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
_, _, _, _ = bind, bindNetwork, socketMode, target
|
|
||||||
|
|
||||||
{
|
{
|
||||||
var programLevel slog.Level
|
var programLevel slog.Level
|
||||||
if err := (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil {
|
if err := (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil {
|
||||||
@@ -119,19 +83,13 @@ func main() {
|
|||||||
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
|
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
var policy policy.Policy
|
var p policy.Policy
|
||||||
|
|
||||||
if err = yaml.Unmarshal(policyData, &policy); err != nil {
|
if err = yaml.Unmarshal(policyData, &p); err != nil {
|
||||||
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
|
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
backend, err := makeReverseProxy(*target)
|
state, err := lib.NewState(p, lib.StateSettings{
|
||||||
if err != nil {
|
|
||||||
log.Fatal(fmt.Errorf("failed to create reverse proxy for %s: %w", *target, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
state, err := lib.NewState(policy, lib.StateSettings{
|
|
||||||
Backend: backend,
|
|
||||||
PackagePath: "git.gammaspectra.live/git/go-away/cmd",
|
PackagePath: "git.gammaspectra.live/git/go-away/cmd",
|
||||||
ChallengeTemplate: *challengeTemplate,
|
ChallengeTemplate: *challengeTemplate,
|
||||||
})
|
})
|
||||||
@@ -144,7 +102,6 @@ func main() {
|
|||||||
slog.Info(
|
slog.Info(
|
||||||
"listening",
|
"listening",
|
||||||
"url", listenUrl,
|
"url", listenUrl,
|
||||||
"target", *target,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
server := http.Server{
|
server := http.Server{
|
||||||
|
|||||||
54
lib/http.go
54
lib/http.go
@@ -3,16 +3,21 @@ package lib
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"codeberg.org/meta/gzipped/v2"
|
"codeberg.org/meta/gzipped/v2"
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
go_away "git.gammaspectra.live/git/go-away"
|
go_away "git.gammaspectra.live/git/go-away"
|
||||||
|
"git.gammaspectra.live/git/go-away/lib/network"
|
||||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||||
"github.com/google/cel-go/common/types"
|
"github.com/google/cel-go/common/types"
|
||||||
"html/template"
|
"html/template"
|
||||||
"maps"
|
"maps"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -54,6 +59,34 @@ func init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeReverseProxy(target string) (http.Handler, error) {
|
||||||
|
u, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse target URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
|
||||||
|
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
|
||||||
|
if u.Scheme == "unix" {
|
||||||
|
// clean path up so we don't use the socket path in proxied requests
|
||||||
|
addr := u.Path
|
||||||
|
u.Path = ""
|
||||||
|
// tell transport how to dial unix sockets
|
||||||
|
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
return dialer.DialContext(ctx, "unix", addr)
|
||||||
|
}
|
||||||
|
// tell transport how to handle the unix url scheme
|
||||||
|
transport.RegisterProtocol("unix", network.UnixRoundTripper{Transport: transport})
|
||||||
|
}
|
||||||
|
|
||||||
|
rp := httputil.NewSingleHostReverseProxy(u)
|
||||||
|
rp.Transport = transport
|
||||||
|
|
||||||
|
return rp, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (state *State) challengePage(w http.ResponseWriter, status int, challenge string, params map[string]any) error {
|
func (state *State) challengePage(w http.ResponseWriter, status int, challenge string, params map[string]any) error {
|
||||||
input := make(map[string]any)
|
input := make(map[string]any)
|
||||||
input["Random"] = cacheBust
|
input["Random"] = cacheBust
|
||||||
@@ -104,8 +137,17 @@ func (state *State) errorPage(w http.ResponseWriter, status int, err error) erro
|
|||||||
|
|
||||||
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
|
host := r.Host
|
||||||
|
|
||||||
|
backend, ok := state.Backends[host]
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
//TODO better matcher! combo ast?
|
//TODO better matcher! combo ast?
|
||||||
env := map[string]any{
|
env := map[string]any{
|
||||||
|
"host": host,
|
||||||
"method": r.Method,
|
"method": r.Method,
|
||||||
"remoteAddress": state.GetRequestAddress(r),
|
"remoteAddress": state.GetRequestAddress(r),
|
||||||
"userAgent": r.UserAgent(),
|
"userAgent": r.UserAgent(),
|
||||||
@@ -127,6 +169,10 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, rule := range state.Rules {
|
for _, rule := range state.Rules {
|
||||||
|
// skip rules that have host match
|
||||||
|
if rule.Host != nil && *rule.Host != host {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if out, _, err := rule.Program.Eval(env); err != nil {
|
if out, _, err := rule.Program.Eval(env); err != nil {
|
||||||
//TODO error
|
//TODO error
|
||||||
panic(err)
|
panic(err)
|
||||||
@@ -136,7 +182,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
default:
|
default:
|
||||||
panic(fmt.Errorf("unknown action %s", rule.Action))
|
panic(fmt.Errorf("unknown action %s", rule.Action))
|
||||||
case policy.RuleActionPASS:
|
case policy.RuleActionPASS:
|
||||||
state.Backend.ServeHTTP(w, r)
|
backend.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
|
case policy.RuleActionCHALLENGE, policy.RuleActionCHECK:
|
||||||
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
|
expiry := time.Now().UTC().Add(DefaultValidity).Round(DefaultValidity)
|
||||||
@@ -154,7 +200,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
// we passed the challenge!
|
// we passed the challenge!
|
||||||
//TODO log?
|
//TODO log?
|
||||||
state.Backend.ServeHTTP(w, r)
|
backend.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -174,7 +220,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
goto nextRule
|
goto nextRule
|
||||||
}
|
}
|
||||||
// we pass the challenge early!
|
// we pass the challenge early!
|
||||||
state.Backend.ServeHTTP(w, r)
|
backend.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -197,7 +243,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
nextRule:
|
nextRule:
|
||||||
}
|
}
|
||||||
|
|
||||||
state.Backend.ServeHTTP(w, r)
|
backend.ServeHTTP(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -43,4 +43,6 @@ type Policy struct {
|
|||||||
Challenges map[string]Challenge `yaml:"challenges"`
|
Challenges map[string]Challenge `yaml:"challenges"`
|
||||||
|
|
||||||
Rules []Rule `yaml:"rules"`
|
Rules []Rule `yaml:"rules"`
|
||||||
|
|
||||||
|
Backends map[string]string `json:"backends"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ const (
|
|||||||
|
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
Name string `yaml:"name"`
|
Name string `yaml:"name"`
|
||||||
|
Host *string `yaml:"host"`
|
||||||
Conditions []string `yaml:"conditions"`
|
Conditions []string `yaml:"conditions"`
|
||||||
|
|
||||||
Action string `yaml:"action"`
|
Action string `yaml:"action"`
|
||||||
|
|||||||
23
lib/state.go
23
lib/state.go
@@ -40,7 +40,7 @@ type State struct {
|
|||||||
Settings StateSettings
|
Settings StateSettings
|
||||||
UrlPath string
|
UrlPath string
|
||||||
Mux *http.ServeMux
|
Mux *http.ServeMux
|
||||||
Backend http.Handler
|
Backends map[string]http.Handler
|
||||||
|
|
||||||
Networks map[string]cidranger.Ranger
|
Networks map[string]cidranger.Ranger
|
||||||
|
|
||||||
@@ -61,6 +61,8 @@ type RuleState struct {
|
|||||||
Name string
|
Name string
|
||||||
Hash string
|
Hash string
|
||||||
|
|
||||||
|
Host *string
|
||||||
|
|
||||||
Program cel.Program
|
Program cel.Program
|
||||||
Action policy.RuleAction
|
Action policy.RuleAction
|
||||||
Challenges []string
|
Challenges []string
|
||||||
@@ -94,7 +96,6 @@ type ChallengeState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type StateSettings struct {
|
type StateSettings struct {
|
||||||
Backend http.Handler
|
|
||||||
PackagePath string
|
PackagePath string
|
||||||
ChallengeTemplate string
|
ChallengeTemplate string
|
||||||
}
|
}
|
||||||
@@ -108,7 +109,16 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
state.UrlPath = "/.well-known/." + state.Settings.PackagePath
|
state.UrlPath = "/.well-known/." + state.Settings.PackagePath
|
||||||
state.Backend = settings.Backend
|
|
||||||
|
state.Backends = make(map[string]http.Handler)
|
||||||
|
|
||||||
|
for k, v := range p.Backends {
|
||||||
|
backend, err := makeReverseProxy(v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err)
|
||||||
|
}
|
||||||
|
state.Backends[k] = backend
|
||||||
|
}
|
||||||
|
|
||||||
state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader)
|
state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -492,6 +502,7 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
|||||||
state.RulesEnv, err = cel.NewEnv(
|
state.RulesEnv, err = cel.NewEnv(
|
||||||
cel.DefaultUTCTimeZone(true),
|
cel.DefaultUTCTimeZone(true),
|
||||||
cel.Variable("remoteAddress", cel.BytesType),
|
cel.Variable("remoteAddress", cel.BytesType),
|
||||||
|
cel.Variable("host", cel.StringType),
|
||||||
cel.Variable("method", cel.StringType),
|
cel.Variable("method", cel.StringType),
|
||||||
cel.Variable("userAgent", cel.StringType),
|
cel.Variable("userAgent", cel.StringType),
|
||||||
cel.Variable("path", cel.StringType),
|
cel.Variable("path", cel.StringType),
|
||||||
@@ -565,12 +576,18 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
|||||||
for _, rule := range p.Rules {
|
for _, rule := range p.Rules {
|
||||||
hasher := sha256.New()
|
hasher := sha256.New()
|
||||||
hasher.Write([]byte(rule.Name))
|
hasher.Write([]byte(rule.Name))
|
||||||
|
hasher.Write([]byte{0})
|
||||||
|
if rule.Host != nil {
|
||||||
|
hasher.Write([]byte(*rule.Host))
|
||||||
|
}
|
||||||
|
hasher.Write([]byte{0})
|
||||||
hasher.Write(privateKeyFingerprint[:])
|
hasher.Write(privateKeyFingerprint[:])
|
||||||
sum := hasher.Sum(nil)
|
sum := hasher.Sum(nil)
|
||||||
|
|
||||||
r := RuleState{
|
r := RuleState{
|
||||||
Name: rule.Name,
|
Name: rule.Name,
|
||||||
Hash: hex.EncodeToString(sum[:8]),
|
Hash: hex.EncodeToString(sum[:8]),
|
||||||
|
Host: rule.Host,
|
||||||
Action: policy.RuleAction(strings.ToUpper(rule.Action)),
|
Action: policy.RuleAction(strings.ToUpper(rule.Action)),
|
||||||
Challenges: rule.Challenges,
|
Challenges: rule.Challenges,
|
||||||
}
|
}
|
||||||
|
|||||||
16
policy.yml
16
policy.yml
@@ -1,4 +1,8 @@
|
|||||||
|
|
||||||
|
# Define backends to use. Rules can be done generally, or only applying to specific hosts
|
||||||
|
backends:
|
||||||
|
git.gammaspectra.live: http://gitea:3000
|
||||||
|
|
||||||
# Define networks to be used later below
|
# Define networks to be used later below
|
||||||
networks:
|
networks:
|
||||||
# todo: support direct ASN lookups
|
# todo: support direct ASN lookups
|
||||||
@@ -218,6 +222,18 @@ conditions:
|
|||||||
# user activity tab
|
# user activity tab
|
||||||
- 'path.matches("^/[^/]") && "tab" in query && query.tab == "activity"'
|
- 'path.matches("^/[^/]") && "tab" in query && query.tab == "activity"'
|
||||||
|
|
||||||
|
# Rules and conditions are served this environment
|
||||||
|
# remoteAddress (net.IP) - Connecting client remote address from headers or properties
|
||||||
|
# host (string) - HTTP Host
|
||||||
|
# method (string) - HTTP Method/Verb
|
||||||
|
# userAgent (string) - HTTP User-Agent header
|
||||||
|
# path (string) - HTTP request Path
|
||||||
|
# query (map[string]string) - HTTP request Query arguments
|
||||||
|
# headers (map[string]string) - HTTP request headers
|
||||||
|
#
|
||||||
|
# Additionally these functions are available
|
||||||
|
# inNetwork(networkName string, address net.IP) bool
|
||||||
|
# inNetwork(networkCIDR string, address net.IP) bool
|
||||||
rules:
|
rules:
|
||||||
- name: undesired-networks
|
- name: undesired-networks
|
||||||
conditions:
|
conditions:
|
||||||
|
|||||||
Reference in New Issue
Block a user