diff --git a/.drone.yml b/.drone.yml index 05132ef..7dd8d29 100644 --- a/.drone.yml +++ b/.drone.yml @@ -3,11 +3,15 @@ kind: pipeline type: docker name: build-go1.22 +environment: + CGO_ENABLED: "0" + GOOS: linux + GOARCH: amd64 + steps: - name: build-go1.22-alpine3.20 image: golang:1.22-alpine3.20 environment: - CGO_ENABLED: "0" GOOS: linux GOARCH: amd64 commands: @@ -19,13 +23,14 @@ kind: pipeline type: docker name: build-go1.24 +environment: + CGO_ENABLED: "0" + GOOS: linux + GOARCH: amd64 + steps: - name: build-go1.24-alpine3.21 image: golang:1.24-alpine3.21 - environment: - CGO_ENABLED: "0" - GOOS: linux - GOARCH: amd64 commands: - apk update - apk add --no-cache git diff --git a/cmd/go-away/main.go b/cmd/go-away/main.go index 0dc4a75..d7cc5e1 100644 --- a/cmd/go-away/main.go +++ b/cmd/go-away/main.go @@ -1,18 +1,24 @@ package main import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" "errors" "flag" "fmt" "git.gammaspectra.live/git/go-away/lib" "git.gammaspectra.live/git/go-away/lib/policy" + "git.gammaspectra.live/git/go-away/utils" "gopkg.in/yaml.v3" "log" "log/slog" + "maps" "net" "net/http" "os" "strconv" + "strings" ) func setupListener(network, address, socketMode string) (net.Listener, string) { @@ -49,6 +55,17 @@ func setupListener(network, address, socketMode string) (net.Listener, string) { return listener, formattedAddress } +type MultiVar []string + +func (v *MultiVar) String() string { + return fmt.Sprintf("%v", *v) +} + +func (v *MultiVar) Set(value string) error { + *v = append(*v, value) + return nil +} + func main() { bind := flag.String("bind", ":8080", "network address to bind HTTP to") bindNetwork := flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp") @@ -63,11 +80,18 @@ func main() { packageName := flag.String("package-path", "git.gammaspectra.live/git/go-away/cmd/go-away", "package name to expose in .well-known url path") + jwtPrivateKeySeed := flag.String("jwt-private-key-seed", "", "Seed for the jwt private key, or on JWT_PRIVATE_KEY_SEED env. One be generated by passing \"generate\" as a value, follows RFC 8032 private key definition. Defaults to random") + + var backends MultiVar + flag.Var(&backends, "backend", "backend definition in the form of an.example.com=http://backend:1234 (can be specified multiple times)") + flag.Parse() + var err error + { var programLevel slog.Level - if err := (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil { + if err = (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil { _, _ = fmt.Fprintf(os.Stderr, "invalid log level %s: %v, using info\n", *slogLevel, err) programLevel = slog.LevelInfo } @@ -82,6 +106,36 @@ func main() { slog.SetDefault(slog.New(h)) } + var seed []byte + + var kValue string + if kValue = os.Getenv("JWT_PRIVATE_KEY_SEED"); kValue != "" { + + } else if *jwtPrivateKeySeed != "" { + kValue = *jwtPrivateKeySeed + } + + if kValue != "" { + if strings.ToLower(kValue) == "generate" { + _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + log.Fatal(fmt.Errorf("failed to generate private key: %w", err)) + } + fmt.Printf("%x\n", priv.Seed()) + os.Exit(0) + } + + seed, err = hex.DecodeString(kValue) + if err != nil { + log.Fatal(fmt.Errorf("failed to decode seed: %w", err)) + } + + if len(seed) != ed25519.SeedSize { + log.Fatal(fmt.Errorf("invalid seed length: %d, expected %d", len(seed), ed25519.SeedSize)) + } + + } + policyData, err := os.ReadFile(*policyFile) if err != nil { log.Fatal(fmt.Errorf("failed to read policy file: %w", err)) @@ -93,11 +147,36 @@ func main() { log.Fatal(fmt.Errorf("failed to parse policy file: %w", err)) } + createdBackends := make(map[string]http.Handler) + + parsedBackends := make(map[string]string) + //TODO: deprecate + maps.Copy(parsedBackends, p.Backends) + for _, backend := range backends { + parts := strings.Split(backend, "=") + if len(parts) != 2 { + log.Fatal(fmt.Errorf("invalid backend definition: %s", backend)) + } + parsedBackends[parts[0]] = parts[1] + } + + for k, v := range parsedBackends { + backend, err := utils.MakeReverseProxy(v) + if err != nil { + log.Fatal(fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err)) + } + + backend.ErrorLog = slog.NewLogLogger(slog.With("backend", k).Handler(), slog.LevelError) + createdBackends[k] = backend + } + state, err := lib.NewState(p, lib.StateSettings{ + Backends: createdBackends, Debug: *debug, PackageName: *packageName, ChallengeTemplate: *challengeTemplate, ChallengeTemplateTheme: *challengeTemplateTheme, + PrivateKeySeed: seed, }) if err != nil { diff --git a/lib/challenge.go b/lib/challenge.go index ab1b5c6..d81292e 100644 --- a/lib/challenge.go +++ b/lib/challenge.go @@ -28,7 +28,7 @@ type ChallengeInformation struct { IssuedAt *jwt.NumericDate `json:"iat,omitempty"` } -func (state *State) GetRequestAddress(r *http.Request) net.IP { +func getRequestAddress(r *http.Request) net.IP { //TODO: verified upstream ipStr := r.Header.Get("X-Real-Ip") if ipStr == "" { @@ -47,7 +47,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h hasher.Write([]byte("challenge\x00")) hasher.Write([]byte(name)) hasher.Write([]byte{0}) - hasher.Write(state.GetRequestAddress(r).To16()) + hasher.Write(getRequestAddress(r).To16()) hasher.Write([]byte{0}) // specific headers @@ -64,7 +64,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h hasher.Write([]byte{0}) _ = binary.Write(hasher, binary.LittleEndian, until.UTC().Unix()) hasher.Write([]byte{0}) - hasher.Write(state.PublicKey) + hasher.Write(state.publicKey) hasher.Write([]byte{0}) return hasher.Sum(nil) @@ -73,7 +73,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h func (state *State) IssueChallengeToken(name string, key, result []byte, until time.Time) (token string, err error) { signer, err := jose.NewSigner(jose.SigningKey{ Algorithm: jose.EdDSA, - Key: state.PrivateKey, + Key: state.privateKey, }, nil) if err != nil { return "", err @@ -135,7 +135,7 @@ func (state *State) VerifyChallengeToken(name string, expectedKey []byte, w http } var i ChallengeInformation - err = token.Claims(state.PublicKey, &i) + err = token.Claims(state.publicKey, &i) if err != nil { return false, err } diff --git a/lib/http.go b/lib/http.go index 3ece22a..edc08e2 100644 --- a/lib/http.go +++ b/lib/http.go @@ -3,7 +3,6 @@ package lib import ( "bytes" "codeberg.org/meta/gzipped/v2" - "context" "crypto/rand" "encoding/base64" "encoding/hex" @@ -11,16 +10,12 @@ import ( "fmt" "git.gammaspectra.live/git/go-away/embed" "git.gammaspectra.live/git/go-away/lib/policy" - "git.gammaspectra.live/git/go-away/utils" "github.com/google/cel-go/common/types" "html/template" "io" "log/slog" "maps" - "net" "net/http" - "net/http/httputil" - "net/url" "path" "path/filepath" "strconv" @@ -72,34 +67,6 @@ func initTemplate(name, data string) error { return nil } -func makeReverseProxy(target string) (*httputil.ReverseProxy, error) { - u, err := url.Parse(target) - if err != nil { - return nil, fmt.Errorf("failed to parse target URL: %w", err) - } - - transport := http.DefaultTransport.(*http.Transport).Clone() - - // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124 - if u.Scheme == "unix" { - // clean path up so we don't use the socket path in proxied requests - addr := u.Path - u.Path = "" - // tell transport how to dial unix sockets - transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - dialer := net.Dialer{} - return dialer.DialContext(ctx, "unix", addr) - } - // tell transport how to handle the unix url scheme - transport.RegisterProtocol("unix", utils.UnixRoundTripper{Transport: transport}) - } - - rp := httputil.NewSingleHostReverseProxy(u) - rp.Transport = transport - - return rp, nil -} - func (state *State) challengePage(w http.ResponseWriter, id string, status int, challenge string, params map[string]any) error { input := make(map[string]any) input["Id"] = id @@ -158,10 +125,10 @@ func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration } } -func (state *State) getLogger(r *http.Request) *slog.Logger { +func GetLoggerForRequest(r *http.Request) *slog.Logger { return slog.With( "request_id", r.Header.Get("X-Away-Id"), - "remote_address", state.GetRequestAddress(r), + "remote_address", getRequestAddress(r), "user_agent", r.UserAgent(), "host", r.Host, "path", r.URL.Path, @@ -172,13 +139,13 @@ func (state *State) getLogger(r *http.Request) *slog.Logger { func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { host := r.Host - backend, ok := state.Backends[host] + backend, ok := state.Settings.Backends[host] if !ok { http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) return } - lg := state.getLogger(r) + lg := GetLoggerForRequest(r) start := time.Now() @@ -186,7 +153,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { env := map[string]any{ "host": host, "method": r.Method, - "remoteAddress": state.GetRequestAddress(r), + "remoteAddress": getRequestAddress(r), "userAgent": r.UserAgent(), "path": r.URL.Path, "query": func() map[string]string { @@ -292,7 +259,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) { if rule.Action == policy.RuleActionCHECK { goto nextRule } - state.getLogger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName) + GetLoggerForRequest(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName) // we pass the challenge early! r.Header.Set(fmt.Sprintf("X-Away-Challenge-%s-Verify", challengeName), "PASS") @@ -407,15 +374,15 @@ func (state *State) setupRoutes() error { state.addTiming(w, "challenge-verify", "Verify client challenge", time.Since(start)) if err != nil { - state.getLogger(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect")) + GetLoggerForRequest(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect")) return err } else if !ok { - state.getLogger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect")) + GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect")) ClearCookie(CookiePrefix+challengeName, w) _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName)) return nil } - state.getLogger(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect")) + GetLoggerForRequest(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect")) token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry) if err != nil { diff --git a/lib/policy/policy.go b/lib/policy/policy.go index 9a824b2..33eca0b 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -44,5 +44,7 @@ type Policy struct { Rules []Rule `yaml:"rules"` + // Backends + // Deprecated Backends map[string]string `json:"backends"` } diff --git a/lib/state.go b/lib/state.go index d66f500..b634a30 100644 --- a/lib/state.go +++ b/lib/state.go @@ -29,6 +29,7 @@ import ( "log/slog" "net" "net/http" + "net/http/httputil" "net/url" "os" "path" @@ -42,7 +43,6 @@ type State struct { Settings StateSettings UrlPath string Mux *http.ServeMux - Backends map[string]http.Handler Networks map[string]cidranger.Ranger @@ -55,8 +55,8 @@ type State struct { Rules []RuleState - PublicKey ed25519.PublicKey - PrivateKey ed25519.PrivateKey + publicKey ed25519.PublicKey + privateKey ed25519.PrivateKey Poison map[string][]byte } @@ -100,6 +100,8 @@ type ChallengeState struct { } type StateSettings struct { + Backends map[string]http.Handler + PrivateKeySeed []byte Debug bool PackageName string ChallengeTemplate string @@ -116,25 +118,36 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) } state.UrlPath = "/.well-known/." + state.Settings.PackageName - state.Backends = make(map[string]http.Handler) + // set a reasonable configuration for default http proxy if there is none + for _, backend := range state.Settings.Backends { + if proxy, ok := backend.(*httputil.ReverseProxy); ok { + if proxy.ErrorHandler == nil { + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + GetLoggerForRequest(r).Error(err.Error()) + _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadGateway, err) + } + } + } + } - for k, v := range p.Backends { - backend, err := makeReverseProxy(v) + if len(state.Settings.PrivateKeySeed) > 0 { + if len(state.Settings.PrivateKeySeed) != ed25519.SeedSize { + return nil, fmt.Errorf("invalid private key seed length: %d", len(state.Settings.PrivateKeySeed)) + } + + state.privateKey = ed25519.NewKeyFromSeed(state.Settings.PrivateKeySeed) + state.publicKey = state.privateKey.Public().(ed25519.PublicKey) + + clear(state.Settings.PrivateKeySeed) + + } else { + state.publicKey, state.privateKey, err = ed25519.GenerateKey(rand.Reader) if err != nil { - return nil, fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err) + return nil, err } - backend.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - state.getLogger(r).Error(fmt.Errorf("backend %s error: %w", k, err).Error()) - _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadGateway, err) - } - state.Backends[k] = backend } - state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader) - if err != nil { - return nil, err - } - privateKeyFingerprint := sha256.Sum256(state.PrivateKey) + privateKeyFingerprint := sha256.Sum256(state.privateKey) if state.Settings.ChallengeTemplate == "" { state.Settings.ChallengeTemplate = "anubis" @@ -423,12 +436,12 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error) if ok, err := c.Verify(key, result); err != nil { return err } else if !ok { - state.getLogger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect")) + GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect")) ClearCookie(CookiePrefix+challengeName, w) _ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName)) return nil } - state.getLogger(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect")) + GetLoggerForRequest(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect")) token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry) if err != nil { diff --git a/utils/http.go b/utils/http.go new file mode 100644 index 0000000..2a2e5b4 --- /dev/null +++ b/utils/http.go @@ -0,0 +1,38 @@ +package utils + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httputil" + "net/url" +) + +func MakeReverseProxy(target string) (*httputil.ReverseProxy, error) { + u, err := url.Parse(target) + if err != nil { + return nil, fmt.Errorf("failed to parse target URL: %w", err) + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + + // https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124 + if u.Scheme == "unix" { + // clean path up so we don't use the socket path in proxied requests + addr := u.Path + u.Path = "" + // tell transport how to dial unix sockets + transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + dialer := net.Dialer{} + return dialer.DialContext(ctx, "unix", addr) + } + // tell transport how to handle the unix url scheme + transport.RegisterProtocol("unix", UnixRoundTripper{Transport: transport}) + } + + rp := httputil.NewSingleHostReverseProxy(u) + rp.Transport = transport + + return rp, nil +}