Support atomically swapping http handler for passhtrough

This commit is contained in:
WeebDataHoarder
2025-04-23 17:28:44 +02:00
parent 3b11792594
commit 9719c0ff39
3 changed files with 74 additions and 106 deletions

View File

@@ -2,7 +2,6 @@ package main
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
@@ -25,8 +24,7 @@ import (
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
"sync/atomic"
)
func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) {
@@ -198,16 +196,6 @@ func main() {
}
policyData, err := os.ReadFile(*policyFile)
if err != nil {
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
}
p, err := policy.NewPolicy(bytes.NewReader(policyData), *policySnippets)
if err != nil {
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
}
createdBackends := make(map[string]http.Handler)
parsedBackends := make(map[string]string)
@@ -263,96 +251,75 @@ func main() {
tlsConfig = acmeManager.TLSConfig()
}
var wg sync.WaitGroup
passThroughCtx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
if *passThrough {
wg.Add(1)
go func() {
defer wg.Done()
server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backend, ok := createdBackends[r.Host]
if !ok {
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
return
}
backend.ServeHTTP(w, r)
}), tlsConfig)
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
slog.Warn(
"listening passthrough",
"url", listenUrl,
)
defer listener.Close()
wg.Add(1)
go func() {
defer wg.Done()
if tlsConfig != nil {
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
} else {
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
}
}()
<-passThroughCtx.Done()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Fatal(err)
}
_ = server.Close()
}()
}
settings := policy.Settings{
Backends: createdBackends,
Debug: *debugMode,
PackageName: *packageName,
ChallengeTemplate: *challengeTemplate,
ChallengeTemplateTheme: *challengeTemplateTheme,
PrivateKeySeed: seed,
ClientIpHeader: *clientIpHeader,
BackendIpHeader: *backendIpHeader,
ChallengeResponseCode: http.StatusTeapot,
}
state, err := lib.NewState(*p, settings)
if err != nil {
log.Fatal(fmt.Errorf("failed to create state: %w", err))
}
// cancel the existing server listener
cancelFunc()
wg.Wait()
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
slog.Warn(
"listening",
"url", listenUrl,
)
server := utils.NewServer(state, tlsConfig)
var serverHandler atomic.Pointer[http.Handler]
server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler := serverHandler.Load(); handler == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
} else {
(*handler).ServeHTTP(w, r)
}
}), tlsConfig)
if *passThrough {
// setup a passthrough handler temporarily
fn := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
backend := utils.SelectHTTPHandler(createdBackends, r.Host)
if backend == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
} else {
backend.ServeHTTP(w, r)
}
}))
serverHandler.Store(&fn)
}
go func() {
policyData, err := os.ReadFile(*policyFile)
if err != nil {
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
}
p, err := policy.NewPolicy(bytes.NewReader(policyData), *policySnippets)
if err != nil {
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
}
settings := policy.Settings{
Backends: createdBackends,
Debug: *debugMode,
PackageName: *packageName,
ChallengeTemplate: *challengeTemplate,
ChallengeTemplateTheme: *challengeTemplateTheme,
PrivateKeySeed: seed,
ClientIpHeader: *clientIpHeader,
BackendIpHeader: *backendIpHeader,
ChallengeResponseCode: http.StatusTeapot,
}
state, err := lib.NewState(*p, settings)
if err != nil {
log.Fatal(fmt.Errorf("failed to create state: %w", err))
}
serverHandler.Store(&state)
slog.Warn(
"handler started",
)
}()
if tlsConfig != nil {
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
} else {
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}