diff --git a/internal/client/signer/local.go b/internal/client/signer/local.go new file mode 100644 index 0000000..39f124d --- /dev/null +++ b/internal/client/signer/local.go @@ -0,0 +1,34 @@ +package signer + +import ( + "context" + "io" + "strings" +) + +type Signer interface { + Sign(data io.Reader) ([]byte, error) + GetPublicKey(format string) ([]byte, error) +} + +type LocalSigner struct { + Signer +} + +func (s *LocalSigner) Sign(ctx context.Context, data string) (string, error) { + signed, err := s.Signer.Sign(strings.NewReader(data)) + if err != nil { + return "", err + } + + return string(signed), nil +} + +func (s *LocalSigner) GetPublicKey(ctx context.Context, format string) (string, error) { + publicKey, err := s.Signer.GetPublicKey(format) + if err != nil { + return "", err + } + + return string(publicKey), nil +} diff --git a/internal/client/signer/local_test.go b/internal/client/signer/local_test.go new file mode 100644 index 0000000..52a80be --- /dev/null +++ b/internal/client/signer/local_test.go @@ -0,0 +1,104 @@ +package signer + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +type SignerMock struct { + mock.Mock +} + +func (m *SignerMock) Sign(data io.Reader) ([]byte, error) { + args := m.Called(data) + var result []byte + if casted, ok := args.Get(0).([]byte); ok { + result = casted + } + + return result, args.Error(1) +} + +func (m *SignerMock) GetPublicKey(format string) ([]byte, error) { + args := m.Called(format) + var result []byte + if casted, ok := args.Get(0).([]byte); ok { + result = casted + } + + return result, args.Error(1) +} + +type LocalSignerServiceTestSuite struct { + suite.Suite + + Service *LocalSigner + + Signer *SignerMock +} + +func (t *LocalSignerServiceTestSuite) SetupSubTest() { + t.Signer = &SignerMock{} + + t.Service = &LocalSigner{ + Signer: t.Signer, + } +} + +func (t *LocalSignerServiceTestSuite) TearDownSubTest() { + t.Signer.AssertExpectations(t.T()) +} + +func (t *LocalSignerServiceTestSuite) TestSign() { + t.Run("successfully sign", func() { + signature := []byte("mock signature") + t.Signer.On("Sign", mock.Anything).Return(signature, nil).Run(func(args mock.Arguments) { + r, _ := io.ReadAll(args.Get(0).(io.Reader)) + t.Equal([]byte("mock body to sign"), r) + }) + + result, err := t.Service.Sign(context.Background(), "mock body to sign") + t.NoError(err) + t.Equal(string(signature), result) + }) + + t.Run("handle error during sign", func() { + expectedErr := errors.New("mock error") + t.Signer.On("Sign", mock.Anything).Return(nil, expectedErr) + + result, err := t.Service.Sign(context.Background(), "mock body to sign") + t.Error(err) + t.Same(expectedErr, err) + t.Empty(result) + }) +} + +func (t *LocalSignerServiceTestSuite) TestGetPublicKey() { + t.Run("successfully get", func() { + publicKey := []byte("mock public key") + t.Signer.On("GetPublicKey", "pem").Return(publicKey, nil) + + result, err := t.Service.GetPublicKey(context.Background(), "pem") + t.NoError(err) + t.Equal(string(publicKey), result) + }) + + t.Run("handle error", func() { + expectedErr := errors.New("mock error") + t.Signer.On("GetPublicKey", "pem").Return(nil, expectedErr) + + result, err := t.Service.GetPublicKey(context.Background(), "pem") + t.Error(err) + t.Same(expectedErr, err) + t.Empty(result) + }) +} + +func TestLocalSignerService(t *testing.T) { + suite.Run(t, new(LocalSignerServiceTestSuite)) +} diff --git a/internal/cmd/serve.go b/internal/cmd/serve.go index 6a7fac3..8eda6d1 100644 --- a/internal/cmd/serve.go +++ b/internal/cmd/serve.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + "ely.by/chrly/internal/di" "ely.by/chrly/internal/http" "ely.by/chrly/internal/otel" ) @@ -15,7 +16,7 @@ var serveCmd = &cobra.Command{ Use: "serve", Short: "Starts HTTP handler for the skins system", RunE: func(cmd *cobra.Command, args []string) error { - return startServer("skinsystem", "api") + return startServer(di.ModuleSkinsystem, di.ModuleProfiles, di.ModuleSigner) }, } diff --git a/internal/cmd/token.go b/internal/cmd/token.go index df20ef3..5d8d81c 100644 --- a/internal/cmd/token.go +++ b/internal/cmd/token.go @@ -9,8 +9,10 @@ import ( ) var tokenCmd = &cobra.Command{ - Use: "token", - Short: "Creates a new token, which allows to interact with Chrly API", + Use: "token scope1 ...", + Example: "token profiles sign", + Short: "Creates a new token, which allows to interact with Chrly API", + ValidArgs: []string{string(security.ProfilesScope), string(security.SignScope)}, RunE: func(cmd *cobra.Command, args []string) error { container := shouldGetContainer() var auth *security.Jwt @@ -19,7 +21,12 @@ var tokenCmd = &cobra.Command{ return err } - token, err := auth.NewToken(security.ProfileScope) + scopes := make([]security.Scope, len(args)) + for i := range args { + scopes[i] = security.Scope(args[i]) + } + + token, err := auth.NewToken(scopes...) if err != nil { return fmt.Errorf("Unable to create a new token. The error is %v\n", err) } diff --git a/internal/di/config.go b/internal/di/config.go index 93664f3..50d89e3 100644 --- a/internal/di/config.go +++ b/internal/di/config.go @@ -6,9 +6,5 @@ import ( ) var configDiOptions = di.Options( - di.Provide(newConfig), + di.Provide(viper.GetViper), ) - -func newConfig() *viper.Viper { - return viper.GetViper() -} diff --git a/internal/di/handlers.go b/internal/di/handlers.go index 5a643e4..39dc9ff 100644 --- a/internal/di/handlers.go +++ b/internal/di/handlers.go @@ -12,12 +12,18 @@ import ( "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux" . "ely.by/chrly/internal/http" + "ely.by/chrly/internal/security" ) +const ModuleSkinsystem = "skinsystem" +const ModuleProfiles = "profiles" +const ModuleSigner = "signer" + var handlersDiOptions = di.Options( di.Provide(newHandlerFactory, di.As(new(http.Handler))), - di.Provide(newSkinsystemHandler, di.WithName("skinsystem")), - di.Provide(newApiHandler, di.WithName("api")), + di.Provide(newSkinsystemHandler, di.WithName(ModuleSkinsystem)), + di.Provide(newProfilesApiHandler, di.WithName(ModuleProfiles)), + di.Provide(newSignerApiHandler, di.WithName(ModuleSigner)), ) func newHandlerFactory( @@ -31,8 +37,8 @@ func newHandlerFactory( // if you set an empty prefix. Since the main application should be mounted at the root prefix, // we use it as the base router var router *mux.Router - if slices.Contains(enabledModules, "skinsystem") { - if err := container.Resolve(&router, di.Name("skinsystem")); err != nil { + if slices.Contains(enabledModules, ModuleSkinsystem) { + if err := container.Resolve(&router, di.Name(ModuleSkinsystem)); err != nil { return nil, err } } else { @@ -43,9 +49,9 @@ func newHandlerFactory( router.Use(otelmux.Middleware("chrly")) router.NotFoundHandler = http.HandlerFunc(NotFoundHandler) - if slices.Contains(enabledModules, "api") { - var apiRouter *mux.Router - if err := container.Resolve(&apiRouter, di.Name("api")); err != nil { + if slices.Contains(enabledModules, ModuleProfiles) { + var profilesApiRouter *mux.Router + if err := container.Resolve(&profilesApiRouter, di.Name(ModuleProfiles)); err != nil { return nil, err } @@ -54,9 +60,29 @@ func newHandlerFactory( return nil, err } - apiRouter.Use(CreateAuthenticationMiddleware(authenticator)) + profilesApiRouter.Use(NewAuthenticationMiddleware(authenticator, security.ProfilesScope)) - mount(router, "/api", apiRouter) + mount(router, "/api/profiles", profilesApiRouter) + } + + if slices.Contains(enabledModules, ModuleSigner) { + var signerApiRouter *mux.Router + if err := container.Resolve(&signerApiRouter, di.Name(ModuleSigner)); err != nil { + return nil, err + } + + var authenticator Authenticator + if err := container.Resolve(&authenticator); err != nil { + return nil, err + } + + authMiddleware := NewAuthenticationMiddleware(authenticator, security.SignScope) + conditionalAuth := NewConditionalMiddleware(func(req *http.Request) bool { + return req.Method != "GET" + }, authMiddleware) + signerApiRouter.Use(conditionalAuth) + + mount(router, "/api/signer", signerApiRouter) } // Resolve health checkers last, because all the services required by the application @@ -81,25 +107,31 @@ func newHandlerFactory( func newSkinsystemHandler( config *viper.Viper, profilesProvider ProfilesProvider, - texturesSigner TexturesSigner, + texturesSigner SignerService, ) *mux.Router { config.SetDefault("textures.extra_param_name", "chrly") config.SetDefault("textures.extra_param_value", "how do you tame a horse in Minecraft?") return (&Skinsystem{ ProfilesProvider: profilesProvider, - TexturesSigner: texturesSigner, + SignerService: texturesSigner, TexturesExtraParamName: config.GetString("textures.extra_param_name"), TexturesExtraParamValue: config.GetString("textures.extra_param_value"), }).Handler() } -func newApiHandler(profilesManager ProfilesManager) *mux.Router { - return (&Api{ +func newProfilesApiHandler(profilesManager ProfilesManager) *mux.Router { + return (&ProfilesApi{ ProfilesManager: profilesManager, }).Handler() } +func newSignerApiHandler(signer Signer) *mux.Router { + return (&SignerApi{ + Signer: signer, + }).Handler() +} + func mount(router *mux.Router, path string, handler http.Handler) { router.PathPrefix(path).Handler( http.StripPrefix( diff --git a/internal/di/security.go b/internal/di/security.go index bec4856..8407295 100644 --- a/internal/di/security.go +++ b/internal/di/security.go @@ -4,10 +4,11 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "encoding/base64" "encoding/pem" - "strings" + "errors" + "log/slog" + "ely.by/chrly/internal/client/signer" "ely.by/chrly/internal/http" "ely.by/chrly/internal/security" @@ -16,41 +17,43 @@ import ( ) var securityDiOptions = di.Options( - di.Provide(newTexturesSigner, - di.As(new(http.TexturesSigner)), + di.Provide(newSigner, + di.As(new(http.Signer)), + di.As(new(signer.Signer)), ), + di.Provide(newSignerService), ) -func newTexturesSigner(config *viper.Viper) (*security.Signer, error) { +func newSigner(config *viper.Viper) (*security.Signer, error) { + var privateKey *rsa.PrivateKey + var err error + keyStr := config.GetString("chrly.signing.key") if keyStr == "" { - // TODO: log a message about the generated signing key and the way to specify it permanently - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, err } - return security.NewSigner(privateKey), nil - } - - var keyBytes []byte - if strings.HasPrefix(keyStr, "base64:") { - base64Value := keyStr[7:] - decodedKey, err := base64.URLEncoding.DecodeString(base64Value) - if err != nil { - return nil, err - } - - keyBytes = decodedKey + slog.Warn("A private signing key has been generated. To make it permanent, specify the valid RSA private key in the config parameter chrly.signing.key") } else { - keyBytes = []byte(keyStr) - } + keyBytes := []byte(keyStr) + rawPem, _ := pem.Decode(keyBytes) + if rawPem == nil { + return nil, errors.New("unable to decode pem key") + } - rawPem, _ := pem.Decode(keyBytes) - privateKey, err := x509.ParsePKCS1PrivateKey(rawPem.Bytes) - if err != nil { - return nil, err + privateKey, err = x509.ParsePKCS1PrivateKey(rawPem.Bytes) + if err != nil { + return nil, err + } } return security.NewSigner(privateKey), nil } + +func newSignerService(s signer.Signer) http.SignerService { + return &signer.LocalSigner{ + Signer: s, + } +} diff --git a/internal/http/http.go b/internal/http/http.go index 3d9477f..53d546f 100644 --- a/internal/http/http.go +++ b/internal/http/http.go @@ -3,7 +3,6 @@ package http import ( "context" "encoding/json" - "log/slog" "net/http" "time" @@ -11,12 +10,10 @@ import ( "github.com/mono83/slf" "github.com/mono83/slf/wd" - "ely.by/chrly/internal/version" + "ely.by/chrly/internal/security" ) func StartServer(ctx context.Context, server *http.Server, logger slf.Logger) { - slog.Debug("Chrly :v (:c)", slog.String("v", version.Version()), slog.String("c", version.Commit())) - srvErr := make(chan error, 1) go func() { logger.Info("Starting the server, HTTP on: :addr", wd.StringParam("addr", server.Addr)) @@ -40,15 +37,13 @@ func StartServer(ctx context.Context, server *http.Server, logger slf.Logger) { } type Authenticator interface { - Authenticate(req *http.Request) error + Authenticate(req *http.Request, scope security.Scope) error } -// The current middleware implementation doesn't check the scope assigned to the token. -// For now there is only one scope and at this moment I don't want to spend time on it -func CreateAuthenticationMiddleware(checker Authenticator) mux.MiddlewareFunc { +func NewAuthenticationMiddleware(authenticator Authenticator, scope security.Scope) mux.MiddlewareFunc { return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - err := checker.Authenticate(req) + err := authenticator.Authenticate(req, scope) if err != nil { apiForbidden(resp, err.Error()) return @@ -59,6 +54,18 @@ func CreateAuthenticationMiddleware(checker Authenticator) mux.MiddlewareFunc { } } +func NewConditionalMiddleware(cond func(req *http.Request) bool, m mux.MiddlewareFunc) mux.MiddlewareFunc { + return func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + if cond(req) { + handler = m.Middleware(handler) + } + + handler.ServeHTTP(resp, req) + }) + } +} + func NotFoundHandler(response http.ResponseWriter, _ *http.Request) { data, _ := json.Marshal(map[string]string{ "status": "404", @@ -73,7 +80,7 @@ func NotFoundHandler(response http.ResponseWriter, _ *http.Request) { func apiBadRequest(resp http.ResponseWriter, errorsPerField map[string][]string) { resp.WriteHeader(http.StatusBadRequest) resp.Header().Set("Content-Type", "application/json") - result, _ := json.Marshal(map[string]interface{}{ + result, _ := json.Marshal(map[string]any{ "errors": errorsPerField, }) _, _ = resp.Write(result) @@ -90,7 +97,7 @@ func apiServerError(resp http.ResponseWriter, err error) { func apiForbidden(resp http.ResponseWriter, reason string) { resp.WriteHeader(http.StatusForbidden) resp.Header().Set("Content-Type", "application/json") - result, _ := json.Marshal(map[string]interface{}{ + result, _ := json.Marshal(map[string]any{ "error": reason, }) _, _ = resp.Write(result) diff --git a/internal/http/http_test.go b/internal/http/http_test.go index 623ab2b..1711cdb 100644 --- a/internal/http/http_test.go +++ b/internal/http/http_test.go @@ -2,34 +2,35 @@ package http import ( "errors" - "io/ioutil" + "io" "net/http" "net/http/httptest" "testing" - testify "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + testify "github.com/stretchr/testify/require" + + "ely.by/chrly/internal/security" ) type authCheckerMock struct { mock.Mock } -func (m *authCheckerMock) Authenticate(req *http.Request) error { - args := m.Called(req) - return args.Error(0) +func (m *authCheckerMock) Authenticate(req *http.Request, scope security.Scope) error { + return m.Called(req, scope).Error(0) } -func TestCreateAuthenticationMiddleware(t *testing.T) { +func TestAuthenticationMiddleware(t *testing.T) { t.Run("pass", func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", "https://example.com", nil) resp := httptest.NewRecorder() auth := &authCheckerMock{} - auth.On("Authenticate", req).Once().Return(nil) + auth.On("Authenticate", req, security.Scope("mock")).Once().Return(nil) isHandlerCalled := false - middlewareFunc := CreateAuthenticationMiddleware(auth) + middlewareFunc := NewAuthenticationMiddleware(auth, "mock") middlewareFunc.Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { isHandlerCalled = true })).ServeHTTP(resp, req) @@ -40,21 +41,21 @@ func TestCreateAuthenticationMiddleware(t *testing.T) { }) t.Run("fail", func(t *testing.T) { - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", "https://example.com", nil) resp := httptest.NewRecorder() auth := &authCheckerMock{} - auth.On("Authenticate", req).Once().Return(errors.New("error reason")) + auth.On("Authenticate", req, security.Scope("mock")).Once().Return(errors.New("error reason")) isHandlerCalled := false - middlewareFunc := CreateAuthenticationMiddleware(auth) + middlewareFunc := NewAuthenticationMiddleware(auth, "mock") middlewareFunc.Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { isHandlerCalled = true })).ServeHTTP(resp, req) testify.False(t, isHandlerCalled, "Handler shouldn't be called") testify.Equal(t, 403, resp.Code) - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) testify.JSONEq(t, `{ "error": "error reason" }`, string(body)) @@ -63,10 +64,56 @@ func TestCreateAuthenticationMiddleware(t *testing.T) { }) } +func TestConditionalMiddleware(t *testing.T) { + t.Run("true", func(t *testing.T) { + req := httptest.NewRequest("GET", "https://example.com", nil) + resp := httptest.NewRecorder() + + isNestedMiddlewareCalled := false + isHandlerCalled := false + NewConditionalMiddleware( + func(req *http.Request) bool { + return true + }, + func(handler http.Handler) http.Handler { + isNestedMiddlewareCalled = true + return handler + }, + ).Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + isHandlerCalled = true + })).ServeHTTP(resp, req) + + testify.True(t, isNestedMiddlewareCalled, "Nested middleware wasn't called") + testify.True(t, isHandlerCalled, "Handler wasn't called from the middleware") + }) + + t.Run("false", func(t *testing.T) { + req := httptest.NewRequest("GET", "https://example.com", nil) + resp := httptest.NewRecorder() + + isNestedMiddlewareCalled := false + isHandlerCalled := false + NewConditionalMiddleware( + func(req *http.Request) bool { + return false + }, + func(handler http.Handler) http.Handler { + isNestedMiddlewareCalled = true + return handler + }, + ).Middleware(http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + isHandlerCalled = true + })).ServeHTTP(resp, req) + + testify.False(t, isNestedMiddlewareCalled, "Nested middleware shouldn't be called") + testify.True(t, isHandlerCalled, "Handler wasn't called from the middleware") + }) +} + func TestNotFoundHandler(t *testing.T) { assert := testify.New(t) - req := httptest.NewRequest("GET", "http://example.com", nil) + req := httptest.NewRequest("GET", "https://example.com", nil) w := httptest.NewRecorder() NotFoundHandler(w, req) @@ -74,7 +121,7 @@ func TestNotFoundHandler(t *testing.T) { resp := w.Result() assert.Equal(404, resp.StatusCode) assert.Equal("application/json", resp.Header.Get("Content-Type")) - response, _ := ioutil.ReadAll(resp.Body) + response, _ := io.ReadAll(resp.Body) assert.JSONEq(`{ "status": "404", "message": "Not Found" diff --git a/internal/http/api.go b/internal/http/profiles.go similarity index 79% rename from internal/http/api.go rename to internal/http/profiles.go index 8fc8716..bd4251b 100644 --- a/internal/http/api.go +++ b/internal/http/profiles.go @@ -17,19 +17,19 @@ type ProfilesManager interface { RemoveProfileByUuid(ctx context.Context, uuid string) error } -type Api struct { +type ProfilesApi struct { ProfilesManager } -func (ctx *Api) Handler() *mux.Router { +func (ctx *ProfilesApi) Handler() *mux.Router { router := mux.NewRouter().StrictSlash(true) - router.HandleFunc("/profiles", ctx.postProfileHandler).Methods(http.MethodPost) - router.HandleFunc("/profiles/{uuid}", ctx.deleteProfileByUuidHandler).Methods(http.MethodDelete) + router.HandleFunc("/", ctx.postProfileHandler).Methods(http.MethodPost) + router.HandleFunc("/{uuid}", ctx.deleteProfileByUuidHandler).Methods(http.MethodDelete) return router } -func (ctx *Api) postProfileHandler(resp http.ResponseWriter, req *http.Request) { +func (ctx *ProfilesApi) postProfileHandler(resp http.ResponseWriter, req *http.Request) { err := req.ParseForm() if err != nil { apiBadRequest(resp, map[string][]string{ @@ -63,7 +63,7 @@ func (ctx *Api) postProfileHandler(resp http.ResponseWriter, req *http.Request) resp.WriteHeader(http.StatusCreated) } -func (ctx *Api) deleteProfileByUuidHandler(resp http.ResponseWriter, req *http.Request) { +func (ctx *ProfilesApi) deleteProfileByUuidHandler(resp http.ResponseWriter, req *http.Request) { uuid := mux.Vars(req)["uuid"] err := ctx.ProfilesManager.RemoveProfileByUuid(req.Context(), uuid) if err != nil { diff --git a/internal/http/api_test.go b/internal/http/profiles_test.go similarity index 81% rename from internal/http/api_test.go rename to internal/http/profiles_test.go index 697770d..34a4cea 100644 --- a/internal/http/api_test.go +++ b/internal/http/profiles_test.go @@ -30,26 +30,26 @@ func (m *ProfilesManagerMock) RemoveProfileByUuid(ctx context.Context, uuid stri return m.Called(ctx, uuid).Error(0) } -type ApiTestSuite struct { +type ProfilesTestSuite struct { suite.Suite - App *Api + App *ProfilesApi ProfilesManager *ProfilesManagerMock } -func (t *ApiTestSuite) SetupSubTest() { +func (t *ProfilesTestSuite) SetupSubTest() { t.ProfilesManager = &ProfilesManagerMock{} - t.App = &Api{ + t.App = &ProfilesApi{ ProfilesManager: t.ProfilesManager, } } -func (t *ApiTestSuite) TearDownSubTest() { +func (t *ProfilesTestSuite) TearDownSubTest() { t.ProfilesManager.AssertExpectations(t.T()) } -func (t *ApiTestSuite) TestPostProfile() { +func (t *ProfilesTestSuite) TestPostProfile() { t.Run("successfully post profile", func() { t.ProfilesManager.On("PersistProfile", mock.Anything, &db.Profile{ Uuid: "0f657aa8-bfbe-415d-b700-5750090d3af3", @@ -61,7 +61,7 @@ func (t *ApiTestSuite) TestPostProfile() { MojangSignature: "bW9jawo=", }).Once().Return(nil) - req := httptest.NewRequest("POST", "http://chrly/profiles", bytes.NewBufferString(url.Values{ + req := httptest.NewRequest("POST", "http://chrly/", bytes.NewBufferString(url.Values{ "uuid": {"0f657aa8-bfbe-415d-b700-5750090d3af3"}, "username": {"mock_username"}, "skinUrl": {"https://example.com/skin.png"}, @@ -82,7 +82,7 @@ func (t *ApiTestSuite) TestPostProfile() { }) t.Run("handle malformed body", func() { - req := httptest.NewRequest("POST", "http://chrly/profiles", strings.NewReader("invalid;=url?encoded_string")) + req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("invalid;=url?encoded_string")) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() @@ -107,7 +107,7 @@ func (t *ApiTestSuite) TestPostProfile() { }, }) - req := httptest.NewRequest("POST", "http://chrly/profiles", strings.NewReader("")) + req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("")) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() @@ -129,7 +129,7 @@ func (t *ApiTestSuite) TestPostProfile() { t.Run("receive other error", func() { t.ProfilesManager.On("PersistProfile", mock.Anything, mock.Anything).Once().Return(errors.New("mock error")) - req := httptest.NewRequest("POST", "http://chrly/profiles", strings.NewReader("")) + req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("")) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() @@ -140,11 +140,11 @@ func (t *ApiTestSuite) TestPostProfile() { }) } -func (t *ApiTestSuite) TestDeleteProfileByUuid() { +func (t *ProfilesTestSuite) TestDeleteProfileByUuid() { t.Run("successfully delete", func() { t.ProfilesManager.On("RemoveProfileByUuid", mock.Anything, "0f657aa8-bfbe-415d-b700-5750090d3af3").Once().Return(nil) - req := httptest.NewRequest("DELETE", "http://chrly/profiles/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) + req := httptest.NewRequest("DELETE", "http://chrly/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) w := httptest.NewRecorder() t.App.Handler().ServeHTTP(w, req) @@ -158,7 +158,7 @@ func (t *ApiTestSuite) TestDeleteProfileByUuid() { t.Run("error from manager", func() { t.ProfilesManager.On("RemoveProfileByUuid", mock.Anything, mock.Anything).Return(errors.New("mock error")) - req := httptest.NewRequest("DELETE", "http://chrly/profiles/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) + req := httptest.NewRequest("DELETE", "http://chrly/0f657aa8-bfbe-415d-b700-5750090d3af3", nil) w := httptest.NewRecorder() t.App.Handler().ServeHTTP(w, req) @@ -168,6 +168,6 @@ func (t *ApiTestSuite) TestDeleteProfileByUuid() { }) } -func TestApi(t *testing.T) { - suite.Run(t, new(ApiTestSuite)) +func TestProfilesApi(t *testing.T) { + suite.Run(t, new(ProfilesTestSuite)) } diff --git a/internal/http/signer.go b/internal/http/signer.go new file mode 100644 index 0000000..7ac7a88 --- /dev/null +++ b/internal/http/signer.go @@ -0,0 +1,60 @@ +package http + +import ( + "encoding/base64" + "fmt" + "io" + "net/http" + + "github.com/gorilla/mux" +) + +type Signer interface { + Sign(data io.Reader) ([]byte, error) + GetPublicKey(format string) ([]byte, error) +} + +type SignerApi struct { + Signer +} + +func (s *SignerApi) Handler() *mux.Router { + router := mux.NewRouter().StrictSlash(true) + router.HandleFunc("/", s.signHandler).Methods(http.MethodPost) + router.HandleFunc("/public-key.{format:(?:pem|der)}", s.getPublicKeyHandler).Methods(http.MethodGet) + + return router +} + +func (s *SignerApi) signHandler(resp http.ResponseWriter, req *http.Request) { + signature, err := s.Signer.Sign(req.Body) + if err != nil { + apiServerError(resp, fmt.Errorf("unable to sign the value: %w", err)) + return + } + + resp.Header().Set("Content-Type", "application/octet-stream+base64") + + buf := make([]byte, base64.StdEncoding.EncodedLen(len(signature))) + base64.StdEncoding.Encode(buf, signature) + _, _ = resp.Write(buf) +} + +func (s *SignerApi) getPublicKeyHandler(resp http.ResponseWriter, req *http.Request) { + format := mux.Vars(req)["format"] + publicKey, err := s.Signer.GetPublicKey(format) + if err != nil { + apiServerError(resp, fmt.Errorf("unable to retrieve public key: %w", err)) + return + } + + if format == "pem" { + resp.Header().Set("Content-Type", "application/x-pem-file") + resp.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.pem"`) + } else { + resp.Header().Set("Content-Type", "application/octet-stream") + resp.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.der"`) + } + + _, _ = resp.Write(publicKey) +} diff --git a/internal/http/signer_test.go b/internal/http/signer_test.go new file mode 100644 index 0000000..4729e8c --- /dev/null +++ b/internal/http/signer_test.go @@ -0,0 +1,146 @@ +package http + +import ( + "bytes" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" +) + +type SignerMock struct { + mock.Mock +} + +func (m *SignerMock) Sign(data io.Reader) ([]byte, error) { + args := m.Called(data) + var result []byte + if casted, ok := args.Get(0).([]byte); ok { + result = casted + } + + return result, args.Error(1) +} + +func (m *SignerMock) GetPublicKey(format string) ([]byte, error) { + args := m.Called(format) + var result []byte + if casted, ok := args.Get(0).([]byte); ok { + result = casted + } + + return result, args.Error(1) +} + +type SignerApiTestSuite struct { + suite.Suite + + App *SignerApi + + Signer *SignerMock +} + +func (t *SignerApiTestSuite) SetupSubTest() { + t.Signer = &SignerMock{} + + t.App = &SignerApi{ + Signer: t.Signer, + } +} + +func (t *SignerApiTestSuite) TearDownSubTest() { + t.Signer.AssertExpectations(t.T()) +} + +func (t *SignerApiTestSuite) TestSign() { + t.Run("successfully sign", func() { + signature := []byte("mock signature") + t.Signer.On("Sign", mock.Anything).Return(signature, nil).Run(func(args mock.Arguments) { + buf := &bytes.Buffer{} + _, _ = io.Copy(buf, args.Get(0).(io.Reader)) + r, _ := io.ReadAll(buf) + + t.Equal([]byte("mock body to sign"), r) + }) + + req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("mock body to sign")) + w := httptest.NewRecorder() + + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusOK, result.StatusCode) + t.Equal("application/octet-stream+base64", result.Header.Get("Content-Type")) + body, _ := io.ReadAll(result.Body) + t.Equal([]byte{0x62, 0x57, 0x39, 0x6a, 0x61, 0x79, 0x42, 0x7a, 0x61, 0x57, 0x64, 0x75, 0x59, 0x58, 0x52, 0x31, 0x63, 0x6d, 0x55, 0x3d}, body) + }) + + t.Run("handle error during sign", func() { + t.Signer.On("Sign", mock.Anything).Return(nil, errors.New("mock error")) + + req := httptest.NewRequest("POST", "http://chrly/", strings.NewReader("mock body to sign")) + w := httptest.NewRecorder() + + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusInternalServerError, result.StatusCode) + }) +} + +func (t *SignerApiTestSuite) TestGetPublicKey() { + t.Run("in pem format", func() { + publicKey := []byte("mock public key in pem format") + t.Signer.On("GetPublicKey", "pem").Return(publicKey, nil) + + req := httptest.NewRequest("GET", "http://chrly/public-key.pem", nil) + w := httptest.NewRecorder() + + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusOK, result.StatusCode) + t.Equal("application/x-pem-file", result.Header.Get("Content-Type")) + t.Equal(`attachment; filename="yggdrasil_session_pubkey.pem"`, result.Header.Get("Content-Disposition")) + body, _ := io.ReadAll(result.Body) + t.Equal(publicKey, body) + }) + + t.Run("in der format", func() { + publicKey := []byte("mock public key in der format") + t.Signer.On("GetPublicKey", "der").Return(publicKey, nil) + + req := httptest.NewRequest("GET", "http://chrly/public-key.der", nil) + w := httptest.NewRecorder() + + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusOK, result.StatusCode) + t.Equal("application/octet-stream", result.Header.Get("Content-Type")) + t.Equal(`attachment; filename="yggdrasil_session_pubkey.der"`, result.Header.Get("Content-Disposition")) + body, _ := io.ReadAll(result.Body) + t.Equal(publicKey, body) + }) + + t.Run("handle error", func() { + t.Signer.On("GetPublicKey", "pem").Return(nil, errors.New("mock error")) + + req := httptest.NewRequest("GET", "http://chrly/public-key.pem", nil) + w := httptest.NewRecorder() + + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusInternalServerError, result.StatusCode) + }) +} + +func TestSignerApi(t *testing.T) { + suite.Run(t, new(SignerApiTestSuite)) +} diff --git a/internal/http/skinsystem.go b/internal/http/skinsystem.go index 4a94c7a..dd88914 100644 --- a/internal/http/skinsystem.go +++ b/internal/http/skinsystem.go @@ -2,12 +2,10 @@ package http import ( "context" - "crypto/rsa" - "crypto/x509" "encoding/base64" "encoding/json" - "encoding/pem" "fmt" + "io" "net/http" "strings" "time" @@ -25,40 +23,39 @@ type ProfilesProvider interface { FindProfileByUsername(ctx context.Context, username string, allowProxy bool) (*db.Profile, error) } -// TexturesSigner uses context because in the future we may separate this logic into a separate microservice -type TexturesSigner interface { - SignTextures(ctx context.Context, textures string) (string, error) - GetPublicKey(ctx context.Context) (*rsa.PublicKey, error) +// SignerService uses context because in the future we may separate this logic as an external microservice +type SignerService interface { + Sign(ctx context.Context, data string) (string, error) + GetPublicKey(ctx context.Context, format string) (string, error) } type Skinsystem struct { ProfilesProvider - TexturesSigner + SignerService TexturesExtraParamName string TexturesExtraParamValue string } -func (ctx *Skinsystem) Handler() *mux.Router { +func (s *Skinsystem) Handler() *mux.Router { router := mux.NewRouter().StrictSlash(true) - router.HandleFunc("/skins/{username}", ctx.skinHandler).Methods(http.MethodGet) - router.HandleFunc("/cloaks/{username}", ctx.capeHandler).Methods(http.MethodGet) + router.HandleFunc("/skins/{username}", s.skinHandler).Methods(http.MethodGet) + router.HandleFunc("/cloaks/{username}", s.capeHandler).Methods(http.MethodGet) // TODO: alias /capes/{username}? - router.HandleFunc("/textures/{username}", ctx.texturesHandler).Methods(http.MethodGet) - router.HandleFunc("/textures/signed/{username}", ctx.signedTexturesHandler).Methods(http.MethodGet) - router.HandleFunc("/profile/{username}", ctx.profileHandler).Methods(http.MethodGet) + router.HandleFunc("/textures/{username}", s.texturesHandler).Methods(http.MethodGet) + router.HandleFunc("/textures/signed/{username}", s.signedTexturesHandler).Methods(http.MethodGet) + router.HandleFunc("/profile/{username}", s.profileHandler).Methods(http.MethodGet) // Legacy - router.HandleFunc("/skins", ctx.skinGetHandler).Methods(http.MethodGet) - router.HandleFunc("/cloaks", ctx.capeGetHandler).Methods(http.MethodGet) + router.HandleFunc("/skins", s.skinGetHandler).Methods(http.MethodGet) + router.HandleFunc("/cloaks", s.capeGetHandler).Methods(http.MethodGet) // Utils - router.HandleFunc("/signature-verification-key.der", ctx.signatureVerificationKeyHandler).Methods(http.MethodGet) - router.HandleFunc("/signature-verification-key.pem", ctx.signatureVerificationKeyHandler).Methods(http.MethodGet) + router.HandleFunc("/signature-verification-key.{format:(?:pem|der)}", s.signatureVerificationKeyHandler).Methods(http.MethodGet) return router } -func (ctx *Skinsystem) skinHandler(response http.ResponseWriter, request *http.Request) { - profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) +func (s *Skinsystem) skinHandler(response http.ResponseWriter, request *http.Request) { + profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) if err != nil { apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) return @@ -71,7 +68,7 @@ func (ctx *Skinsystem) skinHandler(response http.ResponseWriter, request *http.R http.Redirect(response, request, profile.SkinUrl, http.StatusMovedPermanently) } -func (ctx *Skinsystem) skinGetHandler(response http.ResponseWriter, request *http.Request) { +func (s *Skinsystem) skinGetHandler(response http.ResponseWriter, request *http.Request) { username := request.URL.Query().Get("name") if username == "" { response.WriteHeader(http.StatusBadRequest) @@ -80,11 +77,11 @@ func (ctx *Skinsystem) skinGetHandler(response http.ResponseWriter, request *htt mux.Vars(request)["username"] = username - ctx.skinHandler(response, request) + s.skinHandler(response, request) } -func (ctx *Skinsystem) capeHandler(response http.ResponseWriter, request *http.Request) { - profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) +func (s *Skinsystem) capeHandler(response http.ResponseWriter, request *http.Request) { + profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), parseUsername(mux.Vars(request)["username"]), true) if err != nil { apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) return @@ -97,7 +94,7 @@ func (ctx *Skinsystem) capeHandler(response http.ResponseWriter, request *http.R http.Redirect(response, request, profile.CapeUrl, http.StatusMovedPermanently) } -func (ctx *Skinsystem) capeGetHandler(response http.ResponseWriter, request *http.Request) { +func (s *Skinsystem) capeGetHandler(response http.ResponseWriter, request *http.Request) { username := request.URL.Query().Get("name") if username == "" { response.WriteHeader(http.StatusBadRequest) @@ -106,11 +103,11 @@ func (ctx *Skinsystem) capeGetHandler(response http.ResponseWriter, request *htt mux.Vars(request)["username"] = username - ctx.capeHandler(response, request) + s.capeHandler(response, request) } -func (ctx *Skinsystem) texturesHandler(response http.ResponseWriter, request *http.Request) { - profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) +func (s *Skinsystem) texturesHandler(response http.ResponseWriter, request *http.Request) { + profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) if err != nil { apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) return @@ -133,8 +130,8 @@ func (ctx *Skinsystem) texturesHandler(response http.ResponseWriter, request *ht _, _ = response.Write(responseData) } -func (ctx *Skinsystem) signedTexturesHandler(response http.ResponseWriter, request *http.Request) { - profile, err := ctx.ProfilesProvider.FindProfileByUsername( +func (s *Skinsystem) signedTexturesHandler(response http.ResponseWriter, request *http.Request) { + profile, err := s.ProfilesProvider.FindProfileByUsername( request.Context(), mux.Vars(request)["username"], getToBool(request.URL.Query().Get("proxy")), @@ -164,8 +161,8 @@ func (ctx *Skinsystem) signedTexturesHandler(response http.ResponseWriter, reque Value: profile.MojangTextures, }, { - Name: ctx.TexturesExtraParamName, - Value: ctx.TexturesExtraParamValue, + Name: s.TexturesExtraParamName, + Value: s.TexturesExtraParamValue, }, }, } @@ -175,8 +172,8 @@ func (ctx *Skinsystem) signedTexturesHandler(response http.ResponseWriter, reque _, _ = response.Write(responseJson) } -func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *http.Request) { - profile, err := ctx.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) +func (s *Skinsystem) profileHandler(response http.ResponseWriter, request *http.Request) { + profile, err := s.ProfilesProvider.FindProfileByUsername(request.Context(), mux.Vars(request)["username"], true) if err != nil { apiServerError(response, fmt.Errorf("unable to retrieve a profile: %w", err)) return @@ -203,7 +200,7 @@ func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *htt } if request.URL.Query().Has("unsigned") && !getToBool(request.URL.Query().Get("unsigned")) { - signature, err := ctx.TexturesSigner.SignTextures(request.Context(), texturesProp.Value) + signature, err := s.SignerService.Sign(request.Context(), texturesProp.Value) if err != nil { apiServerError(response, fmt.Errorf("unable to sign textures: %w", err)) return @@ -218,8 +215,8 @@ func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *htt Props: []*mojang.Property{ texturesProp, { - Name: ctx.TexturesExtraParamName, - Value: ctx.TexturesExtraParamValue, + Name: s.TexturesExtraParamName, + Value: s.TexturesExtraParamValue, }, }, } @@ -229,32 +226,23 @@ func (ctx *Skinsystem) profileHandler(response http.ResponseWriter, request *htt _, _ = response.Write(responseJson) } -func (ctx *Skinsystem) signatureVerificationKeyHandler(response http.ResponseWriter, request *http.Request) { - publicKey, err := ctx.TexturesSigner.GetPublicKey(request.Context()) +func (s *Skinsystem) signatureVerificationKeyHandler(response http.ResponseWriter, request *http.Request) { + format := mux.Vars(request)["format"] + publicKey, err := s.SignerService.GetPublicKey(request.Context(), format) if err != nil { - panic(err) + apiServerError(response, fmt.Errorf("unable to retrieve public key: %w", err)) + return } - asn1Bytes, err := x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - panic(err) - } - - if strings.HasSuffix(request.URL.Path, ".pem") { - publicKeyBlock := pem.Block{ - Type: "PUBLIC KEY", - Bytes: asn1Bytes, - } - - publicKeyPemBytes := pem.EncodeToMemory(&publicKeyBlock) - - response.Header().Set("Content-Disposition", "attachment; filename=\"yggdrasil_session_pubkey.pem\"") - _, _ = response.Write(publicKeyPemBytes) + if format == "pem" { + response.Header().Set("Content-Type", "application/x-pem-file") + response.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.pem"`) } else { response.Header().Set("Content-Type", "application/octet-stream") - response.Header().Set("Content-Disposition", "attachment; filename=\"yggdrasil_session_pubkey.der\"") - _, _ = response.Write(asn1Bytes) + response.Header().Set("Content-Disposition", `attachment; filename="yggdrasil_session_pubkey.der"`) } + + _, _ = io.WriteString(response, publicKey) } func parseUsername(username string) string { diff --git a/internal/http/skinsystem_test.go b/internal/http/skinsystem_test.go index 1c83e80..347a730 100644 --- a/internal/http/skinsystem_test.go +++ b/internal/http/skinsystem_test.go @@ -2,14 +2,10 @@ package http import ( "context" - "crypto/rsa" - "crypto/x509" - "encoding/pem" "errors" "io" "net/http" "net/http/httptest" - "strings" "testing" "time" @@ -34,23 +30,18 @@ func (m *ProfilesProviderMock) FindProfileByUsername(ctx context.Context, userna return result, args.Error(1) } -type TexturesSignerMock struct { +type SignerServiceMock struct { mock.Mock } -func (m *TexturesSignerMock) SignTextures(ctx context.Context, textures string) (string, error) { - args := m.Called(ctx, textures) +func (m *SignerServiceMock) Sign(ctx context.Context, data string) (string, error) { + args := m.Called(ctx, data) return args.String(0), args.Error(1) } -func (m *TexturesSignerMock) GetPublicKey(ctx context.Context) (*rsa.PublicKey, error) { - args := m.Called(ctx) - var publicKey *rsa.PublicKey - if casted, ok := args.Get(0).(*rsa.PublicKey); ok { - publicKey = casted - } - - return publicKey, args.Error(1) +func (m *SignerServiceMock) GetPublicKey(ctx context.Context, format string) (string, error) { + args := m.Called(ctx, format) + return args.String(0), args.Error(1) } type SkinsystemTestSuite struct { @@ -59,7 +50,7 @@ type SkinsystemTestSuite struct { App *Skinsystem ProfilesProvider *ProfilesProviderMock - TexturesSigner *TexturesSignerMock + SignerService *SignerServiceMock } /******************** @@ -73,11 +64,11 @@ func (t *SkinsystemTestSuite) SetupSubTest() { } t.ProfilesProvider = &ProfilesProviderMock{} - t.TexturesSigner = &TexturesSignerMock{} + t.SignerService = &SignerServiceMock{} t.App = &Skinsystem{ ProfilesProvider: t.ProfilesProvider, - TexturesSigner: t.TexturesSigner, + SignerService: t.SignerService, TexturesExtraParamName: "texturesParamName", TexturesExtraParamValue: "texturesParamValue", } @@ -85,7 +76,7 @@ func (t *SkinsystemTestSuite) SetupSubTest() { func (t *SkinsystemTestSuite) TearDownSubTest() { t.ProfilesProvider.AssertExpectations(t.T()) - t.TexturesSigner.AssertExpectations(t.T()) + t.SignerService.AssertExpectations(t.T()) } func (t *SkinsystemTestSuite) TestSkinHandler() { @@ -470,7 +461,7 @@ func (t *SkinsystemTestSuite) TestProfile() { SkinUrl: "https://example.com/skin.png", SkinModel: "slim", }, nil) - t.TexturesSigner.On("SignTextures", mock.Anything, "eyJ0aW1lc3RhbXAiOjE2MTQyMTQyMjMwMDAsInByb2ZpbGVJZCI6Im1vY2stdXVpZCIsInByb2ZpbGVOYW1lIjoibW9ja191c2VybmFtZSIsInRleHR1cmVzIjp7IlNLSU4iOnsidXJsIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS9za2luLnBuZyIsIm1ldGFkYXRhIjp7Im1vZGVsIjoic2xpbSJ9fX19").Return("mock signature", nil) + t.SignerService.On("Sign", mock.Anything, "eyJ0aW1lc3RhbXAiOjE2MTQyMTQyMjMwMDAsInByb2ZpbGVJZCI6Im1vY2stdXVpZCIsInByb2ZpbGVOYW1lIjoibW9ja191c2VybmFtZSIsInRleHR1cmVzIjp7IlNLSU4iOnsidXJsIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS9za2luLnBuZyIsIm1ldGFkYXRhIjp7Im1vZGVsIjoic2xpbSJ9fX19").Return("mock signature", nil) t.App.Handler().ServeHTTP(w, req) @@ -526,7 +517,7 @@ func (t *SkinsystemTestSuite) TestProfile() { w := httptest.NewRecorder() t.ProfilesProvider.On("FindProfileByUsername", mock.Anything, "mock_username", true).Return(&db.Profile{}, nil) - t.TexturesSigner.On("SignTextures", mock.Anything, mock.Anything).Return("", errors.New("mock error")) + t.SignerService.On("Sign", mock.Anything, mock.Anything).Return("", errors.New("mock error")) t.App.Handler().ServeHTTP(w, req) @@ -535,77 +526,52 @@ func (t *SkinsystemTestSuite) TestProfile() { }) } -type signingKeyTestCase struct { - Name string - KeyFormat string - BeforeTest func(suite *SkinsystemTestSuite) - PanicErr string - AfterTest func(suite *SkinsystemTestSuite, response *http.Response) -} - -var signingKeyTestsCases = []*signingKeyTestCase{ - { - Name: "Get public key in DER format", - KeyFormat: "DER", - BeforeTest: func(suite *SkinsystemTestSuite) { - pubPem, _ := pem.Decode([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnm\nUDlzHBQH3DpYef5WCO32TDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQ==\n-----END PUBLIC KEY-----")) - publicKey, _ := x509.ParsePKIXPublicKey(pubPem.Bytes) - - suite.TexturesSigner.On("GetPublicKey", mock.Anything).Return(publicKey, nil) - }, - AfterTest: func(suite *SkinsystemTestSuite, response *http.Response) { - suite.Equal(200, response.StatusCode) - suite.Equal("application/octet-stream", response.Header.Get("Content-Type")) - suite.Equal("attachment; filename=\"yggdrasil_session_pubkey.der\"", response.Header.Get("Content-Disposition")) - body, _ := io.ReadAll(response.Body) - suite.Equal([]byte{48, 92, 48, 13, 6, 9, 42, 134, 72, 134, 247, 13, 1, 1, 1, 5, 0, 3, 75, 0, 48, 72, 2, 65, 0, 214, 212, 165, 80, 153, 144, 194, 169, 126, 246, 25, 211, 197, 183, 150, 233, 157, 1, 166, 49, 44, 25, 230, 80, 57, 115, 28, 20, 7, 220, 58, 88, 121, 254, 86, 8, 237, 246, 76, 53, 58, 125, 226, 9, 231, 192, 52, 148, 12, 176, 130, 214, 120, 195, 8, 182, 116, 97, 206, 207, 253, 97, 2, 247, 2, 3, 1, 0, 1}, body) - }, - }, - { - Name: "Get public key in PEM format", - KeyFormat: "PEM", - BeforeTest: func(suite *SkinsystemTestSuite) { - pubPem, _ := pem.Decode([]byte("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnm\nUDlzHBQH3DpYef5WCO32TDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQ==\n-----END PUBLIC KEY-----")) - publicKey, _ := x509.ParsePKIXPublicKey(pubPem.Bytes) - - suite.TexturesSigner.On("GetPublicKey", mock.Anything).Return(publicKey, nil) - }, - AfterTest: func(suite *SkinsystemTestSuite, response *http.Response) { - suite.Equal(200, response.StatusCode) - suite.Equal("text/plain; charset=utf-8", response.Header.Get("Content-Type")) - suite.Equal("attachment; filename=\"yggdrasil_session_pubkey.pem\"", response.Header.Get("Content-Disposition")) - body, _ := io.ReadAll(response.Body) - suite.Equal("-----BEGIN PUBLIC KEY-----\nMFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnm\nUDlzHBQH3DpYef5WCO32TDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQ==\n-----END PUBLIC KEY-----\n", string(body)) - }, - }, - { - Name: "Error while obtaining public key", - KeyFormat: "DER", - BeforeTest: func(suite *SkinsystemTestSuite) { - suite.TexturesSigner.On("GetPublicKey", mock.Anything).Return(nil, errors.New("textures signer error")) - }, - PanicErr: "textures signer error", - }, -} - func (t *SkinsystemTestSuite) TestSignatureVerificationKey() { - for _, testCase := range signingKeyTestsCases { - t.Run(testCase.Name, func() { - testCase.BeforeTest(t) + t.Run("in pem format", func() { + publicKey := "mock public key in pem format" + t.SignerService.On("GetPublicKey", mock.Anything, "pem").Return(publicKey, nil) - req := httptest.NewRequest("GET", "http://chrly/signature-verification-key."+strings.ToLower(testCase.KeyFormat), nil) - w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "http://chrly/signature-verification-key.pem", nil) + w := httptest.NewRecorder() - if testCase.PanicErr != "" { - t.PanicsWithError(testCase.PanicErr, func() { - t.App.Handler().ServeHTTP(w, req) - }) - } else { - t.App.Handler().ServeHTTP(w, req) - testCase.AfterTest(t, w.Result()) - } - }) - } + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusOK, result.StatusCode) + t.Equal("application/x-pem-file", result.Header.Get("Content-Type")) + t.Equal(`attachment; filename="yggdrasil_session_pubkey.pem"`, result.Header.Get("Content-Disposition")) + body, _ := io.ReadAll(result.Body) + t.Equal(publicKey, string(body)) + }) + + t.Run("in der format", func() { + publicKey := "mock public key in der format" + t.SignerService.On("GetPublicKey", mock.Anything, "der").Return(publicKey, nil) + + req := httptest.NewRequest("GET", "http://chrly/signature-verification-key.der", nil) + w := httptest.NewRecorder() + + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusOK, result.StatusCode) + t.Equal("application/octet-stream", result.Header.Get("Content-Type")) + t.Equal(`attachment; filename="yggdrasil_session_pubkey.der"`, result.Header.Get("Content-Disposition")) + body, _ := io.ReadAll(result.Body) + t.Equal(publicKey, string(body)) + }) + + t.Run("handle error", func() { + t.SignerService.On("GetPublicKey", mock.Anything, "pem").Return("", errors.New("mock error")) + + req := httptest.NewRequest("GET", "http://chrly/signature-verification-key.pem", nil) + w := httptest.NewRecorder() + + t.App.Handler().ServeHTTP(w, req) + + result := w.Result() + t.Equal(http.StatusInternalServerError, result.StatusCode) + }) } func TestSkinsystem(t *testing.T) { diff --git a/internal/security/jwt.go b/internal/security/jwt.go index 878958c..e04f3d9 100644 --- a/internal/security/jwt.go +++ b/internal/security/jwt.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "net/http" + "slices" "strings" "time" @@ -15,14 +16,23 @@ import ( var now = time.Now var signingMethod = jwt.SigningMethodHS256 -const scopesClaim = "scopes" - type Scope string const ( - ProfileScope Scope = "profiles" + ProfilesScope Scope = "profiles" + SignScope Scope = "sign" ) +var validScopes = []Scope{ + ProfilesScope, + SignScope, +} + +type claims struct { + jwt.RegisteredClaims + Scopes []Scope `json:"scopes"` +} + func NewJwt(key []byte) *Jwt { return &Jwt{ Key: key, @@ -38,11 +48,20 @@ func (t *Jwt) NewToken(scopes ...Scope) (string, error) { return "", errors.New("you must specify at least one scope") } - token := jwt.NewWithClaims(signingMethod, jwt.MapClaims{ - "iss": "chrly", - "iat": now().Unix(), - scopesClaim: scopes, - }) + for _, scope := range scopes { + if !slices.Contains(validScopes, scope) { + return "", fmt.Errorf("unknown scope %s", scope) + } + } + + token := jwt.New(signingMethod) + token.Claims = &claims{ + jwt.RegisteredClaims{ + Issuer: "chrly", + IssuedAt: jwt.NewNumericDate(now()), + }, + scopes, + } token.Header["v"] = version.MajorVersion return token.SignedString(t.Key) @@ -52,7 +71,7 @@ func (t *Jwt) NewToken(scopes ...Scope) (string, error) { var MissingAuthenticationError = errors.New("authentication value not provided") var InvalidTokenError = errors.New("passed authentication value is invalid") -func (t *Jwt) Authenticate(req *http.Request) error { +func (t *Jwt) Authenticate(req *http.Request, scope Scope) error { bearerToken := req.Header.Get("Authorization") if bearerToken == "" { return MissingAuthenticationError @@ -62,8 +81,8 @@ func (t *Jwt) Authenticate(req *http.Request) error { return InvalidTokenError } - tokenStr := bearerToken[7:] - token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { + tokenStr := bearerToken[7:] // trim "bearer " part + token, err := jwt.ParseWithClaims(tokenStr, &claims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } @@ -78,5 +97,10 @@ func (t *Jwt) Authenticate(req *http.Request) error { return errors.Join(InvalidTokenError, errors.New("missing v header")) } + claims := token.Claims.(*claims) + if !slices.Contains(claims.Scopes, scope) { + return errors.New("the token doesn't have the scope to perform the action") + } + return nil } diff --git a/internal/security/jwt_test.go b/internal/security/jwt_test.go index 40b53ea..232d26a 100644 --- a/internal/security/jwt_test.go +++ b/internal/security/jwt_test.go @@ -16,10 +16,16 @@ func TestJwtAuth_NewToken(t *testing.T) { return time.Date(2024, 2, 1, 11, 26, 15, 0, time.UTC) } - t.Run("with scope", func(t *testing.T) { - token, err := jwt.NewToken(ProfileScope, "custom-scope") + t.Run("with known scope", func(t *testing.T) { + token, err := jwt.NewToken(ProfilesScope, SignScope) require.NoError(t, err) - require.Equal(t, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInYiOjV9.eyJpYXQiOjE3MDY3ODY3NzUsImlzcyI6ImNocmx5Iiwic2NvcGVzIjpbInByb2ZpbGVzIiwiY3VzdG9tLXNjb3BlIl19.Iq673YyWWkJZjIkBmKYRN8Lx9qoD39S_e-MegG0aORM", token) + require.Equal(t, "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIsInYiOjV9.eyJpc3MiOiJjaHJseSIsImlhdCI6MTcwNjc4Njc3NSwic2NvcGVzIjpbInByb2ZpbGVzIiwic2lnbiJdfQ.HkNGiDba3I_bLGN6sF0eTE5n6rMLgYfAZZEqI4xb2X4", token) + }) + + t.Run("with unknown scope", func(t *testing.T) { + token, err := jwt.NewToken("scope-123") + require.ErrorContains(t, err, "unknown") + require.Empty(t, token) }) t.Run("no scopes", func(t *testing.T) { @@ -34,41 +40,48 @@ func TestJwtAuth_Authenticate(t *testing.T) { t.Run("success", func(t *testing.T) { req := httptest.NewRequest("POST", "http://localhost", nil) req.Header.Add("Authorization", "Bearer "+jwtString) - err := jwt.Authenticate(req) + err := jwt.Authenticate(req, ProfilesScope) require.NoError(t, err) }) + t.Run("has no required scope", func(t *testing.T) { + req := httptest.NewRequest("POST", "http://localhost", nil) + req.Header.Add("Authorization", "Bearer "+jwtString) + err := jwt.Authenticate(req, SignScope) + require.ErrorContains(t, err, "scope") + }) + t.Run("request without auth header", func(t *testing.T) { req := httptest.NewRequest("POST", "http://localhost", nil) - err := jwt.Authenticate(req) + err := jwt.Authenticate(req, ProfilesScope) require.ErrorIs(t, err, MissingAuthenticationError) }) t.Run("no bearer token prefix", func(t *testing.T) { req := httptest.NewRequest("POST", "http://localhost", nil) req.Header.Add("Authorization", "trash") - err := jwt.Authenticate(req) + err := jwt.Authenticate(req, ProfilesScope) require.ErrorIs(t, err, InvalidTokenError) }) t.Run("bearer token but not jwt", func(t *testing.T) { req := httptest.NewRequest("POST", "http://localhost", nil) req.Header.Add("Authorization", "Bearer seems.like.jwt") - err := jwt.Authenticate(req) + err := jwt.Authenticate(req, ProfilesScope) require.ErrorIs(t, err, InvalidTokenError) }) t.Run("invalid signature", func(t *testing.T) { req := httptest.NewRequest("POST", "http://localhost", nil) req.Header.Add("Authorization", "Bearer "+jwtString+"123") - err := jwt.Authenticate(req) + err := jwt.Authenticate(req, ProfilesScope) require.ErrorIs(t, err, InvalidTokenError) }) t.Run("missing v header", func(t *testing.T) { req := httptest.NewRequest("POST", "http://localhost", nil) req.Header.Add("Authorization", "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjE3MDY3ODY3NzUsImlzcyI6ImNocmx5Iiwic2NvcGVzIjpbInByb2ZpbGVzIl19.zOX2ZKyU37kjwt1p9uCHxALxWQD2UC0wWcAcNvBXGq0") - err := jwt.Authenticate(req) + err := jwt.Authenticate(req, ProfilesScope) require.ErrorIs(t, err, InvalidTokenError) require.ErrorContains(t, err, "missing v header") }) diff --git a/internal/security/signer.go b/internal/security/signer.go index 2ed2a42..b42049e 100644 --- a/internal/security/signer.go +++ b/internal/security/signer.go @@ -1,15 +1,18 @@ package security import ( - "context" "crypto" "crypto/rand" "crypto/rsa" "crypto/sha1" - "encoding/base64" + "crypto/x509" + "encoding/pem" + "errors" + "io" ) var randomReader = rand.Reader +var invalidKeyFormat = errors.New(`invalid key format: it should be"der" or "pem"`) func NewSigner(key *rsa.PrivateKey) *Signer { return &Signer{Key: key} @@ -19,23 +22,38 @@ type Signer struct { Key *rsa.PrivateKey } -func (s *Signer) SignTextures(ctx context.Context, textures string) (string, error) { - message := []byte(textures) +func (s *Signer) Sign(data io.Reader) ([]byte, error) { messageHash := sha1.New() - _, err := messageHash.Write(message) + _, err := io.Copy(messageHash, data) if err != nil { - return "", err + return nil, err } messageHashSum := messageHash.Sum(nil) signature, err := rsa.SignPKCS1v15(randomReader, s.Key, crypto.SHA1, messageHashSum) if err != nil { - return "", err + return nil, err } - return base64.StdEncoding.EncodeToString(signature), nil + return signature, nil } -func (s *Signer) GetPublicKey(ctx context.Context) (*rsa.PublicKey, error) { - return &s.Key.PublicKey, nil +func (s *Signer) GetPublicKey(format string) ([]byte, error) { + if format != "der" && format != "pem" { + return nil, invalidKeyFormat + } + + asn1Bytes, err := x509.MarshalPKIXPublicKey(s.Key.Public()) + if err != nil { + return nil, err + } + + if format == "pem" { + return pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: asn1Bytes, + }), nil + } + + return asn1Bytes, nil } diff --git a/internal/security/signer_test.go b/internal/security/signer_test.go index ea7f551..39384b3 100644 --- a/internal/security/signer_test.go +++ b/internal/security/signer_test.go @@ -1,10 +1,9 @@ package security import ( - "context" - "crypto/rsa" "crypto/x509" "encoding/pem" + "strings" "testing" "github.com/stretchr/testify/require" @@ -17,7 +16,7 @@ func (c *ConstantReader) Read(p []byte) (int, error) { return 1, nil } -func TestSigner_SignTextures(t *testing.T) { +func TestSigner_Sign(t *testing.T) { randomReader = &ConstantReader{} rawKey, _ := pem.Decode([]byte("-----BEGIN RSA PRIVATE KEY-----\nMIIBOwIBAAJBANbUpVCZkMKpfvYZ08W3lumdAaYxLBnmUDlzHBQH3DpYef5WCO32\nTDU6feIJ58A0lAywgtZ4wwi2dGHOz/1hAvcCAwEAAQJAItaxSHTe6PKbyEU/9pxj\nONdhYRYwVLLo56gnMYhkyoEqaaMsfov8hhoepkYZBMvZFB2bDOsQ2SaJ+E2eiBO4\nAQIhAPssS0+BR9w0bOdmjGqmdE9NrN5UJQcOW13s29+6QzUBAiEA2vWOepA5Apiu\npEA3pwoGdkVCrNSnnKjDQzDXBnpd3/cCIEFNd9sY4qUG4FWdXN6RnmXL7Sj0uZfH\nDMwzu8rEM5sBAiEAhvdoDNqLmbMdq3c+FsPSOeL1d21Zp/JK8kbPtFmHNf8CIQDV\n6FSZDwvWfuxaM7BsycQONkjDBTPNu+lqctJBGnBv3A==\n-----END RSA PRIVATE KEY-----\n")) @@ -25,9 +24,14 @@ func TestSigner_SignTextures(t *testing.T) { signer := NewSigner(key) - signature, err := signer.SignTextures(context.Background(), "eyJ0aW1lc3RhbXAiOjE2MTQzMDcxMzQsInByb2ZpbGVJZCI6ImZmYzhmZGM5NTgyNDUwOWU4YTU3Yzk5Yjk0MGZiOTk2IiwicHJvZmlsZU5hbWUiOiJFcmlja1NrcmF1Y2giLCJ0ZXh0dXJlcyI6eyJTS0lOIjp7InVybCI6Imh0dHA6Ly9lbHkuYnkvc3RvcmFnZS9za2lucy82OWM2NzQwZDI5OTNlNWQ2ZjZhN2ZjOTI0MjBlZmMyOS5wbmcifX0sImVseSI6dHJ1ZX0") + signature, err := signer.Sign(strings.NewReader("mock string to sign")) require.NoError(t, err) - require.Equal(t, "IyHCxTP5ITquEXTHcwCtLd08jWWy16JwlQeWg8naxhoAVQecHGRdzHRscuxtdq/446kmeox7h4EfRN2A2ZLL+A==", signature) + require.Equal(t, []byte{ + 0xd0, 0x88, 0xc6, 0x65, 0x27, 0x5d, 0xe4, 0x86, 0x6b, 0x7a, 0x5a, 0xd, 0x94, 0x6f, 0x80, 0x88, 0x12, 0x8e, 0x65, + 0x75, 0xfb, 0xba, 0xcb, 0x7f, 0x90, 0xf5, 0xae, 0x5d, 0x2c, 0x5d, 0x60, 0xf6, 0x83, 0x54, 0xd3, 0x40, 0xd, 0x1f, + 0xc0, 0xbc, 0x6d, 0xa8, 0x6f, 0x6, 0xd8, 0x38, 0x74, 0x5b, 0x4f, 0x15, 0x82, 0x6d, 0x67, 0x95, 0x7b, 0xf, 0xcc, + 0xf3, 0x51, 0xfe, 0xcd, 0xb9, 0x1e, 0xdf, + }, signature) } func TestSigner_GetPublicKey(t *testing.T) { @@ -38,7 +42,40 @@ func TestSigner_GetPublicKey(t *testing.T) { signer := NewSigner(key) - publicKey, err := signer.GetPublicKey(context.Background()) - require.NoError(t, err) - require.IsType(t, &rsa.PublicKey{}, publicKey) + t.Run("pem format", func(t *testing.T) { + publicKey, err := signer.GetPublicKey("pem") + require.NoError(t, err) + require.Equal(t, []byte{ + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x50, 0x55, 0x42, 0x4c, 0x49, 0x43, 0x20, + 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0xa, 0x4d, 0x46, 0x77, 0x77, 0x44, 0x51, 0x59, 0x4a, 0x4b, + 0x6f, 0x5a, 0x49, 0x68, 0x76, 0x63, 0x4e, 0x41, 0x51, 0x45, 0x42, 0x42, 0x51, 0x41, 0x44, 0x53, 0x77, 0x41, + 0x77, 0x53, 0x41, 0x4a, 0x42, 0x41, 0x4e, 0x62, 0x55, 0x70, 0x56, 0x43, 0x5a, 0x6b, 0x4d, 0x4b, 0x70, 0x66, + 0x76, 0x59, 0x5a, 0x30, 0x38, 0x57, 0x33, 0x6c, 0x75, 0x6d, 0x64, 0x41, 0x61, 0x59, 0x78, 0x4c, 0x42, 0x6e, + 0x6d, 0xa, 0x55, 0x44, 0x6c, 0x7a, 0x48, 0x42, 0x51, 0x48, 0x33, 0x44, 0x70, 0x59, 0x65, 0x66, 0x35, 0x57, + 0x43, 0x4f, 0x33, 0x32, 0x54, 0x44, 0x55, 0x36, 0x66, 0x65, 0x49, 0x4a, 0x35, 0x38, 0x41, 0x30, 0x6c, 0x41, + 0x79, 0x77, 0x67, 0x74, 0x5a, 0x34, 0x77, 0x77, 0x69, 0x32, 0x64, 0x47, 0x48, 0x4f, 0x7a, 0x2f, 0x31, 0x68, + 0x41, 0x76, 0x63, 0x43, 0x41, 0x77, 0x45, 0x41, 0x41, 0x51, 0x3d, 0x3d, 0xa, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, + 0x45, 0x4e, 0x44, 0x20, 0x50, 0x55, 0x42, 0x4c, 0x49, 0x43, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, + 0x2d, 0xa, + }, publicKey) + }) + + t.Run("der format", func(t *testing.T) { + publicKey, err := signer.GetPublicKey("der") + require.NoError(t, err) + require.Equal(t, []byte{ + 0x30, 0x5c, 0x30, 0xd, 0x6, 0x9, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0xd, 0x1, 0x1, 0x1, 0x5, 0x0, 0x3, 0x4b, 0x0, + 0x30, 0x48, 0x2, 0x41, 0x0, 0xd6, 0xd4, 0xa5, 0x50, 0x99, 0x90, 0xc2, 0xa9, 0x7e, 0xf6, 0x19, 0xd3, 0xc5, + 0xb7, 0x96, 0xe9, 0x9d, 0x1, 0xa6, 0x31, 0x2c, 0x19, 0xe6, 0x50, 0x39, 0x73, 0x1c, 0x14, 0x7, 0xdc, 0x3a, + 0x58, 0x79, 0xfe, 0x56, 0x8, 0xed, 0xf6, 0x4c, 0x35, 0x3a, 0x7d, 0xe2, 0x9, 0xe7, 0xc0, 0x34, 0x94, 0xc, + 0xb0, 0x82, 0xd6, 0x78, 0xc3, 0x8, 0xb6, 0x74, 0x61, 0xce, 0xcf, 0xfd, 0x61, 0x2, 0xf7, 0x2, 0x3, 0x1, 0x0, + 0x1, + }, publicKey) + }) + + t.Run("unknown format", func(t *testing.T) { + publicKey, err := signer.GetPublicKey("unknown") + require.Nil(t, publicKey) + require.ErrorContains(t, err, "invalid") + }) }