mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-15 12:34:44 +00:00
Refactor Chats Service (#48)
This commit is contained in:
61
chats/handler/chats.go
Normal file
61
chats/handler/chats.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingID = errors.BadRequest("MISSING_ID", "Missing ID")
|
||||
ErrMissingAuthorID = errors.BadRequest("MISSING_AUTHOR_ID", "Missing Author ID")
|
||||
ErrMissingText = errors.BadRequest("MISSING_TEXT", "Missing text")
|
||||
ErrMissingChatID = errors.BadRequest("MISSING_CHAT_ID", "Missing Chat ID")
|
||||
ErrMissingUserIDs = errors.BadRequest("MISSING_USER_IDs", "Two or more user IDs are required")
|
||||
ErrNotFound = errors.NotFound("NOT_FOUND", "Chat not found")
|
||||
)
|
||||
|
||||
type Chats struct {
|
||||
DB *gorm.DB
|
||||
Time func() time.Time
|
||||
}
|
||||
|
||||
type Chat struct {
|
||||
ID string
|
||||
UserIDs string `gorm:"uniqueIndex"` // sorted json array
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID string
|
||||
AuthorID string
|
||||
ChatID string
|
||||
Text string
|
||||
SentAt time.Time
|
||||
}
|
||||
|
||||
func (m *Message) Serialize() *pb.Message {
|
||||
return &pb.Message{
|
||||
Id: m.ID,
|
||||
AuthorId: m.AuthorID,
|
||||
ChatId: m.ChatID,
|
||||
Text: m.Text,
|
||||
SentAt: timestamppb.New(m.SentAt),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Chat) Serialize() *pb.Chat {
|
||||
var userIDs []string
|
||||
json.Unmarshal([]byte(c.UserIDs), &userIDs)
|
||||
|
||||
return &pb.Chat{
|
||||
Id: c.ID,
|
||||
UserIds: userIDs,
|
||||
CreatedAt: timestamppb.New(c.CreatedAt),
|
||||
}
|
||||
}
|
||||
83
chats/handler/chats_test.go
Normal file
83
chats/handler/chats_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/micro/services/chats/handler"
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func testHandler(t *testing.T) *handler.Chats {
|
||||
// connect to the database
|
||||
db, err := gorm.Open(postgres.Open("postgresql://postgres@localhost:5432/chats?sslmode=disable"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error connecting to database: %v", err)
|
||||
}
|
||||
|
||||
// migrate the database
|
||||
if err := db.AutoMigrate(&handler.Chat{}, &handler.Message{}); err != nil {
|
||||
t.Fatalf("Error migrating database: %v", err)
|
||||
}
|
||||
|
||||
// clean any data from a previous run
|
||||
if err := db.Exec("TRUNCATE TABLE chats, messages CASCADE").Error; err != nil {
|
||||
t.Fatalf("Error cleaning database: %v", err)
|
||||
}
|
||||
|
||||
return &handler.Chats{DB: db, Time: func() time.Time { return time.Unix(1611327673, 0) }}
|
||||
}
|
||||
|
||||
func assertChatsMatch(t *testing.T, exp, act *pb.Chat) {
|
||||
if act == nil {
|
||||
t.Errorf("Chat not returned")
|
||||
return
|
||||
}
|
||||
|
||||
// adapt this check so we can reuse the func in testing create, where we don't know the exact id
|
||||
// which will be generated
|
||||
if len(exp.Id) > 0 {
|
||||
assert.Equal(t, exp.Id, act.Id)
|
||||
} else {
|
||||
assert.NotEmpty(t, act.Id)
|
||||
}
|
||||
|
||||
assert.Equal(t, exp.UserIds, act.UserIds)
|
||||
|
||||
if act.CreatedAt == nil {
|
||||
t.Errorf("CreatedAt not set")
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, exp.CreatedAt.AsTime().Equal(act.CreatedAt.AsTime()))
|
||||
}
|
||||
|
||||
func assertMessagesMatch(t *testing.T, exp, act *pb.Message) {
|
||||
if act == nil {
|
||||
t.Errorf("Message not returned")
|
||||
return
|
||||
}
|
||||
|
||||
// adapt this check so we can reuse the func in testing create, where we don't know the exact id
|
||||
// which will be generated
|
||||
if len(exp.Id) > 0 {
|
||||
assert.Equal(t, exp.Id, act.Id)
|
||||
} else {
|
||||
assert.NotEmpty(t, act.Id)
|
||||
}
|
||||
|
||||
assert.Equal(t, exp.Text, act.Text)
|
||||
assert.Equal(t, exp.AuthorId, act.AuthorId)
|
||||
assert.Equal(t, exp.ChatId, act.ChatId)
|
||||
|
||||
if act.SentAt == nil {
|
||||
t.Errorf("SentAt not set")
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, exp.SentAt.AsTime().Equal(act.SentAt.AsTime()))
|
||||
}
|
||||
57
chats/handler/create_chat.go
Normal file
57
chats/handler/create_chat.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
)
|
||||
|
||||
// Create a chat between two or more users, if a chat already exists for these users, the existing
|
||||
// chat will be returned
|
||||
func (c *Chats) CreateChat(ctx context.Context, req *pb.CreateChatRequest, rsp *pb.CreateChatResponse) error {
|
||||
// validate the request
|
||||
if len(req.UserIds) < 2 {
|
||||
return ErrMissingUserIDs
|
||||
}
|
||||
|
||||
// sort the user ids and then marshal to json
|
||||
sort.Strings(req.UserIds)
|
||||
bytes, err := json.Marshal(req.UserIds)
|
||||
if err != nil {
|
||||
logger.Errorf("Error mashaling user ids: %v", err)
|
||||
return errors.InternalServerError("ENCODING_ERROR", "Error encoding user ids")
|
||||
}
|
||||
|
||||
// construct the chat
|
||||
chat := Chat{
|
||||
ID: uuid.New().String(),
|
||||
CreatedAt: time.Now(),
|
||||
UserIDs: string(bytes),
|
||||
}
|
||||
|
||||
// write to the database, if we get a unique key error, the chat already exists
|
||||
err = c.DB.Create(&chat).Error
|
||||
if err == nil {
|
||||
rsp.Chat = chat.Serialize()
|
||||
return nil
|
||||
}
|
||||
if !strings.Contains(err.Error(), "idx_chats_user_ids") {
|
||||
logger.Errorf("Error creating chat: %v", err)
|
||||
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database")
|
||||
}
|
||||
|
||||
var existing Chat
|
||||
if err := c.DB.Where(&Chat{UserIDs: chat.UserIDs}).First(&existing).Error; err != nil {
|
||||
logger.Errorf("Error reading chat: %v", err)
|
||||
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database")
|
||||
}
|
||||
rsp.Chat = existing.Serialize()
|
||||
return nil
|
||||
}
|
||||
63
chats/handler/create_chat_test.go
Normal file
63
chats/handler/create_chat_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/services/chats/handler"
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateChat(t *testing.T) {
|
||||
userIDs := []string{uuid.New().String(), uuid.New().String()}
|
||||
|
||||
tt := []struct {
|
||||
Name string
|
||||
UserIDs []string
|
||||
Error error
|
||||
}{
|
||||
{
|
||||
Name: "NoUserIDs",
|
||||
Error: handler.ErrMissingUserIDs,
|
||||
},
|
||||
{
|
||||
Name: "OneUserID",
|
||||
UserIDs: userIDs[1:],
|
||||
Error: handler.ErrMissingUserIDs,
|
||||
},
|
||||
{
|
||||
Name: "Valid",
|
||||
UserIDs: userIDs,
|
||||
},
|
||||
{
|
||||
Name: "Repeat",
|
||||
UserIDs: userIDs,
|
||||
},
|
||||
}
|
||||
|
||||
var chat *pb.Chat
|
||||
h := testHandler(t)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
var rsp pb.CreateChatResponse
|
||||
err := h.CreateChat(context.TODO(), &pb.CreateChatRequest{
|
||||
UserIds: tc.UserIDs,
|
||||
}, &rsp)
|
||||
|
||||
assert.Equal(t, tc.Error, err)
|
||||
if tc.Error != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotNil(t, rsp.Chat)
|
||||
if chat == nil {
|
||||
chat = rsp.Chat
|
||||
} else {
|
||||
assertChatsMatch(t, chat, rsp.Chat)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
53
chats/handler/create_message.go
Normal file
53
chats/handler/create_message.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Create a message within a chat
|
||||
func (c *Chats) CreateMessage(ctx context.Context, req *pb.CreateMessageRequest, rsp *pb.CreateMessageResponse) error {
|
||||
// validate the request
|
||||
if len(req.AuthorId) == 0 {
|
||||
return ErrMissingAuthorID
|
||||
}
|
||||
if len(req.ChatId) == 0 {
|
||||
return ErrMissingChatID
|
||||
}
|
||||
if len(req.Text) == 0 {
|
||||
return ErrMissingText
|
||||
}
|
||||
|
||||
return c.DB.Transaction(func(tx *gorm.DB) error {
|
||||
// lookup the chat
|
||||
var conv Chat
|
||||
if err := tx.Where(&Chat{ID: req.ChatId}).First(&conv).Error; err == gorm.ErrRecordNotFound {
|
||||
return ErrNotFound
|
||||
} else if err != nil {
|
||||
logger.Errorf("Error reading chat: %v", err)
|
||||
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database")
|
||||
}
|
||||
|
||||
// create the message
|
||||
msg := &Message{
|
||||
ID: uuid.New().String(),
|
||||
SentAt: c.Time(),
|
||||
Text: req.Text,
|
||||
AuthorID: req.AuthorId,
|
||||
ChatID: req.ChatId,
|
||||
}
|
||||
if err := tx.Create(msg).Error; err != nil {
|
||||
logger.Errorf("Error creating message: %v", err)
|
||||
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database")
|
||||
}
|
||||
|
||||
// serialize the response
|
||||
rsp.Message = msg.Serialize()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
89
chats/handler/create_message_test.go
Normal file
89
chats/handler/create_message_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/micro/services/chats/handler"
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateMessage(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
// seed some data
|
||||
var cRsp pb.CreateChatResponse
|
||||
err := h.CreateChat(context.TODO(), &pb.CreateChatRequest{
|
||||
UserIds: []string{uuid.New().String(), uuid.New().String()},
|
||||
}, &cRsp)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating chat: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
Name string
|
||||
AuthorID string
|
||||
ChatID string
|
||||
Text string
|
||||
Error error
|
||||
}{
|
||||
{
|
||||
Name: "MissingChatID",
|
||||
Text: "HelloWorld",
|
||||
AuthorID: uuid.New().String(),
|
||||
Error: handler.ErrMissingChatID,
|
||||
},
|
||||
{
|
||||
Name: "MissingAuthorID",
|
||||
ChatID: uuid.New().String(),
|
||||
Text: "HelloWorld",
|
||||
Error: handler.ErrMissingAuthorID,
|
||||
},
|
||||
{
|
||||
Name: "MissingText",
|
||||
ChatID: uuid.New().String(),
|
||||
AuthorID: uuid.New().String(),
|
||||
Error: handler.ErrMissingText,
|
||||
},
|
||||
{
|
||||
Name: "ChatNotFound",
|
||||
ChatID: uuid.New().String(),
|
||||
AuthorID: uuid.New().String(),
|
||||
Text: "HelloWorld",
|
||||
Error: handler.ErrNotFound,
|
||||
},
|
||||
{
|
||||
Name: "Valid",
|
||||
ChatID: cRsp.Chat.Id,
|
||||
AuthorID: uuid.New().String(),
|
||||
Text: "HelloWorld",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
var rsp pb.CreateMessageResponse
|
||||
err := h.CreateMessage(context.TODO(), &pb.CreateMessageRequest{
|
||||
Text: tc.Text, ChatId: tc.ChatID, AuthorId: tc.AuthorID,
|
||||
}, &rsp)
|
||||
|
||||
assert.Equal(t, tc.Error, err)
|
||||
if tc.Error != nil {
|
||||
assert.Nil(t, rsp.Message)
|
||||
return
|
||||
}
|
||||
|
||||
assertMessagesMatch(t, &pb.Message{
|
||||
AuthorId: tc.AuthorID,
|
||||
ChatId: tc.ChatID,
|
||||
SentAt: timestamppb.New(h.Time()),
|
||||
Text: tc.Text,
|
||||
}, rsp.Message)
|
||||
})
|
||||
}
|
||||
}
|
||||
45
chats/handler/list_messages.go
Normal file
45
chats/handler/list_messages.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
)
|
||||
|
||||
const DefaultLimit = 25
|
||||
|
||||
// List the messages within a chat in reverse chronological order, using sent_before to
|
||||
// offset as older messages need to be loaded
|
||||
func (c *Chats) ListMessages(ctx context.Context, req *pb.ListMessagesRequest, rsp *pb.ListMessagesResponse) error {
|
||||
// validate the request
|
||||
if len(req.ChatId) == 0 {
|
||||
return ErrMissingChatID
|
||||
}
|
||||
|
||||
// construct the query
|
||||
q := c.DB.Where(&Message{ChatID: req.ChatId}).Order("sent_at DESC")
|
||||
if req.SentBefore != nil {
|
||||
q = q.Where("sent_at < ?", req.SentBefore.AsTime())
|
||||
}
|
||||
if req.Limit != nil {
|
||||
q.Limit(int(req.Limit.Value))
|
||||
} else {
|
||||
q.Limit(DefaultLimit)
|
||||
}
|
||||
|
||||
// execute the query
|
||||
var msgs []Message
|
||||
if err := q.Find(&msgs).Error; err != nil {
|
||||
logger.Errorf("Error reading messages: %v", err)
|
||||
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to database")
|
||||
}
|
||||
|
||||
// serialize the response
|
||||
rsp.Messages = make([]*pb.Message, len(msgs))
|
||||
for i, m := range msgs {
|
||||
rsp.Messages[i] = m.Serialize()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
116
chats/handler/list_messages_test.go
Normal file
116
chats/handler/list_messages_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/services/chats/handler"
|
||||
pb "github.com/micro/services/chats/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
)
|
||||
|
||||
func TestListMessages(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
h.Time = time.Now
|
||||
|
||||
// seed some data
|
||||
var chatRsp pb.CreateChatResponse
|
||||
err := h.CreateChat(context.TODO(), &pb.CreateChatRequest{
|
||||
UserIds: []string{uuid.New().String(), uuid.New().String()},
|
||||
}, &chatRsp)
|
||||
assert.NoError(t, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msgs := make([]*pb.Message, 50)
|
||||
for i := 0; i < len(msgs); i++ {
|
||||
var rsp pb.CreateMessageResponse
|
||||
err := h.CreateMessage(context.TODO(), &pb.CreateMessageRequest{
|
||||
ChatId: chatRsp.Chat.Id,
|
||||
AuthorId: uuid.New().String(),
|
||||
Text: strconv.Itoa(i),
|
||||
}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
msgs[i] = rsp.Message
|
||||
}
|
||||
|
||||
t.Run("MissingChatID", func(t *testing.T) {
|
||||
var rsp pb.ListMessagesResponse
|
||||
err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{}, &rsp)
|
||||
assert.Equal(t, handler.ErrMissingChatID, err)
|
||||
assert.Nil(t, rsp.Messages)
|
||||
})
|
||||
|
||||
t.Run("NoOffset", func(t *testing.T) {
|
||||
var rsp pb.ListMessagesResponse
|
||||
err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{
|
||||
ChatId: chatRsp.Chat.Id,
|
||||
}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(rsp.Messages) != handler.DefaultLimit {
|
||||
t.Fatalf("Expected %v messages but got %v", handler.DefaultLimit, len(rsp.Messages))
|
||||
return
|
||||
}
|
||||
expected := msgs[25:]
|
||||
sortMessages(rsp.Messages)
|
||||
for i, msg := range rsp.Messages {
|
||||
assertMessagesMatch(t, expected[i], msg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LimitSet", func(t *testing.T) {
|
||||
var rsp pb.ListMessagesResponse
|
||||
err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{
|
||||
ChatId: chatRsp.Chat.Id,
|
||||
Limit: &wrapperspb.Int32Value{Value: 10},
|
||||
}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(rsp.Messages) != 10 {
|
||||
t.Fatalf("Expected %v messages but got %v", 10, len(rsp.Messages))
|
||||
return
|
||||
}
|
||||
expected := msgs[40:]
|
||||
sortMessages(rsp.Messages)
|
||||
for i, msg := range rsp.Messages {
|
||||
assertMessagesMatch(t, expected[i], msg)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OffsetAndLimit", func(t *testing.T) {
|
||||
var rsp pb.ListMessagesResponse
|
||||
err := h.ListMessages(context.TODO(), &pb.ListMessagesRequest{
|
||||
ChatId: chatRsp.Chat.Id,
|
||||
Limit: &wrapperspb.Int32Value{Value: 5},
|
||||
SentBefore: msgs[20].SentAt,
|
||||
}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(rsp.Messages) != 5 {
|
||||
t.Fatalf("Expected %v messages but got %v", 5, len(rsp.Messages))
|
||||
return
|
||||
}
|
||||
expected := msgs[15:20]
|
||||
sortMessages(rsp.Messages)
|
||||
for i, msg := range rsp.Messages {
|
||||
assertMessagesMatch(t, expected[i], msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// sortMessages by the time they were sent
|
||||
func sortMessages(msgs []*pb.Message) {
|
||||
sort.Slice(msgs, func(i, j int) bool {
|
||||
if msgs[i].SentAt == nil || msgs[j].SentAt == nil {
|
||||
return true
|
||||
}
|
||||
return msgs[i].SentAt.AsTime().Before(msgs[j].SentAt.AsTime())
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user