Multitenant users api (#75)

This commit is contained in:
Dominic Wong
2021-03-24 08:40:27 +00:00
committed by GitHub
parent a711e10961
commit d68fdeb516
26 changed files with 233 additions and 352 deletions

View File

@@ -1,16 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// versions: // versions:
// protoc-gen-go v1.23.0 // protoc-gen-go v1.26.0
// protoc v3.13.0 // protoc v3.15.5
// source: proto/streams.proto // source: proto/streams.proto
package streams package streams
import ( import (
proto "github.com/golang/protobuf/proto"
timestamp "github.com/golang/protobuf/ptypes/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl" protoimpl "google.golang.org/protobuf/runtime/protoimpl"
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
reflect "reflect" reflect "reflect"
sync "sync" sync "sync"
) )
@@ -22,10 +21,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type PublishResponse struct { type PublishResponse struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@@ -69,9 +64,9 @@ type Message struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
Topic string `protobuf:"bytes,1,opt,name=topic,proto3" json:"topic,omitempty"` Topic string `protobuf:"bytes,1,opt,name=topic,proto3" json:"topic,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
SentAt *timestamp.Timestamp `protobuf:"bytes,3,opt,name=sent_at,json=sentAt,proto3" json:"sent_at,omitempty"` SentAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=sent_at,json=sentAt,proto3" json:"sent_at,omitempty"`
} }
func (x *Message) Reset() { func (x *Message) Reset() {
@@ -120,7 +115,7 @@ func (x *Message) GetMessage() string {
return "" return ""
} }
func (x *Message) GetSentAt() *timestamp.Timestamp { func (x *Message) GetSentAt() *timestamppb.Timestamp {
if x != nil { if x != nil {
return x.SentAt return x.SentAt
} }
@@ -317,10 +312,8 @@ var file_proto_streams_proto_rawDesc = []byte{
0x30, 0x01, 0x12, 0x38, 0x0a, 0x05, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x15, 0x2e, 0x73, 0x74, 0x30, 0x01, 0x12, 0x38, 0x0a, 0x05, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x15, 0x2e, 0x73, 0x74,
0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x1a, 0x16, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x54, 0x6f, 0x6b, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x2e, 0x54, 0x6f, 0x6b,
0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x31, 0x5a, 0x2f, 0x65, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x11, 0x5a, 0x0f,
0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6d, 0x69, 0x63, 0x72, 0x6f, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3b, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x62,
0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d,
0x73, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3b, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x73, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
@@ -338,12 +331,12 @@ func file_proto_streams_proto_rawDescGZIP() []byte {
var file_proto_streams_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_proto_streams_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_proto_streams_proto_goTypes = []interface{}{ var file_proto_streams_proto_goTypes = []interface{}{
(*PublishResponse)(nil), // 0: streams.PublishResponse (*PublishResponse)(nil), // 0: streams.PublishResponse
(*Message)(nil), // 1: streams.Message (*Message)(nil), // 1: streams.Message
(*SubscribeRequest)(nil), // 2: streams.SubscribeRequest (*SubscribeRequest)(nil), // 2: streams.SubscribeRequest
(*TokenRequest)(nil), // 3: streams.TokenRequest (*TokenRequest)(nil), // 3: streams.TokenRequest
(*TokenResponse)(nil), // 4: streams.TokenResponse (*TokenResponse)(nil), // 4: streams.TokenResponse
(*timestamp.Timestamp)(nil), // 5: google.protobuf.Timestamp (*timestamppb.Timestamp)(nil), // 5: google.protobuf.Timestamp
} }
var file_proto_streams_proto_depIdxs = []int32{ var file_proto_streams_proto_depIdxs = []int32{
5, // 0: streams.Message.sent_at:type_name -> google.protobuf.Timestamp 5, // 0: streams.Message.sent_at:type_name -> google.protobuf.Timestamp

View File

@@ -6,7 +6,7 @@ package streams
import ( import (
fmt "fmt" fmt "fmt"
proto "github.com/golang/protobuf/proto" proto "github.com/golang/protobuf/proto"
_ "github.com/golang/protobuf/ptypes/timestamp" _ "google.golang.org/protobuf/types/known/timestamppb"
math "math" math "math"
) )

View File

@@ -1,7 +1,7 @@
syntax = "proto3"; syntax = "proto3";
package streams; package streams;
option go_package = "github.com/micro/services/streams/proto;streams"; option go_package = "./proto;streams";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";
service Streams { service Streams {
@@ -35,4 +35,4 @@ message TokenRequest {
message TokenResponse { message TokenResponse {
string token = 1; string token = 1;
} }

View File

View File

@@ -1,235 +0,0 @@
// +build integration
package signup
import (
"encoding/json"
"errors"
"math/rand"
"os"
"os/exec"
"strings"
"testing"
"time"
"github.com/micro/micro/v3/test"
)
const (
retryCount = 1
)
var letterRunes = []rune("abcdefghijklmnopqrstuvwxyz")
func randStringRunes(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letterRunes[rand.Intn(len(letterRunes))]
}
return string(b)
}
func setupUsersTests(serv test.Server, t *test.T) {
envToConfigKey := map[string][]string{}
if err := test.Try("Set up config values", t, func() ([]byte, error) {
for envKey, configKeys := range envToConfigKey {
val := os.Getenv(envKey)
if len(val) == 0 {
t.Fatalf("'%v' flag is missing", envKey)
}
for _, configKey := range configKeys {
outp, err := serv.Command().Exec("config", "set", configKey, val)
if err != nil {
return outp, err
}
}
}
return serv.Command().Exec("config", "set", "micro.billing.max_included_services", "3")
}, 10*time.Second); err != nil {
t.Fatal(err)
return
}
services := []struct {
envVar string
deflt string
}{
{envVar: "POSTS_SVC", deflt: "../../../users"},
}
for _, v := range services {
outp, err := serv.Command().Exec("run", v.deflt)
if err != nil {
t.Fatal(string(outp))
return
}
}
if err := test.Try("Find posts and tags", t, func() ([]byte, error) {
outp, err := serv.Command().Exec("services")
if err != nil {
return outp, err
}
list := []string{"users"}
logOutp := []byte{}
fail := false
for _, s := range list {
if !strings.Contains(string(outp), s) {
o, _ := serv.Command().Exec("logs", s)
logOutp = append(logOutp, o...)
fail = true
}
}
if fail {
return append(outp, logOutp...), errors.New("Can't find required services in list")
}
return outp, err
}, 180*time.Second); err != nil {
return
}
// setup rules
// Adjust rules before we signup into a non admin account
outp, err := serv.Command().Exec("auth", "create", "rule", "--access=granted", "--scope=''", "--resource=\"service:users:*\"", "users")
if err != nil {
t.Fatalf("Error setting up rules: %v", outp)
return
}
// copy the config with the admin logged in so we can use it for reading logs
// we dont want to have an open access rule for logs as it's not how it works in live
confPath := serv.Command().Config
outp, err = exec.Command("cp", "-rf", confPath, confPath+".admin").CombinedOutput()
if err != nil {
t.Fatalf("Error copying config: %v", outp)
return
}
}
func TestUsersService(t *testing.T) {
test.TrySuite(t, testUsers, retryCount)
}
func testUsers(t *test.T) {
t.Parallel()
serv := test.NewServer(t, test.WithLogin())
defer serv.Close()
if err := serv.Run(); err != nil {
return
}
setupUsersTests(serv, t)
cmd := serv.Command()
email := "test@gmail.com"
password := "testPassw"
username := "john"
id := "7"
if err := test.Try("Save user", t, func() ([]byte, error) {
// Attention! The content must be unquoted, don't add quotes.
outp, err := cmd.Exec("users", "create", "--id="+id, "--email="+email, "--password="+password, "--username=john")
if err != nil {
outp1, _ := cmd.Exec("logs", "users")
return append(outp, outp1...), err
}
return outp, err
}, 15*time.Second); err != nil {
return
}
outp, err := cmd.Exec("users", "read", "--id="+id)
if err != nil {
t.Fatal(string(outp), err)
return
}
if !strings.Contains(string(outp), email) ||
!strings.Contains(string(outp), username) ||
!strings.Contains(string(outp), id) {
t.Fatal(string(outp))
return
}
// no password
outp, err = cmd.Exec("users", "login", "--email="+email)
if err == nil {
t.Fatal(string(outp))
return
}
// wrong password
outp, err = cmd.Exec("users", "login", "--email="+email, "--password=somethingincorrect")
if err == nil {
t.Fatal(string(outp))
return
}
outp, err = cmd.Exec("users", "login", "--username="+username, "--password="+password)
if err != nil {
t.Fatal(string(outp), err)
return
}
loginRsp := map[string]interface{}{}
err = json.Unmarshal(outp, &loginRsp)
if err != nil {
t.Fatal(err)
return
}
session, ok := loginRsp["session"].(map[string]interface{})
if !ok {
t.Fatal(string(outp))
return
}
sessionID := session["id"].(string)
sessionUsername := session["username"].(string)
if sessionUsername != username {
t.Fatal(string(outp))
return
}
if len(sessionID) == 0 {
t.Fatal(string(outp))
return
}
outp, err = cmd.Exec("users", "login", "--email="+email, "--password="+password)
if err != nil {
t.Fatal(string(outp), err)
return
}
outp, err = cmd.Exec("users", "login", "--email="+email, "--password="+password)
if err != nil {
t.Fatal(string(outp), err)
return
}
outp, err = cmd.Exec("users", "readSession", "--sessionId="+sessionID)
if err != nil {
t.Fatal(string(outp), err)
return
}
loginRsp = map[string]interface{}{}
err = json.Unmarshal(outp, &loginRsp)
if err != nil {
t.Fatal(err)
return
}
session, ok = loginRsp["session"].(map[string]interface{})
if !ok {
t.Fatal(string(outp))
return
}
sessionID = session["id"].(string)
sessionUsername = session["username"].(string)
if sessionUsername != username {
t.Fatal(string(outp))
return
}
}

View File

@@ -2,10 +2,12 @@ package handler
import ( import (
"context" "context"
"regexp"
"strings" "strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -14,6 +16,11 @@ import (
// Create a user // Create a user
func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error { func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.FirstName) == 0 { if len(req.FirstName) == 0 {
return ErrMissingFirstName return ErrMissingFirstName
@@ -34,11 +41,15 @@ func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat
// hash and salt the password using bcrypt // hash and salt the password using bcrypt
phash, err := hashAndSalt(req.Password) phash, err := hashAndSalt(req.Password)
if err != nil { if err != nil {
logger.Errorf("Error hasing and salting password: %v", err) logger.Errorf("Error hashing and salting password: %v", err)
return errors.InternalServerError("HASHING_ERROR", "Error hashing password") return errors.InternalServerError("HASHING_ERROR", "Error hashing password")
} }
db, err := u.getDBConn(ctx)
return u.DB.Transaction(func(tx *gorm.DB) error { if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return db.Transaction(func(tx *gorm.DB) error {
// write the user to the database // write the user to the database
user := &User{ user := &User{
ID: uuid.New().String(), ID: uuid.New().String(),
@@ -47,10 +58,12 @@ func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat
Email: strings.ToLower(req.Email), Email: strings.ToLower(req.Email),
Password: phash, Password: phash,
} }
err = u.DB.Create(user).Error err = tx.Create(user).Error
if err != nil && strings.Contains(err.Error(), "idx_users_email") {
return ErrDuplicateEmail if err != nil {
} else if err != nil { if match, _ := regexp.MatchString(`idx_[\S]+_users_email`, err.Error()); match {
return ErrDuplicateEmail
}
logger.Errorf("Error writing to the database: %v", err) logger.Errorf("Error writing to the database: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database")
} }

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
"github.com/micro/services/users/handler" "github.com/micro/services/users/handler"
@@ -61,7 +60,7 @@ func TestCreate(t *testing.T) {
h := testHandler(t) h := testHandler(t)
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
err := h.Create(context.TODO(), &pb.CreateRequest{ err := h.Create(microAccountCtx(), &pb.CreateRequest{
FirstName: tc.FirstName, FirstName: tc.FirstName,
LastName: tc.LastName, LastName: tc.LastName,
Email: tc.Email, Email: tc.Email,
@@ -79,7 +78,7 @@ func TestCreate(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &req, &rsp) err := h.Create(microAccountCtx(), &req, &rsp)
assert.NoError(t, err) assert.NoError(t, err)
u := rsp.User u := rsp.User
@@ -101,7 +100,7 @@ func TestCreate(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &req, &rsp) err := h.Create(microAccountCtx(), &req, &rsp)
assert.Equal(t, handler.ErrDuplicateEmail, err) assert.Equal(t, handler.ErrDuplicateEmail, err)
assert.Nil(t, rsp.User) assert.Nil(t, rsp.User)
}) })
@@ -114,7 +113,7 @@ func TestCreate(t *testing.T) {
Email: "johndoe@gmail.com", Email: "johndoe@gmail.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &req, &rsp) err := h.Create(microAccountCtx(), &req, &rsp)
assert.NoError(t, err) assert.NoError(t, err)
u := rsp.User u := rsp.User

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -11,13 +12,22 @@ import (
// Delete a user // Delete a user
func (u *Users) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error { func (u *Users) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.Id) == 0 { if len(req.Id) == 0 {
return ErrMissingID return ErrMissingID
} }
db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// delete the users tokens // delete the users tokens
return u.DB.Transaction(func(tx *gorm.DB) error { return db.Transaction(func(tx *gorm.DB) error {
if err := tx.Delete(&Token{}, &Token{UserID: req.Id}).Error; err != nil { if err := tx.Delete(&Token{}, &Token{UserID: req.Id}).Error; err != nil {
logger.Errorf("Error writing to the database: %v", err) logger.Errorf("Error writing to the database: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database")

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
"github.com/micro/services/users/handler" "github.com/micro/services/users/handler"
@@ -13,7 +12,7 @@ func TestDelete(t *testing.T) {
h := testHandler(t) h := testHandler(t)
t.Run("MissingID", func(t *testing.T) { t.Run("MissingID", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{}, &pb.DeleteResponse{}) err := h.Delete(microAccountCtx(), &pb.DeleteRequest{}, &pb.DeleteResponse{})
assert.Equal(t, handler.ErrMissingID, err) assert.Equal(t, handler.ErrMissingID, err)
}) })
@@ -25,7 +24,7 @@ func TestDelete(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &cReq, &cRsp) err := h.Create(microAccountCtx(), &cReq, &cRsp)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp.User == nil { if cRsp.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -33,14 +32,14 @@ func TestDelete(t *testing.T) {
} }
t.Run("Valid", func(t *testing.T) { t.Run("Valid", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{ err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
Id: cRsp.User.Id, Id: cRsp.User.Id,
}, &pb.DeleteResponse{}) }, &pb.DeleteResponse{})
assert.NoError(t, err) assert.NoError(t, err)
// check it was actually deleted // check it was actually deleted
var rsp pb.ReadResponse var rsp pb.ReadResponse
err = h.Read(context.TODO(), &pb.ReadRequest{ err = h.Read(microAccountCtx(), &pb.ReadRequest{
Ids: []string{cRsp.User.Id}, Ids: []string{cRsp.User.Id},
}, &rsp) }, &rsp)
assert.NoError(t, err) assert.NoError(t, err)
@@ -48,7 +47,7 @@ func TestDelete(t *testing.T) {
}) })
t.Run("Retry", func(t *testing.T) { t.Run("Retry", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{ err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
Id: cRsp.User.Id, Id: cRsp.User.Id,
}, &pb.DeleteResponse{}) }, &pb.DeleteResponse{})
assert.NoError(t, err) assert.NoError(t, err)

View File

@@ -1,13 +1,18 @@
package handler package handler
import ( import (
"context"
"fmt"
"regexp" "regexp"
"strings"
"time" "time"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/schema"
) )
var ( var (
@@ -58,8 +63,38 @@ type Token struct {
} }
type Users struct { type Users struct {
DB *gorm.DB Time func() time.Time
Time func() time.Time Dialector gorm.Dialector
dbMigrations map[string]bool
}
func NewHandler(t func() time.Time, d gorm.Dialector) *Users {
return &Users{Time: t, Dialector: d, dbMigrations: map[string]bool{}}
}
func (u *Users) getDBConn(ctx context.Context) (*gorm.DB, error) {
acc, ok := auth.AccountFromContext(ctx)
if !ok {
return nil, fmt.Errorf("missing account from context")
}
db, err := gorm.Open(u.Dialector, &gorm.Config{
NamingStrategy: schema.NamingStrategy{
TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")),
},
})
if err != nil {
return nil, err
}
// skip migration if we've already done it
if u.dbMigrations[acc.Issuer] {
return db, nil
}
if err := db.AutoMigrate(&User{}, &Token{}); err != nil {
return nil, err
}
// record success
u.dbMigrations[acc.Issuer] = true
return db, nil
} }
// isEmailValid checks if the email provided passes the required structure and length. // isEmailValid checks if the email provided passes the required structure and length.

View File

@@ -1,11 +1,14 @@
package handler_test package handler_test
import ( import (
"context"
"os" "os"
"testing" "testing"
"time" "time"
"github.com/micro/micro/v3/service/auth"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"gorm.io/gorm/schema"
"github.com/micro/services/users/handler" "github.com/micro/services/users/handler"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -19,13 +22,16 @@ func testHandler(t *testing.T) *handler.Users {
if len(addr) == 0 { if len(addr) == 0 {
addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable"
} }
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) dial := postgres.Open(addr)
db, err := gorm.Open(dial, &gorm.Config{
NamingStrategy: schema.NamingStrategy{TablePrefix: "micro_"},
})
if err != nil { if err != nil {
t.Fatalf("Error connecting to database: %v", err) t.Fatalf("Error connecting to database: %v", err)
} }
// clean any data from a previous run // clean any data from a previous run
if err := db.Exec("DROP TABLE IF EXISTS users, tokens CASCADE").Error; err != nil { if err := db.Exec("DROP TABLE IF EXISTS micro_users, micro_tokens CASCADE").Error; err != nil {
t.Fatalf("Error cleaning database: %v", err) t.Fatalf("Error cleaning database: %v", err)
} }
@@ -33,8 +39,7 @@ func testHandler(t *testing.T) *handler.Users {
if err := db.AutoMigrate(&handler.User{}, &handler.Token{}); err != nil { if err := db.AutoMigrate(&handler.User{}, &handler.Token{}); err != nil {
t.Fatalf("Error migrating database: %v", err) t.Fatalf("Error migrating database: %v", err)
} }
return handler.NewHandler(time.Now, dial)
return &handler.Users{DB: db, Time: time.Now}
} }
func assertUsersMatch(t *testing.T, exp, act *pb.User) { func assertUsersMatch(t *testing.T, exp, act *pb.User) {
@@ -47,3 +52,9 @@ func assertUsersMatch(t *testing.T, exp, act *pb.User) {
assert.Equal(t, exp.LastName, act.LastName) assert.Equal(t, exp.LastName, act.LastName)
assert.Equal(t, exp.Email, act.Email) assert.Equal(t, exp.Email, act.Email)
} }
func microAccountCtx() context.Context {
return auth.ContextWithAccount(context.TODO(), &auth.Account{
Issuer: "micro",
})
}

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -10,9 +11,18 @@ import (
// List all users // List all users
func (u *Users) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error { func (u *Users) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// query the database // query the database
db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
var users []User var users []User
if err := u.DB.Model(&User{}).Find(&users).Error; err != nil { if err := db.Model(&User{}).Find(&users).Error; err != nil {
logger.Errorf("Error reading from the database: %v", err) logger.Errorf("Error reading from the database: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database")
} }

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -19,7 +18,7 @@ func TestList(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &cReq1, &cRsp1) err := h.Create(microAccountCtx(), &cReq1, &cRsp1)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp1.User == nil { if cRsp1.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -33,7 +32,7 @@ func TestList(t *testing.T) {
Email: "johndoe@gmail.com", Email: "johndoe@gmail.com",
Password: "passwordabc", Password: "passwordabc",
} }
err = h.Create(context.TODO(), &cReq2, &cRsp2) err = h.Create(microAccountCtx(), &cReq2, &cRsp2)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp2.User == nil { if cRsp2.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -41,7 +40,7 @@ func TestList(t *testing.T) {
} }
var rsp pb.ListResponse var rsp pb.ListResponse
err = h.List(context.TODO(), &pb.ListRequest{}, &rsp) err = h.List(microAccountCtx(), &pb.ListRequest{}, &rsp)
assert.NoError(t, err) assert.NoError(t, err)
if rsp.Users == nil { if rsp.Users == nil {
t.Error("No users returned") t.Error("No users returned")

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -12,6 +13,10 @@ import (
// Login using email and password returns the users profile and a token // Login using email and password returns the users profile and a token
func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginResponse) error { func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.Email) == 0 { if len(req.Email) == 0 {
return ErrMissingEmail return ErrMissingEmail
@@ -20,7 +25,13 @@ func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginRe
return ErrInvalidPassword return ErrInvalidPassword
} }
return u.DB.Transaction(func(tx *gorm.DB) error { db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return db.Transaction(func(tx *gorm.DB) error {
// lookup the user // lookup the user
var user User var user User
if err := tx.Where(&User{Email: req.Email}).First(&user).Error; err == gorm.ErrRecordNotFound { if err := tx.Where(&User{Email: req.Email}).First(&user).Error; err == gorm.ErrRecordNotFound {

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
"github.com/micro/services/users/handler" "github.com/micro/services/users/handler"
@@ -20,7 +19,7 @@ func TestLogin(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &cReq, &cRsp) err := h.Create(microAccountCtx(), &cReq, &cRsp)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp.User == nil { if cRsp.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -67,7 +66,7 @@ func TestLogin(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
var rsp pb.LoginResponse var rsp pb.LoginResponse
err := h.Login(context.TODO(), &pb.LoginRequest{ err := h.Login(microAccountCtx(), &pb.LoginRequest{
Email: tc.Email, Password: tc.Password, Email: tc.Email, Password: tc.Password,
}, &rsp) }, &rsp)
assert.Equal(t, tc.Error, err) assert.Equal(t, tc.Error, err)

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -11,12 +12,21 @@ import (
// Logout expires all tokens for the user // Logout expires all tokens for the user
func (u *Users) Logout(ctx context.Context, req *pb.LogoutRequest, rsp *pb.LogoutResponse) error { func (u *Users) Logout(ctx context.Context, req *pb.LogoutRequest, rsp *pb.LogoutResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.Id) == 0 { if len(req.Id) == 0 {
return ErrMissingID return ErrMissingID
} }
return u.DB.Transaction(func(tx *gorm.DB) error { db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return db.Transaction(func(tx *gorm.DB) error {
// lookup the user // lookup the user
var user User var user User
if err := tx.Where(&User{ID: req.Id}).Preload("Tokens").First(&user).Error; err == gorm.ErrRecordNotFound { if err := tx.Where(&User{ID: req.Id}).Preload("Tokens").First(&user).Error; err == gorm.ErrRecordNotFound {

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
@@ -14,12 +13,12 @@ func TestLogout(t *testing.T) {
h := testHandler(t) h := testHandler(t)
t.Run("MissingUserID", func(t *testing.T) { t.Run("MissingUserID", func(t *testing.T) {
err := h.Logout(context.TODO(), &pb.LogoutRequest{}, &pb.LogoutResponse{}) err := h.Logout(microAccountCtx(), &pb.LogoutRequest{}, &pb.LogoutResponse{})
assert.Equal(t, handler.ErrMissingID, err) assert.Equal(t, handler.ErrMissingID, err)
}) })
t.Run("UserNotFound", func(t *testing.T) { t.Run("UserNotFound", func(t *testing.T) {
err := h.Logout(context.TODO(), &pb.LogoutRequest{Id: uuid.New().String()}, &pb.LogoutResponse{}) err := h.Logout(microAccountCtx(), &pb.LogoutRequest{Id: uuid.New().String()}, &pb.LogoutResponse{})
assert.Equal(t, handler.ErrNotFound, err) assert.Equal(t, handler.ErrNotFound, err)
}) })
@@ -32,17 +31,17 @@ func TestLogout(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &cReq, &cRsp) err := h.Create(microAccountCtx(), &cReq, &cRsp)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp.User == nil { if cRsp.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
return return
} }
err = h.Logout(context.TODO(), &pb.LogoutRequest{Id: cRsp.User.Id}, &pb.LogoutResponse{}) err = h.Logout(microAccountCtx(), &pb.LogoutRequest{Id: cRsp.User.Id}, &pb.LogoutResponse{})
assert.NoError(t, err) assert.NoError(t, err)
err = h.Validate(context.TODO(), &pb.ValidateRequest{Token: cRsp.Token}, &pb.ValidateResponse{}) err = h.Validate(microAccountCtx(), &pb.ValidateRequest{Token: cRsp.Token}, &pb.ValidateResponse{})
assert.Error(t, err) assert.Error(t, err)
}) })
} }

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -10,14 +11,23 @@ import (
// Read users using ID // Read users using ID
func (u *Users) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error { func (u *Users) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.Ids) == 0 { if len(req.Ids) == 0 {
return ErrMissingIDs return ErrMissingIDs
} }
// query the database // query the database
db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
var users []User var users []User
if err := u.DB.Model(&User{}).Where("id IN (?)", req.Ids).Find(&users).Error; err != nil { if err := db.Model(&User{}).Where("id IN (?)", req.Ids).Find(&users).Error; err != nil {
logger.Errorf("Error reading from the database: %v", err) logger.Errorf("Error reading from the database: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database")
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"strings" "strings"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -11,6 +12,10 @@ import (
// Read users using email // Read users using email
func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp *pb.ReadByEmailResponse) error { func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp *pb.ReadByEmailResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.Emails) == 0 { if len(req.Emails) == 0 {
return ErrMissingEmails return ErrMissingEmails
@@ -21,8 +26,13 @@ func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp
} }
// query the database // query the database
db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
var users []User var users []User
if err := u.DB.Model(&User{}).Where("lower(email) IN (?)", emails).Find(&users).Error; err != nil { if err := db.Model(&User{}).Where("lower(email) IN (?)", emails).Find(&users).Error; err != nil {
logger.Errorf("Error reading from the database: %v", err) logger.Errorf("Error reading from the database: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database")
} }

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"strings" "strings"
"testing" "testing"
@@ -15,14 +14,14 @@ func TestReadByEmail(t *testing.T) {
t.Run("MissingEmails", func(t *testing.T) { t.Run("MissingEmails", func(t *testing.T) {
var rsp pb.ReadByEmailResponse var rsp pb.ReadByEmailResponse
err := h.ReadByEmail(context.TODO(), &pb.ReadByEmailRequest{}, &rsp) err := h.ReadByEmail(microAccountCtx(), &pb.ReadByEmailRequest{}, &rsp)
assert.Equal(t, handler.ErrMissingEmails, err) assert.Equal(t, handler.ErrMissingEmails, err)
assert.Nil(t, rsp.Users) assert.Nil(t, rsp.Users)
}) })
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
var rsp pb.ReadByEmailResponse var rsp pb.ReadByEmailResponse
err := h.ReadByEmail(context.TODO(), &pb.ReadByEmailRequest{Emails: []string{"foo"}}, &rsp) err := h.ReadByEmail(microAccountCtx(), &pb.ReadByEmailRequest{Emails: []string{"foo"}}, &rsp)
assert.Nil(t, err) assert.Nil(t, err)
if rsp.Users == nil { if rsp.Users == nil {
t.Fatal("Expected the users object to not be nil") t.Fatal("Expected the users object to not be nil")
@@ -38,7 +37,7 @@ func TestReadByEmail(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &req1, &rsp1) err := h.Create(microAccountCtx(), &req1, &rsp1)
assert.NoError(t, err) assert.NoError(t, err)
if rsp1.User == nil { if rsp1.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -52,7 +51,7 @@ func TestReadByEmail(t *testing.T) {
Email: "apple@tree.com", Email: "apple@tree.com",
Password: "passwordabc", Password: "passwordabc",
} }
err = h.Create(context.TODO(), &req2, &rsp2) err = h.Create(microAccountCtx(), &req2, &rsp2)
assert.NoError(t, err) assert.NoError(t, err)
if rsp2.User == nil { if rsp2.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -61,7 +60,7 @@ func TestReadByEmail(t *testing.T) {
// test the read // test the read
var rsp pb.ReadByEmailResponse var rsp pb.ReadByEmailResponse
err = h.ReadByEmail(context.TODO(), &pb.ReadByEmailRequest{ err = h.ReadByEmail(microAccountCtx(), &pb.ReadByEmailRequest{
Emails: []string{rsp1.User.Email, strings.ToUpper(rsp2.User.Email)}, Emails: []string{rsp1.User.Email, strings.ToUpper(rsp2.User.Email)},
}, &rsp) }, &rsp)
assert.NoError(t, err) assert.NoError(t, err)

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
"github.com/micro/services/users/handler" "github.com/micro/services/users/handler"
@@ -14,14 +13,14 @@ func TestRead(t *testing.T) {
t.Run("MissingIDs", func(t *testing.T) { t.Run("MissingIDs", func(t *testing.T) {
var rsp pb.ReadResponse var rsp pb.ReadResponse
err := h.Read(context.TODO(), &pb.ReadRequest{}, &rsp) err := h.Read(microAccountCtx(), &pb.ReadRequest{}, &rsp)
assert.Equal(t, handler.ErrMissingIDs, err) assert.Equal(t, handler.ErrMissingIDs, err)
assert.Nil(t, rsp.Users) assert.Nil(t, rsp.Users)
}) })
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
var rsp pb.ReadResponse var rsp pb.ReadResponse
err := h.Read(context.TODO(), &pb.ReadRequest{Ids: []string{"foo"}}, &rsp) err := h.Read(microAccountCtx(), &pb.ReadRequest{Ids: []string{"foo"}}, &rsp)
assert.Nil(t, err) assert.Nil(t, err)
if rsp.Users == nil { if rsp.Users == nil {
t.Fatal("Expected the users object to not be nil") t.Fatal("Expected the users object to not be nil")
@@ -37,7 +36,7 @@ func TestRead(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &req1, &rsp1) err := h.Create(microAccountCtx(), &req1, &rsp1)
assert.NoError(t, err) assert.NoError(t, err)
if rsp1.User == nil { if rsp1.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -51,7 +50,7 @@ func TestRead(t *testing.T) {
Email: "apple@tree.com", Email: "apple@tree.com",
Password: "passwordabc", Password: "passwordabc",
} }
err = h.Create(context.TODO(), &req2, &rsp2) err = h.Create(microAccountCtx(), &req2, &rsp2)
assert.NoError(t, err) assert.NoError(t, err)
if rsp2.User == nil { if rsp2.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -60,7 +59,7 @@ func TestRead(t *testing.T) {
// test the read // test the read
var rsp pb.ReadResponse var rsp pb.ReadResponse
err = h.Read(context.TODO(), &pb.ReadRequest{ err = h.Read(microAccountCtx(), &pb.ReadRequest{
Ids: []string{rsp1.User.Id, rsp2.User.Id}, Ids: []string{rsp1.User.Id, rsp2.User.Id},
}, &rsp) }, &rsp)
assert.NoError(t, err) assert.NoError(t, err)

View File

@@ -2,8 +2,10 @@ package handler
import ( import (
"context" "context"
"regexp"
"strings" "strings"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -12,6 +14,10 @@ import (
// Update a user // Update a user
func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.UpdateResponse) error { func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.UpdateResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.Id) == 0 { if len(req.Id) == 0 {
return ErrMissingID return ErrMissingID
@@ -34,7 +40,12 @@ func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Updat
// lookup the user // lookup the user
var user User var user User
if err := u.DB.Where(&User{ID: req.Id}).First(&user).Error; err == gorm.ErrRecordNotFound { db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
if err := db.Where(&User{ID: req.Id}).First(&user).Error; err == gorm.ErrRecordNotFound {
return ErrNotFound return ErrNotFound
} else if err != nil { } else if err != nil {
logger.Errorf("Error reading from the database: %v", err) logger.Errorf("Error reading from the database: %v", err)
@@ -61,10 +72,11 @@ func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Updat
} }
// write the user to the database // write the user to the database
err := u.DB.Save(user).Error err = db.Save(user).Error
if err != nil && strings.Contains(err.Error(), "idx_users_email") { if err != nil {
return ErrDuplicateEmail if match, _ := regexp.MatchString(`idx_[\S]+_users_email`, err.Error()); match {
} else if err != nil { return ErrDuplicateEmail
}
logger.Errorf("Error writing to the database: %v", err) logger.Errorf("Error writing to the database: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database") return errors.InternalServerError("DATABASE_ERROR", "Error connecting to the database")
} }

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
"github.com/micro/services/users/handler" "github.com/micro/services/users/handler"
@@ -15,14 +14,14 @@ func TestUpdate(t *testing.T) {
t.Run("MissingID", func(t *testing.T) { t.Run("MissingID", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{}, &rsp) err := h.Update(microAccountCtx(), &pb.UpdateRequest{}, &rsp)
assert.Equal(t, handler.ErrMissingID, err) assert.Equal(t, handler.ErrMissingID, err)
assert.Nil(t, rsp.User) assert.Nil(t, rsp.User)
}) })
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{Id: "foo"}, &rsp) err := h.Update(microAccountCtx(), &pb.UpdateRequest{Id: "foo"}, &rsp)
assert.Equal(t, handler.ErrNotFound, err) assert.Equal(t, handler.ErrNotFound, err)
assert.Nil(t, rsp.User) assert.Nil(t, rsp.User)
}) })
@@ -35,7 +34,7 @@ func TestUpdate(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &cReq1, &cRsp1) err := h.Create(microAccountCtx(), &cReq1, &cRsp1)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp1.User == nil { if cRsp1.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -49,7 +48,7 @@ func TestUpdate(t *testing.T) {
Email: "johndoe@gmail.com", Email: "johndoe@gmail.com",
Password: "passwordabc", Password: "passwordabc",
} }
err = h.Create(context.TODO(), &cReq2, &cRsp2) err = h.Create(microAccountCtx(), &cReq2, &cRsp2)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp2.User == nil { if cRsp2.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -58,7 +57,7 @@ func TestUpdate(t *testing.T) {
t.Run("BlankFirstName", func(t *testing.T) { t.Run("BlankFirstName", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{ err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp1.User.Id, FirstName: &wrapperspb.StringValue{}, Id: cRsp1.User.Id, FirstName: &wrapperspb.StringValue{},
}, &rsp) }, &rsp)
assert.Equal(t, handler.ErrMissingFirstName, err) assert.Equal(t, handler.ErrMissingFirstName, err)
@@ -67,7 +66,7 @@ func TestUpdate(t *testing.T) {
t.Run("BlankLastName", func(t *testing.T) { t.Run("BlankLastName", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{ err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp1.User.Id, LastName: &wrapperspb.StringValue{}, Id: cRsp1.User.Id, LastName: &wrapperspb.StringValue{},
}, &rsp) }, &rsp)
assert.Equal(t, handler.ErrMissingLastName, err) assert.Equal(t, handler.ErrMissingLastName, err)
@@ -76,7 +75,7 @@ func TestUpdate(t *testing.T) {
t.Run("BlankLastName", func(t *testing.T) { t.Run("BlankLastName", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{ err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp1.User.Id, LastName: &wrapperspb.StringValue{}, Id: cRsp1.User.Id, LastName: &wrapperspb.StringValue{},
}, &rsp) }, &rsp)
assert.Equal(t, handler.ErrMissingLastName, err) assert.Equal(t, handler.ErrMissingLastName, err)
@@ -85,7 +84,7 @@ func TestUpdate(t *testing.T) {
t.Run("BlankEmail", func(t *testing.T) { t.Run("BlankEmail", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{ err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{}, Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{},
}, &rsp) }, &rsp)
assert.Equal(t, handler.ErrMissingEmail, err) assert.Equal(t, handler.ErrMissingEmail, err)
@@ -94,7 +93,7 @@ func TestUpdate(t *testing.T) {
t.Run("InvalidEmail", func(t *testing.T) { t.Run("InvalidEmail", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{ err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{Value: "foo.bar"}, Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{Value: "foo.bar"},
}, &rsp) }, &rsp)
assert.Equal(t, handler.ErrInvalidEmail, err) assert.Equal(t, handler.ErrInvalidEmail, err)
@@ -103,7 +102,7 @@ func TestUpdate(t *testing.T) {
t.Run("EmailAlreadyExists", func(t *testing.T) { t.Run("EmailAlreadyExists", func(t *testing.T) {
var rsp pb.UpdateResponse var rsp pb.UpdateResponse
err := h.Update(context.TODO(), &pb.UpdateRequest{ err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{Value: cRsp2.User.Email}, Id: cRsp1.User.Id, Email: &wrapperspb.StringValue{Value: cRsp2.User.Email},
}, &rsp) }, &rsp)
assert.Equal(t, handler.ErrDuplicateEmail, err) assert.Equal(t, handler.ErrDuplicateEmail, err)
@@ -118,7 +117,7 @@ func TestUpdate(t *testing.T) {
LastName: &wrapperspb.StringValue{Value: "Bar"}, LastName: &wrapperspb.StringValue{Value: "Bar"},
} }
var uRsp pb.UpdateResponse var uRsp pb.UpdateResponse
err := h.Update(context.TODO(), &uReq, &uRsp) err := h.Update(microAccountCtx(), &uReq, &uRsp)
assert.NoError(t, err) assert.NoError(t, err)
if uRsp.User == nil { if uRsp.User == nil {
t.Error("No user returned") t.Error("No user returned")
@@ -135,14 +134,14 @@ func TestUpdate(t *testing.T) {
Id: cRsp2.User.Id, Id: cRsp2.User.Id,
Password: &wrapperspb.StringValue{Value: "helloworld"}, Password: &wrapperspb.StringValue{Value: "helloworld"},
} }
err := h.Update(context.TODO(), &uReq, &pb.UpdateResponse{}) err := h.Update(microAccountCtx(), &uReq, &pb.UpdateResponse{})
assert.NoError(t, err) assert.NoError(t, err)
lReq := pb.LoginRequest{ lReq := pb.LoginRequest{
Email: cRsp2.User.Email, Email: cRsp2.User.Email,
Password: "helloworld", Password: "helloworld",
} }
err = h.Login(context.TODO(), &lReq, &pb.LoginResponse{}) err = h.Login(microAccountCtx(), &lReq, &pb.LoginResponse{})
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }

View File

@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/users/proto" pb "github.com/micro/services/users/proto"
@@ -11,12 +12,21 @@ import (
// Validate a token, each time a token is validated it extends its lifetime for another week // Validate a token, each time a token is validated it extends its lifetime for another week
func (u *Users) Validate(ctx context.Context, req *pb.ValidateRequest, rsp *pb.ValidateResponse) error { func (u *Users) Validate(ctx context.Context, req *pb.ValidateRequest, rsp *pb.ValidateResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request // validate the request
if len(req.Token) == 0 { if len(req.Token) == 0 {
return ErrMissingToken return ErrMissingToken
} }
return u.DB.Transaction(func(tx *gorm.DB) error { db, err := u.getDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return db.Transaction(func(tx *gorm.DB) error {
// lookup the token // lookup the token
var token Token var token Token
if err := tx.Where(&Token{Key: req.Token}).Preload("User").First(&token).Error; err == gorm.ErrRecordNotFound { if err := tx.Where(&Token{Key: req.Token}).Preload("User").First(&token).Error; err == gorm.ErrRecordNotFound {

View File

@@ -1,7 +1,6 @@
package handler_test package handler_test
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -22,7 +21,7 @@ func TestValidate(t *testing.T) {
Email: "john@doe.com", Email: "john@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err := h.Create(context.TODO(), &cReq1, &cRsp1) err := h.Create(microAccountCtx(), &cReq1, &cRsp1)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp1.User == nil { if cRsp1.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -36,7 +35,7 @@ func TestValidate(t *testing.T) {
Email: "barry@doe.com", Email: "barry@doe.com",
Password: "passwordabc", Password: "passwordabc",
} }
err = h.Create(context.TODO(), &cReq2, &cRsp2) err = h.Create(microAccountCtx(), &cReq2, &cRsp2)
assert.NoError(t, err) assert.NoError(t, err)
if cRsp2.User == nil { if cRsp2.User == nil {
t.Fatal("No user returned") t.Fatal("No user returned")
@@ -88,7 +87,7 @@ func TestValidate(t *testing.T) {
} }
var rsp pb.ValidateResponse var rsp pb.ValidateResponse
err := h.Validate(context.TODO(), &pb.ValidateRequest{Token: tc.Token}, &rsp) err := h.Validate(microAccountCtx(), &pb.ValidateRequest{Token: tc.Token}, &rsp)
assert.Equal(t, tc.Error, err) assert.Equal(t, tc.Error, err)
if tc.User != nil { if tc.User != nil {

View File

@@ -3,14 +3,12 @@ package main
import ( import (
"time" "time"
"github.com/micro/services/users/handler"
pb "github.com/micro/services/users/proto"
"github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service"
"github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/config"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
"github.com/micro/services/users/handler"
pb "github.com/micro/services/users/proto"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm"
) )
var dbAddress = "postgresql://postgres:postgres@localhost:5432/users?sslmode=disable" var dbAddress = "postgresql://postgres:postgres@localhost:5432/users?sslmode=disable"
@@ -28,16 +26,8 @@ func main() {
logger.Fatalf("Error loading config: %v", err) logger.Fatalf("Error loading config: %v", err)
} }
addr := cfg.String(dbAddress) addr := cfg.String(dbAddress)
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
if err != nil {
logger.Fatalf("Error connecting to database: %v", err)
}
if err := db.AutoMigrate(&handler.User{}, &handler.Token{}); err != nil {
logger.Fatalf("Error migrating database: %v", err)
}
// Register handler // Register handler
pb.RegisterUsersHandler(srv.Server(), &handler.Users{DB: db, Time: time.Now}) pb.RegisterUsersHandler(srv.Server(), handler.NewHandler(time.Now, postgres.Open(addr)))
// Run service // Run service
if err := srv.Run(); err != nil { if err := srv.Run(); err != nil {