From 6352339ebf58807a2040f27d9815a6e9c8c73732 Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Thu, 25 Mar 2021 23:27:56 +0000 Subject: [PATCH] lockdown seen service (#81) * seen * readme --- seen/README.md | 6 ++++++ seen/handler/handler.go | 39 +++++++++++++++++++++++++++++++----- seen/handler/handler_test.go | 39 +++++++++++++++++++----------------- seen/main.go | 17 ++++++++-------- 4 files changed, 70 insertions(+), 31 deletions(-) create mode 100644 seen/README.md diff --git a/seen/README.md b/seen/README.md new file mode 100644 index 0000000..096d7cb --- /dev/null +++ b/seen/README.md @@ -0,0 +1,6 @@ +Seen is a service to keep track of which resources a user has seen (read). For example, it can be used to keep track of what notifications have been seen by a user, or what messages they've read in a chat. + + +# Seen Service + +The seen service is a service to keep track of which resources a user has seen (read). diff --git a/seen/handler/handler.go b/seen/handler/handler.go index 34cb5eb..4789c87 100644 --- a/seen/handler/handler.go +++ b/seen/handler/handler.go @@ -4,6 +4,8 @@ import ( "context" "time" + "github.com/micro/micro/v3/service/auth" + gorm2 "github.com/micro/services/pkg/gorm" "gorm.io/gorm" "github.com/google/uuid" @@ -22,7 +24,7 @@ var ( ) type Seen struct { - DB *gorm.DB + gorm2.Helper } type SeenInstance struct { @@ -35,6 +37,10 @@ type SeenInstance struct { // Set a resource as seen by a user. If no timestamp is provided, the current time is used. func (s *Seen) Set(ctx context.Context, req *pb.SetRequest, rsp *pb.SetResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.UserId) == 0 { return ErrMissingUserID @@ -57,7 +63,12 @@ func (s *Seen) Set(ctx context.Context, req *pb.SetRequest, rsp *pb.SetResponse) ResourceID: req.ResourceId, ResourceType: req.ResourceType, } - if err := s.DB.Where(&instance).First(&instance).Error; err == gorm.ErrRecordNotFound { + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + if err := db.Where(&instance).First(&instance).Error; err == gorm.ErrRecordNotFound { instance.ID = uuid.New().String() } else if err != nil { logger.Errorf("Error with store: %v", err) @@ -66,7 +77,7 @@ func (s *Seen) Set(ctx context.Context, req *pb.SetRequest, rsp *pb.SetResponse) // update the resource instance.Timestamp = req.Timestamp.AsTime() - if err := s.DB.Save(&instance).Error; err != nil { + if err := db.Save(&instance).Error; err != nil { logger.Errorf("Error with store: %v", err) return ErrStore } @@ -77,6 +88,10 @@ func (s *Seen) Set(ctx context.Context, req *pb.SetRequest, rsp *pb.SetResponse) // Unset a resource as seen, used in cases where a user viewed a resource but wants to override // this so they remember to action it in the future, e.g. "Mark this as unread". func (s *Seen) Unset(ctx context.Context, req *pb.UnsetRequest, rsp *pb.UnsetResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.UserId) == 0 { return ErrMissingUserID @@ -88,8 +103,13 @@ func (s *Seen) Unset(ctx context.Context, req *pb.UnsetRequest, rsp *pb.UnsetRes return ErrMissingResourceType } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // delete the object from the store - err := s.DB.Delete(SeenInstance{}, SeenInstance{ + err = db.Delete(SeenInstance{}, SeenInstance{ UserID: req.UserId, ResourceID: req.ResourceId, ResourceType: req.ResourceType, @@ -106,6 +126,10 @@ func (s *Seen) Unset(ctx context.Context, req *pb.UnsetRequest, rsp *pb.UnsetRes // is returned for a given resource_id, it indicates that resource has not yet been seen by the // user. func (s *Seen) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.UserId) == 0 { return ErrMissingUserID @@ -117,8 +141,13 @@ func (s *Seen) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRespon return ErrMissingResourceType } + db, err := s.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // query the store - q := s.DB.Where(SeenInstance{UserID: req.UserId, ResourceType: req.ResourceType}) + q := db.Where(SeenInstance{UserID: req.UserId, ResourceType: req.ResourceType}) q = q.Where("resource_id IN (?)", req.ResourceIds) var data []SeenInstance if err := q.Find(&data).Error; err != nil { diff --git a/seen/handler/handler_test.go b/seen/handler/handler_test.go index b352120..93d0448 100644 --- a/seen/handler/handler_test.go +++ b/seen/handler/handler_test.go @@ -2,17 +2,17 @@ package handler_test import ( "context" + "database/sql" "os" "testing" "time" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/services/seen/handler" pb "github.com/micro/services/seen/proto" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/timestamppb" - "gorm.io/driver/postgres" - "gorm.io/gorm" ) func testHandler(t *testing.T) *handler.Seen { @@ -21,22 +21,19 @@ func testHandler(t *testing.T) *handler.Seen { if len(addr) == 0 { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - t.Fatalf("Error connecting to database: %v", err) - } - - // migrate the database - if err := db.AutoMigrate(&handler.SeenInstance{}); err != nil { - t.Fatalf("Error migrating database: %v", err) + t.Fatalf("Failed to open connection to DB %s", err) } // clean any data from a previous run - if err := db.Exec("TRUNCATE TABLE seen_instances CASCADE").Error; err != nil { + if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_seen_instances CASCADE"); err != nil { t.Fatalf("Error cleaning database: %v", err) } - return &handler.Seen{DB: db} + h := &handler.Seen{} + h.DBConn(sqlDB).Migrations(&handler.SeenInstance{}) + return h } func TestSet(t *testing.T) { @@ -91,7 +88,7 @@ func TestSet(t *testing.T) { h := testHandler(t) for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { - err := h.Set(context.TODO(), &pb.SetRequest{ + err := h.Set(microAccountCtx(), &pb.SetRequest{ UserId: tc.UserID, ResourceId: tc.ResourceID, ResourceType: tc.ResourceType, @@ -110,7 +107,7 @@ func TestUnset(t *testing.T) { ResourceId: uuid.New().String(), ResourceType: "message", } - err := h.Set(context.TODO(), seed, &pb.SetResponse{}) + err := h.Set(microAccountCtx(), seed, &pb.SetResponse{}) assert.NoError(t, err) tt := []struct { @@ -154,7 +151,7 @@ func TestUnset(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { - err := h.Unset(context.TODO(), &pb.UnsetRequest{ + err := h.Unset(microAccountCtx(), &pb.UnsetRequest{ UserId: tc.UserID, ResourceId: tc.ResourceID, ResourceType: tc.ResourceType, @@ -208,7 +205,7 @@ func TestRead(t *testing.T) { }, } for _, d := range td { - assert.NoError(t, h.Set(context.TODO(), &pb.SetRequest{ + assert.NoError(t, h.Set(microAccountCtx(), &pb.SetRequest{ UserId: d.UserID, ResourceId: d.ResourceID, ResourceType: d.ResourceType, @@ -218,7 +215,7 @@ func TestRead(t *testing.T) { // check only the requested values are returned var rsp pb.ReadResponse - err := h.Read(context.TODO(), &pb.ReadRequest{ + err := h.Read(microAccountCtx(), &pb.ReadRequest{ UserId: "user-1", ResourceType: "message", ResourceIds: []string{"message-1", "message-2", "message-3"}, @@ -239,7 +236,7 @@ func TestRead(t *testing.T) { } // unsetting a resource should remove it from the list - err = h.Unset(context.TODO(), &pb.UnsetRequest{ + err = h.Unset(microAccountCtx(), &pb.UnsetRequest{ UserId: "user-1", ResourceId: "message-2", ResourceType: "message", @@ -247,7 +244,7 @@ func TestRead(t *testing.T) { assert.NoError(t, err) rsp = pb.ReadResponse{} - err = h.Read(context.TODO(), &pb.ReadRequest{ + err = h.Read(microAccountCtx(), &pb.ReadRequest{ UserId: "user-1", ResourceType: "message", ResourceIds: []string{"message-1", "message-2", "message-3"}, @@ -261,3 +258,9 @@ func TestRead(t *testing.T) { func microSecondTime(tt time.Time) time.Time { return time.Unix(tt.Unix(), int64(tt.Nanosecond()-tt.Nanosecond()%1000)).UTC() } + +func microAccountCtx() context.Context { + return auth.ContextWithAccount(context.TODO(), &auth.Account{ + Issuer: "micro", + }) +} diff --git a/seen/main.go b/seen/main.go index ec7468f..6e52d62 100644 --- a/seen/main.go +++ b/seen/main.go @@ -1,14 +1,16 @@ package main import ( + "database/sql" + "github.com/micro/services/seen/handler" pb "github.com/micro/services/seen/proto" - "gorm.io/driver/postgres" - "gorm.io/gorm" "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/logger" + + _ "github.com/jackc/pgx/v4/stdlib" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/seen?sslmode=disable" @@ -26,16 +28,15 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - logger.Fatalf("Error connecting to database: %v", err) - } - if err := db.AutoMigrate(&handler.SeenInstance{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) + logger.Fatalf("Failed to open connection to DB %s", err) } + h := &handler.Seen{} + h.DBConn(sqlDB).Migrations(&handler.SeenInstance{}) // Register handler - pb.RegisterSeenHandler(srv.Server(), &handler.Seen{DB: db.Debug()}) + pb.RegisterSeenHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil {