From 8e38c8b83465be5a50d0c353a5dd9f3110493784 Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Thu, 25 Mar 2021 17:33:20 +0000 Subject: [PATCH] multi tenant threads (#78) * multitenant threads * auth for v1 --- groups/main.go | 2 ++ threads/handler/create_conversation.go | 12 +++++++- threads/handler/create_conversation_test.go | 3 +- threads/handler/create_message.go | 16 +++++++++-- threads/handler/create_message_test.go | 5 ++-- threads/handler/delete_conversation.go | 12 +++++++- threads/handler/delete_conversation_test.go | 11 ++++---- threads/handler/list_conversations.go | 12 +++++++- threads/handler/list_conversations_test.go | 9 +++--- threads/handler/list_messages.go | 12 +++++++- threads/handler/list_messages_test.go | 13 ++++----- threads/handler/read_conversation.go | 12 +++++++- threads/handler/read_conversation_test.go | 5 ++-- threads/handler/recent_messages.go | 12 +++++++- threads/handler/recent_messages_test.go | 11 ++++---- threads/handler/streams.go | 4 +-- threads/handler/streams_test.go | 31 ++++++++++----------- threads/handler/update_conversation.go | 14 ++++++++-- threads/handler/update_conversation_test.go | 13 ++++----- threads/main.go | 16 +++++------ 20 files changed, 149 insertions(+), 76 deletions(-) diff --git a/groups/main.go b/groups/main.go index f4243da..b910377 100644 --- a/groups/main.go +++ b/groups/main.go @@ -8,6 +8,8 @@ import ( "github.com/micro/micro/v3/service/logger" "github.com/micro/services/groups/handler" pb "github.com/micro/services/groups/proto" + + _ "github.com/jackc/pgx/v4/stdlib" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/groups?sslmode=disable" diff --git a/threads/handler/create_conversation.go b/threads/handler/create_conversation.go index 4ea9b2b..75a56f4 100644 --- a/threads/handler/create_conversation.go +++ b/threads/handler/create_conversation.go @@ -4,6 +4,7 @@ import ( "context" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/threads/proto" @@ -11,6 +12,10 @@ import ( // Create a conversation func (s *Threads) CreateConversation(ctx context.Context, req *pb.CreateConversationRequest, rsp *pb.CreateConversationResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.GroupId) == 0 { return ErrMissingGroupID @@ -26,7 +31,12 @@ func (s *Threads) CreateConversation(ctx context.Context, req *pb.CreateConversa GroupID: req.GroupId, CreatedAt: s.Time(), } - if err := s.DB.Create(conv).Error; err != nil { + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + if err := db.Create(conv).Error; err != nil { logger.Errorf("Error creating conversation: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database") } diff --git a/threads/handler/create_conversation_test.go b/threads/handler/create_conversation_test.go index 3c2db53..7079dc4 100644 --- a/threads/handler/create_conversation_test.go +++ b/threads/handler/create_conversation_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/threads/handler" @@ -40,7 +39,7 @@ func TestCreateConversation(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: tc.Topic, GroupId: tc.GroupID, }, &rsp) diff --git a/threads/handler/create_message.go b/threads/handler/create_message.go index a03a492..8e43168 100644 --- a/threads/handler/create_message.go +++ b/threads/handler/create_message.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/threads/proto" @@ -13,6 +14,10 @@ import ( // Create a message within a conversation func (s *Threads) CreateMessage(ctx context.Context, req *pb.CreateMessageRequest, rsp *pb.CreateMessageResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.AuthorId) == 0 { return ErrMissingAuthorID @@ -24,9 +29,14 @@ func (s *Threads) CreateMessage(ctx context.Context, req *pb.CreateMessageReques return ErrMissingText } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // lookup the conversation var conv Conversation - if err := s.DB.Where(&Conversation{ID: req.ConversationId}).First(&conv).Error; err == gorm.ErrRecordNotFound { + if err := db.Where(&Conversation{ID: req.ConversationId}).First(&conv).Error; err == gorm.ErrRecordNotFound { return ErrNotFound } else if err != nil { logger.Errorf("Error reading conversation: %v", err) @@ -44,7 +54,7 @@ func (s *Threads) CreateMessage(ctx context.Context, req *pb.CreateMessageReques if len(msg.ID) == 0 { msg.ID = uuid.New().String() } - if err := s.DB.Create(msg).Error; err == nil { + if err := db.Create(msg).Error; err == nil { rsp.Message = msg.Serialize() return nil } else if !strings.Contains(err.Error(), "messages_pkey") { @@ -54,7 +64,7 @@ func (s *Threads) CreateMessage(ctx context.Context, req *pb.CreateMessageReques // a message already exists with this id var existing Message - if err := s.DB.Where(&Message{ID: msg.ID}).First(&existing).Error; err != nil { + if err := db.Where(&Message{ID: msg.ID}).First(&existing).Error; err != nil { logger.Errorf("Error creating message: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database") } diff --git a/threads/handler/create_message_test.go b/threads/handler/create_message_test.go index 5bfea50..d02c146 100644 --- a/threads/handler/create_message_test.go +++ b/threads/handler/create_message_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/threads/handler" @@ -17,7 +16,7 @@ func TestCreateMessage(t *testing.T) { // seed some data var cRsp pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "HelloWorld", GroupId: uuid.New().String(), }, &cRsp) if err != nil { @@ -84,7 +83,7 @@ func TestCreateMessage(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.CreateMessageResponse - err := h.CreateMessage(context.TODO(), &pb.CreateMessageRequest{ + err := h.CreateMessage(microAccountCtx(), &pb.CreateMessageRequest{ AuthorId: tc.AuthorID, ConversationId: tc.ConversationID, Text: tc.Text, diff --git a/threads/handler/delete_conversation.go b/threads/handler/delete_conversation.go index b8d79b0..3762592 100644 --- a/threads/handler/delete_conversation.go +++ b/threads/handler/delete_conversation.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/threads/proto" @@ -11,12 +12,21 @@ import ( // Delete a conversation and all the messages within func (s *Threads) DeleteConversation(ctx context.Context, req *pb.DeleteConversationRequest, rsp *pb.DeleteConversationResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } - return s.DB.Transaction(func(tx *gorm.DB) error { + return db.Transaction(func(tx *gorm.DB) error { // delete all the messages if err := tx.Where(&Message{ConversationID: req.Id}).Delete(&Message{}).Error; err != nil { logger.Errorf("Error deleting messages: %v", err) diff --git a/threads/handler/delete_conversation_test.go b/threads/handler/delete_conversation_test.go index ae16885..ef89010 100644 --- a/threads/handler/delete_conversation_test.go +++ b/threads/handler/delete_conversation_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/threads/handler" @@ -16,7 +15,7 @@ func TestDeleteConversation(t *testing.T) { // seed some data var cRsp pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "HelloWorld", GroupId: uuid.New().String(), }, &cRsp) if err != nil { @@ -25,24 +24,24 @@ func TestDeleteConversation(t *testing.T) { } t.Run("MissingID", func(t *testing.T) { - err := h.DeleteConversation(context.TODO(), &pb.DeleteConversationRequest{}, &pb.DeleteConversationResponse{}) + err := h.DeleteConversation(microAccountCtx(), &pb.DeleteConversationRequest{}, &pb.DeleteConversationResponse{}) assert.Equal(t, handler.ErrMissingID, err) }) t.Run("Valid", func(t *testing.T) { - err := h.DeleteConversation(context.TODO(), &pb.DeleteConversationRequest{ + err := h.DeleteConversation(microAccountCtx(), &pb.DeleteConversationRequest{ Id: cRsp.Conversation.Id, }, &pb.DeleteConversationResponse{}) assert.NoError(t, err) - err = h.ReadConversation(context.TODO(), &pb.ReadConversationRequest{ + err = h.ReadConversation(microAccountCtx(), &pb.ReadConversationRequest{ Id: cRsp.Conversation.Id, }, &pb.ReadConversationResponse{}) assert.Equal(t, handler.ErrNotFound, err) }) t.Run("Retry", func(t *testing.T) { - err := h.DeleteConversation(context.TODO(), &pb.DeleteConversationRequest{ + err := h.DeleteConversation(microAccountCtx(), &pb.DeleteConversationRequest{ Id: cRsp.Conversation.Id, }, &pb.DeleteConversationResponse{}) assert.NoError(t, err) diff --git a/threads/handler/list_conversations.go b/threads/handler/list_conversations.go index 8a62f90..40811d9 100644 --- a/threads/handler/list_conversations.go +++ b/threads/handler/list_conversations.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/threads/proto" @@ -10,14 +11,23 @@ import ( // List all the conversations for a group func (s *Threads) ListConversations(ctx context.Context, req *pb.ListConversationsRequest, rsp *pb.ListConversationsResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.GroupId) == 0 { return ErrMissingGroupID } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // query the database var convs []Conversation - if err := s.DB.Where(&Conversation{GroupID: req.GroupId}).Find(&convs).Error; err != nil { + if err := db.Where(&Conversation{GroupID: req.GroupId}).Find(&convs).Error; err != nil { logger.Errorf("Error reading conversation: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database") } diff --git a/threads/handler/list_conversations_test.go b/threads/handler/list_conversations_test.go index 87393aa..1efc442 100644 --- a/threads/handler/list_conversations_test.go +++ b/threads/handler/list_conversations_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/google/uuid" @@ -15,7 +14,7 @@ func TestListConversations(t *testing.T) { // seed some data var cRsp1 pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "HelloWorld", GroupId: uuid.New().String(), }, &cRsp1) if err != nil { @@ -23,7 +22,7 @@ func TestListConversations(t *testing.T) { return } var cRsp2 pb.CreateConversationResponse - err = h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err = h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "FooBar", GroupId: uuid.New().String(), }, &cRsp2) if err != nil { @@ -33,14 +32,14 @@ func TestListConversations(t *testing.T) { t.Run("MissingGroupID", func(t *testing.T) { var rsp pb.ListConversationsResponse - err := h.ListConversations(context.TODO(), &pb.ListConversationsRequest{}, &rsp) + err := h.ListConversations(microAccountCtx(), &pb.ListConversationsRequest{}, &rsp) assert.Equal(t, handler.ErrMissingGroupID, err) assert.Nil(t, rsp.Conversations) }) t.Run("Valid", func(t *testing.T) { var rsp pb.ListConversationsResponse - err := h.ListConversations(context.TODO(), &pb.ListConversationsRequest{ + err := h.ListConversations(microAccountCtx(), &pb.ListConversationsRequest{ GroupId: cRsp1.Conversation.GroupId, }, &rsp) diff --git a/threads/handler/list_messages.go b/threads/handler/list_messages.go index e637493..794f24c 100644 --- a/threads/handler/list_messages.go +++ b/threads/handler/list_messages.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/threads/proto" @@ -13,13 +14,22 @@ const DefaultLimit = 25 // List the messages within a conversation in reverse chronological order, using sent_before to // offset as older messages need to be loaded func (s *Threads) ListMessages(ctx context.Context, req *pb.ListMessagesRequest, rsp *pb.ListMessagesResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.ConversationId) == 0 { return ErrMissingConversationID } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // construct the query - q := s.DB.Where(&Message{ConversationID: req.ConversationId}).Order("sent_at DESC") + q := db.Where(&Message{ConversationID: req.ConversationId}).Order("sent_at DESC") if req.SentBefore != nil { q = q.Where("sent_at < ?", req.SentBefore.AsTime()) } diff --git a/threads/handler/list_messages_test.go b/threads/handler/list_messages_test.go index 73046c9..e3a27c9 100644 --- a/threads/handler/list_messages_test.go +++ b/threads/handler/list_messages_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "sort" "strconv" "testing" @@ -20,7 +19,7 @@ func TestListMessages(t *testing.T) { // seed some data var convRsp pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "TestListMessages", GroupId: uuid.New().String(), }, &convRsp) assert.NoError(t, err) @@ -31,7 +30,7 @@ func TestListMessages(t *testing.T) { msgs := make([]*pb.Message, 50) for i := 0; i < len(msgs); i++ { var rsp pb.CreateMessageResponse - err := h.CreateMessage(context.TODO(), &pb.CreateMessageRequest{ + err := h.CreateMessage(microAccountCtx(), &pb.CreateMessageRequest{ ConversationId: convRsp.Conversation.Id, AuthorId: uuid.New().String(), Text: strconv.Itoa(i), @@ -42,14 +41,14 @@ func TestListMessages(t *testing.T) { t.Run("MissingConversationID", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{}, &rsp) + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{}, &rsp) assert.Equal(t, handler.ErrMissingConversationID, err) assert.Nil(t, rsp.Messages) }) t.Run("NoOffset", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{ + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{ ConversationId: convRsp.Conversation.Id, }, &rsp) assert.NoError(t, err) @@ -67,7 +66,7 @@ func TestListMessages(t *testing.T) { t.Run("LimitSet", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{ + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{ ConversationId: convRsp.Conversation.Id, Limit: &wrapperspb.Int32Value{Value: 10}, }, &rsp) @@ -86,7 +85,7 @@ func TestListMessages(t *testing.T) { t.Run("OffsetAndLimit", func(t *testing.T) { var rsp pb.ListMessagesResponse - err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{ + err := h.ListMessages(microAccountCtx(), &pb.ListMessagesRequest{ ConversationId: convRsp.Conversation.Id, Limit: &wrapperspb.Int32Value{Value: 5}, SentBefore: msgs[20].SentAt, diff --git a/threads/handler/read_conversation.go b/threads/handler/read_conversation.go index 2ae08c1..b4402cb 100644 --- a/threads/handler/read_conversation.go +++ b/threads/handler/read_conversation.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "gorm.io/gorm" "github.com/micro/micro/v3/service/errors" @@ -12,6 +13,10 @@ import ( // Read a conversation using its ID, can filter using group ID if provided func (s *Threads) ReadConversation(ctx context.Context, req *pb.ReadConversationRequest, rsp *pb.ReadConversationResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID @@ -23,9 +28,14 @@ func (s *Threads) ReadConversation(ctx context.Context, req *pb.ReadConversation q.GroupID = req.GroupId.Value } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // execute the query var conv Conversation - if err := s.DB.Where(&q).First(&conv).Error; err == gorm.ErrRecordNotFound { + if err := db.Where(&q).First(&conv).Error; err == gorm.ErrRecordNotFound { return ErrNotFound } else if err != nil { logger.Errorf("Error reading conversation: %v", err) diff --git a/threads/handler/read_conversation_test.go b/threads/handler/read_conversation_test.go index 35cfc25..bd194c7 100644 --- a/threads/handler/read_conversation_test.go +++ b/threads/handler/read_conversation_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/google/uuid" @@ -16,7 +15,7 @@ func TestReadConversation(t *testing.T) { // seed some data var cRsp pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "HelloWorld", GroupId: uuid.New().String(), }, &cRsp) if err != nil { @@ -55,7 +54,7 @@ func TestReadConversation(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.ReadConversationResponse - err := h.ReadConversation(context.TODO(), &pb.ReadConversationRequest{ + err := h.ReadConversation(microAccountCtx(), &pb.ReadConversationRequest{ Id: tc.ID, GroupId: tc.GroupID, }, &rsp) assert.Equal(t, tc.Error, err) diff --git a/threads/handler/recent_messages.go b/threads/handler/recent_messages.go index addda0d..26dadd3 100644 --- a/threads/handler/recent_messages.go +++ b/threads/handler/recent_messages.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/threads/proto" @@ -13,6 +14,10 @@ import ( // most messages retrieved per conversation is 25, however this can be overriden using the // limit_per_conversation option func (s *Threads) RecentMessages(ctx context.Context, req *pb.RecentMessagesRequest, rsp *pb.RecentMessagesResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.ConversationIds) == 0 { return ErrMissingConversationIDs @@ -23,9 +28,14 @@ func (s *Threads) RecentMessages(ctx context.Context, req *pb.RecentMessagesRequ limit = int(req.LimitPerConversation.Value) } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // query the database var msgs []Message - err := s.DB.Transaction(func(tx *gorm.DB) error { + err = db.Transaction(func(tx *gorm.DB) error { for _, id := range req.ConversationIds { var cms []Message if err := tx.Where(&Message{ConversationID: id}).Order("sent_at DESC").Limit(limit).Find(&cms).Error; err != nil { diff --git a/threads/handler/recent_messages_test.go b/threads/handler/recent_messages_test.go index 3f8ab87..5a9a26d 100644 --- a/threads/handler/recent_messages_test.go +++ b/threads/handler/recent_messages_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "fmt" "testing" "time" @@ -22,7 +21,7 @@ func TestRecentMessages(t *testing.T) { convos := make(map[string][]*pb.Message, 3) for i := 0; i < 3; i++ { var convRsp pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "TestRecentMessages", GroupId: uuid.New().String(), }, &convRsp) assert.NoError(t, err) @@ -35,7 +34,7 @@ func TestRecentMessages(t *testing.T) { for j := 0; j < 50; j++ { var rsp pb.CreateMessageResponse - err := h.CreateMessage(context.TODO(), &pb.CreateMessageRequest{ + err := h.CreateMessage(microAccountCtx(), &pb.CreateMessageRequest{ ConversationId: convRsp.Conversation.Id, AuthorId: uuid.New().String(), Text: fmt.Sprintf("Conversation %v, Message %v", i, j), @@ -47,14 +46,14 @@ func TestRecentMessages(t *testing.T) { t.Run("MissingConversationIDs", func(t *testing.T) { var rsp pb.RecentMessagesResponse - err := h.RecentMessages(context.TODO(), &pb.RecentMessagesRequest{}, &rsp) + err := h.RecentMessages(microAccountCtx(), &pb.RecentMessagesRequest{}, &rsp) assert.Equal(t, handler.ErrMissingConversationIDs, err) assert.Nil(t, rsp.Messages) }) t.Run("LimitSet", func(t *testing.T) { var rsp pb.RecentMessagesResponse - err := h.RecentMessages(context.TODO(), &pb.RecentMessagesRequest{ + err := h.RecentMessages(microAccountCtx(), &pb.RecentMessagesRequest{ ConversationIds: ids, LimitPerConversation: &wrapperspb.Int32Value{Value: 10}, }, &rsp) @@ -79,7 +78,7 @@ func TestRecentMessages(t *testing.T) { reducedIDs := ids[:2] var rsp pb.RecentMessagesResponse - err := h.RecentMessages(context.TODO(), &pb.RecentMessagesRequest{ + err := h.RecentMessages(microAccountCtx(), &pb.RecentMessagesRequest{ ConversationIds: reducedIDs, }, &rsp) assert.NoError(t, err) diff --git a/threads/handler/streams.go b/threads/handler/streams.go index 97cf327..3a64cc2 100644 --- a/threads/handler/streams.go +++ b/threads/handler/streams.go @@ -4,9 +4,9 @@ import ( "time" "github.com/micro/micro/v3/service/errors" + gorm2 "github.com/micro/services/pkg/gorm" pb "github.com/micro/services/threads/proto" "google.golang.org/protobuf/types/known/timestamppb" - "gorm.io/gorm" ) var ( @@ -21,7 +21,7 @@ var ( ) type Threads struct { - DB *gorm.DB + gorm2.Helper Time func() time.Time } diff --git a/threads/handler/streams_test.go b/threads/handler/streams_test.go index 24de420..c94cfca 100644 --- a/threads/handler/streams_test.go +++ b/threads/handler/streams_test.go @@ -1,17 +1,17 @@ package handler_test import ( + "context" + "database/sql" "os" "testing" "time" "github.com/golang/protobuf/ptypes/timestamp" + "github.com/micro/micro/v3/service/auth" "github.com/micro/services/threads/handler" pb "github.com/micro/services/threads/proto" "github.com/stretchr/testify/assert" - - "gorm.io/driver/postgres" - "gorm.io/gorm" ) func testHandler(t *testing.T) *handler.Threads { @@ -20,27 +20,20 @@ func testHandler(t *testing.T) *handler.Threads { if len(addr) == 0 { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - t.Fatalf("Error connecting to database: %v", err) + t.Fatalf("Failed to open connection to DB %s", err) } // clean any data from a previous run - if err := db.Exec("DROP TABLE IF EXISTS conversations, messages CASCADE").Error; err != nil { + if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_conversations, micro_messages CASCADE"); err != nil { t.Fatalf("Error cleaning database: %v", err) } - // migrate the database - if err := db.AutoMigrate(&handler.Conversation{}, &handler.Message{}); err != nil { - t.Fatalf("Error migrating database: %v", err) - } + h := &handler.Threads{Time: func() time.Time { return time.Unix(1611327673, 0) }} + h.DBConn(sqlDB).Migrations(&handler.Conversation{}, &handler.Message{}) - // clean any data from a previous run - if err := db.Exec("TRUNCATE TABLE conversations, messages CASCADE").Error; err != nil { - t.Fatalf("Error cleaning database: %v", err) - } - - return &handler.Threads{DB: db, Time: func() time.Time { return time.Unix(1611327673, 0) }} + return h } func assertConversationsMatch(t *testing.T, exp, act *pb.Conversation) { @@ -99,3 +92,9 @@ func microSecondTime(t *timestamp.Timestamp) time.Time { tt := t.AsTime() return time.Unix(tt.Unix(), int64(tt.Nanosecond()-tt.Nanosecond()%1000)) } + +func microAccountCtx() context.Context { + return auth.ContextWithAccount(context.TODO(), &auth.Account{ + Issuer: "micro", + }) +} diff --git a/threads/handler/update_conversation.go b/threads/handler/update_conversation.go index 5076c32..5110d3e 100644 --- a/threads/handler/update_conversation.go +++ b/threads/handler/update_conversation.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/threads/proto" @@ -11,6 +12,10 @@ import ( // Update a conversations topic func (s *Threads) UpdateConversation(ctx context.Context, req *pb.UpdateConversationRequest, rsp *pb.UpdateConversationResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID @@ -19,9 +24,14 @@ func (s *Threads) UpdateConversation(ctx context.Context, req *pb.UpdateConversa return ErrMissingTopic } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // lookup the conversation var conv Conversation - if err := s.DB.Where(&Conversation{ID: req.Id}).First(&conv).Error; err == gorm.ErrRecordNotFound { + if err := db.Where(&Conversation{ID: req.Id}).First(&conv).Error; err == gorm.ErrRecordNotFound { return ErrNotFound } else if err != nil { logger.Errorf("Error reading conversation: %v", err) @@ -30,7 +40,7 @@ func (s *Threads) UpdateConversation(ctx context.Context, req *pb.UpdateConversa // update the conversation conv.Topic = req.Topic - if err := s.DB.Save(&conv).Error; err != nil { + if err := db.Save(&conv).Error; err != nil { logger.Errorf("Error updating conversation: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database") } diff --git a/threads/handler/update_conversation_test.go b/threads/handler/update_conversation_test.go index ebbbf5f..e9293ca 100644 --- a/threads/handler/update_conversation_test.go +++ b/threads/handler/update_conversation_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/google/uuid" @@ -15,7 +14,7 @@ func TestUpdateConversation(t *testing.T) { // seed some data var cRsp pb.CreateConversationResponse - err := h.CreateConversation(context.TODO(), &pb.CreateConversationRequest{ + err := h.CreateConversation(microAccountCtx(), &pb.CreateConversationRequest{ Topic: "HelloWorld", GroupId: uuid.New().String(), }, &cRsp) if err != nil { @@ -24,21 +23,21 @@ func TestUpdateConversation(t *testing.T) { } t.Run("MissingID", func(t *testing.T) { - err := h.UpdateConversation(context.TODO(), &pb.UpdateConversationRequest{ + err := h.UpdateConversation(microAccountCtx(), &pb.UpdateConversationRequest{ Topic: "NewTopic", }, &pb.UpdateConversationResponse{}) assert.Equal(t, handler.ErrMissingID, err) }) t.Run("MissingTopic", func(t *testing.T) { - err := h.UpdateConversation(context.TODO(), &pb.UpdateConversationRequest{ + err := h.UpdateConversation(microAccountCtx(), &pb.UpdateConversationRequest{ Id: uuid.New().String(), }, &pb.UpdateConversationResponse{}) assert.Equal(t, handler.ErrMissingTopic, err) }) t.Run("InvalidID", func(t *testing.T) { - err := h.UpdateConversation(context.TODO(), &pb.UpdateConversationRequest{ + err := h.UpdateConversation(microAccountCtx(), &pb.UpdateConversationRequest{ Id: uuid.New().String(), Topic: "NewTopic", }, &pb.UpdateConversationResponse{}) @@ -46,14 +45,14 @@ func TestUpdateConversation(t *testing.T) { }) t.Run("Valid", func(t *testing.T) { - err := h.UpdateConversation(context.TODO(), &pb.UpdateConversationRequest{ + err := h.UpdateConversation(microAccountCtx(), &pb.UpdateConversationRequest{ Id: cRsp.Conversation.Id, Topic: "NewTopic", }, &pb.UpdateConversationResponse{}) assert.NoError(t, err) var rsp pb.ReadConversationResponse - err = h.ReadConversation(context.TODO(), &pb.ReadConversationRequest{ + err = h.ReadConversation(microAccountCtx(), &pb.ReadConversationRequest{ Id: cRsp.Conversation.Id, }, &rsp) assert.NoError(t, err) diff --git a/threads/main.go b/threads/main.go index d8b98be..c0e56b0 100644 --- a/threads/main.go +++ b/threads/main.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "time" "github.com/micro/services/threads/handler" @@ -9,8 +10,8 @@ import ( "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/logger" - "gorm.io/driver/postgres" - "gorm.io/gorm" + + _ "github.com/jackc/pgx/v4/stdlib" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/threads?sslmode=disable" @@ -28,16 +29,15 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - logger.Fatalf("Error connecting to database: %v", err) - } - if err := db.AutoMigrate(&handler.Conversation{}, &handler.Message{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) + logger.Fatalf("Failed to open connection to DB %s", err) } + h := &handler.Threads{Time: time.Now} + h.DBConn(sqlDB).Migrations(&handler.Conversation{}, &handler.Message{}) // Register handler - pb.RegisterThreadsHandler(srv.Server(), &handler.Threads{DB: db, Time: time.Now}) + pb.RegisterThreadsHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil {