mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-11 10:54:28 +00:00
Multi tenant groups (#77)
* multitenant groups * switch users service to use new wrapper * fix tests * skip pkg dir * Check for auth
This commit is contained in:
@@ -5,8 +5,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/groups/proto"
|
||||
gorm2 "github.com/micro/services/pkg/gorm"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -41,10 +45,14 @@ func (g *Group) Serialize() *pb.Group {
|
||||
}
|
||||
|
||||
type Groups struct {
|
||||
DB *gorm.DB
|
||||
gorm2.Helper
|
||||
}
|
||||
|
||||
func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error {
|
||||
_, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
// validate the request
|
||||
if len(req.Name) == 0 {
|
||||
return ErrMissingName
|
||||
@@ -52,7 +60,13 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea
|
||||
|
||||
// create the group object
|
||||
group := &Group{ID: uuid.New().String(), Name: req.Name}
|
||||
if err := g.DB.Create(group).Error; err != nil {
|
||||
db, err := g.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
}
|
||||
|
||||
if err := db.Create(group).Error; err != nil {
|
||||
return ErrStore
|
||||
}
|
||||
|
||||
@@ -62,14 +76,23 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea
|
||||
}
|
||||
|
||||
func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error {
|
||||
_, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
// validate the request
|
||||
if len(req.Ids) == 0 {
|
||||
return ErrMissingIDs
|
||||
}
|
||||
|
||||
db, err := g.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
}
|
||||
// query the database
|
||||
var groups []Group
|
||||
if err := g.DB.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil {
|
||||
if err := db.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil {
|
||||
return ErrStore
|
||||
}
|
||||
|
||||
@@ -83,6 +106,10 @@ func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResp
|
||||
}
|
||||
|
||||
func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.UpdateResponse) error {
|
||||
_, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
// validate the request
|
||||
if len(req.Id) == 0 {
|
||||
return ErrMissingID
|
||||
@@ -90,8 +117,13 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda
|
||||
if len(req.Name) == 0 {
|
||||
return ErrMissingName
|
||||
}
|
||||
db, err := g.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
}
|
||||
|
||||
return g.DB.Transaction(func(tx *gorm.DB) error {
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// find the group
|
||||
var group Group
|
||||
if err := tx.Where(&Group{ID: req.Id}).First(&group).Error; err == gorm.ErrRecordNotFound {
|
||||
@@ -113,13 +145,22 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda
|
||||
}
|
||||
|
||||
func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error {
|
||||
_, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
// validate the request
|
||||
if len(req.Id) == 0 {
|
||||
return ErrMissingID
|
||||
}
|
||||
|
||||
db, err := g.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
}
|
||||
// delete from the database
|
||||
if err := g.DB.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound {
|
||||
if err := db.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return ErrStore
|
||||
@@ -129,10 +170,19 @@ func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Dele
|
||||
}
|
||||
|
||||
func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error {
|
||||
_, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
db, err := g.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
}
|
||||
if len(req.MemberId) > 0 {
|
||||
// only list groups the user is a member of
|
||||
var ms []Membership
|
||||
q := g.DB.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships")
|
||||
q := db.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships")
|
||||
if err := q.Find(&ms).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -145,7 +195,7 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp
|
||||
|
||||
// load all groups
|
||||
var groups []Group
|
||||
if err := g.DB.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil {
|
||||
if err := db.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil {
|
||||
return ErrStore
|
||||
}
|
||||
|
||||
@@ -159,6 +209,10 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp
|
||||
}
|
||||
|
||||
func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *pb.AddMemberResponse) error {
|
||||
_, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
// validate the request
|
||||
if len(req.GroupId) == 0 {
|
||||
return ErrMissingGroupID
|
||||
@@ -166,8 +220,13 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p
|
||||
if len(req.MemberId) == 0 {
|
||||
return ErrMissingMemberID
|
||||
}
|
||||
db, err := g.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
}
|
||||
|
||||
return g.DB.Transaction(func(tx *gorm.DB) error {
|
||||
return db.Transaction(func(tx *gorm.DB) error {
|
||||
// check the group exists
|
||||
var group Group
|
||||
if err := tx.Where(&Group{ID: req.GroupId}).First(&group).Error; err == gorm.ErrRecordNotFound {
|
||||
@@ -191,6 +250,10 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p
|
||||
}
|
||||
|
||||
func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest, rsp *pb.RemoveMemberResponse) error {
|
||||
_, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
// validate the request
|
||||
if len(req.GroupId) == 0 {
|
||||
return ErrMissingGroupID
|
||||
@@ -199,9 +262,14 @@ func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest,
|
||||
return ErrMissingMemberID
|
||||
}
|
||||
|
||||
db, err := g.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
}
|
||||
// delete the membership
|
||||
m := &Membership{MemberID: req.MemberId, GroupID: req.GroupId}
|
||||
if err := g.DB.Where(m).Delete(m).Error; err != nil {
|
||||
if err := db.Where(m).Delete(m).Error; err != nil {
|
||||
return ErrStore
|
||||
}
|
||||
|
||||
|
||||
@@ -2,16 +2,16 @@ package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/services/groups/handler"
|
||||
pb "github.com/micro/services/groups/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func testHandler(t *testing.T) *handler.Groups {
|
||||
@@ -20,33 +20,30 @@ func testHandler(t *testing.T) *handler.Groups {
|
||||
if len(addr) == 0 {
|
||||
addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable"
|
||||
}
|
||||
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
|
||||
sqlDB, err := sql.Open("pgx", addr)
|
||||
if err != nil {
|
||||
t.Fatalf("Error connecting to database: %v", err)
|
||||
t.Fatalf("Failed to open connection to DB %s", err)
|
||||
}
|
||||
|
||||
// clean any data from a previous run
|
||||
if err := db.Exec("DROP TABLE IF EXISTS groups, memberships CASCADE").Error; err != nil {
|
||||
if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_groups, micro_memberships CASCADE"); err != nil {
|
||||
t.Fatalf("Error cleaning database: %v", err)
|
||||
}
|
||||
|
||||
// migrate the database
|
||||
if err := db.AutoMigrate(&handler.Group{}, &handler.Membership{}); err != nil {
|
||||
t.Fatalf("Error migrating database: %v", err)
|
||||
}
|
||||
|
||||
return &handler.Groups{DB: db}
|
||||
h := &handler.Groups{}
|
||||
h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{})
|
||||
return h
|
||||
}
|
||||
func TestCreate(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
t.Run("MissingName", func(t *testing.T) {
|
||||
err := h.Create(context.TODO(), &pb.CreateRequest{}, &pb.CreateResponse{})
|
||||
err := h.Create(microAccountCtx(), &pb.CreateRequest{}, &pb.CreateResponse{})
|
||||
assert.Equal(t, handler.ErrMissingName, err)
|
||||
})
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
err := h.Create(context.TODO(), &pb.CreateRequest{
|
||||
err := h.Create(microAccountCtx(), &pb.CreateRequest{
|
||||
Name: "Doe Family Group",
|
||||
}, &pb.CreateResponse{})
|
||||
assert.NoError(t, err)
|
||||
@@ -57,21 +54,21 @@ func TestUpdate(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
t.Run("MissingID", func(t *testing.T) {
|
||||
err := h.Update(context.TODO(), &pb.UpdateRequest{
|
||||
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
|
||||
Name: "Doe Family Group",
|
||||
}, &pb.UpdateResponse{})
|
||||
assert.Equal(t, handler.ErrMissingID, err)
|
||||
})
|
||||
|
||||
t.Run("MissingName", func(t *testing.T) {
|
||||
err := h.Update(context.TODO(), &pb.UpdateRequest{
|
||||
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
|
||||
Id: uuid.New().String(),
|
||||
}, &pb.UpdateResponse{})
|
||||
assert.Equal(t, handler.ErrMissingName, err)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
err := h.Update(context.TODO(), &pb.UpdateRequest{
|
||||
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
|
||||
Id: uuid.New().String(),
|
||||
Name: "Bar Family Group",
|
||||
}, &pb.UpdateResponse{})
|
||||
@@ -81,19 +78,19 @@ func TestUpdate(t *testing.T) {
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
// create a demo group
|
||||
var cRsp pb.CreateResponse
|
||||
err := h.Create(context.TODO(), &pb.CreateRequest{
|
||||
err := h.Create(microAccountCtx(), &pb.CreateRequest{
|
||||
Name: "Doe Family Group",
|
||||
}, &cRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = h.Update(context.TODO(), &pb.UpdateRequest{
|
||||
err = h.Update(microAccountCtx(), &pb.UpdateRequest{
|
||||
Id: cRsp.Group.Id,
|
||||
Name: "Bar Family Group",
|
||||
}, &pb.UpdateResponse{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var rRsp pb.ReadResponse
|
||||
err = h.Read(context.TODO(), &pb.ReadRequest{
|
||||
err = h.Read(microAccountCtx(), &pb.ReadRequest{
|
||||
Ids: []string{cRsp.Group.Id},
|
||||
}, &rRsp)
|
||||
assert.NoError(t, err)
|
||||
@@ -111,12 +108,12 @@ func TestDelete(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
t.Run("MissingID", func(t *testing.T) {
|
||||
err := h.Delete(context.TODO(), &pb.DeleteRequest{}, &pb.DeleteResponse{})
|
||||
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{}, &pb.DeleteResponse{})
|
||||
assert.Equal(t, handler.ErrMissingID, err)
|
||||
})
|
||||
|
||||
t.Run("NotFound", func(t *testing.T) {
|
||||
err := h.Delete(context.TODO(), &pb.DeleteRequest{
|
||||
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
|
||||
Id: uuid.New().String(),
|
||||
}, &pb.DeleteResponse{})
|
||||
assert.NoError(t, err)
|
||||
@@ -124,19 +121,19 @@ func TestDelete(t *testing.T) {
|
||||
|
||||
// create a demo group
|
||||
var cRsp pb.CreateResponse
|
||||
err := h.Create(context.TODO(), &pb.CreateRequest{
|
||||
err := h.Create(microAccountCtx(), &pb.CreateRequest{
|
||||
Name: "Doe Family Group",
|
||||
}, &cRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
err := h.Delete(context.TODO(), &pb.DeleteRequest{
|
||||
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
|
||||
Id: cRsp.Group.Id,
|
||||
}, &pb.DeleteResponse{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
var rRsp pb.ReadResponse
|
||||
err = h.Read(context.TODO(), &pb.ReadRequest{
|
||||
err = h.Read(microAccountCtx(), &pb.ReadRequest{
|
||||
Ids: []string{cRsp.Group.Id},
|
||||
}, &rRsp)
|
||||
assert.Nil(t, rRsp.Groups[cRsp.Group.Id])
|
||||
@@ -147,27 +144,27 @@ func TestList(t *testing.T) {
|
||||
|
||||
// create two demo groups
|
||||
var cRsp1 pb.CreateResponse
|
||||
err := h.Create(context.TODO(), &pb.CreateRequest{
|
||||
err := h.Create(microAccountCtx(), &pb.CreateRequest{
|
||||
Name: "Alpha Group",
|
||||
}, &cRsp1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var cRsp2 pb.CreateResponse
|
||||
err = h.Create(context.TODO(), &pb.CreateRequest{
|
||||
err = h.Create(microAccountCtx(), &pb.CreateRequest{
|
||||
Name: "Bravo Group",
|
||||
}, &cRsp2)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// add a member to the first group
|
||||
uid := uuid.New().String()
|
||||
err = h.AddMember(context.TODO(), &pb.AddMemberRequest{
|
||||
err = h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
|
||||
GroupId: cRsp1.Group.Id, MemberId: uid,
|
||||
}, &pb.AddMemberResponse{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("Unscoped", func(t *testing.T) {
|
||||
var rsp pb.ListResponse
|
||||
err = h.List(context.TODO(), &pb.ListRequest{}, &rsp)
|
||||
err = h.List(microAccountCtx(), &pb.ListRequest{}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
assert.Lenf(t, rsp.Groups, 2, "Two groups should be returned")
|
||||
if len(rsp.Groups) != 2 {
|
||||
@@ -188,13 +185,12 @@ func TestList(t *testing.T) {
|
||||
|
||||
t.Run("Scoped", func(t *testing.T) {
|
||||
var rsp pb.ListResponse
|
||||
err = h.List(context.TODO(), &pb.ListRequest{MemberId: uid}, &rsp)
|
||||
err = h.List(microAccountCtx(), &pb.ListRequest{MemberId: uid}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
assert.Lenf(t, rsp.Groups, 1, "One group should be returned")
|
||||
if len(rsp.Groups) != 1 {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, cRsp1.Group.Id, rsp.Groups[0].Id)
|
||||
assert.Equal(t, cRsp1.Group.Name, rsp.Groups[0].Name)
|
||||
assert.Len(t, rsp.Groups[0].MemberIds, 1)
|
||||
@@ -206,21 +202,21 @@ func TestAddMember(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
t.Run("MissingGroupID", func(t *testing.T) {
|
||||
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
|
||||
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
|
||||
MemberId: uuid.New().String(),
|
||||
}, &pb.AddMemberResponse{})
|
||||
assert.Equal(t, handler.ErrMissingGroupID, err)
|
||||
})
|
||||
|
||||
t.Run("MissingMemberID", func(t *testing.T) {
|
||||
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
|
||||
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
|
||||
GroupId: uuid.New().String(),
|
||||
}, &pb.AddMemberResponse{})
|
||||
assert.Equal(t, handler.ErrMissingMemberID, err)
|
||||
})
|
||||
|
||||
t.Run("GroupNotFound", func(t *testing.T) {
|
||||
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
|
||||
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
|
||||
GroupId: uuid.New().String(),
|
||||
MemberId: uuid.New().String(),
|
||||
}, &pb.AddMemberResponse{})
|
||||
@@ -229,13 +225,13 @@ func TestAddMember(t *testing.T) {
|
||||
|
||||
// create a test group
|
||||
var cRsp pb.CreateResponse
|
||||
err := h.Create(context.TODO(), &pb.CreateRequest{
|
||||
err := h.Create(microAccountCtx(), &pb.CreateRequest{
|
||||
Name: "Alpha Group",
|
||||
}, &cRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
|
||||
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
|
||||
GroupId: cRsp.Group.Id,
|
||||
MemberId: uuid.New().String(),
|
||||
}, &pb.AddMemberResponse{})
|
||||
@@ -243,7 +239,7 @@ func TestAddMember(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Retry", func(t *testing.T) {
|
||||
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
|
||||
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
|
||||
GroupId: cRsp.Group.Id,
|
||||
MemberId: uuid.New().String(),
|
||||
}, &pb.AddMemberResponse{})
|
||||
@@ -255,14 +251,14 @@ func TestRemoveMember(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
t.Run("MissingGroupID", func(t *testing.T) {
|
||||
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
|
||||
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
|
||||
MemberId: uuid.New().String(),
|
||||
}, &pb.RemoveMemberResponse{})
|
||||
assert.Equal(t, handler.ErrMissingGroupID, err)
|
||||
})
|
||||
|
||||
t.Run("MissingMemberID", func(t *testing.T) {
|
||||
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
|
||||
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
|
||||
GroupId: uuid.New().String(),
|
||||
}, &pb.RemoveMemberResponse{})
|
||||
assert.Equal(t, handler.ErrMissingMemberID, err)
|
||||
@@ -270,13 +266,13 @@ func TestRemoveMember(t *testing.T) {
|
||||
|
||||
// create a test group
|
||||
var cRsp pb.CreateResponse
|
||||
err := h.Create(context.TODO(), &pb.CreateRequest{
|
||||
err := h.Create(microAccountCtx(), &pb.CreateRequest{
|
||||
Name: "Alpha Group",
|
||||
}, &cRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
|
||||
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
|
||||
GroupId: cRsp.Group.Id,
|
||||
MemberId: uuid.New().String(),
|
||||
}, &pb.RemoveMemberResponse{})
|
||||
@@ -284,10 +280,16 @@ func TestRemoveMember(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Retry", func(t *testing.T) {
|
||||
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
|
||||
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
|
||||
GroupId: cRsp.Group.Id,
|
||||
MemberId: uuid.New().String(),
|
||||
}, &pb.RemoveMemberResponse{})
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func microAccountCtx() context.Context {
|
||||
return auth.ContextWithAccount(context.TODO(), &auth.Account{
|
||||
Issuer: "micro",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/micro/services/groups/handler"
|
||||
pb "github.com/micro/services/groups/proto"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"database/sql"
|
||||
|
||||
"github.com/micro/micro/v3/service"
|
||||
"github.com/micro/micro/v3/service/config"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
"github.com/micro/services/groups/handler"
|
||||
pb "github.com/micro/services/groups/proto"
|
||||
)
|
||||
|
||||
var dbAddress = "postgresql://postgres:postgres@localhost:5432/groups?sslmode=disable"
|
||||
@@ -26,16 +25,14 @@ func main() {
|
||||
logger.Fatalf("Error loading config: %v", err)
|
||||
}
|
||||
addr := cfg.String(dbAddress)
|
||||
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
|
||||
sqlDB, err := sql.Open("pgx", addr)
|
||||
if err != nil {
|
||||
logger.Fatalf("Error connecting to database: %v", err)
|
||||
logger.Fatalf("Failed to open connection to DB %s", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&handler.Group{}, &handler.Membership{}); err != nil {
|
||||
logger.Fatalf("Error migrating database: %v", err)
|
||||
}
|
||||
|
||||
h := &handler.Groups{}
|
||||
h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{})
|
||||
// Register handler
|
||||
pb.RegisterGroupsHandler(srv.Server(), &handler.Groups{DB: db.Debug()})
|
||||
pb.RegisterGroupsHandler(srv.Server(), h)
|
||||
|
||||
// Run service
|
||||
if err := srv.Run(); err != nil {
|
||||
|
||||
444
pkg/gorm/wrapper.go
Normal file
444
pkg/gorm/wrapper.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Helper struct {
|
||||
sync.RWMutex
|
||||
gormConns map[string]*gorm.DB
|
||||
dbConn *sql.DB
|
||||
migrations []interface{}
|
||||
}
|
||||
|
||||
func (h *Helper) Migrations(migrations ...interface{}) *Helper {
|
||||
h.migrations = migrations
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Helper) DBConn(conn *sql.DB) *Helper {
|
||||
h.dbConn = conn
|
||||
h.gormConns = map[string]*gorm.DB{}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) {
|
||||
acc, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing account from context")
|
||||
}
|
||||
h.RLock()
|
||||
if conn, ok := h.gormConns[acc.Issuer]; ok {
|
||||
h.RUnlock()
|
||||
return conn, nil
|
||||
}
|
||||
h.RUnlock()
|
||||
h.Lock()
|
||||
// double check
|
||||
if conn, ok := h.gormConns[acc.Issuer]; ok {
|
||||
h.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
defer h.Unlock()
|
||||
ns := schema.NamingStrategy{
|
||||
TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")),
|
||||
}
|
||||
db, err := gorm.Open(
|
||||
newGormDialector(postgres.Config{
|
||||
Conn: h.dbConn,
|
||||
}, ns),
|
||||
&gorm.Config{
|
||||
NamingStrategy: ns,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(h.migrations) == 0 {
|
||||
// record success
|
||||
h.gormConns[acc.Issuer] = db
|
||||
return db, nil
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(h.migrations...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// record success
|
||||
h.gormConns[acc.Issuer] = db
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func newGormDialector(config postgres.Config, ns schema.NamingStrategy) gorm.Dialector {
|
||||
return &postgresDial{
|
||||
Dialector: postgres.Dialector{Config: &config},
|
||||
namer: &ns,
|
||||
}
|
||||
}
|
||||
|
||||
// postgresDial is a postgres dialector that prefixes index names with the table prefix when doing migrations.
|
||||
// NOTE, it does not support the gorm tag priority option
|
||||
type postgresDial struct {
|
||||
postgres.Dialector
|
||||
namer schema.Namer
|
||||
}
|
||||
|
||||
func (p postgresDial) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return gormMigrator{
|
||||
postgres.Migrator{
|
||||
migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: p,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}},
|
||||
},
|
||||
p.namer,
|
||||
}
|
||||
}
|
||||
|
||||
type gormMigrator struct {
|
||||
postgres.Migrator
|
||||
namer schema.Namer
|
||||
}
|
||||
|
||||
// AutoMigrate
|
||||
func (m gormMigrator) AutoMigrate(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
tx := m.DB.Session(&gorm.Session{NewDB: true})
|
||||
if !tx.Migrator().HasTable(value) {
|
||||
if err := tx.Migrator().CreateTable(value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
var foundColumn gorm.ColumnType
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
if columnType.Name() == field.DBName {
|
||||
foundColumn = columnType
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
// found, smart migrate
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
if !tx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
if !tx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range m.ParseIndexes(stmt.Schema) {
|
||||
if !tx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m gormMigrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := m.LookIndex(stmt.Schema, name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
createIndexSQL += " ON ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type + "(?)"
|
||||
} else {
|
||||
createIndexSQL += " ?"
|
||||
}
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m gormMigrator) LookIndex(sch *schema.Schema, name string) *schema.Index {
|
||||
if sch != nil {
|
||||
indexes := m.ParseIndexes(sch)
|
||||
for _, index := range indexes {
|
||||
if index.Name == name {
|
||||
return &index
|
||||
}
|
||||
|
||||
for _, field := range index.Fields {
|
||||
if field.Name == name {
|
||||
return &index
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g gormMigrator) parseFieldIndexes(field *schema.Field) (indexes []schema.Index) {
|
||||
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
|
||||
if value != "" {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||
if k == "INDEX" || k == "UNIQUEINDEX" {
|
||||
var (
|
||||
name string
|
||||
tag = strings.Join(v[1:], ":")
|
||||
idx = strings.Index(tag, ",")
|
||||
settings = schema.ParseTagSetting(tag, ",")
|
||||
length, _ = strconv.Atoi(settings["LENGTH"])
|
||||
)
|
||||
|
||||
if idx == -1 {
|
||||
idx = len(tag)
|
||||
}
|
||||
|
||||
if idx != -1 {
|
||||
name = tag[0:idx]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
name = g.namer.IndexName(field.Schema.Table, field.Name)
|
||||
} else {
|
||||
ns := g.namer.(*schema.NamingStrategy)
|
||||
name = fmt.Sprintf("%s%s", ns.TablePrefix, name)
|
||||
}
|
||||
|
||||
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
|
||||
settings["CLASS"] = "UNIQUE"
|
||||
}
|
||||
|
||||
//priority, err := strconv.Atoi(settings["PRIORITY"])
|
||||
//if err != nil {
|
||||
// priority = 10
|
||||
//}
|
||||
|
||||
indexes = append(indexes, schema.Index{
|
||||
Name: name,
|
||||
Class: settings["CLASS"],
|
||||
Type: settings["TYPE"],
|
||||
Where: settings["WHERE"],
|
||||
Comment: settings["COMMENT"],
|
||||
Option: settings["OPTION"],
|
||||
Fields: []schema.IndexOption{{
|
||||
Field: field,
|
||||
Expression: settings["EXPRESSION"],
|
||||
Sort: settings["SORT"],
|
||||
Collate: settings["COLLATE"],
|
||||
Length: length,
|
||||
//priority: priority, // TODO does not support priority
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m gormMigrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, false) {
|
||||
tx := m.DB.Session(&gorm.Session{NewDB: true})
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
var (
|
||||
createTableSQL = "CREATE TABLE ? ("
|
||||
values = []interface{}{m.CurrentTable(stmt)}
|
||||
hasPrimaryKeyInDataType bool
|
||||
)
|
||||
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
createTableSQL += "? ?"
|
||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
|
||||
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
||||
createTableSQL += ","
|
||||
}
|
||||
|
||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||
createTableSQL += "PRIMARY KEY ?,"
|
||||
primaryKeys := []interface{}{}
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
values = append(values, primaryKeys)
|
||||
}
|
||||
|
||||
for _, idx := range m.ParseIndexes(stmt.Schema) {
|
||||
if m.CreateIndexAfterCreateTable {
|
||||
defer func(value interface{}, name string) {
|
||||
errr = tx.Migrator().CreateIndex(value, name)
|
||||
}(value, idx.Name)
|
||||
} else {
|
||||
if idx.Class != "" {
|
||||
createTableSQL += idx.Class + " "
|
||||
}
|
||||
createTableSQL += "INDEX ? ?"
|
||||
|
||||
if idx.Option != "" {
|
||||
createTableSQL += " " + idx.Option
|
||||
}
|
||||
|
||||
createTableSQL += ","
|
||||
values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(migrator.BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
sql, vars := buildConstraint(constraint)
|
||||
createTableSQL += sql + ","
|
||||
values = append(values, vars...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||
}
|
||||
|
||||
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
|
||||
|
||||
createTableSQL += ")"
|
||||
|
||||
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
|
||||
createTableSQL += fmt.Sprint(tableOption)
|
||||
}
|
||||
|
||||
errr = tx.Exec(createTableSQL, values...).Error
|
||||
return errr
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type gormIndexOption struct {
|
||||
schema.IndexOption
|
||||
priority int
|
||||
}
|
||||
|
||||
func (g gormMigrator) ParseIndexes(sch *schema.Schema) map[string]schema.Index {
|
||||
var indexes = map[string]schema.Index{}
|
||||
|
||||
for _, field := range sch.Fields {
|
||||
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
||||
for _, index := range g.parseFieldIndexes(field) {
|
||||
idx := indexes[index.Name]
|
||||
idx.Name = index.Name
|
||||
if idx.Class == "" {
|
||||
idx.Class = index.Class
|
||||
}
|
||||
if idx.Type == "" {
|
||||
idx.Type = index.Type
|
||||
}
|
||||
if idx.Where == "" {
|
||||
idx.Where = index.Where
|
||||
}
|
||||
if idx.Comment == "" {
|
||||
idx.Comment = index.Comment
|
||||
}
|
||||
if idx.Option == "" {
|
||||
idx.Option = index.Option
|
||||
}
|
||||
|
||||
idx.Fields = append(idx.Fields, index.Fields...)
|
||||
// TODO priority not supported
|
||||
//sort.Slice(idx.Fields, func(i, j int) bool {
|
||||
// return idx.Fields[i].priority < idx.Fields[j].priority
|
||||
//})
|
||||
|
||||
indexes[index.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return indexes
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat
|
||||
logger.Errorf("Error hashing and salting password: %v", err)
|
||||
return errors.InternalServerError("HASHING_ERROR", "Error hashing password")
|
||||
}
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -20,7 +20,7 @@ func (u *Users) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Delet
|
||||
if len(req.Id) == 0 {
|
||||
return ErrMissingID
|
||||
}
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -1,21 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
gorm2 "github.com/micro/services/pkg/gorm"
|
||||
pb "github.com/micro/services/users/proto"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -66,52 +58,12 @@ type Token struct {
|
||||
}
|
||||
|
||||
type Users struct {
|
||||
sync.RWMutex
|
||||
Time func() time.Time
|
||||
dbConn *sql.DB
|
||||
gormConns map[string]*gorm.DB
|
||||
gorm2.Helper
|
||||
Time func() time.Time
|
||||
}
|
||||
|
||||
func NewHandler(t func() time.Time, dbConn *sql.DB) *Users {
|
||||
return &Users{Time: t, dbConn: dbConn, gormConns: map[string]*gorm.DB{}}
|
||||
}
|
||||
|
||||
func (u *Users) getDBConn(ctx context.Context) (*gorm.DB, error) {
|
||||
acc, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing account from context")
|
||||
}
|
||||
u.RLock()
|
||||
if conn, ok := u.gormConns[acc.Issuer]; ok {
|
||||
u.RUnlock()
|
||||
return conn, nil
|
||||
}
|
||||
u.RUnlock()
|
||||
u.Lock()
|
||||
// double check
|
||||
if conn, ok := u.gormConns[acc.Issuer]; ok {
|
||||
u.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
defer u.Unlock()
|
||||
db, err := gorm.Open(
|
||||
postgres.New(postgres.Config{
|
||||
Conn: u.dbConn,
|
||||
}),
|
||||
&gorm.Config{
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := db.AutoMigrate(&User{}, &Token{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// record success
|
||||
u.gormConns[acc.Issuer] = db
|
||||
return db, nil
|
||||
func NewHandler(t func() time.Time) *Users {
|
||||
return &Users{Time: t}
|
||||
}
|
||||
|
||||
// isEmailValid checks if the email provided passes the required structure and length.
|
||||
|
||||
@@ -31,7 +31,9 @@ func testHandler(t *testing.T) *handler.Users {
|
||||
t.Fatalf("Error cleaning database: %v", err)
|
||||
}
|
||||
|
||||
return handler.NewHandler(time.Now, sqlDB)
|
||||
h := handler.NewHandler(time.Now)
|
||||
h.DBConn(sqlDB).Migrations(&handler.User{}, &handler.Token{})
|
||||
return h
|
||||
}
|
||||
|
||||
func assertUsersMatch(t *testing.T, exp, act *pb.User) {
|
||||
|
||||
@@ -16,7 +16,7 @@ func (u *Users) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListRespo
|
||||
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
// query the database
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -25,7 +25,7 @@ func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginRe
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -21,7 +21,7 @@ func (u *Users) Logout(ctx context.Context, req *pb.LogoutRequest, rsp *pb.Logou
|
||||
return ErrMissingID
|
||||
}
|
||||
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -21,7 +21,7 @@ func (u *Users) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRespo
|
||||
}
|
||||
|
||||
// query the database
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -26,7 +26,7 @@ func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp
|
||||
}
|
||||
|
||||
// query the database
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -40,7 +40,7 @@ func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Updat
|
||||
|
||||
// lookup the user
|
||||
var user User
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -21,7 +21,7 @@ func (u *Users) Validate(ctx context.Context, req *pb.ValidateRequest, rsp *pb.V
|
||||
return ErrMissingToken
|
||||
}
|
||||
|
||||
db, err := u.getDBConn(ctx)
|
||||
db, err := u.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to DB: %v", err)
|
||||
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
|
||||
|
||||
@@ -34,7 +34,9 @@ func main() {
|
||||
if err != nil {
|
||||
logger.Fatalf("Failed to open connection to DB %s", err)
|
||||
}
|
||||
pb.RegisterUsersHandler(srv.Server(), handler.NewHandler(time.Now, sqlDB))
|
||||
h := handler.NewHandler(time.Now)
|
||||
h.DBConn(sqlDB).Migrations(&handler.User{}, &handler.Token{})
|
||||
pb.RegisterUsersHandler(srv.Server(), h)
|
||||
|
||||
// Run service
|
||||
if err := srv.Run(); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user