From 72aca7ee64734bc1cadb468f80c3611b7f93e39d Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Thu, 25 Mar 2021 22:53:48 +0000 Subject: [PATCH] codes --- codes/handler/create.go | 12 +++++++++++- codes/handler/create_test.go | 13 ++++++++++--- codes/handler/handler.go | 4 ++-- codes/handler/handler_test.go | 18 +++++++----------- codes/handler/verify.go | 12 +++++++++++- codes/handler/verify_test.go | 15 +++++++-------- codes/main.go | 16 ++++++++-------- seen/skip | 1 - 8 files changed, 56 insertions(+), 35 deletions(-) delete mode 100644 seen/skip diff --git a/codes/handler/create.go b/codes/handler/create.go index 13a06f8..b1b87d6 100644 --- a/codes/handler/create.go +++ b/codes/handler/create.go @@ -5,12 +5,17 @@ import ( "math/rand" "strconv" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/codes/proto" ) func (c *Codes) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Identity) == 0 { return ErrMissingIdentity @@ -24,8 +29,13 @@ func (c *Codes) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat code.ExpiresAt = c.Time().Add(DefaultTTL) } + db, err := c.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // write to the database - if err := c.DB.Create(&code).Error; err != nil { + if err := db.Create(&code).Error; err != nil { logger.Errorf("Error creating code in database: %v", err) return errors.InternalServerError("DATABASE_ERORR", "Error connecting to database") } diff --git a/codes/handler/create_test.go b/codes/handler/create_test.go index 736c6d7..57379eb 100644 --- a/codes/handler/create_test.go +++ b/codes/handler/create_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/micro/micro/v3/service/auth" "github.com/micro/services/codes/handler" pb "github.com/micro/services/codes/proto" "github.com/stretchr/testify/assert" @@ -15,21 +16,21 @@ func TestCreate(t *testing.T) { t.Run("MissingIdentity", func(t *testing.T) { var rsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{}, &rsp) + err := h.Create(microAccountCtx(), &pb.CreateRequest{}, &rsp) assert.Equal(t, handler.ErrMissingIdentity, err) assert.Empty(t, rsp.Code) }) t.Run("NoExpiry", func(t *testing.T) { var rsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{Identity: "07503196715"}, &rsp) + err := h.Create(microAccountCtx(), &pb.CreateRequest{Identity: "07503196715"}, &rsp) assert.NoError(t, err) assert.NotEmpty(t, rsp.Code) }) t.Run("WithExpiry", func(t *testing.T) { var rsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ Identity: "demo@m3o.com", ExpiresAt: timestamppb.Now(), }, &rsp) @@ -37,3 +38,9 @@ func TestCreate(t *testing.T) { assert.NotEmpty(t, rsp.Code) }) } + +func microAccountCtx() context.Context { + return auth.ContextWithAccount(context.TODO(), &auth.Account{ + Issuer: "micro", + }) +} diff --git a/codes/handler/handler.go b/codes/handler/handler.go index 2405778..9cb4153 100644 --- a/codes/handler/handler.go +++ b/codes/handler/handler.go @@ -4,7 +4,7 @@ import ( "time" "github.com/micro/micro/v3/service/errors" - "gorm.io/gorm" + "github.com/micro/services/pkg/gorm" ) var ( @@ -17,7 +17,7 @@ var ( ) type Codes struct { - DB *gorm.DB + gorm.Helper Time func() time.Time } diff --git a/codes/handler/handler_test.go b/codes/handler/handler_test.go index ab180a7..207e30b 100644 --- a/codes/handler/handler_test.go +++ b/codes/handler/handler_test.go @@ -1,13 +1,12 @@ package handler_test import ( + "database/sql" "os" "testing" "time" "github.com/micro/services/codes/handler" - "gorm.io/driver/postgres" - "gorm.io/gorm" ) func testHandler(t *testing.T) *handler.Codes { @@ -17,20 +16,17 @@ func testHandler(t *testing.T) *handler.Codes { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - t.Fatalf("Error connecting to database: %v", err) - } - - // migrate the database - if err := db.AutoMigrate(&handler.Code{}); err != nil { - t.Fatalf("Error migrating database: %v", err) + t.Fatalf("Failed to open connection to DB %s", err) } // clean any data from a previous run - if err := db.Exec("TRUNCATE TABLE codes CASCADE").Error; err != nil { + if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_codes CASCADE"); err != nil { t.Fatalf("Error cleaning database: %v", err) } - return &handler.Codes{DB: db, Time: time.Now} + h := &handler.Codes{Time: time.Now} + h.DBConn(sqlDB).Migrations(&handler.Code{}) + return h } diff --git a/codes/handler/verify.go b/codes/handler/verify.go index f0257b2..1abb537 100644 --- a/codes/handler/verify.go +++ b/codes/handler/verify.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/codes/proto" @@ -10,6 +11,10 @@ import ( ) func (c *Codes) Verify(ctx context.Context, req *pb.VerifyRequest, rsp *pb.VerifyResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Code) == 0 { return ErrMissingCode @@ -18,9 +23,14 @@ func (c *Codes) Verify(ctx context.Context, req *pb.VerifyRequest, rsp *pb.Verif return ErrMissingIdentity } + db, err := c.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // lookup the code var code Code - if err := c.DB.Where(&Code{Code: req.Code, Identity: req.Identity}).First(&code).Error; err == gorm.ErrRecordNotFound { + if err := db.Where(&Code{Code: req.Code, Identity: req.Identity}).First(&code).Error; err == gorm.ErrRecordNotFound { return ErrInvalidCode } else if err != nil { logger.Errorf("Error reading code from database: %v", err) diff --git a/codes/handler/verify_test.go b/codes/handler/verify_test.go index fd26463..2222d93 100644 --- a/codes/handler/verify_test.go +++ b/codes/handler/verify_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "time" @@ -15,30 +14,30 @@ func TestVerify(t *testing.T) { t.Run("MissingIdentity", func(t *testing.T) { var rsp pb.VerifyResponse - err := h.Verify(context.TODO(), &pb.VerifyRequest{Code: "123456"}, &rsp) + err := h.Verify(microAccountCtx(), &pb.VerifyRequest{Code: "123456"}, &rsp) assert.Equal(t, handler.ErrMissingIdentity, err) }) t.Run("MissingCode", func(t *testing.T) { var rsp pb.VerifyResponse - err := h.Verify(context.TODO(), &pb.VerifyRequest{Identity: "demo@m3o.com"}, &rsp) + err := h.Verify(microAccountCtx(), &pb.VerifyRequest{Identity: "demo@m3o.com"}, &rsp) assert.Equal(t, handler.ErrMissingCode, err) }) // generate a code to test var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{Identity: "demo@m3o.com"}, &cRsp) + err := h.Create(microAccountCtx(), &pb.CreateRequest{Identity: "demo@m3o.com"}, &cRsp) assert.NoError(t, err) t.Run("IncorrectCode", func(t *testing.T) { var rsp pb.VerifyResponse - err := h.Verify(context.TODO(), &pb.VerifyRequest{Identity: "demo@m3o.com", Code: "12345"}, &rsp) + err := h.Verify(microAccountCtx(), &pb.VerifyRequest{Identity: "demo@m3o.com", Code: "12345"}, &rsp) assert.Equal(t, handler.ErrInvalidCode, err) }) t.Run("IncorrectEmail", func(t *testing.T) { var rsp pb.VerifyResponse - err := h.Verify(context.TODO(), &pb.VerifyRequest{Identity: "john@m3o.com", Code: cRsp.Code}, &rsp) + err := h.Verify(microAccountCtx(), &pb.VerifyRequest{Identity: "john@m3o.com", Code: cRsp.Code}, &rsp) assert.Equal(t, handler.ErrInvalidCode, err) }) @@ -48,13 +47,13 @@ func TestVerify(t *testing.T) { defer func() { h.Time = ot }() var rsp pb.VerifyResponse - err := h.Verify(context.TODO(), &pb.VerifyRequest{Identity: "demo@m3o.com", Code: cRsp.Code}, &rsp) + err := h.Verify(microAccountCtx(), &pb.VerifyRequest{Identity: "demo@m3o.com", Code: cRsp.Code}, &rsp) assert.Equal(t, handler.ErrExpiredCode, err) }) t.Run("ValidCode", func(t *testing.T) { var rsp pb.VerifyResponse - err := h.Verify(context.TODO(), &pb.VerifyRequest{Identity: "demo@m3o.com", Code: cRsp.Code}, &rsp) + err := h.Verify(microAccountCtx(), &pb.VerifyRequest{Identity: "demo@m3o.com", Code: cRsp.Code}, &rsp) assert.NoError(t, err) }) } diff --git a/codes/main.go b/codes/main.go index f9c57b4..386fbe5 100644 --- a/codes/main.go +++ b/codes/main.go @@ -1,6 +1,7 @@ package main import ( + "database/sql" "time" "github.com/micro/services/codes/handler" @@ -9,8 +10,8 @@ import ( "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/logger" - "gorm.io/driver/postgres" - "gorm.io/gorm" + + _ "github.com/jackc/pgx/v4/stdlib" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/codes?sslmode=disable" @@ -28,16 +29,15 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - logger.Fatalf("Error connecting to database: %v", err) - } - if err := db.AutoMigrate(&handler.Code{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) + logger.Fatalf("Failed to open connection to DB %s", err) } + h := &handler.Codes{Time: time.Now} + h.DBConn(sqlDB).Migrations(&handler.Code{}) // Register handler - pb.RegisterCodesHandler(srv.Server(), &handler.Codes{DB: db, Time: time.Now}) + pb.RegisterCodesHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil { diff --git a/seen/skip b/seen/skip deleted file mode 100644 index 8b13789..0000000 --- a/seen/skip +++ /dev/null @@ -1 +0,0 @@ -