From 81252c1611a95bc314b8aa0c398a67d7c3b3757b Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Thu, 25 Mar 2021 22:22:46 +0000 Subject: [PATCH] lockdown invites (#80) --- invites/handler/invites.go | 49 +++++++++++++++++++++++++++++---- invites/handler/invites_test.go | 45 ++++++++++++++++-------------- invites/main.go | 17 ++++++------ 3 files changed, 76 insertions(+), 35 deletions(-) diff --git a/invites/handler/invites.go b/invites/handler/invites.go index 7c12764..70b49f2 100644 --- a/invites/handler/invites.go +++ b/invites/handler/invites.go @@ -8,9 +8,11 @@ import ( "strings" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/invites/proto" + gorm2 "github.com/micro/services/pkg/gorm" "gorm.io/gorm" ) @@ -43,11 +45,15 @@ func (i *Invite) Serialize() *pb.Invite { } type Invites struct { - DB *gorm.DB + gorm2.Helper } // Create an invite func (i *Invites) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.GroupId) == 0 { return ErrMissingGroupID @@ -66,8 +72,12 @@ func (i *Invites) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Cre GroupID: req.GroupId, Email: strings.ToLower(req.Email), } - if err := i.DB.Create(invite).Error; err != nil && strings.Contains(err.Error(), "group_email") { - } else if err != nil { + db, err := i.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + if err := db.Create(invite).Error; err != nil && !strings.Contains(err.Error(), "group_email") { logger.Errorf("Error writing to the store: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } @@ -79,6 +89,10 @@ func (i *Invites) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Cre // Read an invite using ID or code func (i *Invites) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request var query Invite if req.Id != nil { @@ -89,9 +103,14 @@ func (i *Invites) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRes return ErrMissingIDAndCode } + db, err := i.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // query the database var invite Invite - if err := i.DB.Where(&query).First(&invite).Error; err == gorm.ErrRecordNotFound { + if err := db.Where(&query).First(&invite).Error; err == gorm.ErrRecordNotFound { return ErrInviteNotFound } else if err != nil { logger.Errorf("Error reading from the store: %v", err) @@ -105,6 +124,10 @@ func (i *Invites) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRes // List invited for a group or specific email func (i *Invites) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if req.Email == nil && req.GroupId == nil { return ErrMissingGroupIDAndEmail @@ -119,9 +142,14 @@ func (i *Invites) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListRes query.Email = strings.ToLower(req.Email.Value) } + db, err := i.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // query the database var invites []Invite - if err := i.DB.Where(&query).Find(&invites).Error; err != nil { + if err := db.Where(&query).Find(&invites).Error; err != nil { logger.Errorf("Error reading from the store: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } @@ -136,13 +164,22 @@ func (i *Invites) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListRes // Delete an invite func (i *Invites) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID } + db, err := i.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // delete from the database - if err := i.DB.Where(&Invite{ID: req.Id}).Delete(&Invite{}).Error; err != nil { + if err := db.Where(&Invite{ID: req.Id}).Delete(&Invite{}).Error; err != nil { logger.Errorf("Error deleting from the store: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } diff --git a/invites/handler/invites_test.go b/invites/handler/invites_test.go index 917874e..375fcd7 100644 --- a/invites/handler/invites_test.go +++ b/invites/handler/invites_test.go @@ -2,17 +2,17 @@ package handler_test import ( "context" + "database/sql" "os" "testing" + "github.com/micro/micro/v3/service/auth" "github.com/micro/services/invites/handler" pb "github.com/micro/services/invites/proto" "google.golang.org/protobuf/types/known/wrapperspb" "github.com/google/uuid" "github.com/stretchr/testify/assert" - "gorm.io/driver/postgres" - "gorm.io/gorm" ) func testHandler(t *testing.T) *handler.Invites { @@ -21,22 +21,19 @@ func testHandler(t *testing.T) *handler.Invites { if len(addr) == 0 { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - t.Fatalf("Error connecting to database: %v", err) + t.Fatalf("Failed to open connection to DB %s", err) } // clean any data from a previous run - if err := db.Exec("DROP TABLE IF EXISTS invites CASCADE").Error; err != nil { + if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_invites CASCADE"); err != nil { t.Fatalf("Error cleaning database: %v", err) } - // migrate the database - if err := db.AutoMigrate(&handler.Invite{}); err != nil { - t.Fatalf("Error migrating database: %v", err) - } - - return &handler.Invites{DB: db} + h := &handler.Invites{} + h.DBConn(sqlDB).Migrations(&handler.Invite{}) + return h } func TestCreate(t *testing.T) { @@ -78,7 +75,7 @@ func TestCreate(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ GroupId: tc.GroupID, Email: tc.Email, }, &rsp) assert.Equal(t, tc.Error, err) @@ -106,7 +103,7 @@ func TestRead(t *testing.T) { // seed some data var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{Email: "john@doe.com", GroupId: uuid.New().String()}, &cRsp) + err := h.Create(microAccountCtx(), &pb.CreateRequest{Email: "john@doe.com", GroupId: uuid.New().String()}, &cRsp) assert.NoError(t, err) if cRsp.Invite == nil { t.Fatal("No invite returned on create") @@ -149,7 +146,7 @@ func TestRead(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.ReadResponse - err := h.Read(context.TODO(), &pb.ReadRequest{Id: tc.ID, Code: tc.Code}, &rsp) + err := h.Read(microAccountCtx(), &pb.ReadRequest{Id: tc.ID, Code: tc.Code}, &rsp) assert.Equal(t, tc.Error, err) if tc.Invite == nil { @@ -166,7 +163,7 @@ func TestList(t *testing.T) { // seed some data var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{Email: "john@doe.com", GroupId: uuid.New().String()}, &cRsp) + err := h.Create(microAccountCtx(), &pb.CreateRequest{Email: "john@doe.com", GroupId: uuid.New().String()}, &cRsp) assert.NoError(t, err) if cRsp.Invite == nil { t.Fatal("No invite returned on create") @@ -212,7 +209,7 @@ func TestList(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.ListResponse - err := h.List(context.TODO(), &pb.ListRequest{Email: tc.Email, GroupId: tc.GroupID}, &rsp) + err := h.List(microAccountCtx(), &pb.ListRequest{Email: tc.Email, GroupId: tc.GroupID}, &rsp) assert.Equal(t, tc.Error, err) if tc.Invite == nil { @@ -232,13 +229,13 @@ func TestDelete(t *testing.T) { h := testHandler(t) t.Run("MissingID", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{}, &pb.DeleteResponse{}) + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{}, &pb.DeleteResponse{}) assert.Equal(t, handler.ErrMissingID, err) }) // seed some data var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{Email: "john@doe.com", GroupId: uuid.New().String()}, &cRsp) + err := h.Create(microAccountCtx(), &pb.CreateRequest{Email: "john@doe.com", GroupId: uuid.New().String()}, &cRsp) assert.NoError(t, err) if cRsp.Invite == nil { t.Fatal("No invite returned on create") @@ -246,15 +243,15 @@ func TestDelete(t *testing.T) { } t.Run("Valid", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{Id: cRsp.Invite.Id}, &pb.DeleteResponse{}) + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{Id: cRsp.Invite.Id}, &pb.DeleteResponse{}) assert.NoError(t, err) - err = h.Read(context.TODO(), &pb.ReadRequest{Id: &wrapperspb.StringValue{Value: cRsp.Invite.Id}}, &pb.ReadResponse{}) + err = h.Read(microAccountCtx(), &pb.ReadRequest{Id: &wrapperspb.StringValue{Value: cRsp.Invite.Id}}, &pb.ReadResponse{}) assert.Equal(t, handler.ErrInviteNotFound, err) }) t.Run("Repeat", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{Id: cRsp.Invite.Id}, &pb.DeleteResponse{}) + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{Id: cRsp.Invite.Id}, &pb.DeleteResponse{}) assert.NoError(t, err) }) } @@ -269,3 +266,9 @@ func assertInvitesMatch(t *testing.T, exp, act *pb.Invite) { assert.Equal(t, exp.Email, act.Email) assert.Equal(t, exp.GroupId, act.GroupId) } + +func microAccountCtx() context.Context { + return auth.ContextWithAccount(context.TODO(), &auth.Account{ + Issuer: "micro", + }) +} diff --git a/invites/main.go b/invites/main.go index 893c4d0..53bb46e 100644 --- a/invites/main.go +++ b/invites/main.go @@ -1,14 +1,16 @@ package main import ( + "database/sql" + "github.com/micro/services/invites/handler" pb "github.com/micro/services/invites/proto" "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/logger" - "gorm.io/driver/postgres" - "gorm.io/gorm" + + _ "github.com/jackc/pgx/v4/stdlib" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/invites?sslmode=disable" @@ -26,16 +28,15 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - logger.Fatalf("Error connecting to database: %v", err) - } - if err := db.AutoMigrate(&handler.Invite{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) + logger.Fatalf("Failed to open connection to DB %s", err) } + h := &handler.Invites{} + h.DBConn(sqlDB).Migrations(&handler.Invite{}) // Register handler - pb.RegisterInvitesHandler(srv.Server(), &handler.Invites{DB: db}) + pb.RegisterInvitesHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil {