Move backends to cmd args, allow setting private key seed via parameter or ENV var
This commit is contained in:
15
.drone.yml
15
.drone.yml
@@ -3,11 +3,15 @@ kind: pipeline
|
|||||||
type: docker
|
type: docker
|
||||||
name: build-go1.22
|
name: build-go1.22
|
||||||
|
|
||||||
|
environment:
|
||||||
|
CGO_ENABLED: "0"
|
||||||
|
GOOS: linux
|
||||||
|
GOARCH: amd64
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: build-go1.22-alpine3.20
|
- name: build-go1.22-alpine3.20
|
||||||
image: golang:1.22-alpine3.20
|
image: golang:1.22-alpine3.20
|
||||||
environment:
|
environment:
|
||||||
CGO_ENABLED: "0"
|
|
||||||
GOOS: linux
|
GOOS: linux
|
||||||
GOARCH: amd64
|
GOARCH: amd64
|
||||||
commands:
|
commands:
|
||||||
@@ -19,13 +23,14 @@ kind: pipeline
|
|||||||
type: docker
|
type: docker
|
||||||
name: build-go1.24
|
name: build-go1.24
|
||||||
|
|
||||||
steps:
|
environment:
|
||||||
- name: build-go1.24-alpine3.21
|
|
||||||
image: golang:1.24-alpine3.21
|
|
||||||
environment:
|
|
||||||
CGO_ENABLED: "0"
|
CGO_ENABLED: "0"
|
||||||
GOOS: linux
|
GOOS: linux
|
||||||
GOARCH: amd64
|
GOARCH: amd64
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: build-go1.24-alpine3.21
|
||||||
|
image: golang:1.24-alpine3.21
|
||||||
commands:
|
commands:
|
||||||
- apk update
|
- apk update
|
||||||
- apk add --no-cache git
|
- apk add --no-cache git
|
||||||
|
|||||||
@@ -1,18 +1,24 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"git.gammaspectra.live/git/go-away/lib"
|
"git.gammaspectra.live/git/go-away/lib"
|
||||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||||
|
"git.gammaspectra.live/git/go-away/utils"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupListener(network, address, socketMode string) (net.Listener, string) {
|
func setupListener(network, address, socketMode string) (net.Listener, string) {
|
||||||
@@ -49,6 +55,17 @@ func setupListener(network, address, socketMode string) (net.Listener, string) {
|
|||||||
return listener, formattedAddress
|
return listener, formattedAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MultiVar []string
|
||||||
|
|
||||||
|
func (v *MultiVar) String() string {
|
||||||
|
return fmt.Sprintf("%v", *v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *MultiVar) Set(value string) error {
|
||||||
|
*v = append(*v, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
bind := flag.String("bind", ":8080", "network address to bind HTTP to")
|
bind := flag.String("bind", ":8080", "network address to bind HTTP to")
|
||||||
bindNetwork := flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp")
|
bindNetwork := flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp")
|
||||||
@@ -63,11 +80,18 @@ func main() {
|
|||||||
|
|
||||||
packageName := flag.String("package-path", "git.gammaspectra.live/git/go-away/cmd/go-away", "package name to expose in .well-known url path")
|
packageName := flag.String("package-path", "git.gammaspectra.live/git/go-away/cmd/go-away", "package name to expose in .well-known url path")
|
||||||
|
|
||||||
|
jwtPrivateKeySeed := flag.String("jwt-private-key-seed", "", "Seed for the jwt private key, or on JWT_PRIVATE_KEY_SEED env. One be generated by passing \"generate\" as a value, follows RFC 8032 private key definition. Defaults to random")
|
||||||
|
|
||||||
|
var backends MultiVar
|
||||||
|
flag.Var(&backends, "backend", "backend definition in the form of an.example.com=http://backend:1234 (can be specified multiple times)")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
{
|
{
|
||||||
var programLevel slog.Level
|
var programLevel slog.Level
|
||||||
if err := (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil {
|
if err = (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil {
|
||||||
_, _ = fmt.Fprintf(os.Stderr, "invalid log level %s: %v, using info\n", *slogLevel, err)
|
_, _ = fmt.Fprintf(os.Stderr, "invalid log level %s: %v, using info\n", *slogLevel, err)
|
||||||
programLevel = slog.LevelInfo
|
programLevel = slog.LevelInfo
|
||||||
}
|
}
|
||||||
@@ -82,6 +106,36 @@ func main() {
|
|||||||
slog.SetDefault(slog.New(h))
|
slog.SetDefault(slog.New(h))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var seed []byte
|
||||||
|
|
||||||
|
var kValue string
|
||||||
|
if kValue = os.Getenv("JWT_PRIVATE_KEY_SEED"); kValue != "" {
|
||||||
|
|
||||||
|
} else if *jwtPrivateKeySeed != "" {
|
||||||
|
kValue = *jwtPrivateKeySeed
|
||||||
|
}
|
||||||
|
|
||||||
|
if kValue != "" {
|
||||||
|
if strings.ToLower(kValue) == "generate" {
|
||||||
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(fmt.Errorf("failed to generate private key: %w", err))
|
||||||
|
}
|
||||||
|
fmt.Printf("%x\n", priv.Seed())
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
seed, err = hex.DecodeString(kValue)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(fmt.Errorf("failed to decode seed: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(seed) != ed25519.SeedSize {
|
||||||
|
log.Fatal(fmt.Errorf("invalid seed length: %d, expected %d", len(seed), ed25519.SeedSize))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
policyData, err := os.ReadFile(*policyFile)
|
policyData, err := os.ReadFile(*policyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
|
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
|
||||||
@@ -93,11 +147,36 @@ func main() {
|
|||||||
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
|
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
createdBackends := make(map[string]http.Handler)
|
||||||
|
|
||||||
|
parsedBackends := make(map[string]string)
|
||||||
|
//TODO: deprecate
|
||||||
|
maps.Copy(parsedBackends, p.Backends)
|
||||||
|
for _, backend := range backends {
|
||||||
|
parts := strings.Split(backend, "=")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
log.Fatal(fmt.Errorf("invalid backend definition: %s", backend))
|
||||||
|
}
|
||||||
|
parsedBackends[parts[0]] = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range parsedBackends {
|
||||||
|
backend, err := utils.MakeReverseProxy(v)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
backend.ErrorLog = slog.NewLogLogger(slog.With("backend", k).Handler(), slog.LevelError)
|
||||||
|
createdBackends[k] = backend
|
||||||
|
}
|
||||||
|
|
||||||
state, err := lib.NewState(p, lib.StateSettings{
|
state, err := lib.NewState(p, lib.StateSettings{
|
||||||
|
Backends: createdBackends,
|
||||||
Debug: *debug,
|
Debug: *debug,
|
||||||
PackageName: *packageName,
|
PackageName: *packageName,
|
||||||
ChallengeTemplate: *challengeTemplate,
|
ChallengeTemplate: *challengeTemplate,
|
||||||
ChallengeTemplateTheme: *challengeTemplateTheme,
|
ChallengeTemplateTheme: *challengeTemplateTheme,
|
||||||
|
PrivateKeySeed: seed,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ type ChallengeInformation struct {
|
|||||||
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
IssuedAt *jwt.NumericDate `json:"iat,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state *State) GetRequestAddress(r *http.Request) net.IP {
|
func getRequestAddress(r *http.Request) net.IP {
|
||||||
//TODO: verified upstream
|
//TODO: verified upstream
|
||||||
ipStr := r.Header.Get("X-Real-Ip")
|
ipStr := r.Header.Get("X-Real-Ip")
|
||||||
if ipStr == "" {
|
if ipStr == "" {
|
||||||
@@ -47,7 +47,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h
|
|||||||
hasher.Write([]byte("challenge\x00"))
|
hasher.Write([]byte("challenge\x00"))
|
||||||
hasher.Write([]byte(name))
|
hasher.Write([]byte(name))
|
||||||
hasher.Write([]byte{0})
|
hasher.Write([]byte{0})
|
||||||
hasher.Write(state.GetRequestAddress(r).To16())
|
hasher.Write(getRequestAddress(r).To16())
|
||||||
hasher.Write([]byte{0})
|
hasher.Write([]byte{0})
|
||||||
|
|
||||||
// specific headers
|
// specific headers
|
||||||
@@ -64,7 +64,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h
|
|||||||
hasher.Write([]byte{0})
|
hasher.Write([]byte{0})
|
||||||
_ = binary.Write(hasher, binary.LittleEndian, until.UTC().Unix())
|
_ = binary.Write(hasher, binary.LittleEndian, until.UTC().Unix())
|
||||||
hasher.Write([]byte{0})
|
hasher.Write([]byte{0})
|
||||||
hasher.Write(state.PublicKey)
|
hasher.Write(state.publicKey)
|
||||||
hasher.Write([]byte{0})
|
hasher.Write([]byte{0})
|
||||||
|
|
||||||
return hasher.Sum(nil)
|
return hasher.Sum(nil)
|
||||||
@@ -73,7 +73,7 @@ func (state *State) GetChallengeKeyForRequest(name string, until time.Time, r *h
|
|||||||
func (state *State) IssueChallengeToken(name string, key, result []byte, until time.Time) (token string, err error) {
|
func (state *State) IssueChallengeToken(name string, key, result []byte, until time.Time) (token string, err error) {
|
||||||
signer, err := jose.NewSigner(jose.SigningKey{
|
signer, err := jose.NewSigner(jose.SigningKey{
|
||||||
Algorithm: jose.EdDSA,
|
Algorithm: jose.EdDSA,
|
||||||
Key: state.PrivateKey,
|
Key: state.privateKey,
|
||||||
}, nil)
|
}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -135,7 +135,7 @@ func (state *State) VerifyChallengeToken(name string, expectedKey []byte, w http
|
|||||||
}
|
}
|
||||||
|
|
||||||
var i ChallengeInformation
|
var i ChallengeInformation
|
||||||
err = token.Claims(state.PublicKey, &i)
|
err = token.Claims(state.publicKey, &i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|||||||
51
lib/http.go
51
lib/http.go
@@ -3,7 +3,6 @@ package lib
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"codeberg.org/meta/gzipped/v2"
|
"codeberg.org/meta/gzipped/v2"
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
@@ -11,16 +10,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"git.gammaspectra.live/git/go-away/embed"
|
"git.gammaspectra.live/git/go-away/embed"
|
||||||
"git.gammaspectra.live/git/go-away/lib/policy"
|
"git.gammaspectra.live/git/go-away/lib/policy"
|
||||||
"git.gammaspectra.live/git/go-away/utils"
|
|
||||||
"github.com/google/cel-go/common/types"
|
"github.com/google/cel-go/common/types"
|
||||||
"html/template"
|
"html/template"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"maps"
|
"maps"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -72,34 +67,6 @@ func initTemplate(name, data string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeReverseProxy(target string) (*httputil.ReverseProxy, error) {
|
|
||||||
u, err := url.Parse(target)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse target URL: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
|
||||||
|
|
||||||
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
|
|
||||||
if u.Scheme == "unix" {
|
|
||||||
// clean path up so we don't use the socket path in proxied requests
|
|
||||||
addr := u.Path
|
|
||||||
u.Path = ""
|
|
||||||
// tell transport how to dial unix sockets
|
|
||||||
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
|
||||||
dialer := net.Dialer{}
|
|
||||||
return dialer.DialContext(ctx, "unix", addr)
|
|
||||||
}
|
|
||||||
// tell transport how to handle the unix url scheme
|
|
||||||
transport.RegisterProtocol("unix", utils.UnixRoundTripper{Transport: transport})
|
|
||||||
}
|
|
||||||
|
|
||||||
rp := httputil.NewSingleHostReverseProxy(u)
|
|
||||||
rp.Transport = transport
|
|
||||||
|
|
||||||
return rp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (state *State) challengePage(w http.ResponseWriter, id string, status int, challenge string, params map[string]any) error {
|
func (state *State) challengePage(w http.ResponseWriter, id string, status int, challenge string, params map[string]any) error {
|
||||||
input := make(map[string]any)
|
input := make(map[string]any)
|
||||||
input["Id"] = id
|
input["Id"] = id
|
||||||
@@ -158,10 +125,10 @@ func (state *State) addTiming(w http.ResponseWriter, name, desc string, duration
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (state *State) getLogger(r *http.Request) *slog.Logger {
|
func GetLoggerForRequest(r *http.Request) *slog.Logger {
|
||||||
return slog.With(
|
return slog.With(
|
||||||
"request_id", r.Header.Get("X-Away-Id"),
|
"request_id", r.Header.Get("X-Away-Id"),
|
||||||
"remote_address", state.GetRequestAddress(r),
|
"remote_address", getRequestAddress(r),
|
||||||
"user_agent", r.UserAgent(),
|
"user_agent", r.UserAgent(),
|
||||||
"host", r.Host,
|
"host", r.Host,
|
||||||
"path", r.URL.Path,
|
"path", r.URL.Path,
|
||||||
@@ -172,13 +139,13 @@ func (state *State) getLogger(r *http.Request) *slog.Logger {
|
|||||||
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
host := r.Host
|
host := r.Host
|
||||||
|
|
||||||
backend, ok := state.Backends[host]
|
backend, ok := state.Settings.Backends[host]
|
||||||
if !ok {
|
if !ok {
|
||||||
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
|
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
lg := state.getLogger(r)
|
lg := GetLoggerForRequest(r)
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
@@ -186,7 +153,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
env := map[string]any{
|
env := map[string]any{
|
||||||
"host": host,
|
"host": host,
|
||||||
"method": r.Method,
|
"method": r.Method,
|
||||||
"remoteAddress": state.GetRequestAddress(r),
|
"remoteAddress": getRequestAddress(r),
|
||||||
"userAgent": r.UserAgent(),
|
"userAgent": r.UserAgent(),
|
||||||
"path": r.URL.Path,
|
"path": r.URL.Path,
|
||||||
"query": func() map[string]string {
|
"query": func() map[string]string {
|
||||||
@@ -292,7 +259,7 @@ func (state *State) handleRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
if rule.Action == policy.RuleActionCHECK {
|
if rule.Action == policy.RuleActionCHECK {
|
||||||
goto nextRule
|
goto nextRule
|
||||||
}
|
}
|
||||||
state.getLogger(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
|
GetLoggerForRequest(r).Warn("challenge passed", "rule", rule.Name, "rule_hash", rule.Hash, "challenge", challengeName)
|
||||||
|
|
||||||
// we pass the challenge early!
|
// we pass the challenge early!
|
||||||
r.Header.Set(fmt.Sprintf("X-Away-Challenge-%s-Verify", challengeName), "PASS")
|
r.Header.Set(fmt.Sprintf("X-Away-Challenge-%s-Verify", challengeName), "PASS")
|
||||||
@@ -407,15 +374,15 @@ func (state *State) setupRoutes() error {
|
|||||||
state.addTiming(w, "challenge-verify", "Verify client challenge", time.Since(start))
|
state.addTiming(w, "challenge-verify", "Verify client challenge", time.Since(start))
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.getLogger(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
GetLoggerForRequest(r).Error(fmt.Errorf("challenge error: %w", err).Error(), "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||||
return err
|
return err
|
||||||
} else if !ok {
|
} else if !ok {
|
||||||
state.getLogger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||||
ClearCookie(CookiePrefix+challengeName, w)
|
ClearCookie(CookiePrefix+challengeName, w)
|
||||||
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
|
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
state.getLogger(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
GetLoggerForRequest(r).Info("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||||
|
|
||||||
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -44,5 +44,7 @@ type Policy struct {
|
|||||||
|
|
||||||
Rules []Rule `yaml:"rules"`
|
Rules []Rule `yaml:"rules"`
|
||||||
|
|
||||||
|
// Backends
|
||||||
|
// Deprecated
|
||||||
Backends map[string]string `json:"backends"`
|
Backends map[string]string `json:"backends"`
|
||||||
}
|
}
|
||||||
|
|||||||
47
lib/state.go
47
lib/state.go
@@ -29,6 +29,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
@@ -42,7 +43,6 @@ type State struct {
|
|||||||
Settings StateSettings
|
Settings StateSettings
|
||||||
UrlPath string
|
UrlPath string
|
||||||
Mux *http.ServeMux
|
Mux *http.ServeMux
|
||||||
Backends map[string]http.Handler
|
|
||||||
|
|
||||||
Networks map[string]cidranger.Ranger
|
Networks map[string]cidranger.Ranger
|
||||||
|
|
||||||
@@ -55,8 +55,8 @@ type State struct {
|
|||||||
|
|
||||||
Rules []RuleState
|
Rules []RuleState
|
||||||
|
|
||||||
PublicKey ed25519.PublicKey
|
publicKey ed25519.PublicKey
|
||||||
PrivateKey ed25519.PrivateKey
|
privateKey ed25519.PrivateKey
|
||||||
|
|
||||||
Poison map[string][]byte
|
Poison map[string][]byte
|
||||||
}
|
}
|
||||||
@@ -100,6 +100,8 @@ type ChallengeState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type StateSettings struct {
|
type StateSettings struct {
|
||||||
|
Backends map[string]http.Handler
|
||||||
|
PrivateKeySeed []byte
|
||||||
Debug bool
|
Debug bool
|
||||||
PackageName string
|
PackageName string
|
||||||
ChallengeTemplate string
|
ChallengeTemplate string
|
||||||
@@ -116,25 +118,36 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
|||||||
}
|
}
|
||||||
state.UrlPath = "/.well-known/." + state.Settings.PackageName
|
state.UrlPath = "/.well-known/." + state.Settings.PackageName
|
||||||
|
|
||||||
state.Backends = make(map[string]http.Handler)
|
// set a reasonable configuration for default http proxy if there is none
|
||||||
|
for _, backend := range state.Settings.Backends {
|
||||||
for k, v := range p.Backends {
|
if proxy, ok := backend.(*httputil.ReverseProxy); ok {
|
||||||
backend, err := makeReverseProxy(v)
|
if proxy.ErrorHandler == nil {
|
||||||
if err != nil {
|
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
return nil, fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err)
|
GetLoggerForRequest(r).Error(err.Error())
|
||||||
}
|
|
||||||
backend.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
|
||||||
state.getLogger(r).Error(fmt.Errorf("backend %s error: %w", k, err).Error())
|
|
||||||
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadGateway, err)
|
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusBadGateway, err)
|
||||||
}
|
}
|
||||||
state.Backends[k] = backend
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
state.PublicKey, state.PrivateKey, err = ed25519.GenerateKey(rand.Reader)
|
if len(state.Settings.PrivateKeySeed) > 0 {
|
||||||
|
if len(state.Settings.PrivateKeySeed) != ed25519.SeedSize {
|
||||||
|
return nil, fmt.Errorf("invalid private key seed length: %d", len(state.Settings.PrivateKeySeed))
|
||||||
|
}
|
||||||
|
|
||||||
|
state.privateKey = ed25519.NewKeyFromSeed(state.Settings.PrivateKeySeed)
|
||||||
|
state.publicKey = state.privateKey.Public().(ed25519.PublicKey)
|
||||||
|
|
||||||
|
clear(state.Settings.PrivateKeySeed)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
state.publicKey, state.privateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
privateKeyFingerprint := sha256.Sum256(state.PrivateKey)
|
}
|
||||||
|
|
||||||
|
privateKeyFingerprint := sha256.Sum256(state.privateKey)
|
||||||
|
|
||||||
if state.Settings.ChallengeTemplate == "" {
|
if state.Settings.ChallengeTemplate == "" {
|
||||||
state.Settings.ChallengeTemplate = "anubis"
|
state.Settings.ChallengeTemplate = "anubis"
|
||||||
@@ -423,12 +436,12 @@ func NewState(p policy.Policy, settings StateSettings) (state *State, err error)
|
|||||||
if ok, err := c.Verify(key, result); err != nil {
|
if ok, err := c.Verify(key, result); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if !ok {
|
} else if !ok {
|
||||||
state.getLogger(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
GetLoggerForRequest(r).Warn("challenge failed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||||
ClearCookie(CookiePrefix+challengeName, w)
|
ClearCookie(CookiePrefix+challengeName, w)
|
||||||
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
|
_ = state.errorPage(w, r.Header.Get("X-Away-Id"), http.StatusForbidden, fmt.Errorf("access denied: failed challenge %s", challengeName))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
state.getLogger(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
GetLoggerForRequest(r).Warn("challenge passed", "challenge", challengeName, "redirect", r.FormValue("redirect"))
|
||||||
|
|
||||||
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
token, err := state.IssueChallengeToken(challengeName, key, []byte(result), expiry)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
38
utils/http.go
Normal file
38
utils/http.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
func MakeReverseProxy(target string) (*httputil.ReverseProxy, error) {
|
||||||
|
u, err := url.Parse(target)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse target URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||||
|
|
||||||
|
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
|
||||||
|
if u.Scheme == "unix" {
|
||||||
|
// clean path up so we don't use the socket path in proxied requests
|
||||||
|
addr := u.Path
|
||||||
|
u.Path = ""
|
||||||
|
// tell transport how to dial unix sockets
|
||||||
|
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||||
|
dialer := net.Dialer{}
|
||||||
|
return dialer.DialContext(ctx, "unix", addr)
|
||||||
|
}
|
||||||
|
// tell transport how to handle the unix url scheme
|
||||||
|
transport.RegisterProtocol("unix", UnixRoundTripper{Transport: transport})
|
||||||
|
}
|
||||||
|
|
||||||
|
rp := httputil.NewSingleHostReverseProxy(u)
|
||||||
|
rp.Transport = transport
|
||||||
|
|
||||||
|
return rp, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user