From d68fdeb5163d81b2b79022b9d45a30da9d89fb65 Mon Sep 17 00:00:00 2001 From: Dominic Wong Date: Wed, 24 Mar 2021 08:40:27 +0000 Subject: [PATCH] Multitenant users api (#75) --- streams/proto/streams.pb.go | 37 ++--- streams/proto/streams.pb.micro.go | 2 +- streams/proto/streams.proto | 4 +- streams/skip | 0 test/integration/users/users_test.go | 235 --------------------------- users/handler/create.go | 27 ++- users/handler/create_test.go | 9 +- users/handler/delete.go | 12 +- users/handler/delete_test.go | 11 +- users/handler/handler.go | 39 ++++- users/handler/handler_test.go | 19 ++- users/handler/list.go | 12 +- users/handler/list_test.go | 7 +- users/handler/login.go | 13 +- users/handler/login_test.go | 5 +- users/handler/logout.go | 12 +- users/handler/logout_test.go | 11 +- users/handler/read.go | 12 +- users/handler/read_by_email.go | 12 +- users/handler/read_by_email_test.go | 11 +- users/handler/read_test.go | 11 +- users/handler/update.go | 22 ++- users/handler/update_test.go | 27 ++- users/handler/validate.go | 12 +- users/handler/validate_test.go | 7 +- users/main.go | 16 +- 26 files changed, 233 insertions(+), 352 deletions(-) delete mode 100644 streams/skip delete mode 100644 test/integration/users/users_test.go diff --git a/streams/proto/streams.pb.go b/streams/proto/streams.pb.go index 257289f..5aaf330 100644 --- a/streams/proto/streams.pb.go +++ b/streams/proto/streams.pb.go @@ -1,16 +1,15 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.23.0 -// protoc v3.13.0 +// protoc-gen-go v1.26.0 +// protoc v3.15.5 // source: proto/streams.proto package streams import ( - proto "github.com/golang/protobuf/proto" - timestamp "github.com/golang/protobuf/ptypes/timestamp" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" ) @@ -22,10 +21,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - type PublishResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -69,9 +64,9 @@ type Message struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Topic string `protobuf:"bytes,1,opt,name=topic,proto3" json:"topic,omitempty"` - Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` - SentAt *timestamp.Timestamp `protobuf:"bytes,3,opt,name=sent_at,json=sentAt,proto3" json:"sent_at,omitempty"` + Topic string `protobuf:"bytes,1,opt,name=topic,proto3" json:"topic,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + SentAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=sent_at,json=sentAt,proto3" json:"sent_at,omitempty"` } func (x *Message) Reset() { @@ -120,7 +115,7 @@ func (x *Message) GetMessage() string { return "" } -func (x *Message) GetSentAt() *timestamp.Timestamp { +func (x *Message) GetSentAt() *timestamppb.Timestamp { if x != nil { return x.SentAt } @@ -317,10 +312,8 @@ var file_proto_streams_proto_rawDesc = []byte{ 0x30, 0x01, 0x12, 0x38, 0x0a, 0x05, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x15, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x54, 0x6f, 0x6b, - 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x31, 0x5a, 0x2f, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x63, 0x72, 0x6f, - 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, - 0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3b, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x62, + 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x11, 0x5a, 0x0f, + 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3b, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } @@ -338,12 +331,12 @@ func file_proto_streams_proto_rawDescGZIP() []byte { var file_proto_streams_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_proto_streams_proto_goTypes = []interface{}{ - (*PublishResponse)(nil), // 0: streams.PublishResponse - (*Message)(nil), // 1: streams.Message - (*SubscribeRequest)(nil), // 2: streams.SubscribeRequest - (*TokenRequest)(nil), // 3: streams.TokenRequest - (*TokenResponse)(nil), // 4: streams.TokenResponse - (*timestamp.Timestamp)(nil), // 5: google.protobuf.Timestamp + (*PublishResponse)(nil), // 0: streams.PublishResponse + (*Message)(nil), // 1: streams.Message + (*SubscribeRequest)(nil), // 2: streams.SubscribeRequest + (*TokenRequest)(nil), // 3: streams.TokenRequest + (*TokenResponse)(nil), // 4: streams.TokenResponse + (*timestamppb.Timestamp)(nil), // 5: google.protobuf.Timestamp } var file_proto_streams_proto_depIdxs = []int32{ 5, // 0: streams.Message.sent_at:type_name -> google.protobuf.Timestamp diff --git a/streams/proto/streams.pb.micro.go b/streams/proto/streams.pb.micro.go index a13fb71..33110a9 100644 --- a/streams/proto/streams.pb.micro.go +++ b/streams/proto/streams.pb.micro.go @@ -6,7 +6,7 @@ package streams import ( fmt "fmt" proto "github.com/golang/protobuf/proto" - _ "github.com/golang/protobuf/ptypes/timestamp" + _ "google.golang.org/protobuf/types/known/timestamppb" math "math" ) diff --git a/streams/proto/streams.proto b/streams/proto/streams.proto index 0cb2b9f..e6c385a 100644 --- a/streams/proto/streams.proto +++ b/streams/proto/streams.proto @@ -1,7 +1,7 @@ syntax = "proto3"; package streams; -option go_package = "github.com/micro/services/streams/proto;streams"; +option go_package = "./proto;streams"; import "google/protobuf/timestamp.proto"; service Streams { @@ -35,4 +35,4 @@ message TokenRequest { message TokenResponse { string token = 1; -} \ No newline at end of file +} diff --git a/streams/skip b/streams/skip deleted file mode 100644 index e69de29..0000000 diff --git a/test/integration/users/users_test.go b/test/integration/users/users_test.go deleted file mode 100644 index 7027150..0000000 --- a/test/integration/users/users_test.go +++ /dev/null @@ -1,235 +0,0 @@ -// +build integration - -package signup - -import ( - "encoding/json" - "errors" - "math/rand" - - "os" - "os/exec" - "strings" - "testing" - "time" - - "github.com/micro/micro/v3/test" -) - -const ( - retryCount = 1 -) - -var letterRunes = []rune("abcdefghijklmnopqrstuvwxyz") - -func randStringRunes(n int) string { - b := make([]rune, n) - for i := range b { - b[i] = letterRunes[rand.Intn(len(letterRunes))] - } - return string(b) -} - -func setupUsersTests(serv test.Server, t *test.T) { - envToConfigKey := map[string][]string{} - - if err := test.Try("Set up config values", t, func() ([]byte, error) { - for envKey, configKeys := range envToConfigKey { - val := os.Getenv(envKey) - if len(val) == 0 { - t.Fatalf("'%v' flag is missing", envKey) - } - for _, configKey := range configKeys { - outp, err := serv.Command().Exec("config", "set", configKey, val) - if err != nil { - return outp, err - } - } - } - return serv.Command().Exec("config", "set", "micro.billing.max_included_services", "3") - }, 10*time.Second); err != nil { - t.Fatal(err) - return - } - - services := []struct { - envVar string - deflt string - }{ - {envVar: "POSTS_SVC", deflt: "../../../users"}, - } - - for _, v := range services { - outp, err := serv.Command().Exec("run", v.deflt) - if err != nil { - t.Fatal(string(outp)) - return - } - } - - if err := test.Try("Find posts and tags", t, func() ([]byte, error) { - outp, err := serv.Command().Exec("services") - if err != nil { - return outp, err - } - list := []string{"users"} - logOutp := []byte{} - fail := false - for _, s := range list { - if !strings.Contains(string(outp), s) { - o, _ := serv.Command().Exec("logs", s) - logOutp = append(logOutp, o...) - fail = true - } - } - if fail { - return append(outp, logOutp...), errors.New("Can't find required services in list") - } - return outp, err - }, 180*time.Second); err != nil { - return - } - - // setup rules - - // Adjust rules before we signup into a non admin account - outp, err := serv.Command().Exec("auth", "create", "rule", "--access=granted", "--scope=''", "--resource=\"service:users:*\"", "users") - if err != nil { - t.Fatalf("Error setting up rules: %v", outp) - return - } - - // copy the config with the admin logged in so we can use it for reading logs - // we dont want to have an open access rule for logs as it's not how it works in live - confPath := serv.Command().Config - outp, err = exec.Command("cp", "-rf", confPath, confPath+".admin").CombinedOutput() - if err != nil { - t.Fatalf("Error copying config: %v", outp) - return - } -} - -func TestUsersService(t *testing.T) { - test.TrySuite(t, testUsers, retryCount) -} - -func testUsers(t *test.T) { - t.Parallel() - - serv := test.NewServer(t, test.WithLogin()) - defer serv.Close() - if err := serv.Run(); err != nil { - return - } - - setupUsersTests(serv, t) - - cmd := serv.Command() - - email := "test@gmail.com" - password := "testPassw" - username := "john" - id := "7" - - if err := test.Try("Save user", t, func() ([]byte, error) { - // Attention! The content must be unquoted, don't add quotes. - outp, err := cmd.Exec("users", "create", "--id="+id, "--email="+email, "--password="+password, "--username=john") - if err != nil { - outp1, _ := cmd.Exec("logs", "users") - return append(outp, outp1...), err - } - return outp, err - }, 15*time.Second); err != nil { - return - } - - outp, err := cmd.Exec("users", "read", "--id="+id) - if err != nil { - t.Fatal(string(outp), err) - return - } - if !strings.Contains(string(outp), email) || - !strings.Contains(string(outp), username) || - !strings.Contains(string(outp), id) { - t.Fatal(string(outp)) - return - } - - // no password - outp, err = cmd.Exec("users", "login", "--email="+email) - if err == nil { - t.Fatal(string(outp)) - return - } - - // wrong password - outp, err = cmd.Exec("users", "login", "--email="+email, "--password=somethingincorrect") - if err == nil { - t.Fatal(string(outp)) - return - } - - outp, err = cmd.Exec("users", "login", "--username="+username, "--password="+password) - if err != nil { - t.Fatal(string(outp), err) - return - } - loginRsp := map[string]interface{}{} - err = json.Unmarshal(outp, &loginRsp) - if err != nil { - t.Fatal(err) - return - } - session, ok := loginRsp["session"].(map[string]interface{}) - if !ok { - t.Fatal(string(outp)) - return - } - sessionID := session["id"].(string) - sessionUsername := session["username"].(string) - if sessionUsername != username { - t.Fatal(string(outp)) - return - } - - if len(sessionID) == 0 { - t.Fatal(string(outp)) - return - } - - outp, err = cmd.Exec("users", "login", "--email="+email, "--password="+password) - if err != nil { - t.Fatal(string(outp), err) - return - } - - outp, err = cmd.Exec("users", "login", "--email="+email, "--password="+password) - if err != nil { - t.Fatal(string(outp), err) - return - } - - outp, err = cmd.Exec("users", "readSession", "--sessionId="+sessionID) - if err != nil { - t.Fatal(string(outp), err) - return - } - - loginRsp = map[string]interface{}{} - err = json.Unmarshal(outp, &loginRsp) - if err != nil { - t.Fatal(err) - return - } - session, ok = loginRsp["session"].(map[string]interface{}) - if !ok { - t.Fatal(string(outp)) - return - } - sessionID = session["id"].(string) - sessionUsername = session["username"].(string) - if sessionUsername != username { - t.Fatal(string(outp)) - return - } -} diff --git a/users/handler/create.go b/users/handler/create.go index f89a849..7f79150 100644 --- a/users/handler/create.go +++ b/users/handler/create.go @@ -2,10 +2,12 @@ package handler import ( "context" + "regexp" "strings" "time" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -14,6 +16,11 @@ import ( // Create a user func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } + // validate the request if len(req.FirstName) == 0 { return ErrMissingFirstName @@ -34,11 +41,15 @@ func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat // hash and salt the password using bcrypt phash, err := hashAndSalt(req.Password) if err != nil { - logger.Errorf("Error hasing and salting password: %v", err) + logger.Errorf("Error hashing and salting password: %v", err) return errors.InternalServerError("HASHING_ERROR", "Error hashing password") } - - return u.DB.Transaction(func(tx *gorm.DB) error { + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + return db.Transaction(func(tx *gorm.DB) error { // write the user to the database user := &User{ ID: uuid.New().String(), @@ -47,10 +58,12 @@ func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat Email: strings.ToLower(req.Email), Password: phash, } - err = u.DB.Create(user).Error - if err != nil && strings.Contains(err.Error(), "idx_users_email") { - return ErrDuplicateEmail - } else if err != nil { + err = tx.Create(user).Error + + if err != nil { + if match, _ := regexp.MatchString(`idx_[\S]+_users_email`, err.Error()); match { + return ErrDuplicateEmail + } logger.Errorf("Error writing to the database: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } diff --git a/users/handler/create_test.go b/users/handler/create_test.go index ef780e3..a098aec 100644 --- a/users/handler/create_test.go +++ b/users/handler/create_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/users/handler" @@ -61,7 +60,7 @@ func TestCreate(t *testing.T) { h := testHandler(t) for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ FirstName: tc.FirstName, LastName: tc.LastName, Email: tc.Email, @@ -79,7 +78,7 @@ func TestCreate(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &req, &rsp) + err := h.Create(microAccountCtx(), &req, &rsp) assert.NoError(t, err) u := rsp.User @@ -101,7 +100,7 @@ func TestCreate(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &req, &rsp) + err := h.Create(microAccountCtx(), &req, &rsp) assert.Equal(t, handler.ErrDuplicateEmail, err) assert.Nil(t, rsp.User) }) @@ -114,7 +113,7 @@ func TestCreate(t *testing.T) { Email: "johndoe@gmail.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &req, &rsp) + err := h.Create(microAccountCtx(), &req, &rsp) assert.NoError(t, err) u := rsp.User diff --git a/users/handler/delete.go b/users/handler/delete.go index 0b0d7de..a258b00 100644 --- a/users/handler/delete.go +++ b/users/handler/delete.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -11,13 +12,22 @@ import ( // Delete a user func (u *Users) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID } + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // delete the users tokens - return u.DB.Transaction(func(tx *gorm.DB) error { + return db.Transaction(func(tx *gorm.DB) error { if err := tx.Delete(&Token{}, &Token{UserID: req.Id}).Error; err != nil { logger.Errorf("Error writing to the database: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") diff --git a/users/handler/delete_test.go b/users/handler/delete_test.go index 36710b0..a2215f5 100644 --- a/users/handler/delete_test.go +++ b/users/handler/delete_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/users/handler" @@ -13,7 +12,7 @@ func TestDelete(t *testing.T) { h := testHandler(t) t.Run("MissingID", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{}, &pb.DeleteResponse{}) + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{}, &pb.DeleteResponse{}) assert.Equal(t, handler.ErrMissingID, err) }) @@ -25,7 +24,7 @@ func TestDelete(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &cReq, &cRsp) + err := h.Create(microAccountCtx(), &cReq, &cRsp) assert.NoError(t, err) if cRsp.User == nil { t.Fatal("No user returned") @@ -33,14 +32,14 @@ func TestDelete(t *testing.T) { } t.Run("Valid", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{ + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{ Id: cRsp.User.Id, }, &pb.DeleteResponse{}) assert.NoError(t, err) // check it was actually deleted var rsp pb.ReadResponse - err = h.Read(context.TODO(), &pb.ReadRequest{ + err = h.Read(microAccountCtx(), &pb.ReadRequest{ Ids: []string{cRsp.User.Id}, }, &rsp) assert.NoError(t, err) @@ -48,7 +47,7 @@ func TestDelete(t *testing.T) { }) t.Run("Retry", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{ + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{ Id: cRsp.User.Id, }, &pb.DeleteResponse{}) assert.NoError(t, err) diff --git a/users/handler/handler.go b/users/handler/handler.go index d78a9c8..644e123 100644 --- a/users/handler/handler.go +++ b/users/handler/handler.go @@ -1,13 +1,18 @@ package handler import ( + "context" + "fmt" "regexp" + "strings" "time" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" pb "github.com/micro/services/users/proto" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" + "gorm.io/gorm/schema" ) var ( @@ -58,8 +63,38 @@ type Token struct { } type Users struct { - DB *gorm.DB - Time func() time.Time + Time func() time.Time + Dialector gorm.Dialector + dbMigrations map[string]bool +} + +func NewHandler(t func() time.Time, d gorm.Dialector) *Users { + return &Users{Time: t, Dialector: d, dbMigrations: map[string]bool{}} +} + +func (u *Users) getDBConn(ctx context.Context) (*gorm.DB, error) { + acc, ok := auth.AccountFromContext(ctx) + if !ok { + return nil, fmt.Errorf("missing account from context") + } + db, err := gorm.Open(u.Dialector, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")), + }, + }) + if err != nil { + return nil, err + } + // skip migration if we've already done it + if u.dbMigrations[acc.Issuer] { + return db, nil + } + if err := db.AutoMigrate(&User{}, &Token{}); err != nil { + return nil, err + } + // record success + u.dbMigrations[acc.Issuer] = true + return db, nil } // isEmailValid checks if the email provided passes the required structure and length. diff --git a/users/handler/handler_test.go b/users/handler/handler_test.go index 295ba3c..38a183f 100644 --- a/users/handler/handler_test.go +++ b/users/handler/handler_test.go @@ -1,11 +1,14 @@ package handler_test import ( + "context" "os" "testing" "time" + "github.com/micro/micro/v3/service/auth" "github.com/stretchr/testify/assert" + "gorm.io/gorm/schema" "github.com/micro/services/users/handler" pb "github.com/micro/services/users/proto" @@ -19,13 +22,16 @@ func testHandler(t *testing.T) *handler.Users { if len(addr) == 0 { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + dial := postgres.Open(addr) + db, err := gorm.Open(dial, &gorm.Config{ + NamingStrategy: schema.NamingStrategy{TablePrefix: "micro_"}, + }) if err != nil { t.Fatalf("Error connecting to database: %v", err) } // clean any data from a previous run - if err := db.Exec("DROP TABLE IF EXISTS users, tokens CASCADE").Error; err != nil { + if err := db.Exec("DROP TABLE IF EXISTS micro_users, micro_tokens CASCADE").Error; err != nil { t.Fatalf("Error cleaning database: %v", err) } @@ -33,8 +39,7 @@ func testHandler(t *testing.T) *handler.Users { if err := db.AutoMigrate(&handler.User{}, &handler.Token{}); err != nil { t.Fatalf("Error migrating database: %v", err) } - - return &handler.Users{DB: db, Time: time.Now} + return handler.NewHandler(time.Now, dial) } func assertUsersMatch(t *testing.T, exp, act *pb.User) { @@ -47,3 +52,9 @@ func assertUsersMatch(t *testing.T, exp, act *pb.User) { assert.Equal(t, exp.LastName, act.LastName) assert.Equal(t, exp.Email, act.Email) } + +func microAccountCtx() context.Context { + return auth.ContextWithAccount(context.TODO(), &auth.Account{ + Issuer: "micro", + }) +} diff --git a/users/handler/list.go b/users/handler/list.go index d4946c3..bc2e496 100644 --- a/users/handler/list.go +++ b/users/handler/list.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -10,9 +11,18 @@ import ( // List all users func (u *Users) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // query the database + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } var users []User - if err := u.DB.Model(&User{}).Find(&users).Error; err != nil { + if err := db.Model(&User{}).Find(&users).Error; err != nil { logger.Errorf("Error reading from the database: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } diff --git a/users/handler/list_test.go b/users/handler/list_test.go index f184c9c..f9aaa82 100644 --- a/users/handler/list_test.go +++ b/users/handler/list_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" pb "github.com/micro/services/users/proto" @@ -19,7 +18,7 @@ func TestList(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &cReq1, &cRsp1) + err := h.Create(microAccountCtx(), &cReq1, &cRsp1) assert.NoError(t, err) if cRsp1.User == nil { t.Fatal("No user returned") @@ -33,7 +32,7 @@ func TestList(t *testing.T) { Email: "johndoe@gmail.com", Password: "passwordabc", } - err = h.Create(context.TODO(), &cReq2, &cRsp2) + err = h.Create(microAccountCtx(), &cReq2, &cRsp2) assert.NoError(t, err) if cRsp2.User == nil { t.Fatal("No user returned") @@ -41,7 +40,7 @@ func TestList(t *testing.T) { } var rsp pb.ListResponse - err = h.List(context.TODO(), &pb.ListRequest{}, &rsp) + err = h.List(microAccountCtx(), &pb.ListRequest{}, &rsp) assert.NoError(t, err) if rsp.Users == nil { t.Error("No users returned") diff --git a/users/handler/login.go b/users/handler/login.go index 14b07fc..13d116c 100644 --- a/users/handler/login.go +++ b/users/handler/login.go @@ -4,6 +4,7 @@ import ( "context" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -12,6 +13,10 @@ import ( // Login using email and password returns the users profile and a token func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Email) == 0 { return ErrMissingEmail @@ -20,7 +25,13 @@ func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginRe return ErrInvalidPassword } - return u.DB.Transaction(func(tx *gorm.DB) error { + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + + return db.Transaction(func(tx *gorm.DB) error { // lookup the user var user User if err := tx.Where(&User{Email: req.Email}).First(&user).Error; err == gorm.ErrRecordNotFound { diff --git a/users/handler/login_test.go b/users/handler/login_test.go index ca6ac1a..4fd4afe 100644 --- a/users/handler/login_test.go +++ b/users/handler/login_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/users/handler" @@ -20,7 +19,7 @@ func TestLogin(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &cReq, &cRsp) + err := h.Create(microAccountCtx(), &cReq, &cRsp) assert.NoError(t, err) if cRsp.User == nil { t.Fatal("No user returned") @@ -67,7 +66,7 @@ func TestLogin(t *testing.T) { for _, tc := range tt { t.Run(tc.Name, func(t *testing.T) { var rsp pb.LoginResponse - err := h.Login(context.TODO(), &pb.LoginRequest{ + err := h.Login(microAccountCtx(), &pb.LoginRequest{ Email: tc.Email, Password: tc.Password, }, &rsp) assert.Equal(t, tc.Error, err) diff --git a/users/handler/logout.go b/users/handler/logout.go index d3bafee..3d76a0d 100644 --- a/users/handler/logout.go +++ b/users/handler/logout.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -11,12 +12,21 @@ import ( // Logout expires all tokens for the user func (u *Users) Logout(ctx context.Context, req *pb.LogoutRequest, rsp *pb.LogoutResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID } - return u.DB.Transaction(func(tx *gorm.DB) error { + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + return db.Transaction(func(tx *gorm.DB) error { // lookup the user var user User if err := tx.Where(&User{ID: req.Id}).Preload("Tokens").First(&user).Error; err == gorm.ErrRecordNotFound { diff --git a/users/handler/logout_test.go b/users/handler/logout_test.go index fca369b..e15f6ae 100644 --- a/users/handler/logout_test.go +++ b/users/handler/logout_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/google/uuid" @@ -14,12 +13,12 @@ func TestLogout(t *testing.T) { h := testHandler(t) t.Run("MissingUserID", func(t *testing.T) { - err := h.Logout(context.TODO(), &pb.LogoutRequest{}, &pb.LogoutResponse{}) + err := h.Logout(microAccountCtx(), &pb.LogoutRequest{}, &pb.LogoutResponse{}) assert.Equal(t, handler.ErrMissingID, err) }) t.Run("UserNotFound", func(t *testing.T) { - err := h.Logout(context.TODO(), &pb.LogoutRequest{Id: uuid.New().String()}, &pb.LogoutResponse{}) + err := h.Logout(microAccountCtx(), &pb.LogoutRequest{Id: uuid.New().String()}, &pb.LogoutResponse{}) assert.Equal(t, handler.ErrNotFound, err) }) @@ -32,17 +31,17 @@ func TestLogout(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &cReq, &cRsp) + err := h.Create(microAccountCtx(), &cReq, &cRsp) assert.NoError(t, err) if cRsp.User == nil { t.Fatal("No user returned") return } - err = h.Logout(context.TODO(), &pb.LogoutRequest{Id: cRsp.User.Id}, &pb.LogoutResponse{}) + err = h.Logout(microAccountCtx(), &pb.LogoutRequest{Id: cRsp.User.Id}, &pb.LogoutResponse{}) assert.NoError(t, err) - err = h.Validate(context.TODO(), &pb.ValidateRequest{Token: cRsp.Token}, &pb.ValidateResponse{}) + err = h.Validate(microAccountCtx(), &pb.ValidateRequest{Token: cRsp.Token}, &pb.ValidateResponse{}) assert.Error(t, err) }) } diff --git a/users/handler/read.go b/users/handler/read.go index 9f941aa..c86f0b9 100644 --- a/users/handler/read.go +++ b/users/handler/read.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -10,14 +11,23 @@ import ( // Read users using ID func (u *Users) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Ids) == 0 { return ErrMissingIDs } // query the database + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } var users []User - if err := u.DB.Model(&User{}).Where("id IN (?)", req.Ids).Find(&users).Error; err != nil { + if err := db.Model(&User{}).Where("id IN (?)", req.Ids).Find(&users).Error; err != nil { logger.Errorf("Error reading from the database: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } diff --git a/users/handler/read_by_email.go b/users/handler/read_by_email.go index 773730d..dcdac59 100644 --- a/users/handler/read_by_email.go +++ b/users/handler/read_by_email.go @@ -4,6 +4,7 @@ import ( "context" "strings" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -11,6 +12,10 @@ import ( // Read users using email func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp *pb.ReadByEmailResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Emails) == 0 { return ErrMissingEmails @@ -21,8 +26,13 @@ func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp } // query the database + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } var users []User - if err := u.DB.Model(&User{}).Where("lower(email) IN (?)", emails).Find(&users).Error; err != nil { + if err := db.Model(&User{}).Where("lower(email) IN (?)", emails).Find(&users).Error; err != nil { logger.Errorf("Error reading from the database: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } diff --git a/users/handler/read_by_email_test.go b/users/handler/read_by_email_test.go index 1599262..d468342 100644 --- a/users/handler/read_by_email_test.go +++ b/users/handler/read_by_email_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "strings" "testing" @@ -15,14 +14,14 @@ func TestReadByEmail(t *testing.T) { t.Run("MissingEmails", func(t *testing.T) { var rsp pb.ReadByEmailResponse - err := h.ReadByEmail(context.TODO(), &pb.ReadByEmailRequest{}, &rsp) + err := h.ReadByEmail(microAccountCtx(), &pb.ReadByEmailRequest{}, &rsp) assert.Equal(t, handler.ErrMissingEmails, err) assert.Nil(t, rsp.Users) }) t.Run("NotFound", func(t *testing.T) { var rsp pb.ReadByEmailResponse - err := h.ReadByEmail(context.TODO(), &pb.ReadByEmailRequest{Emails: []string{"foo"}}, &rsp) + err := h.ReadByEmail(microAccountCtx(), &pb.ReadByEmailRequest{Emails: []string{"foo"}}, &rsp) assert.Nil(t, err) if rsp.Users == nil { t.Fatal("Expected the users object to not be nil") @@ -38,7 +37,7 @@ func TestReadByEmail(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &req1, &rsp1) + err := h.Create(microAccountCtx(), &req1, &rsp1) assert.NoError(t, err) if rsp1.User == nil { t.Fatal("No user returned") @@ -52,7 +51,7 @@ func TestReadByEmail(t *testing.T) { Email: "apple@tree.com", Password: "passwordabc", } - err = h.Create(context.TODO(), &req2, &rsp2) + err = h.Create(microAccountCtx(), &req2, &rsp2) assert.NoError(t, err) if rsp2.User == nil { t.Fatal("No user returned") @@ -61,7 +60,7 @@ func TestReadByEmail(t *testing.T) { // test the read var rsp pb.ReadByEmailResponse - err = h.ReadByEmail(context.TODO(), &pb.ReadByEmailRequest{ + err = h.ReadByEmail(microAccountCtx(), &pb.ReadByEmailRequest{ Emails: []string{rsp1.User.Email, strings.ToUpper(rsp2.User.Email)}, }, &rsp) assert.NoError(t, err) diff --git a/users/handler/read_test.go b/users/handler/read_test.go index a10becb..167ac75 100644 --- a/users/handler/read_test.go +++ b/users/handler/read_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/users/handler" @@ -14,14 +13,14 @@ func TestRead(t *testing.T) { t.Run("MissingIDs", func(t *testing.T) { var rsp pb.ReadResponse - err := h.Read(context.TODO(), &pb.ReadRequest{}, &rsp) + err := h.Read(microAccountCtx(), &pb.ReadRequest{}, &rsp) assert.Equal(t, handler.ErrMissingIDs, err) assert.Nil(t, rsp.Users) }) t.Run("NotFound", func(t *testing.T) { var rsp pb.ReadResponse - err := h.Read(context.TODO(), &pb.ReadRequest{Ids: []string{"foo"}}, &rsp) + err := h.Read(microAccountCtx(), &pb.ReadRequest{Ids: []string{"foo"}}, &rsp) assert.Nil(t, err) if rsp.Users == nil { t.Fatal("Expected the users object to not be nil") @@ -37,7 +36,7 @@ func TestRead(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &req1, &rsp1) + err := h.Create(microAccountCtx(), &req1, &rsp1) assert.NoError(t, err) if rsp1.User == nil { t.Fatal("No user returned") @@ -51,7 +50,7 @@ func TestRead(t *testing.T) { Email: "apple@tree.com", Password: "passwordabc", } - err = h.Create(context.TODO(), &req2, &rsp2) + err = h.Create(microAccountCtx(), &req2, &rsp2) assert.NoError(t, err) if rsp2.User == nil { t.Fatal("No user returned") @@ -60,7 +59,7 @@ func TestRead(t *testing.T) { // test the read var rsp pb.ReadResponse - err = h.Read(context.TODO(), &pb.ReadRequest{ + err = h.Read(microAccountCtx(), &pb.ReadRequest{ Ids: []string{rsp1.User.Id, rsp2.User.Id}, }, &rsp) assert.NoError(t, err) diff --git a/users/handler/update.go b/users/handler/update.go index 71669a1..aa37641 100644 --- a/users/handler/update.go +++ b/users/handler/update.go @@ -2,8 +2,10 @@ package handler import ( "context" + "regexp" "strings" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -12,6 +14,10 @@ import ( // Update a user func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.UpdateResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID @@ -34,7 +40,12 @@ func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Updat // lookup the user var user User - if err := u.DB.Where(&User{ID: req.Id}).First(&user).Error; err == gorm.ErrRecordNotFound { + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + if err := db.Where(&User{ID: req.Id}).First(&user).Error; err == gorm.ErrRecordNotFound { return ErrNotFound } else if err != nil { logger.Errorf("Error reading from the database: %v", err) @@ -61,10 +72,11 @@ func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Updat } // write the user to the database - err := u.DB.Save(user).Error - if err != nil && strings.Contains(err.Error(), "idx_users_email") { - return ErrDuplicateEmail - } else if err != nil { + err = db.Save(user).Error + if err != nil { + if match, _ := regexp.MatchString(`idx_[\S]+_users_email`, err.Error()); match { + return ErrDuplicateEmail + } logger.Errorf("Error writing to the database: %v", err) return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") } diff --git a/users/handler/update_test.go b/users/handler/update_test.go index dd91775..c0ce6e5 100644 --- a/users/handler/update_test.go +++ b/users/handler/update_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "github.com/micro/services/users/handler" @@ -15,14 +14,14 @@ func TestUpdate(t *testing.T) { t.Run("MissingID", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{}, &rsp) + err := h.Update(microAccountCtx(), &pb.UpdateRequest{}, &rsp) assert.Equal(t, handler.ErrMissingID, err) assert.Nil(t, rsp.User) }) t.Run("NotFound", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{Id: "foo"}, &rsp) + err := h.Update(microAccountCtx(), &pb.UpdateRequest{Id: "foo"}, &rsp) assert.Equal(t, handler.ErrNotFound, err) assert.Nil(t, rsp.User) }) @@ -35,7 +34,7 @@ func TestUpdate(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &cReq1, &cRsp1) + err := h.Create(microAccountCtx(), &cReq1, &cRsp1) assert.NoError(t, err) if cRsp1.User == nil { t.Fatal("No user returned") @@ -49,7 +48,7 @@ func TestUpdate(t *testing.T) { Email: "johndoe@gmail.com", Password: "passwordabc", } - err = h.Create(context.TODO(), &cReq2, &cRsp2) + err = h.Create(microAccountCtx(), &cReq2, &cRsp2) assert.NoError(t, err) if cRsp2.User == nil { t.Fatal("No user returned") @@ -58,7 +57,7 @@ func TestUpdate(t *testing.T) { t.Run("BlankFirstName", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: cRsp1.User.Id, FirstName: &wrapperspb.StringValue{}, }, &rsp) assert.Equal(t, handler.ErrMissingFirstName, err) @@ -67,7 +66,7 @@ func TestUpdate(t *testing.T) { t.Run("BlankLastName", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: cRsp1.User.Id, LastName: &wrapperspb.StringValue{}, }, &rsp) assert.Equal(t, handler.ErrMissingLastName, err) @@ -76,7 +75,7 @@ func TestUpdate(t *testing.T) { t.Run("BlankLastName", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: cRsp1.User.Id, LastName: &wrapperspb.StringValue{}, }, &rsp) assert.Equal(t, handler.ErrMissingLastName, err) @@ -85,7 +84,7 @@ func TestUpdate(t *testing.T) { t.Run("BlankEmail", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{}, }, &rsp) assert.Equal(t, handler.ErrMissingEmail, err) @@ -94,7 +93,7 @@ func TestUpdate(t *testing.T) { t.Run("InvalidEmail", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{Value: "foo.bar"}, }, &rsp) assert.Equal(t, handler.ErrInvalidEmail, err) @@ -103,7 +102,7 @@ func TestUpdate(t *testing.T) { t.Run("EmailAlreadyExists", func(t *testing.T) { var rsp pb.UpdateResponse - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{Value: cRsp2.User.Email}, }, &rsp) assert.Equal(t, handler.ErrDuplicateEmail, err) @@ -118,7 +117,7 @@ func TestUpdate(t *testing.T) { LastName: &wrapperspb.StringValue{Value: "Bar"}, } var uRsp pb.UpdateResponse - err := h.Update(context.TODO(), &uReq, &uRsp) + err := h.Update(microAccountCtx(), &uReq, &uRsp) assert.NoError(t, err) if uRsp.User == nil { t.Error("No user returned") @@ -135,14 +134,14 @@ func TestUpdate(t *testing.T) { Id: cRsp2.User.Id, Password: &wrapperspb.StringValue{Value: "helloworld"}, } - err := h.Update(context.TODO(), &uReq, &pb.UpdateResponse{}) + err := h.Update(microAccountCtx(), &uReq, &pb.UpdateResponse{}) assert.NoError(t, err) lReq := pb.LoginRequest{ Email: cRsp2.User.Email, Password: "helloworld", } - err = h.Login(context.TODO(), &lReq, &pb.LoginResponse{}) + err = h.Login(microAccountCtx(), &lReq, &pb.LoginResponse{}) assert.NoError(t, err) }) } diff --git a/users/handler/validate.go b/users/handler/validate.go index 6efe75a..18beb31 100644 --- a/users/handler/validate.go +++ b/users/handler/validate.go @@ -3,6 +3,7 @@ package handler import ( "context" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/users/proto" @@ -11,12 +12,21 @@ import ( // Validate a token, each time a token is validated it extends its lifetime for another week func (u *Users) Validate(ctx context.Context, req *pb.ValidateRequest, rsp *pb.ValidateResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Token) == 0 { return ErrMissingToken } - return u.DB.Transaction(func(tx *gorm.DB) error { + db, err := u.getDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + return db.Transaction(func(tx *gorm.DB) error { // lookup the token var token Token if err := tx.Where(&Token{Key: req.Token}).Preload("User").First(&token).Error; err == gorm.ErrRecordNotFound { diff --git a/users/handler/validate_test.go b/users/handler/validate_test.go index 1a05c71..bbf82b8 100644 --- a/users/handler/validate_test.go +++ b/users/handler/validate_test.go @@ -1,7 +1,6 @@ package handler_test import ( - "context" "testing" "time" @@ -22,7 +21,7 @@ func TestValidate(t *testing.T) { Email: "john@doe.com", Password: "passwordabc", } - err := h.Create(context.TODO(), &cReq1, &cRsp1) + err := h.Create(microAccountCtx(), &cReq1, &cRsp1) assert.NoError(t, err) if cRsp1.User == nil { t.Fatal("No user returned") @@ -36,7 +35,7 @@ func TestValidate(t *testing.T) { Email: "barry@doe.com", Password: "passwordabc", } - err = h.Create(context.TODO(), &cReq2, &cRsp2) + err = h.Create(microAccountCtx(), &cReq2, &cRsp2) assert.NoError(t, err) if cRsp2.User == nil { t.Fatal("No user returned") @@ -88,7 +87,7 @@ func TestValidate(t *testing.T) { } var rsp pb.ValidateResponse - err := h.Validate(context.TODO(), &pb.ValidateRequest{Token: tc.Token}, &rsp) + err := h.Validate(microAccountCtx(), &pb.ValidateRequest{Token: tc.Token}, &rsp) assert.Equal(t, tc.Error, err) if tc.User != nil { diff --git a/users/main.go b/users/main.go index 9a78dd0..24e8770 100644 --- a/users/main.go +++ b/users/main.go @@ -3,14 +3,12 @@ package main import ( "time" - "github.com/micro/services/users/handler" - pb "github.com/micro/services/users/proto" - "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/logger" + "github.com/micro/services/users/handler" + pb "github.com/micro/services/users/proto" "gorm.io/driver/postgres" - "gorm.io/gorm" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/users?sslmode=disable" @@ -28,16 +26,8 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) - if err != nil { - logger.Fatalf("Error connecting to database: %v", err) - } - if err := db.AutoMigrate(&handler.User{}, &handler.Token{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) - } - // Register handler - pb.RegisterUsersHandler(srv.Server(), &handler.Users{DB: db, Time: time.Now}) + pb.RegisterUsersHandler(srv.Server(), handler.NewHandler(time.Now, postgres.Open(addr))) // Run service if err := srv.Run(); err != nil {