From 7ae45b522e4c0852d7fba794702ba682ee9e0432 Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Wed, 21 Apr 2021 15:09:38 +0100 Subject: [PATCH] make multi tenancy based on namespace and ID of account (#85) * make multi tenancy based on namespace and ID of account not just namespace * fix tests --- go.sum | 2 -- groups/handler/handler_test.go | 3 ++- invites/handler/invites_test.go | 3 ++- pkg/gorm/wrapper.go | 19 ++++++++++++++----- streams/handler/handler.go | 14 +++++++++----- streams/handler/handler_test.go | 22 +++++++++++----------- streams/handler/publish.go | 2 +- streams/handler/publish_test.go | 4 ++-- streams/handler/subscribe.go | 9 +++++++-- streams/handler/subscribe_test.go | 4 ++-- streams/handler/token.go | 10 +++++++--- streams/main.go | 20 +++++++++----------- users/handler/handler_test.go | 3 ++- 13 files changed, 68 insertions(+), 47 deletions(-) diff --git a/go.sum b/go.sum index a96d363..2a19fe1 100644 --- a/go.sum +++ b/go.sum @@ -392,8 +392,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5 github.com/micro/dev v0.0.0-20201117163752-d3cfc9788dfa h1:1BoFPE4/NTF7WKLZWsEFImOsN143QAU7Dkw9J2/qFXA= github.com/micro/dev v0.0.0-20201117163752-d3cfc9788dfa/go.mod h1:j/8E+ezN/ij7a9BXBHMKmLayFfUW1O4h/Owdv67B0X0= github.com/micro/micro/v3 v3.0.0-beta.6.0.20201016094841-ca8ffd563b2b/go.mod h1:RPJTp9meQAppzW/9jgQtfJmPpRJAySVPbz9uur4B3Ko= -github.com/micro/micro/v3 v3.1.2-0.20210311170414-40583563ada6 h1:uilKEf27gjxx/ZL0wdRGsKmjZJtm/PNEyG/RT4L0pzw= -github.com/micro/micro/v3 v3.1.2-0.20210311170414-40583563ada6/go.mod h1:+cr/21X4agxmBxMuztg/LOFrNBk9xLVJNq4UofWNWic= github.com/micro/micro/v3 v3.2.1-0.20210416134206-20d3a6b03014 h1:6yuX6VfXT8XZCK9PrFeh4KN/ZG7iaQeMbhl593C9SCE= github.com/micro/micro/v3 v3.2.1-0.20210416134206-20d3a6b03014/go.mod h1:UqfLMsy88SNqc31m7tNMQb6xLNGtsKkjJJFp3iHFXfs= github.com/miekg/dns v1.1.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= diff --git a/groups/handler/handler_test.go b/groups/handler/handler_test.go index 026409b..aa249c8 100644 --- a/groups/handler/handler_test.go +++ b/groups/handler/handler_test.go @@ -26,7 +26,7 @@ func testHandler(t *testing.T) *handler.Groups { } // clean any data from a previous run - if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_groups, micro_memberships CASCADE"); err != nil { + if _, err := sqlDB.Exec(`DROP TABLE IF EXISTS "micro_someID_groups", "micro_someID_memberships" CASCADE`); err != nil { t.Fatalf("Error cleaning database: %v", err) } @@ -291,5 +291,6 @@ func TestRemoveMember(t *testing.T) { func microAccountCtx() context.Context { return auth.ContextWithAccount(context.TODO(), &auth.Account{ Issuer: "micro", + ID: "someID", }) } diff --git a/invites/handler/invites_test.go b/invites/handler/invites_test.go index 375fcd7..6a869f3 100644 --- a/invites/handler/invites_test.go +++ b/invites/handler/invites_test.go @@ -27,7 +27,7 @@ func testHandler(t *testing.T) *handler.Invites { } // clean any data from a previous run - if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_invites CASCADE"); err != nil { + if _, err := sqlDB.Exec(`DROP TABLE IF EXISTS "micro_someID_invites" CASCADE`); err != nil { t.Fatalf("Error cleaning database: %v", err) } @@ -270,5 +270,6 @@ func assertInvitesMatch(t *testing.T, exp, act *pb.Invite) { func microAccountCtx() context.Context { return auth.ContextWithAccount(context.TODO(), &auth.Account{ Issuer: "micro", + ID: "someID", }) } diff --git a/pkg/gorm/wrapper.go b/pkg/gorm/wrapper.go index 3c10caa..3f9b54c 100644 --- a/pkg/gorm/wrapper.go +++ b/pkg/gorm/wrapper.go @@ -34,26 +34,35 @@ func (h *Helper) DBConn(conn *sql.DB) *Helper { return h } +func getTenancyKey(acc *auth.Account) string { + owner := acc.Metadata["apikey_owner"] + if len(owner) == 0 { + owner = acc.ID + } + return fmt.Sprintf("%s_%s", acc.Issuer, owner) +} + func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) { acc, ok := auth.AccountFromContext(ctx) if !ok { return nil, fmt.Errorf("missing account from context") } h.RLock() - if conn, ok := h.gormConns[acc.Issuer]; ok { + tenancyKey := getTenancyKey(acc) + if conn, ok := h.gormConns[tenancyKey]; ok { h.RUnlock() return conn, nil } h.RUnlock() h.Lock() // double check - if conn, ok := h.gormConns[acc.Issuer]; ok { + if conn, ok := h.gormConns[tenancyKey]; ok { h.Unlock() return conn, nil } defer h.Unlock() ns := schema.NamingStrategy{ - TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")), + TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(tenancyKey, "-", "")), } db, err := gorm.Open( newGormDialector(postgres.Config{ @@ -67,7 +76,7 @@ func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) { } if len(h.migrations) == 0 { // record success - h.gormConns[acc.Issuer] = db + h.gormConns[tenancyKey] = db return db, nil } @@ -76,7 +85,7 @@ func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) { } // record success - h.gormConns[acc.Issuer] = db + h.gormConns[tenancyKey] = db return db, nil } diff --git a/streams/handler/handler.go b/streams/handler/handler.go index a3f895c..a9d33c2 100644 --- a/streams/handler/handler.go +++ b/streams/handler/handler.go @@ -4,11 +4,12 @@ import ( "fmt" "time" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/events" + gorm2 "github.com/micro/services/pkg/gorm" "github.com/nats-io/nats-streaming-server/util" - "gorm.io/gorm" ) var ( @@ -26,18 +27,21 @@ type Token struct { Token string `gorm:"primaryKey"` Topic string ExpiresAt time.Time - Namespace string } type Streams struct { - DB *gorm.DB + gorm2.Helper Events events.Stream Time func() time.Time } // fmtTopic returns a topic string with namespace prefix -func fmtTopic(ns, topic string) string { - return fmt.Sprintf("%s.%s", ns, topic) +func fmtTopic(acc *auth.Account, topic string) string { + owner := acc.Metadata["apikey_owner"] + if len(owner) == 0 { + owner = acc.ID + } + return fmt.Sprintf("%s.%s.%s", acc.Issuer, owner, topic) } // validateTopicInput validates that topic is alphanumeric diff --git a/streams/handler/handler_test.go b/streams/handler/handler_test.go index 4270e3e..46e81f2 100644 --- a/streams/handler/handler_test.go +++ b/streams/handler/handler_test.go @@ -1,14 +1,13 @@ package handler_test import ( + "database/sql" "os" "testing" "time" "github.com/micro/micro/v3/service/events" "github.com/micro/services/streams/handler" - "gorm.io/driver/postgres" - "gorm.io/gorm" ) func testHandler(t *testing.T) *handler.Streams { @@ -17,23 +16,24 @@ func testHandler(t *testing.T) *handler.Streams { if len(addr) == 0 { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + + sqlDB, err := sql.Open("pgx", addr) if err != nil { - t.Fatalf("Error connecting to database: %v", err) + t.Fatalf("Failed to open connection to DB %s", err) + } + // clean any data from a previous run + if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_users, micro_tokens CASCADE"); err != nil { + t.Fatalf("Error cleaning database: %v", err) } - // migrate the database - if err := db.AutoMigrate(&handler.Token{}); err != nil { - t.Fatalf("Error migrating database: %v", err) - } - - return &handler.Streams{ - DB: db, + h := &handler.Streams{ Events: new(eventsMock), Time: func() time.Time { return time.Unix(1612787045, 0) }, } + h.DBConn(sqlDB).Migrations(&handler.Token{}) + return h } type eventsMock struct { diff --git a/streams/handler/publish.go b/streams/handler/publish.go index 2b59447..a24f354 100644 --- a/streams/handler/publish.go +++ b/streams/handler/publish.go @@ -27,5 +27,5 @@ func (s *Streams) Publish(ctx context.Context, req *pb.Message, rsp *pb.PublishR // publish the message logger.Infof("Publishing message to topic: %v", req.Topic) - return s.Events.Publish(fmtTopic(acc.Issuer, req.Topic), req.Message) + return s.Events.Publish(fmtTopic(acc, req.Topic), req.Message) } diff --git a/streams/handler/publish_test.go b/streams/handler/publish_test.go index 631b57f..38aa3fa 100644 --- a/streams/handler/publish_test.go +++ b/streams/handler/publish_test.go @@ -34,7 +34,7 @@ func TestPublish(t *testing.T) { t.Run("ValidMessage", func(t *testing.T) { h := testHandler(t) - ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"}) + ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo", ID: "foo-id"}) err := h.Publish(ctx, &pb.Message{ Topic: topic, Message: msg, }, &pb.PublishResponse{}) @@ -42,6 +42,6 @@ func TestPublish(t *testing.T) { assert.Equal(t, 1, h.Events.(*eventsMock).PublishCount) assert.Equal(t, msg, h.Events.(*eventsMock).PublishMessage) // topic is prefixed with acc issuer to implement multitenancy - assert.Equal(t, "foo."+topic, h.Events.(*eventsMock).PublishTopic) + assert.Equal(t, "foo.foo-id."+topic, h.Events.(*eventsMock).PublishTopic) }) } diff --git a/streams/handler/subscribe.go b/streams/handler/subscribe.go index 512874f..d463cd1 100644 --- a/streams/handler/subscribe.go +++ b/streams/handler/subscribe.go @@ -34,7 +34,12 @@ func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, strea // find the token and check to see if it has expired var token Token - if err := s.DB.Where(&Token{Token: req.Token, Namespace: acc.Issuer}).First(&token).Error; err == gorm.ErrRecordNotFound { + dbConn, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error reading token from store: %v", err) + return errors.InternalServerError("DATABASE_ERROR", "Error reading token from database") + } + if err := dbConn.Where(&Token{Token: req.Token}).First(&token).Error; err == gorm.ErrRecordNotFound { return ErrInvalidToken } else if err != nil { logger.Errorf("Error reading token from store: %v", err) @@ -51,7 +56,7 @@ func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, strea // start the subscription logger.Infof("Subscribing to %v via queue %v", req.Topic, token.Token) - evChan, err := s.Events.Consume(fmtTopic(acc.Issuer, req.Topic), events.WithGroup(token.Token)) + evChan, err := s.Events.Consume(fmtTopic(acc, req.Topic), events.WithGroup(token.Token)) if err != nil { logger.Errorf("Error connecting to events stream: %v", err) return errors.InternalServerError("EVENTS_ERROR", "Error connecting to events stream") diff --git a/streams/handler/subscribe_test.go b/streams/handler/subscribe_test.go index d63c5ca..e1cfb55 100644 --- a/streams/handler/subscribe_test.go +++ b/streams/handler/subscribe_test.go @@ -108,7 +108,7 @@ func TestSubscribe(t *testing.T) { h.Events.(*eventsMock).ConsumeChan = c var tRsp pb.TokenResponse - ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"}) + ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{ID: "foo", Issuer: "my-ns"}) err := h.Token(ctx, &pb.TokenRequest{ Topic: "helloworld", }, &tRsp) @@ -158,7 +158,7 @@ func TestSubscribe(t *testing.T) { close(c) wg.Wait() assert.NoError(t, subsErr) - assert.Equal(t, "foo.helloworld", h.Events.(*eventsMock).ConsumeTopic) + assert.Equal(t, "my-ns.foo.helloworld", h.Events.(*eventsMock).ConsumeTopic) // sleep to wait for the subscribe loop to push the message to the stream //time.Sleep(1 * time.Second) diff --git a/streams/handler/token.go b/streams/handler/token.go index 65327ca..26fd929 100644 --- a/streams/handler/token.go +++ b/streams/handler/token.go @@ -11,7 +11,7 @@ import ( ) func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.TokenResponse) error { - acc, ok := auth.AccountFromContext(ctx) + _, ok := auth.AccountFromContext(ctx) if !ok { return errors.Unauthorized("UNAUTHORIZED", "Unauthorized") } @@ -26,9 +26,13 @@ func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.Token Token: uuid.New().String(), ExpiresAt: s.Time().Add(TokenTTL), Topic: req.Topic, - Namespace: acc.Issuer, } - if err := s.DB.Create(&t).Error; err != nil { + dbConn, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error creating token in store: %v", err) + return errors.InternalServerError("DATABASE_ERROR", "Error writing token to database") + } + if err := dbConn.Create(&t).Error; err != nil { logger.Errorf("Error creating token in store: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error writing token to database") } diff --git a/streams/main.go b/streams/main.go index 1d20c92..dff5fbd 100644 --- a/streams/main.go +++ b/streams/main.go @@ -1,12 +1,11 @@ package main import ( + "database/sql" "time" "github.com/micro/services/streams/handler" pb "github.com/micro/services/streams/proto" - "gorm.io/driver/postgres" - "gorm.io/gorm" "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" @@ -29,20 +28,19 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - logger.Fatalf("Error connecting to database: %v", err) - } - if err := db.AutoMigrate(&handler.Token{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) + logger.Fatalf("Failed to open connection to DB %s", err) } - // Register handler - pb.RegisterStreamsHandler(srv.Server(), &handler.Streams{ - DB: db, + h := &handler.Streams{ Events: events.DefaultStream, Time: time.Now, - }) + } + h.DBConn(sqlDB).Migrations(&handler.Token{}) + + // Register handler + pb.RegisterStreamsHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil { diff --git a/users/handler/handler_test.go b/users/handler/handler_test.go index 91778c8..3dfbad7 100644 --- a/users/handler/handler_test.go +++ b/users/handler/handler_test.go @@ -27,7 +27,7 @@ func testHandler(t *testing.T) *handler.Users { t.Fatalf("Failed to open connection to DB %s", err) } // clean any data from a previous run - if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_users, micro_tokens CASCADE"); err != nil { + if _, err := sqlDB.Exec(`DROP TABLE IF EXISTS "micro_someID_users", "micro_someID_tokens" CASCADE`); err != nil { t.Fatalf("Error cleaning database: %v", err) } @@ -50,5 +50,6 @@ func assertUsersMatch(t *testing.T, exp, act *pb.User) { func microAccountCtx() context.Context { return auth.ContextWithAccount(context.TODO(), &auth.Account{ Issuer: "micro", + ID: "someID", }) }