Support atomically swapping http handler for passhtrough

This commit is contained in:
WeebDataHoarder
2025-04-23 17:28:44 +02:00
parent 3b11792594
commit 9719c0ff39
3 changed files with 74 additions and 106 deletions

View File

@@ -2,7 +2,6 @@ package main
import ( import (
"bytes" "bytes"
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
@@ -25,8 +24,7 @@ import (
"runtime/debug" "runtime/debug"
"strconv" "strconv"
"strings" "strings"
"sync" "sync/atomic"
"time"
) )
func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) { func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) {
@@ -198,16 +196,6 @@ func main() {
} }
policyData, err := os.ReadFile(*policyFile)
if err != nil {
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
}
p, err := policy.NewPolicy(bytes.NewReader(policyData), *policySnippets)
if err != nil {
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
}
createdBackends := make(map[string]http.Handler) createdBackends := make(map[string]http.Handler)
parsedBackends := make(map[string]string) parsedBackends := make(map[string]string)
@@ -263,58 +251,43 @@ func main() {
tlsConfig = acmeManager.TLSConfig() tlsConfig = acmeManager.TLSConfig()
} }
var wg sync.WaitGroup
passThroughCtx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
if *passThrough {
wg.Add(1)
go func() {
defer wg.Done()
server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backend, ok := createdBackends[r.Host]
if !ok {
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
return
}
backend.ServeHTTP(w, r)
}), tlsConfig)
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy) listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
slog.Warn( slog.Warn(
"listening passthrough", "listening",
"url", listenUrl, "url", listenUrl,
) )
defer listener.Close()
wg.Add(1) var serverHandler atomic.Pointer[http.Handler]
go func() { server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer wg.Done() if handler := serverHandler.Load(); handler == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
if tlsConfig != nil {
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
} else { } else {
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) { (*handler).ServeHTTP(w, r)
log.Fatal(err)
} }
}), tlsConfig)
if *passThrough {
// setup a passthrough handler temporarily
fn := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backend := utils.SelectHTTPHandler(createdBackends, r.Host)
if backend == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
} else {
backend.ServeHTTP(w, r)
} }
}() }))
serverHandler.Store(&fn)
<-passThroughCtx.Done()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Fatal(err)
} }
_ = server.Close()
}() go func() {
policyData, err := os.ReadFile(*policyFile)
if err != nil {
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
}
p, err := policy.NewPolicy(bytes.NewReader(policyData), *policySnippets)
if err != nil {
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
} }
settings := policy.Settings{ settings := policy.Settings{
@@ -335,24 +308,18 @@ func main() {
log.Fatal(fmt.Errorf("failed to create state: %w", err)) log.Fatal(fmt.Errorf("failed to create state: %w", err))
} }
// cancel the existing server listener serverHandler.Store(&state)
cancelFunc()
wg.Wait()
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
slog.Warn( slog.Warn(
"listening", "handler started",
"url", listenUrl,
) )
}()
server := utils.NewServer(state, tlsConfig)
if tlsConfig != nil { if tlsConfig != nil {
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) { if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err) log.Fatal(err)
} }
} else { } else {
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) { if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err) log.Fatal(err)
} }

View File

@@ -10,7 +10,6 @@ import (
"log/slog" "log/slog"
"maps" "maps"
"net/http" "net/http"
"strings"
) )
// Defines challenge.StateInterface // Defines challenge.StateInterface
@@ -142,17 +141,5 @@ func (state *State) Settings() policy.Settings {
} }
func (state *State) GetBackend(host string) http.Handler { func (state *State) GetBackend(host string) http.Handler {
backend, ok := state.Settings().Backends[host] return utils.SelectHTTPHandler(state.Settings().Backends, host)
if !ok {
// do wildcard match
wildcard := "*." + strings.Join(strings.Split(host, ".")[1:], ".")
backend, ok = state.Settings().Backends[wildcard]
if !ok {
// return fallback
backend = state.Settings().Backends["*"]
}
}
//TODO: dynamic
return backend
} }

View File

@@ -15,7 +15,6 @@ import (
) )
func NewServer(handler http.Handler, tlsConfig *tls.Config) *http.Server { func NewServer(handler http.Handler, tlsConfig *tls.Config) *http.Server {
if tlsConfig == nil { if tlsConfig == nil {
proto := new(http.Protocols) proto := new(http.Protocols)
proto.SetHTTP1(true) proto.SetHTTP1(true)
@@ -36,6 +35,21 @@ func NewServer(handler http.Handler, tlsConfig *tls.Config) *http.Server {
} }
} }
func SelectHTTPHandler(backends map[string]http.Handler, host string) http.Handler {
backend, ok := backends[host]
if !ok {
// do wildcard match
wildcard := "*." + strings.Join(strings.Split(host, ".")[1:], ".")
backend, ok = backends[wildcard]
if !ok {
// return fallback
backend = backends["*"]
}
}
return backend
}
func EnsureNoOpenRedirect(redirect string) (string, error) { func EnsureNoOpenRedirect(redirect string) (string, error) {
uri, err := url.Parse(redirect) uri, err := url.Parse(redirect)
if err != nil { if err != nil {