Multi tenant groups (#77)

* multitenant groups

* switch users service to use new wrapper

* fix tests

* skip pkg dir

* Check for auth
This commit is contained in:
Dominic Wong
2021-03-25 15:53:14 +00:00
committed by GitHub
parent b37cc09835
commit c42aeaa0a9
17 changed files with 592 additions and 125 deletions

View File

@@ -5,8 +5,12 @@ import (
"strings"
"github.com/google/uuid"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/groups/proto"
gorm2 "github.com/micro/services/pkg/gorm"
"gorm.io/gorm"
)
@@ -41,10 +45,14 @@ func (g *Group) Serialize() *pb.Group {
}
type Groups struct {
DB *gorm.DB
gorm2.Helper
}
func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Name) == 0 {
return ErrMissingName
@@ -52,7 +60,13 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea
// create the group object
group := &Group{ID: uuid.New().String(), Name: req.Name}
if err := g.DB.Create(group).Error; err != nil {
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
if err := db.Create(group).Error; err != nil {
return ErrStore
}
@@ -62,14 +76,23 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea
}
func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Ids) == 0 {
return ErrMissingIDs
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// query the database
var groups []Group
if err := g.DB.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil {
if err := db.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil {
return ErrStore
}
@@ -83,6 +106,10 @@ func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResp
}
func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.UpdateResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Id) == 0 {
return ErrMissingID
@@ -90,8 +117,13 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda
if len(req.Name) == 0 {
return ErrMissingName
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return g.DB.Transaction(func(tx *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
// find the group
var group Group
if err := tx.Where(&Group{ID: req.Id}).First(&group).Error; err == gorm.ErrRecordNotFound {
@@ -113,13 +145,22 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda
}
func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Id) == 0 {
return ErrMissingID
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// delete from the database
if err := g.DB.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound {
if err := db.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound {
return nil
} else if err != nil {
return ErrStore
@@ -129,10 +170,19 @@ func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Dele
}
func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
if len(req.MemberId) > 0 {
// only list groups the user is a member of
var ms []Membership
q := g.DB.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships")
q := db.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships")
if err := q.Find(&ms).Error; err != nil {
return err
}
@@ -145,7 +195,7 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp
// load all groups
var groups []Group
if err := g.DB.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil {
if err := db.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil {
return ErrStore
}
@@ -159,6 +209,10 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp
}
func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *pb.AddMemberResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.GroupId) == 0 {
return ErrMissingGroupID
@@ -166,8 +220,13 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p
if len(req.MemberId) == 0 {
return ErrMissingMemberID
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return g.DB.Transaction(func(tx *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
// check the group exists
var group Group
if err := tx.Where(&Group{ID: req.GroupId}).First(&group).Error; err == gorm.ErrRecordNotFound {
@@ -191,6 +250,10 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p
}
func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest, rsp *pb.RemoveMemberResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.GroupId) == 0 {
return ErrMissingGroupID
@@ -199,9 +262,14 @@ func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest,
return ErrMissingMemberID
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// delete the membership
m := &Membership{MemberID: req.MemberId, GroupID: req.GroupId}
if err := g.DB.Where(m).Delete(m).Error; err != nil {
if err := db.Where(m).Delete(m).Error; err != nil {
return ErrStore
}

View File

@@ -2,16 +2,16 @@ package handler_test
import (
"context"
"database/sql"
"os"
"sort"
"testing"
"github.com/google/uuid"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/services/groups/handler"
pb "github.com/micro/services/groups/proto"
"github.com/stretchr/testify/assert"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
func testHandler(t *testing.T) *handler.Groups {
@@ -20,33 +20,30 @@ func testHandler(t *testing.T) *handler.Groups {
if len(addr) == 0 {
addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable"
}
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
sqlDB, err := sql.Open("pgx", addr)
if err != nil {
t.Fatalf("Error connecting to database: %v", err)
t.Fatalf("Failed to open connection to DB %s", err)
}
// clean any data from a previous run
if err := db.Exec("DROP TABLE IF EXISTS groups, memberships CASCADE").Error; err != nil {
if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_groups, micro_memberships CASCADE"); err != nil {
t.Fatalf("Error cleaning database: %v", err)
}
// migrate the database
if err := db.AutoMigrate(&handler.Group{}, &handler.Membership{}); err != nil {
t.Fatalf("Error migrating database: %v", err)
}
return &handler.Groups{DB: db}
h := &handler.Groups{}
h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{})
return h
}
func TestCreate(t *testing.T) {
h := testHandler(t)
t.Run("MissingName", func(t *testing.T) {
err := h.Create(context.TODO(), &pb.CreateRequest{}, &pb.CreateResponse{})
err := h.Create(microAccountCtx(), &pb.CreateRequest{}, &pb.CreateResponse{})
assert.Equal(t, handler.ErrMissingName, err)
})
t.Run("Valid", func(t *testing.T) {
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Doe Family Group",
}, &pb.CreateResponse{})
assert.NoError(t, err)
@@ -57,21 +54,21 @@ func TestUpdate(t *testing.T) {
h := testHandler(t)
t.Run("MissingID", func(t *testing.T) {
err := h.Update(context.TODO(), &pb.UpdateRequest{
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Name: "Doe Family Group",
}, &pb.UpdateResponse{})
assert.Equal(t, handler.ErrMissingID, err)
})
t.Run("MissingName", func(t *testing.T) {
err := h.Update(context.TODO(), &pb.UpdateRequest{
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: uuid.New().String(),
}, &pb.UpdateResponse{})
assert.Equal(t, handler.ErrMissingName, err)
})
t.Run("NotFound", func(t *testing.T) {
err := h.Update(context.TODO(), &pb.UpdateRequest{
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: uuid.New().String(),
Name: "Bar Family Group",
}, &pb.UpdateResponse{})
@@ -81,19 +78,19 @@ func TestUpdate(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
// create a demo group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Doe Family Group",
}, &cRsp)
assert.NoError(t, err)
err = h.Update(context.TODO(), &pb.UpdateRequest{
err = h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp.Group.Id,
Name: "Bar Family Group",
}, &pb.UpdateResponse{})
assert.NoError(t, err)
var rRsp pb.ReadResponse
err = h.Read(context.TODO(), &pb.ReadRequest{
err = h.Read(microAccountCtx(), &pb.ReadRequest{
Ids: []string{cRsp.Group.Id},
}, &rRsp)
assert.NoError(t, err)
@@ -111,12 +108,12 @@ func TestDelete(t *testing.T) {
h := testHandler(t)
t.Run("MissingID", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{}, &pb.DeleteResponse{})
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{}, &pb.DeleteResponse{})
assert.Equal(t, handler.ErrMissingID, err)
})
t.Run("NotFound", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
Id: uuid.New().String(),
}, &pb.DeleteResponse{})
assert.NoError(t, err)
@@ -124,19 +121,19 @@ func TestDelete(t *testing.T) {
// create a demo group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Doe Family Group",
}, &cRsp)
assert.NoError(t, err)
t.Run("Valid", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
Id: cRsp.Group.Id,
}, &pb.DeleteResponse{})
assert.NoError(t, err)
var rRsp pb.ReadResponse
err = h.Read(context.TODO(), &pb.ReadRequest{
err = h.Read(microAccountCtx(), &pb.ReadRequest{
Ids: []string{cRsp.Group.Id},
}, &rRsp)
assert.Nil(t, rRsp.Groups[cRsp.Group.Id])
@@ -147,27 +144,27 @@ func TestList(t *testing.T) {
// create two demo groups
var cRsp1 pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Alpha Group",
}, &cRsp1)
assert.NoError(t, err)
var cRsp2 pb.CreateResponse
err = h.Create(context.TODO(), &pb.CreateRequest{
err = h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Bravo Group",
}, &cRsp2)
assert.NoError(t, err)
// add a member to the first group
uid := uuid.New().String()
err = h.AddMember(context.TODO(), &pb.AddMemberRequest{
err = h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: cRsp1.Group.Id, MemberId: uid,
}, &pb.AddMemberResponse{})
assert.NoError(t, err)
t.Run("Unscoped", func(t *testing.T) {
var rsp pb.ListResponse
err = h.List(context.TODO(), &pb.ListRequest{}, &rsp)
err = h.List(microAccountCtx(), &pb.ListRequest{}, &rsp)
assert.NoError(t, err)
assert.Lenf(t, rsp.Groups, 2, "Two groups should be returned")
if len(rsp.Groups) != 2 {
@@ -188,13 +185,12 @@ func TestList(t *testing.T) {
t.Run("Scoped", func(t *testing.T) {
var rsp pb.ListResponse
err = h.List(context.TODO(), &pb.ListRequest{MemberId: uid}, &rsp)
err = h.List(microAccountCtx(), &pb.ListRequest{MemberId: uid}, &rsp)
assert.NoError(t, err)
assert.Lenf(t, rsp.Groups, 1, "One group should be returned")
if len(rsp.Groups) != 1 {
return
}
assert.Equal(t, cRsp1.Group.Id, rsp.Groups[0].Id)
assert.Equal(t, cRsp1.Group.Name, rsp.Groups[0].Name)
assert.Len(t, rsp.Groups[0].MemberIds, 1)
@@ -206,21 +202,21 @@ func TestAddMember(t *testing.T) {
h := testHandler(t)
t.Run("MissingGroupID", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
assert.Equal(t, handler.ErrMissingGroupID, err)
})
t.Run("MissingMemberID", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: uuid.New().String(),
}, &pb.AddMemberResponse{})
assert.Equal(t, handler.ErrMissingMemberID, err)
})
t.Run("GroupNotFound", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: uuid.New().String(),
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
@@ -229,13 +225,13 @@ func TestAddMember(t *testing.T) {
// create a test group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Alpha Group",
}, &cRsp)
assert.NoError(t, err)
t.Run("Valid", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
@@ -243,7 +239,7 @@ func TestAddMember(t *testing.T) {
})
t.Run("Retry", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
@@ -255,14 +251,14 @@ func TestRemoveMember(t *testing.T) {
h := testHandler(t)
t.Run("MissingGroupID", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
MemberId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
assert.Equal(t, handler.ErrMissingGroupID, err)
})
t.Run("MissingMemberID", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
GroupId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
assert.Equal(t, handler.ErrMissingMemberID, err)
@@ -270,13 +266,13 @@ func TestRemoveMember(t *testing.T) {
// create a test group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Alpha Group",
}, &cRsp)
assert.NoError(t, err)
t.Run("Valid", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
@@ -284,10 +280,16 @@ func TestRemoveMember(t *testing.T) {
})
t.Run("Retry", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
assert.NoError(t, err)
})
}
func microAccountCtx() context.Context {
return auth.ContextWithAccount(context.TODO(), &auth.Account{
Issuer: "micro",
})
}

View File

@@ -1,14 +1,13 @@
package main
import (
"github.com/micro/services/groups/handler"
pb "github.com/micro/services/groups/proto"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"database/sql"
"github.com/micro/micro/v3/service"
"github.com/micro/micro/v3/service/config"
"github.com/micro/micro/v3/service/logger"
"github.com/micro/services/groups/handler"
pb "github.com/micro/services/groups/proto"
)
var dbAddress = "postgresql://postgres:postgres@localhost:5432/groups?sslmode=disable"
@@ -26,16 +25,14 @@ func main() {
logger.Fatalf("Error loading config: %v", err)
}
addr := cfg.String(dbAddress)
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
sqlDB, err := sql.Open("pgx", addr)
if err != nil {
logger.Fatalf("Error connecting to database: %v", err)
logger.Fatalf("Failed to open connection to DB %s", err)
}
if err := db.AutoMigrate(&handler.Group{}, &handler.Membership{}); err != nil {
logger.Fatalf("Error migrating database: %v", err)
}
h := &handler.Groups{}
h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{})
// Register handler
pb.RegisterGroupsHandler(srv.Server(), &handler.Groups{DB: db.Debug()})
pb.RegisterGroupsHandler(srv.Server(), h)
// Run service
if err := srv.Run(); err != nil {

444
pkg/gorm/wrapper.go Normal file
View File

@@ -0,0 +1,444 @@
package gorm
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"sync"
"github.com/micro/micro/v3/service/auth"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Helper struct {
sync.RWMutex
gormConns map[string]*gorm.DB
dbConn *sql.DB
migrations []interface{}
}
func (h *Helper) Migrations(migrations ...interface{}) *Helper {
h.migrations = migrations
return h
}
func (h *Helper) DBConn(conn *sql.DB) *Helper {
h.dbConn = conn
h.gormConns = map[string]*gorm.DB{}
return h
}
func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) {
acc, ok := auth.AccountFromContext(ctx)
if !ok {
return nil, fmt.Errorf("missing account from context")
}
h.RLock()
if conn, ok := h.gormConns[acc.Issuer]; ok {
h.RUnlock()
return conn, nil
}
h.RUnlock()
h.Lock()
// double check
if conn, ok := h.gormConns[acc.Issuer]; ok {
h.Unlock()
return conn, nil
}
defer h.Unlock()
ns := schema.NamingStrategy{
TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")),
}
db, err := gorm.Open(
newGormDialector(postgres.Config{
Conn: h.dbConn,
}, ns),
&gorm.Config{
NamingStrategy: ns,
})
if err != nil {
return nil, err
}
if len(h.migrations) == 0 {
// record success
h.gormConns[acc.Issuer] = db
return db, nil
}
if err := db.AutoMigrate(h.migrations...); err != nil {
return nil, err
}
// record success
h.gormConns[acc.Issuer] = db
return db, nil
}
func newGormDialector(config postgres.Config, ns schema.NamingStrategy) gorm.Dialector {
return &postgresDial{
Dialector: postgres.Dialector{Config: &config},
namer: &ns,
}
}
// postgresDial is a postgres dialector that prefixes index names with the table prefix when doing migrations.
// NOTE, it does not support the gorm tag priority option
type postgresDial struct {
postgres.Dialector
namer schema.Namer
}
func (p postgresDial) Migrator(db *gorm.DB) gorm.Migrator {
return gormMigrator{
postgres.Migrator{
migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: p,
CreateIndexAfterCreateTable: true,
}},
},
p.namer,
}
}
type gormMigrator struct {
postgres.Migrator
namer schema.Namer
}
// AutoMigrate
func (m gormMigrator) AutoMigrate(values ...interface{}) error {
for _, value := range m.ReorderModels(values, true) {
tx := m.DB.Session(&gorm.Session{NewDB: true})
if !tx.Migrator().HasTable(value) {
if err := tx.Migrator().CreateTable(value); err != nil {
return err
}
} else {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
for _, field := range stmt.Schema.FieldsByDBName {
var foundColumn gorm.ColumnType
for _, columnType := range columnTypes {
if columnType.Name() == field.DBName {
foundColumn = columnType
break
}
}
if foundColumn == nil {
// not found, add column
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
return err
}
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
// found, smart migrate
return err
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
if !tx.Migrator().HasConstraint(value, constraint.Name) {
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
return err
}
}
}
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
if !tx.Migrator().HasConstraint(value, chk.Name) {
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
return err
}
}
}
}
for _, idx := range m.ParseIndexes(stmt.Schema) {
if !tx.Migrator().HasIndex(value, idx.Name) {
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
return err
}
}
}
return nil
}); err != nil {
return err
}
}
}
return nil
}
func (m gormMigrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := m.LookIndex(stmt.Schema, name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"
createIndexSQL += " ON ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type + "(?)"
} else {
createIndexSQL += " ?"
}
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
}
return fmt.Errorf("failed to create index with name %v", name)
})
}
func (m gormMigrator) LookIndex(sch *schema.Schema, name string) *schema.Index {
if sch != nil {
indexes := m.ParseIndexes(sch)
for _, index := range indexes {
if index.Name == name {
return &index
}
for _, field := range index.Fields {
if field.Name == name {
return &index
}
}
}
}
return nil
}
func (g gormMigrator) parseFieldIndexes(field *schema.Field) (indexes []schema.Index) {
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
if value != "" {
v := strings.Split(value, ":")
k := strings.TrimSpace(strings.ToUpper(v[0]))
if k == "INDEX" || k == "UNIQUEINDEX" {
var (
name string
tag = strings.Join(v[1:], ":")
idx = strings.Index(tag, ",")
settings = schema.ParseTagSetting(tag, ",")
length, _ = strconv.Atoi(settings["LENGTH"])
)
if idx == -1 {
idx = len(tag)
}
if idx != -1 {
name = tag[0:idx]
}
if name == "" {
name = g.namer.IndexName(field.Schema.Table, field.Name)
} else {
ns := g.namer.(*schema.NamingStrategy)
name = fmt.Sprintf("%s%s", ns.TablePrefix, name)
}
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
settings["CLASS"] = "UNIQUE"
}
//priority, err := strconv.Atoi(settings["PRIORITY"])
//if err != nil {
// priority = 10
//}
indexes = append(indexes, schema.Index{
Name: name,
Class: settings["CLASS"],
Type: settings["TYPE"],
Where: settings["WHERE"],
Comment: settings["COMMENT"],
Option: settings["OPTION"],
Fields: []schema.IndexOption{{
Field: field,
Expression: settings["EXPRESSION"],
Sort: settings["SORT"],
Collate: settings["COLLATE"],
Length: length,
//priority: priority, // TODO does not support priority
}},
})
}
}
}
return
}
func (m gormMigrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) {
tx := m.DB.Session(&gorm.Session{NewDB: true})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
var (
createTableSQL = "CREATE TABLE ? ("
values = []interface{}{m.CurrentTable(stmt)}
hasPrimaryKeyInDataType bool
)
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[dbName]
createTableSQL += "? ?"
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
createTableSQL += ","
}
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
createTableSQL += "PRIMARY KEY ?,"
primaryKeys := []interface{}{}
for _, field := range stmt.Schema.PrimaryFields {
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
}
values = append(values, primaryKeys)
}
for _, idx := range m.ParseIndexes(stmt.Schema) {
if m.CreateIndexAfterCreateTable {
defer func(value interface{}, name string) {
errr = tx.Migrator().CreateIndex(value, name)
}(value, idx.Name)
} else {
if idx.Class != "" {
createTableSQL += idx.Class + " "
}
createTableSQL += "INDEX ? ?"
if idx.Option != "" {
createTableSQL += " " + idx.Option
}
createTableSQL += ","
values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(migrator.BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
}
}
for _, rel := range stmt.Schema.Relationships.Relations {
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint)
createTableSQL += sql + ","
values = append(values, vars...)
}
}
}
}
for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
}
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
createTableSQL += ")"
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
createTableSQL += fmt.Sprint(tableOption)
}
errr = tx.Exec(createTableSQL, values...).Error
return errr
}); err != nil {
return err
}
}
return nil
}
type gormIndexOption struct {
schema.IndexOption
priority int
}
func (g gormMigrator) ParseIndexes(sch *schema.Schema) map[string]schema.Index {
var indexes = map[string]schema.Index{}
for _, field := range sch.Fields {
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
for _, index := range g.parseFieldIndexes(field) {
idx := indexes[index.Name]
idx.Name = index.Name
if idx.Class == "" {
idx.Class = index.Class
}
if idx.Type == "" {
idx.Type = index.Type
}
if idx.Where == "" {
idx.Where = index.Where
}
if idx.Comment == "" {
idx.Comment = index.Comment
}
if idx.Option == "" {
idx.Option = index.Option
}
idx.Fields = append(idx.Fields, index.Fields...)
// TODO priority not supported
//sort.Slice(idx.Fields, func(i, j int) bool {
// return idx.Fields[i].priority < idx.Fields[j].priority
//})
indexes[index.Name] = idx
}
}
}
return indexes
}
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}

0
pkg/skip Normal file
View File

View File

@@ -44,7 +44,7 @@ func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat
logger.Errorf("Error hashing and salting password: %v", err)
return errors.InternalServerError("HASHING_ERROR", "Error hashing password")
}
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -20,7 +20,7 @@ func (u *Users) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Delet
if len(req.Id) == 0 {
return ErrMissingID
}
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -1,21 +1,13 @@
package handler
import (
"context"
"database/sql"
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors"
gorm2 "github.com/micro/services/pkg/gorm"
pb "github.com/micro/services/users/proto"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/schema"
)
var (
@@ -66,52 +58,12 @@ type Token struct {
}
type Users struct {
sync.RWMutex
Time func() time.Time
dbConn *sql.DB
gormConns map[string]*gorm.DB
gorm2.Helper
Time func() time.Time
}
func NewHandler(t func() time.Time, dbConn *sql.DB) *Users {
return &Users{Time: t, dbConn: dbConn, gormConns: map[string]*gorm.DB{}}
}
func (u *Users) getDBConn(ctx context.Context) (*gorm.DB, error) {
acc, ok := auth.AccountFromContext(ctx)
if !ok {
return nil, fmt.Errorf("missing account from context")
}
u.RLock()
if conn, ok := u.gormConns[acc.Issuer]; ok {
u.RUnlock()
return conn, nil
}
u.RUnlock()
u.Lock()
// double check
if conn, ok := u.gormConns[acc.Issuer]; ok {
u.Unlock()
return conn, nil
}
defer u.Unlock()
db, err := gorm.Open(
postgres.New(postgres.Config{
Conn: u.dbConn,
}),
&gorm.Config{
NamingStrategy: schema.NamingStrategy{
TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")),
},
})
if err != nil {
return nil, err
}
if err := db.AutoMigrate(&User{}, &Token{}); err != nil {
return nil, err
}
// record success
u.gormConns[acc.Issuer] = db
return db, nil
func NewHandler(t func() time.Time) *Users {
return &Users{Time: t}
}
// isEmailValid checks if the email provided passes the required structure and length.

View File

@@ -31,7 +31,9 @@ func testHandler(t *testing.T) *handler.Users {
t.Fatalf("Error cleaning database: %v", err)
}
return handler.NewHandler(time.Now, sqlDB)
h := handler.NewHandler(time.Now)
h.DBConn(sqlDB).Migrations(&handler.User{}, &handler.Token{})
return h
}
func assertUsersMatch(t *testing.T, exp, act *pb.User) {

View File

@@ -16,7 +16,7 @@ func (u *Users) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListRespo
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// query the database
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -25,7 +25,7 @@ func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginRe
return ErrInvalidPassword
}
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -21,7 +21,7 @@ func (u *Users) Logout(ctx context.Context, req *pb.LogoutRequest, rsp *pb.Logou
return ErrMissingID
}
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -21,7 +21,7 @@ func (u *Users) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRespo
}
// query the database
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -26,7 +26,7 @@ func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp
}
// query the database
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -40,7 +40,7 @@ func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Updat
// lookup the user
var user User
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -21,7 +21,7 @@ func (u *Users) Validate(ctx context.Context, req *pb.ValidateRequest, rsp *pb.V
return ErrMissingToken
}
db, err := u.getDBConn(ctx)
db, err := u.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")

View File

@@ -34,7 +34,9 @@ func main() {
if err != nil {
logger.Fatalf("Failed to open connection to DB %s", err)
}
pb.RegisterUsersHandler(srv.Server(), handler.NewHandler(time.Now, sqlDB))
h := handler.NewHandler(time.Now)
h.DBConn(sqlDB).Migrations(&handler.User{}, &handler.Token{})
pb.RegisterUsersHandler(srv.Server(), h)
// Run service
if err := srv.Run(); err != nil {