From 88085c7044bc69e776f27191134fd2b1aabb8951 Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Thu, 25 Mar 2021 21:34:57 +0000 Subject: [PATCH] multi tenant chats (#79) --- chats/handler/chats.go | 4 ++-- chats/handler/chats_test.go | 30 ++++++++++++++++------------ chats/handler/create_chat.go | 19 ++++++++++++++---- chats/handler/create_chat_test.go | 3 +-- chats/handler/create_message.go | 16 ++++++++++++--- chats/handler/create_message_test.go | 5 ++--- chats/handler/list_messages.go | 12 ++++++++++- chats/handler/list_messages_test.go | 13 ++++++------ chats/main.go | 16 +++++++-------- 9 files changed, 75 insertions(+), 43 deletions(-) diff --git a/chats/handler/chats.go b/chats/handler/chats.go index 0bc0dc2..d2745b3 100644 --- a/chats/handler/chats.go +++ b/chats/handler/chats.go @@ -5,10 +5,10 @@ import ( "time" pb "github.com/micro/services/chats/proto" + "github.com/micro/services/pkg/gorm" "github.com/micro/micro/v3/service/errors" "google.golang.org/protobuf/types/known/timestamppb" - "gorm.io/gorm" ) var ( @@ -21,7 +21,7 @@ var ( ) type Chats struct { - DB *gorm.DB + gorm.Helper Time func() time.Time } diff --git a/chats/handler/chats_test.go b/chats/handler/chats_test.go index 8e8da8e..63ad3ae 100644 --- a/chats/handler/chats_test.go +++ b/chats/handler/chats_test.go @@ -1,17 +1,18 @@ package handler_test import ( + "context" + "database/sql" "os" "testing" "time" + "github.com/micro/micro/v3/service/auth" "github.com/micro/services/chats/handler" pb "github.com/micro/services/chats/proto" "github.com/stretchr/testify/assert" "github.com/golang/protobuf/ptypes/timestamp" - "gorm.io/driver/postgres" - "gorm.io/gorm" ) func testHandler(t *testing.T) *handler.Chats { @@ -20,22 +21,19 @@ func testHandler(t *testing.T) *handler.Chats { if len(addr) == 0 { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - t.Fatalf("Error connecting to database: %v", err) + t.Fatalf("Failed to open connection to DB %s", err) } // clean any data from a previous run - if err := db.Exec("DROP TABLE IF EXISTS chats, messages CASCADE").Error; err != nil { + if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_chats, micro_messages CASCADE"); err != nil { t.Fatalf("Error cleaning database: %v", err) } - // migrate the database - if err := db.AutoMigrate(&handler.Chat{}, &handler.Message{}); err != nil { - t.Fatalf("Error migrating database: %v", err) - } - - return &handler.Chats{DB: db, Time: func() time.Time { return time.Unix(1611327673, 0) }} + h := &handler.Chats{Time: func() time.Time { return time.Unix(1611327673, 0) }} + h.DBConn(sqlDB).Migrations(&handler.Chat{}, &handler.Message{}) + return h } func assertChatsMatch(t *testing.T, exp, act *pb.Chat) { @@ -64,8 +62,8 @@ func assertChatsMatch(t *testing.T, exp, act *pb.Chat) { // postgres has a resolution of 100microseconds so just test that it's accurate to the second func microSecondTime(t *timestamp.Timestamp) time.Time { - tt:=t.AsTime() - return time.Unix(tt.Unix(), int64( tt.Nanosecond() - tt.Nanosecond() % 1000)) + tt := t.AsTime() + return time.Unix(tt.Unix(), int64(tt.Nanosecond()-tt.Nanosecond()%1000)) } func assertMessagesMatch(t *testing.T, exp, act *pb.Message) { @@ -91,3 +89,9 @@ func assertMessagesMatch(t *testing.T, exp, act *pb.Message) { } assert.True(t, microSecondTime(exp.SentAt).Equal(microSecondTime(act.SentAt))) } + +func microAccountCtx() context.Context { + return auth.ContextWithAccount(context.TODO(), &auth.Account{ + Issuer: "micro", + }) +} diff --git a/chats/handler/create_chat.go b/chats/handler/create_chat.go index ee421a3..0c333d1 100644 --- a/chats/handler/create_chat.go +++ b/chats/handler/create_chat.go @@ -3,11 +3,12 @@ package handler import ( "context" "encoding/json" + "regexp" "sort" - "strings" "time" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/chats/proto" @@ -16,6 +17,10 @@ import ( // Create a chat between two or more users, if a chat already exists for these users, the existing // chat will be returned func (c *Chats) CreateChat(ctx context.Context, req *pb.CreateChatRequest, rsp *pb.CreateChatResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.UserIds) < 2 { return ErrMissingUserIDs @@ -36,19 +41,25 @@ func (c *Chats) CreateChat(ctx context.Context, req *pb.CreateChatRequest, rsp * UserIDs: string(bytes), } + db, err := c.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // write to the database, if we get a unique key error, the chat already exists - err = c.DB.Create(&chat).Error + err = db.Create(&chat).Error if err == nil { rsp.Chat = chat.Serialize() return nil } - if !strings.Contains(err.Error(), "idx_chats_user_ids") { + + if match, _ := regexp.MatchString(`idx_[\S]+_chats_user_ids`, err.Error()); !match { logger.Errorf("Error creating chat: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database") } var existing Chat - if err := c.DB.Where(&Chat{UserIDs: chat.UserIDs}).First(&existing).Error; err != nil { + if err := db.Where(&Chat{UserIDs: chat.UserIDs}).First(&existing).Error; err != nil { logger.Errorf("Error reading chat: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database") } diff --git a/chats/handler/create_chat_test.go b/chats/handler/create_chat_test.go index e98689f..f3968d7 100644 --- a/chats/handler/create_chat_test.go +++ b/chats/handler/create_chat_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/google/uuid" @@ -43,7 +42,7 @@ func TestCreateChat(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.CreateChatResponse - err := h.CreateChat(context.TODO(), &pb.CreateChatRequest{ + err := h.CreateChat(microAccountCtx(), &pb.CreateChatRequest{ UserIds: tc.UserIDs, }, &rsp) diff --git a/chats/handler/create_message.go b/chats/handler/create_message.go index 003597a..ed3fb97 100644 --- a/chats/handler/create_message.go +++ b/chats/handler/create_message.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/chats/proto" @@ -13,6 +14,10 @@ import ( // Create a message within a chat func (c *Chats) CreateMessage(ctx context.Context, req *pb.CreateMessageRequest, rsp *pb.CreateMessageResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.AuthorId) == 0 { return ErrMissingAuthorID @@ -24,9 +29,14 @@ func (c *Chats) CreateMessage(ctx context.Context, req *pb.CreateMessageRequest, return ErrMissingText } + db, err := c.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // lookup the chat var conv Chat - if err := c.DB.Where(&Chat{ID: req.ChatId}).First(&conv).Error; err == gorm.ErrRecordNotFound { + if err := db.Where(&Chat{ID: req.ChatId}).First(&conv).Error; err == gorm.ErrRecordNotFound { return ErrNotFound } else if err != nil { logger.Errorf("Error reading chat: %v", err) @@ -44,7 +54,7 @@ func (c *Chats) CreateMessage(ctx context.Context, req *pb.CreateMessageRequest, if len(msg.ID) == 0 { msg.ID = uuid.New().String() } - if err := c.DB.Create(msg).Error; err == nil { + if err := db.Create(msg).Error; err == nil { rsp.Message = msg.Serialize() return nil } else if !strings.Contains(err.Error(), "messages_pkey") { @@ -54,7 +64,7 @@ func (c *Chats) CreateMessage(ctx context.Context, req *pb.CreateMessageRequest, // a message already exists with this id var existing Message - if err := c.DB.Where(&Message{ID: msg.ID}).First(&existing).Error; err != nil { + if err := db.Where(&Message{ID: msg.ID}).First(&existing).Error; err != nil { logger.Errorf("Error creating message: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database") } diff --git a/chats/handler/create_message_test.go b/chats/handler/create_message_test.go index 623eec1..c2caa43 100644 --- a/chats/handler/create_message_test.go +++ b/chats/handler/create_message_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/chats/handler" @@ -17,7 +16,7 @@ func TestCreateMessage(t *testing.T) { // seed some data var cRsp pb.CreateChatResponse - err := h.CreateChat(context.TODO(), &pb.CreateChatRequest{ + err := h.CreateChat(microAccountCtx(), &pb.CreateChatRequest{ UserIds: []string{uuid.New().String(), uuid.New().String()}, }, &cRsp) if err != nil { @@ -84,7 +83,7 @@ func TestCreateMessage(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.CreateMessageResponse - err := h.CreateMessage(context.TODO(), &pb.CreateMessageRequest{ + err := h.CreateMessage(microAccountCtx(), &pb.CreateMessageRequest{ AuthorId: tc.AuthorID, ChatId: tc.ChatID, Text: tc.Text, diff --git a/chats/handler/list_messages.go b/chats/handler/list_messages.go index 7ac8acd..8831d31 100644 --- a/chats/handler/list_messages.go +++ b/chats/handler/list_messages.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/chats/proto" @@ -13,13 +14,22 @@ const DefaultLimit = 25 // List the messages within a chat in reverse chronological order, using sent_before to // offset as older messages need to be loaded func (c *Chats) ListMessages(ctx context.Context, req *pb.ListMessagesRequest, rsp *pb.ListMessagesResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.ChatId) == 0 { return ErrMissingChatID } + db, err := c.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // construct the query - q := c.DB.Where(&Message{ChatID: req.ChatId}).Order("sent_at DESC") + q := db.Where(&Message{ChatID: req.ChatId}).Order("sent_at DESC") if req.SentBefore != nil { q = q.Where("sent_at < ?", req.SentBefore.AsTime()) } diff --git a/chats/handler/list_messages_test.go b/chats/handler/list_messages_test.go index f47e71a..2570e6e 100644 --- a/chats/handler/list_messages_test.go +++ b/chats/handler/list_messages_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "sort" "strconv" "testing" @@ -20,7 +19,7 @@ func TestListMessages(t *testing.T) { // seed some data var chatRsp pb.CreateChatResponse - err := h.CreateChat(context.TODO(), &pb.CreateChatRequest{ + err := h.CreateChat(microAccountCtx(), &pb.CreateChatRequest{ UserIds: []string{uuid.New().String(), uuid.New().String()}, }, &chatRsp) assert.NoError(t, err) @@ -31,7 +30,7 @@ func TestListMessages(t *testing.T) { msgs := make([]*pb.Message, 50) for i := 0; i < len(msgs); i++ { var rsp pb.CreateMessageResponse - err := h.CreateMessage(context.TODO(), &pb.CreateMessageRequest{ + err := h.CreateMessage(microAccountCtx(), &pb.CreateMessageRequest{ ChatId: chatRsp.Chat.Id, AuthorId: uuid.New().String(), Text: strconv.Itoa(i), @@ -42,14 +41,14 @@ func TestListMessages(t *testing.T) { t.Run("MissingChatID", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{}, &rsp) + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{}, &rsp) assert.Equal(t, handler.ErrMissingChatID, err) assert.Nil(t, rsp.Messages) }) t.Run("NoOffset", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{ + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{ ChatId: chatRsp.Chat.Id, }, &rsp) assert.NoError(t, err) @@ -67,7 +66,7 @@ func TestListMessages(t *testing.T) { t.Run("LimitSet", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{ + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{ ChatId: chatRsp.Chat.Id, Limit: &wrapperspb.Int32Value{Value: 10}, }, &rsp) @@ -86,7 +85,7 @@ func TestListMessages(t *testing.T) { t.Run("OffsetAndLimit", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{ + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{ ChatId: chatRsp.Chat.Id, Limit: &wrapperspb.Int32Value{Value: 5}, SentBefore: msgs[20].SentAt, diff --git a/chats/main.go b/chats/main.go index a9b60ec..93732fc 100644 --- a/chats/main.go +++ b/chats/main.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "time" "github.com/micro/services/chats/handler" @@ -9,8 +10,8 @@ import ( "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/logger" - "gorm.io/driver/postgres" - "gorm.io/gorm" + + _ "github.com/jackc/pgx/v4/stdlib" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/chats?sslmode=disable" @@ -28,16 +29,15 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - logger.Fatalf("Error connecting to database: %v", err) - } - if err := db.AutoMigrate(&handler.Chat{}, &handler.Message{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) + logger.Fatalf("Failed to open connection to DB %s", err) } + h := &handler.Chats{Time: time.Now} + h.DBConn(sqlDB).Migrations(&handler.Chat{}, &handler.Message{}) // Register handler - pb.RegisterChatsHandler(srv.Server(), &handler.Chats{DB: db, Time: time.Now}) + pb.RegisterChatsHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil {