make multi tenancy based on namespace and ID of account (#85)

* make multi tenancy based on namespace and ID of account not just namespace

* fix tests
This commit is contained in:
Dominic Wong
2021-04-21 15:09:38 +01:00
committed by GitHub
parent 4361ef6d90
commit 7ae45b522e
13 changed files with 68 additions and 47 deletions

2
go.sum
View File

@@ -392,8 +392,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5
github.com/micro/dev v0.0.0-20201117163752-d3cfc9788dfa h1:1BoFPE4/NTF7WKLZWsEFImOsN143QAU7Dkw9J2/qFXA= github.com/micro/dev v0.0.0-20201117163752-d3cfc9788dfa h1:1BoFPE4/NTF7WKLZWsEFImOsN143QAU7Dkw9J2/qFXA=
github.com/micro/dev v0.0.0-20201117163752-d3cfc9788dfa/go.mod h1:j/8E+ezN/ij7a9BXBHMKmLayFfUW1O4h/Owdv67B0X0= github.com/micro/dev v0.0.0-20201117163752-d3cfc9788dfa/go.mod h1:j/8E+ezN/ij7a9BXBHMKmLayFfUW1O4h/Owdv67B0X0=
github.com/micro/micro/v3 v3.0.0-beta.6.0.20201016094841-ca8ffd563b2b/go.mod h1:RPJTp9meQAppzW/9jgQtfJmPpRJAySVPbz9uur4B3Ko= github.com/micro/micro/v3 v3.0.0-beta.6.0.20201016094841-ca8ffd563b2b/go.mod h1:RPJTp9meQAppzW/9jgQtfJmPpRJAySVPbz9uur4B3Ko=
github.com/micro/micro/v3 v3.1.2-0.20210311170414-40583563ada6 h1:uilKEf27gjxx/ZL0wdRGsKmjZJtm/PNEyG/RT4L0pzw=
github.com/micro/micro/v3 v3.1.2-0.20210311170414-40583563ada6/go.mod h1:+cr/21X4agxmBxMuztg/LOFrNBk9xLVJNq4UofWNWic=
github.com/micro/micro/v3 v3.2.1-0.20210416134206-20d3a6b03014 h1:6yuX6VfXT8XZCK9PrFeh4KN/ZG7iaQeMbhl593C9SCE= github.com/micro/micro/v3 v3.2.1-0.20210416134206-20d3a6b03014 h1:6yuX6VfXT8XZCK9PrFeh4KN/ZG7iaQeMbhl593C9SCE=
github.com/micro/micro/v3 v3.2.1-0.20210416134206-20d3a6b03014/go.mod h1:UqfLMsy88SNqc31m7tNMQb6xLNGtsKkjJJFp3iHFXfs= github.com/micro/micro/v3 v3.2.1-0.20210416134206-20d3a6b03014/go.mod h1:UqfLMsy88SNqc31m7tNMQb6xLNGtsKkjJJFp3iHFXfs=
github.com/miekg/dns v1.1.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=

View File

@@ -26,7 +26,7 @@ func testHandler(t *testing.T) *handler.Groups {
} }
// clean any data from a previous run // clean any data from a previous run
if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_groups, micro_memberships CASCADE"); err != nil { if _, err := sqlDB.Exec(`DROP TABLE IF EXISTS "micro_someID_groups", "micro_someID_memberships" CASCADE`); err != nil {
t.Fatalf("Error cleaning database: %v", err) t.Fatalf("Error cleaning database: %v", err)
} }
@@ -291,5 +291,6 @@ func TestRemoveMember(t *testing.T) {
func microAccountCtx() context.Context { func microAccountCtx() context.Context {
return auth.ContextWithAccount(context.TODO(), &auth.Account{ return auth.ContextWithAccount(context.TODO(), &auth.Account{
Issuer: "micro", Issuer: "micro",
ID: "someID",
}) })
} }

View File

@@ -27,7 +27,7 @@ func testHandler(t *testing.T) *handler.Invites {
} }
// clean any data from a previous run // clean any data from a previous run
if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_invites CASCADE"); err != nil { if _, err := sqlDB.Exec(`DROP TABLE IF EXISTS "micro_someID_invites" CASCADE`); err != nil {
t.Fatalf("Error cleaning database: %v", err) t.Fatalf("Error cleaning database: %v", err)
} }
@@ -270,5 +270,6 @@ func assertInvitesMatch(t *testing.T, exp, act *pb.Invite) {
func microAccountCtx() context.Context { func microAccountCtx() context.Context {
return auth.ContextWithAccount(context.TODO(), &auth.Account{ return auth.ContextWithAccount(context.TODO(), &auth.Account{
Issuer: "micro", Issuer: "micro",
ID: "someID",
}) })
} }

View File

@@ -34,26 +34,35 @@ func (h *Helper) DBConn(conn *sql.DB) *Helper {
return h return h
} }
func getTenancyKey(acc *auth.Account) string {
owner := acc.Metadata["apikey_owner"]
if len(owner) == 0 {
owner = acc.ID
}
return fmt.Sprintf("%s_%s", acc.Issuer, owner)
}
func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) { func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) {
acc, ok := auth.AccountFromContext(ctx) acc, ok := auth.AccountFromContext(ctx)
if !ok { if !ok {
return nil, fmt.Errorf("missing account from context") return nil, fmt.Errorf("missing account from context")
} }
h.RLock() h.RLock()
if conn, ok := h.gormConns[acc.Issuer]; ok { tenancyKey := getTenancyKey(acc)
if conn, ok := h.gormConns[tenancyKey]; ok {
h.RUnlock() h.RUnlock()
return conn, nil return conn, nil
} }
h.RUnlock() h.RUnlock()
h.Lock() h.Lock()
// double check // double check
if conn, ok := h.gormConns[acc.Issuer]; ok { if conn, ok := h.gormConns[tenancyKey]; ok {
h.Unlock() h.Unlock()
return conn, nil return conn, nil
} }
defer h.Unlock() defer h.Unlock()
ns := schema.NamingStrategy{ ns := schema.NamingStrategy{
TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")), TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(tenancyKey, "-", "")),
} }
db, err := gorm.Open( db, err := gorm.Open(
newGormDialector(postgres.Config{ newGormDialector(postgres.Config{
@@ -67,7 +76,7 @@ func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) {
} }
if len(h.migrations) == 0 { if len(h.migrations) == 0 {
// record success // record success
h.gormConns[acc.Issuer] = db h.gormConns[tenancyKey] = db
return db, nil return db, nil
} }
@@ -76,7 +85,7 @@ func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) {
} }
// record success // record success
h.gormConns[acc.Issuer] = db h.gormConns[tenancyKey] = db
return db, nil return db, nil
} }

View File

@@ -4,11 +4,12 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/events" "github.com/micro/micro/v3/service/events"
gorm2 "github.com/micro/services/pkg/gorm"
"github.com/nats-io/nats-streaming-server/util" "github.com/nats-io/nats-streaming-server/util"
"gorm.io/gorm"
) )
var ( var (
@@ -26,18 +27,21 @@ type Token struct {
Token string `gorm:"primaryKey"` Token string `gorm:"primaryKey"`
Topic string Topic string
ExpiresAt time.Time ExpiresAt time.Time
Namespace string
} }
type Streams struct { type Streams struct {
DB *gorm.DB gorm2.Helper
Events events.Stream Events events.Stream
Time func() time.Time Time func() time.Time
} }
// fmtTopic returns a topic string with namespace prefix // fmtTopic returns a topic string with namespace prefix
func fmtTopic(ns, topic string) string { func fmtTopic(acc *auth.Account, topic string) string {
return fmt.Sprintf("%s.%s", ns, topic) owner := acc.Metadata["apikey_owner"]
if len(owner) == 0 {
owner = acc.ID
}
return fmt.Sprintf("%s.%s.%s", acc.Issuer, owner, topic)
} }
// validateTopicInput validates that topic is alphanumeric // validateTopicInput validates that topic is alphanumeric

View File

@@ -1,14 +1,13 @@
package handler_test package handler_test
import ( import (
"database/sql"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/micro/micro/v3/service/events" "github.com/micro/micro/v3/service/events"
"github.com/micro/services/streams/handler" "github.com/micro/services/streams/handler"
"gorm.io/driver/postgres"
"gorm.io/gorm"
) )
func testHandler(t *testing.T) *handler.Streams { func testHandler(t *testing.T) *handler.Streams {
@@ -17,23 +16,24 @@ func testHandler(t *testing.T) *handler.Streams {
if len(addr) == 0 { if len(addr) == 0 {
addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable"
} }
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
sqlDB, err := sql.Open("pgx", addr)
if err != nil { if err != nil {
t.Fatalf("Error connecting to database: %v", err) t.Fatalf("Failed to open connection to DB %s", err)
}
// clean any data from a previous run
if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_users, micro_tokens CASCADE"); err != nil {
t.Fatalf("Error cleaning database: %v", err)
} }
// migrate the database h := &handler.Streams{
if err := db.AutoMigrate(&handler.Token{}); err != nil {
t.Fatalf("Error migrating database: %v", err)
}
return &handler.Streams{
DB: db,
Events: new(eventsMock), Events: new(eventsMock),
Time: func() time.Time { Time: func() time.Time {
return time.Unix(1612787045, 0) return time.Unix(1612787045, 0)
}, },
} }
h.DBConn(sqlDB).Migrations(&handler.Token{})
return h
} }
type eventsMock struct { type eventsMock struct {

View File

@@ -27,5 +27,5 @@ func (s *Streams) Publish(ctx context.Context, req *pb.Message, rsp *pb.PublishR
// publish the message // publish the message
logger.Infof("Publishing message to topic: %v", req.Topic) logger.Infof("Publishing message to topic: %v", req.Topic)
return s.Events.Publish(fmtTopic(acc.Issuer, req.Topic), req.Message) return s.Events.Publish(fmtTopic(acc, req.Topic), req.Message)
} }

View File

@@ -34,7 +34,7 @@ func TestPublish(t *testing.T) {
t.Run("ValidMessage", func(t *testing.T) { t.Run("ValidMessage", func(t *testing.T) {
h := testHandler(t) h := testHandler(t)
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"}) ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo", ID: "foo-id"})
err := h.Publish(ctx, &pb.Message{ err := h.Publish(ctx, &pb.Message{
Topic: topic, Message: msg, Topic: topic, Message: msg,
}, &pb.PublishResponse{}) }, &pb.PublishResponse{})
@@ -42,6 +42,6 @@ func TestPublish(t *testing.T) {
assert.Equal(t, 1, h.Events.(*eventsMock).PublishCount) assert.Equal(t, 1, h.Events.(*eventsMock).PublishCount)
assert.Equal(t, msg, h.Events.(*eventsMock).PublishMessage) assert.Equal(t, msg, h.Events.(*eventsMock).PublishMessage)
// topic is prefixed with acc issuer to implement multitenancy // topic is prefixed with acc issuer to implement multitenancy
assert.Equal(t, "foo."+topic, h.Events.(*eventsMock).PublishTopic) assert.Equal(t, "foo.foo-id."+topic, h.Events.(*eventsMock).PublishTopic)
}) })
} }

View File

@@ -34,7 +34,12 @@ func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, strea
// find the token and check to see if it has expired // find the token and check to see if it has expired
var token Token var token Token
if err := s.DB.Where(&Token{Token: req.Token, Namespace: acc.Issuer}).First(&token).Error; err == gorm.ErrRecordNotFound { dbConn, err := s.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error reading token from store: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error reading token from database")
}
if err := dbConn.Where(&Token{Token: req.Token}).First(&token).Error; err == gorm.ErrRecordNotFound {
return ErrInvalidToken return ErrInvalidToken
} else if err != nil { } else if err != nil {
logger.Errorf("Error reading token from store: %v", err) logger.Errorf("Error reading token from store: %v", err)
@@ -51,7 +56,7 @@ func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, strea
// start the subscription // start the subscription
logger.Infof("Subscribing to %v via queue %v", req.Topic, token.Token) logger.Infof("Subscribing to %v via queue %v", req.Topic, token.Token)
evChan, err := s.Events.Consume(fmtTopic(acc.Issuer, req.Topic), events.WithGroup(token.Token)) evChan, err := s.Events.Consume(fmtTopic(acc, req.Topic), events.WithGroup(token.Token))
if err != nil { if err != nil {
logger.Errorf("Error connecting to events stream: %v", err) logger.Errorf("Error connecting to events stream: %v", err)
return errors.InternalServerError("EVENTS_ERROR", "Error connecting to events stream") return errors.InternalServerError("EVENTS_ERROR", "Error connecting to events stream")

View File

@@ -108,7 +108,7 @@ func TestSubscribe(t *testing.T) {
h.Events.(*eventsMock).ConsumeChan = c h.Events.(*eventsMock).ConsumeChan = c
var tRsp pb.TokenResponse var tRsp pb.TokenResponse
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"}) ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{ID: "foo", Issuer: "my-ns"})
err := h.Token(ctx, &pb.TokenRequest{ err := h.Token(ctx, &pb.TokenRequest{
Topic: "helloworld", Topic: "helloworld",
}, &tRsp) }, &tRsp)
@@ -158,7 +158,7 @@ func TestSubscribe(t *testing.T) {
close(c) close(c)
wg.Wait() wg.Wait()
assert.NoError(t, subsErr) assert.NoError(t, subsErr)
assert.Equal(t, "foo.helloworld", h.Events.(*eventsMock).ConsumeTopic) assert.Equal(t, "my-ns.foo.helloworld", h.Events.(*eventsMock).ConsumeTopic)
// sleep to wait for the subscribe loop to push the message to the stream // sleep to wait for the subscribe loop to push the message to the stream
//time.Sleep(1 * time.Second) //time.Sleep(1 * time.Second)

View File

@@ -11,7 +11,7 @@ import (
) )
func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.TokenResponse) error { func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.TokenResponse) error {
acc, ok := auth.AccountFromContext(ctx) _, ok := auth.AccountFromContext(ctx)
if !ok { if !ok {
return errors.Unauthorized("UNAUTHORIZED", "Unauthorized") return errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
} }
@@ -26,9 +26,13 @@ func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.Token
Token: uuid.New().String(), Token: uuid.New().String(),
ExpiresAt: s.Time().Add(TokenTTL), ExpiresAt: s.Time().Add(TokenTTL),
Topic: req.Topic, Topic: req.Topic,
Namespace: acc.Issuer,
} }
if err := s.DB.Create(&t).Error; err != nil { dbConn, err := s.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error creating token in store: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error writing token to database")
}
if err := dbConn.Create(&t).Error; err != nil {
logger.Errorf("Error creating token in store: %v", err) logger.Errorf("Error creating token in store: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error writing token to database") return errors.InternalServerError("DATABASE_ERROR", "Error writing token to database")
} }

View File

@@ -1,12 +1,11 @@
package main package main
import ( import (
"database/sql"
"time" "time"
"github.com/micro/services/streams/handler" "github.com/micro/services/streams/handler"
pb "github.com/micro/services/streams/proto" pb "github.com/micro/services/streams/proto"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service"
"github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/config"
@@ -29,20 +28,19 @@ func main() {
logger.Fatalf("Error loading config: %v", err) logger.Fatalf("Error loading config: %v", err)
} }
addr := cfg.String(dbAddress) addr := cfg.String(dbAddress)
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) sqlDB, err := sql.Open("pgx", addr)
if err != nil { if err != nil {
logger.Fatalf("Error connecting to database: %v", err) logger.Fatalf("Failed to open connection to DB %s", err)
}
if err := db.AutoMigrate(&handler.Token{}); err != nil {
logger.Fatalf("Error migrating database: %v", err)
} }
// Register handler h := &handler.Streams{
pb.RegisterStreamsHandler(srv.Server(), &handler.Streams{
DB: db,
Events: events.DefaultStream, Events: events.DefaultStream,
Time: time.Now, Time: time.Now,
}) }
h.DBConn(sqlDB).Migrations(&handler.Token{})
// Register handler
pb.RegisterStreamsHandler(srv.Server(), h)
// Run service // Run service
if err := srv.Run(); err != nil { if err := srv.Run(); err != nil {

View File

@@ -27,7 +27,7 @@ func testHandler(t *testing.T) *handler.Users {
t.Fatalf("Failed to open connection to DB %s", err) t.Fatalf("Failed to open connection to DB %s", err)
} }
// clean any data from a previous run // clean any data from a previous run
if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_users, micro_tokens CASCADE"); err != nil { if _, err := sqlDB.Exec(`DROP TABLE IF EXISTS "micro_someID_users", "micro_someID_tokens" CASCADE`); err != nil {
t.Fatalf("Error cleaning database: %v", err) t.Fatalf("Error cleaning database: %v", err)
} }
@@ -50,5 +50,6 @@ func assertUsersMatch(t *testing.T, exp, act *pb.User) {
func microAccountCtx() context.Context { func microAccountCtx() context.Context {
return auth.ContextWithAccount(context.TODO(), &auth.Account{ return auth.ContextWithAccount(context.TODO(), &auth.Account{
Issuer: "micro", Issuer: "micro",
ID: "someID",
}) })
} }