384 lines
11 KiB
Go
384 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"git.gammaspectra.live/git/go-away/lib"
|
|
"git.gammaspectra.live/git/go-away/lib/policy"
|
|
"git.gammaspectra.live/git/go-away/lib/settings"
|
|
"git.gammaspectra.live/git/go-away/utils"
|
|
"github.com/goccy/go-yaml"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/pprof"
|
|
"os"
|
|
"os/signal"
|
|
"path"
|
|
"runtime"
|
|
"runtime/debug"
|
|
"strings"
|
|
"syscall"
|
|
)
|
|
|
|
// Identity of the running binary. It is logged at startup, and the main name
// and version are also handed to the policy state. The literals below are
// fallbacks for when build information is unavailable; init replaces them
// from the build information embedded by the Go toolchain.
var (
	// internalCmdName is the last element of the main package's import path.
	// It also forms the default -path flag value ("/.well-known/." + name).
	internalCmdName = "go-away"
	// internalMainName is the last element of the main module's path.
	internalMainName = "go-away"
	// internalMainVersion is the main module's version as recorded at build time.
	internalMainVersion = "dev"
)

// init fills in the binary identity from runtime/debug build information.
// Only the last path element of the package and module paths is kept. The
// command name ends up in the default base URL path, which must be a single
// segment such as ".go-away", not a full import path like
// "example.com/org/repo/cmd/name". Build-info fields that are empty leave
// the corresponding fallback untouched.
func init() {
	buildInfo, ok := debug.ReadBuildInfo()
	if !ok {
		return
	}

	if buildInfo.Path != "" {
		internalCmdName = path.Base(buildInfo.Path)
	}
	if buildInfo.Main.Path != "" {
		internalMainName = path.Base(buildInfo.Main.Path)
	}
	if buildInfo.Main.Version != "" {
		internalMainVersion = buildInfo.Main.Version
	}
}
|
|
|
|
// MultiVar is a flag.Value that collects every occurrence of a repeatable
// command-line flag, in the order the occurrences were given.
type MultiVar []string

// Compile-time check that *MultiVar satisfies flag.Value.
var _ flag.Value = (*MultiVar)(nil)

// String formats the collected values the way fmt's %v formats a string
// slice (e.g. "[a b]"). The flag package may call String on a zero-valued
// receiver such as a nil pointer; that case returns "[]", the same text an
// empty MultiVar produces, instead of panicking on the dereference.
func (v *MultiVar) String() string {
	if v == nil {
		return "[]"
	}
	return fmt.Sprintf("%v", *v)
}

// Set appends value to the collection. It never returns an error.
func (v *MultiVar) Set(value string) error {
	*v = append(*v, value)
	return nil
}
|
|
|
|
// fatal reports err through the default slog logger and then prints it again
// on stderr, below a separator banner, so it stands out from structured log
// output. It terminates the process with exit status 1 and never returns.
func fatal(err error) {
	msg := err.Error()
	slog.Error(msg)

	banner := []string{
		"================================================",
		"Fatal error:",
		msg,
	}
	for _, line := range banner {
		_, _ = fmt.Fprintln(os.Stderr, line)
	}

	os.Exit(1)
}
|
|
|
|
func main() {
|
|
|
|
opt := settings.DefaultSettings
|
|
|
|
flag.StringVar(&opt.Bind.Address, "bind", opt.Bind.Address, "network address to bind HTTP/HTTP(s) to")
|
|
flag.StringVar(&opt.Bind.Network, "bind-network", opt.Bind.Network, "network family to bind HTTP to, e.g. unix, tcp")
|
|
flag.BoolVar(&opt.Bind.Proxy, "bind-proxy", opt.Bind.Proxy, "use PROXY protocol in front of the listener")
|
|
flag.StringVar(&opt.Bind.SocketMode, "socket-mode", opt.Bind.SocketMode, "socket mode (permissions) for unix domain sockets.")
|
|
flag.StringVar(&opt.BindMetrics, "metrics-bind", opt.BindMetrics, "network address to bind metrics on")
|
|
flag.StringVar(&opt.BindDebug, "debug-bind", opt.BindDebug, "network address to bind debug on")
|
|
|
|
slogLevel := flag.String("slog-level", "WARN", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
|
|
flag.BoolVar(&opt.Bind.Passthrough, "passthrough", opt.Bind.Passthrough, "passthrough mode sends all requests to matching backends until state is loaded")
|
|
check := flag.Bool("check", false, "check configuration and policies, then exit")
|
|
flag.StringVar(&opt.Bind.TLSAcmeAutoCert, "acme-autocert", opt.Bind.TLSAcmeAutoCert, "enables HTTP(s) mode and uses the provided ACME server URL or available service (available: letsencrypt)")
|
|
|
|
clientIpHeader := flag.String("client-ip-header", "", "Client HTTP header to fetch their IP address from (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)")
|
|
backendIpHeader := flag.String("backend-ip-header", "", "Backend HTTP header to set the client IP address from, if empty defaults to leaving Client header alone (X-Real-Ip, X-Client-Ip, X-Forwarded-For, Cf-Connecting-Ip, etc.)")
|
|
|
|
cachePath := flag.String("cache", path.Join(os.TempDir(), "go_away_cache"), "path to temporary cache directory")
|
|
|
|
policyFile := flag.String("policy", "", "path to policy YAML file")
|
|
var policySnippets MultiVar
|
|
flag.Var(&policySnippets, "policy-snippets", "path to YAML snippets folder (can be specified multiple times)")
|
|
|
|
flag.StringVar(&opt.ChallengeTemplate, "challenge-template", opt.ChallengeTemplate, "name or path of the challenge template to use (anubis, forgejo)")
|
|
|
|
templateTheme := flag.String("challenge-template-theme", opt.ChallengeTemplateOverrides["Theme"], "override template theme to use (forgejo => [forgejo-auto, forgejo-dark, forgejo-light, gitea...])")
|
|
templateLogo := flag.String("challenge-template-logo", opt.ChallengeTemplateOverrides["Logo"], "override template logo to use")
|
|
|
|
basePath := flag.String("path", "/.well-known/."+internalCmdName, "base path where to expose go-away package onto, challenges will be served from here")
|
|
|
|
jwtPrivateKeySeed := flag.String("jwt-private-key-seed", "", "Seed for the jwt private key, or on JWT_PRIVATE_KEY_SEED env. One be generated by passing \"generate\" as a value, follows RFC 8032 private key definition. Defaults to random")
|
|
|
|
var backends MultiVar
|
|
flag.Var(&backends, "backend", "backend definition in the form of an.example.com=http://backend:1234 (can be specified multiple times)")
|
|
|
|
settingsFile := flag.String("config", "", "path to config override YAML file")
|
|
|
|
flag.Parse()
|
|
|
|
if *backendIpHeader == "" {
|
|
*backendIpHeader = *clientIpHeader
|
|
}
|
|
|
|
var err error
|
|
|
|
{
|
|
var programLevel slog.Level
|
|
if err = (&programLevel).UnmarshalText([]byte(*slogLevel)); err != nil {
|
|
_, _ = fmt.Fprintf(os.Stderr, "invalid log level %s: %v, using info\n", *slogLevel, err)
|
|
programLevel = slog.LevelInfo
|
|
}
|
|
|
|
leveler := &slog.LevelVar{}
|
|
leveler.Set(programLevel)
|
|
|
|
h := slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
|
|
AddSource: programLevel <= slog.LevelDebug,
|
|
Level: leveler,
|
|
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
|
if a.Key == "source" {
|
|
if src, ok := a.Value.Any().(*slog.Source); ok {
|
|
return slog.String(a.Key, fmt.Sprintf("%s:%d", src.File, src.Line))
|
|
}
|
|
}
|
|
return a
|
|
},
|
|
})
|
|
slog.SetDefault(slog.New(h))
|
|
// set default log logger to slog logger level
|
|
slog.SetLogLoggerLevel(programLevel)
|
|
}
|
|
|
|
slog.Info("go-away", "package", internalMainName, "version", internalMainVersion, "cmd", internalCmdName, "go", runtime.Version(), "os", runtime.GOOS, "arch", runtime.GOARCH)
|
|
|
|
// preload missing settings
|
|
opt.ChallengeTemplateOverrides["Theme"] = *templateTheme
|
|
opt.ChallengeTemplateOverrides["Logo"] = *templateLogo
|
|
|
|
// load overrides
|
|
if *settingsFile != "" {
|
|
settingsData, err := os.ReadFile(*settingsFile)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("could not read settings file: %w", err))
|
|
}
|
|
err = yaml.Unmarshal(settingsData, &opt)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("could not parse settings file: %w", err))
|
|
}
|
|
}
|
|
|
|
var seed []byte
|
|
|
|
var kValue string
|
|
if kValue = os.Getenv("JWT_PRIVATE_KEY_SEED"); kValue != "" {
|
|
|
|
} else if *jwtPrivateKeySeed != "" {
|
|
kValue = *jwtPrivateKeySeed
|
|
}
|
|
|
|
if kValue != "" {
|
|
if strings.ToLower(kValue) == "generate" {
|
|
_, priv, err := ed25519.GenerateKey(rand.Reader)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("failed to generate private key: %w", err))
|
|
}
|
|
fmt.Printf("%x\n", priv.Seed())
|
|
os.Exit(0)
|
|
}
|
|
|
|
seed, err = hex.DecodeString(kValue)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("failed to decode seed: %w", err))
|
|
}
|
|
|
|
if len(seed) != ed25519.SeedSize {
|
|
fatal(fmt.Errorf("invalid seed length: %d, expected %d", len(seed), ed25519.SeedSize))
|
|
}
|
|
|
|
}
|
|
|
|
createdBackends := make(map[string]http.Handler)
|
|
for _, backend := range backends {
|
|
if backend == "" {
|
|
// skip empty to allow no values
|
|
continue
|
|
}
|
|
parts := strings.Split(backend, "=")
|
|
if len(parts) != 2 {
|
|
fatal(fmt.Errorf("invalid backend definition: %s, expected 2 parts, got %v", backend, parts))
|
|
}
|
|
|
|
// make no-settings, default backend
|
|
opt.Backends[parts[0]] = settings.Backend{
|
|
URL: parts[1],
|
|
IpHeader: *backendIpHeader,
|
|
}
|
|
}
|
|
|
|
for k, v := range opt.Backends {
|
|
if v.IpHeader == "" {
|
|
//set default value
|
|
v.IpHeader = *backendIpHeader
|
|
}
|
|
|
|
backend, err := v.Create()
|
|
if err != nil {
|
|
fatal(fmt.Errorf("backend %s: failed to make reverse proxy: %w", k, err))
|
|
}
|
|
|
|
backend.ErrorLog = slog.NewLogLogger(slog.With("backend", k).Handler(), slog.LevelDebug)
|
|
createdBackends[k] = backend
|
|
}
|
|
|
|
if len(createdBackends) == 0 {
|
|
fatal(fmt.Errorf("no backends defined in cmdline or settings file"))
|
|
}
|
|
|
|
var cache utils.Cache
|
|
var acmeCache string
|
|
if *cachePath != "" {
|
|
err = os.MkdirAll(*cachePath, 0755)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("failed to create cache directory: %w", err))
|
|
}
|
|
for _, n := range []string{"networks", "acme"} {
|
|
err = os.MkdirAll(path.Join(*cachePath, n), 0755)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("failed to create cache sub directory %s: %w", n, err))
|
|
}
|
|
}
|
|
|
|
cache, err = utils.CacheDirectory(*cachePath)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("failed to open cache directory: %w", err))
|
|
}
|
|
|
|
acmeCache = path.Join(*cachePath, "acme")
|
|
}
|
|
|
|
loadPolicyState := func() (*lib.State, error) {
|
|
policyData, err := os.ReadFile(*policyFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read policy file: %w", err)
|
|
}
|
|
|
|
p, err := policy.NewPolicy(bytes.NewReader(policyData), policySnippets...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to parse policy file: %w", err)
|
|
}
|
|
|
|
stateSettings := policy.StateSettings{
|
|
Cache: cache,
|
|
Backends: createdBackends,
|
|
MainName: internalMainName,
|
|
MainVersion: internalMainVersion,
|
|
BasePath: *basePath,
|
|
PrivateKeySeed: seed,
|
|
ClientIpHeader: *clientIpHeader,
|
|
BackendIpHeader: *backendIpHeader,
|
|
ChallengeResponseCode: http.StatusTeapot,
|
|
}
|
|
|
|
state, err := lib.NewState(*p, opt, stateSettings)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create state: %w", err)
|
|
}
|
|
return state, nil
|
|
}
|
|
|
|
if *check {
|
|
_, err := loadPolicyState()
|
|
if err != nil {
|
|
fatal(err)
|
|
}
|
|
slog.Info("load ok")
|
|
os.Exit(0)
|
|
}
|
|
|
|
listener, listenUrl := opt.Bind.Listener()
|
|
slog.Warn(
|
|
"listening",
|
|
"url", listenUrl,
|
|
)
|
|
|
|
server, swap, err := opt.Bind.Server(createdBackends, acmeCache)
|
|
if err != nil {
|
|
fatal(fmt.Errorf("failed to create server: %w", err))
|
|
}
|
|
|
|
server.ErrorLog = slog.NewLogLogger(slog.With("server", "http").Handler(), slog.LevelDebug)
|
|
|
|
go func() {
|
|
handler, err := loadPolicyState()
|
|
if err != nil {
|
|
fatal(fmt.Errorf("failed to load policy state: %w", err))
|
|
}
|
|
|
|
swap(handler)
|
|
slog.Warn(
|
|
"handler configuration loaded",
|
|
"key_fingerprint", hex.EncodeToString(handler.PrivateKeyFingerprint()),
|
|
)
|
|
|
|
// allow reloading from now on
|
|
c := make(chan os.Signal, 1)
|
|
signal.Notify(c, syscall.SIGHUP)
|
|
for sig := range c {
|
|
if sig != syscall.SIGHUP {
|
|
continue
|
|
}
|
|
oldHandler := handler
|
|
handler, err = loadPolicyState()
|
|
if err != nil {
|
|
slog.Error("handler configuration reload error", "err", err)
|
|
continue
|
|
}
|
|
|
|
swap(handler)
|
|
slog.Warn("handler configuration reloaded")
|
|
if oldHandler != nil {
|
|
_ = oldHandler.Close()
|
|
}
|
|
}
|
|
}()
|
|
|
|
if opt.BindDebug != "" {
|
|
go func() {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
|
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
|
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
|
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
|
debugServer := http.Server{
|
|
Addr: opt.BindDebug,
|
|
Handler: mux,
|
|
ErrorLog: slog.NewLogLogger(slog.With("server", "debug").Handler(), slog.LevelDebug),
|
|
}
|
|
|
|
slog.Warn(
|
|
"listening debug",
|
|
"bind", opt.BindDebug,
|
|
)
|
|
if err = debugServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
|
fatal(err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
if opt.BindMetrics != "" {
|
|
go func() {
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/metrics", promhttp.Handler())
|
|
metricsServer := http.Server{
|
|
Addr: opt.BindMetrics,
|
|
Handler: mux,
|
|
ErrorLog: slog.NewLogLogger(slog.With("server", "metrics").Handler(), slog.LevelDebug),
|
|
}
|
|
|
|
slog.Warn(
|
|
"listening metrics",
|
|
"bind", opt.BindMetrics,
|
|
)
|
|
if err = metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
|
fatal(err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
if server.TLSConfig != nil {
|
|
if err := server.ServeTLS(listener, "", ""); !errors.Is(err, http.ErrServerClosed) {
|
|
fatal(err)
|
|
}
|
|
} else {
|
|
if err := server.Serve(listener); !errors.Is(err, http.ErrServerClosed) {
|
|
fatal(err)
|
|
}
|
|
}
|
|
|
|
}
|