Support atomically swapping http handler for passhtrough
This commit is contained in:
@@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
@@ -25,8 +24,7 @@ import (
|
|||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) {
|
func setupListener(network, address, socketMode string, proxy bool) (net.Listener, string) {
|
||||||
@@ -198,16 +196,6 @@ func main() {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
policyData, err := os.ReadFile(*policyFile)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
p, err := policy.NewPolicy(bytes.NewReader(policyData), *policySnippets)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
createdBackends := make(map[string]http.Handler)
|
createdBackends := make(map[string]http.Handler)
|
||||||
|
|
||||||
parsedBackends := make(map[string]string)
|
parsedBackends := make(map[string]string)
|
||||||
@@ -263,58 +251,43 @@ func main() {
|
|||||||
tlsConfig = acmeManager.TLSConfig()
|
tlsConfig = acmeManager.TLSConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
passThroughCtx, cancelFunc := context.WithCancel(context.Background())
|
|
||||||
defer cancelFunc()
|
|
||||||
|
|
||||||
if *passThrough {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
backend, ok := createdBackends[r.Host]
|
|
||||||
if !ok {
|
|
||||||
http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
backend.ServeHTTP(w, r)
|
|
||||||
}), tlsConfig)
|
|
||||||
|
|
||||||
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
|
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
|
||||||
slog.Warn(
|
slog.Warn(
|
||||||
"listening passthrough",
|
"listening",
|
||||||
"url", listenUrl,
|
"url", listenUrl,
|
||||||
)
|
)
|
||||||
defer listener.Close()
|
|
||||||
|
|
||||||
wg.Add(1)
|
var serverHandler atomic.Pointer[http.Handler]
|
||||||
go func() {
|
server := utils.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer wg.Done()
|
if handler := serverHandler.Load(); handler == nil {
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
||||||
if tlsConfig != nil {
|
|
||||||
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
|
(*handler).ServeHTTP(w, r)
|
||||||
log.Fatal(err)
|
|
||||||
}
|
}
|
||||||
|
}), tlsConfig)
|
||||||
|
|
||||||
|
if *passThrough {
|
||||||
|
// setup a passthrough handler temporarily
|
||||||
|
fn := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
backend := utils.SelectHTTPHandler(createdBackends, r.Host)
|
||||||
|
if backend == nil {
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
||||||
|
} else {
|
||||||
|
backend.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
}()
|
}))
|
||||||
|
serverHandler.Store(&fn)
|
||||||
<-passThroughCtx.Done()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if err := server.Shutdown(ctx); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
}
|
||||||
_ = server.Close()
|
|
||||||
}()
|
go func() {
|
||||||
|
policyData, err := os.ReadFile(*policyFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(fmt.Errorf("failed to read policy file: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := policy.NewPolicy(bytes.NewReader(policyData), *policySnippets)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(fmt.Errorf("failed to parse policy file: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
settings := policy.Settings{
|
settings := policy.Settings{
|
||||||
@@ -335,24 +308,18 @@ func main() {
|
|||||||
log.Fatal(fmt.Errorf("failed to create state: %w", err))
|
log.Fatal(fmt.Errorf("failed to create state: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// cancel the existing server listener
|
serverHandler.Store(&state)
|
||||||
cancelFunc()
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
listener, listenUrl := setupListener(*bindNetwork, *bind, *socketMode, *bindProxy)
|
|
||||||
slog.Warn(
|
slog.Warn(
|
||||||
"listening",
|
"handler started",
|
||||||
"url", listenUrl,
|
|
||||||
)
|
)
|
||||||
|
}()
|
||||||
server := utils.NewServer(state, tlsConfig)
|
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
|
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
|
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"maps"
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Defines challenge.StateInterface
|
// Defines challenge.StateInterface
|
||||||
@@ -142,17 +141,5 @@ func (state *State) Settings() policy.Settings {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (state *State) GetBackend(host string) http.Handler {
|
func (state *State) GetBackend(host string) http.Handler {
|
||||||
backend, ok := state.Settings().Backends[host]
|
return utils.SelectHTTPHandler(state.Settings().Backends, host)
|
||||||
if !ok {
|
|
||||||
// do wildcard match
|
|
||||||
wildcard := "*." + strings.Join(strings.Split(host, ".")[1:], ".")
|
|
||||||
backend, ok = state.Settings().Backends[wildcard]
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
// return fallback
|
|
||||||
backend = state.Settings().Backends["*"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//TODO: dynamic
|
|
||||||
return backend
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func NewServer(handler http.Handler, tlsConfig *tls.Config) *http.Server {
|
func NewServer(handler http.Handler, tlsConfig *tls.Config) *http.Server {
|
||||||
|
|
||||||
if tlsConfig == nil {
|
if tlsConfig == nil {
|
||||||
proto := new(http.Protocols)
|
proto := new(http.Protocols)
|
||||||
proto.SetHTTP1(true)
|
proto.SetHTTP1(true)
|
||||||
@@ -36,6 +35,21 @@ func NewServer(handler http.Handler, tlsConfig *tls.Config) *http.Server {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SelectHTTPHandler(backends map[string]http.Handler, host string) http.Handler {
|
||||||
|
backend, ok := backends[host]
|
||||||
|
if !ok {
|
||||||
|
// do wildcard match
|
||||||
|
wildcard := "*." + strings.Join(strings.Split(host, ".")[1:], ".")
|
||||||
|
backend, ok = backends[wildcard]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
// return fallback
|
||||||
|
backend = backends["*"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return backend
|
||||||
|
}
|
||||||
|
|
||||||
func EnsureNoOpenRedirect(redirect string) (string, error) {
|
func EnsureNoOpenRedirect(redirect string) (string, error) {
|
||||||
uri, err := url.Parse(redirect)
|
uri, err := url.Parse(redirect)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user