diff --git a/groups/handler/handler.go b/groups/handler/handler.go index 45da8ce..df4f55a 100644 --- a/groups/handler/handler.go +++ b/groups/handler/handler.go @@ -5,8 +5,12 @@ import ( "strings" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" + "github.com/micro/micro/v3/service/logger" pb "github.com/micro/services/groups/proto" + gorm2 "github.com/micro/services/pkg/gorm" + "gorm.io/gorm" ) @@ -41,10 +45,14 @@ func (g *Group) Serialize() *pb.Group { } type Groups struct { - DB *gorm.DB + gorm2.Helper } func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Name) == 0 { return ErrMissingName @@ -52,7 +60,13 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea // create the group object group := &Group{ID: uuid.New().String(), Name: req.Name} - if err := g.DB.Create(group).Error; err != nil { + db, err := g.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } + + if err := db.Create(group).Error; err != nil { return ErrStore } @@ -62,14 +76,23 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea } func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Ids) == 0 { return ErrMissingIDs } + db, err := g.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // query the database var groups []Group - if err := g.DB.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil { + if err := db.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil { return ErrStore } @@ -83,6 +106,10 @@ func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResp } func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.UpdateResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID @@ -90,8 +117,13 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda if len(req.Name) == 0 { return ErrMissingName } + db, err := g.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } - return g.DB.Transaction(func(tx *gorm.DB) error { + return db.Transaction(func(tx *gorm.DB) error { // find the group var group Group if err := tx.Where(&Group{ID: req.Id}).First(&group).Error; err == gorm.ErrRecordNotFound { @@ -113,13 +145,22 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda } func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.Id) == 0 { return ErrMissingID } + db, err := g.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // delete from the database - if err := g.DB.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound { + if err := db.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound { return nil } else if err != nil { return ErrStore @@ -129,10 +170,19 @@ func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Dele } func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } + db, err := g.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } if len(req.MemberId) > 0 { // only list groups the user is a member of var ms []Membership - q := g.DB.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships") + q := db.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships") if err := q.Find(&ms).Error; err != nil { return err } @@ -145,7 +195,7 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp // load all groups var groups []Group - if err := g.DB.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil { + if err := db.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil { return ErrStore } @@ -159,6 +209,10 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp } func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *pb.AddMemberResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.GroupId) == 0 { return ErrMissingGroupID @@ -166,8 +220,13 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p if len(req.MemberId) == 0 { return ErrMissingMemberID } + db, err := g.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } - return g.DB.Transaction(func(tx *gorm.DB) error { + return db.Transaction(func(tx *gorm.DB) error { // check the group exists var group Group if err := tx.Where(&Group{ID: req.GroupId}).First(&group).Error; err == gorm.ErrRecordNotFound { @@ -191,6 +250,10 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p } func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest, rsp *pb.RemoveMemberResponse) error { + _, ok := auth.AccountFromContext(ctx) + if !ok { + errors.Unauthorized("UNAUTHORIZED", "Unauthorized") + } // validate the request if len(req.GroupId) == 0 { return ErrMissingGroupID @@ -199,9 +262,14 @@ func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest, return ErrMissingMemberID } + db, err := g.GetDBConn(ctx) + if err != nil { + logger.Errorf("Error connecting to DB: %v", err) + return errors.InternalServerError("DB_ERROR", "Error connecting to DB") + } // delete the membership m := &Membership{MemberID: req.MemberId, GroupID: req.GroupId} - if err := g.DB.Where(m).Delete(m).Error; err != nil { + if err := db.Where(m).Delete(m).Error; err != nil { return ErrStore } diff --git a/groups/handler/handler_test.go b/groups/handler/handler_test.go index d927584..026409b 100644 --- a/groups/handler/handler_test.go +++ b/groups/handler/handler_test.go @@ -2,16 +2,16 @@ package handler_test import ( "context" + "database/sql" "os" "sort" "testing" "github.com/google/uuid" + "github.com/micro/micro/v3/service/auth" "github.com/micro/services/groups/handler" pb "github.com/micro/services/groups/proto" "github.com/stretchr/testify/assert" - "gorm.io/driver/postgres" - "gorm.io/gorm" ) func testHandler(t *testing.T) *handler.Groups { @@ -20,33 +20,30 @@ func testHandler(t *testing.T) *handler.Groups { if len(addr) == 0 { addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable" } - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - t.Fatalf("Error connecting to database: %v", err) + t.Fatalf("Failed to open connection to DB %s", err) } // clean any data from a previous run - if err := db.Exec("DROP TABLE IF EXISTS groups, memberships CASCADE").Error; err != nil { + if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_groups, micro_memberships CASCADE"); err != nil { t.Fatalf("Error cleaning database: %v", err) } - // migrate the database - if err := db.AutoMigrate(&handler.Group{}, &handler.Membership{}); err != nil { - t.Fatalf("Error migrating database: %v", err) - } - - return &handler.Groups{DB: db} + h := &handler.Groups{} + h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{}) + return h } func TestCreate(t *testing.T) { h := testHandler(t) t.Run("MissingName", func(t *testing.T) { - err := h.Create(context.TODO(), &pb.CreateRequest{}, &pb.CreateResponse{}) + err := h.Create(microAccountCtx(), &pb.CreateRequest{}, &pb.CreateResponse{}) assert.Equal(t, handler.ErrMissingName, err) }) t.Run("Valid", func(t *testing.T) { - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ Name: "Doe Family Group", }, &pb.CreateResponse{}) assert.NoError(t, err) @@ -57,21 +54,21 @@ func TestUpdate(t *testing.T) { h := testHandler(t) t.Run("MissingID", func(t *testing.T) { - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Name: "Doe Family Group", }, &pb.UpdateResponse{}) assert.Equal(t, handler.ErrMissingID, err) }) t.Run("MissingName", func(t *testing.T) { - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: uuid.New().String(), }, &pb.UpdateResponse{}) assert.Equal(t, handler.ErrMissingName, err) }) t.Run("NotFound", func(t *testing.T) { - err := h.Update(context.TODO(), &pb.UpdateRequest{ + err := h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: uuid.New().String(), Name: "Bar Family Group", }, &pb.UpdateResponse{}) @@ -81,19 +78,19 @@ func TestUpdate(t *testing.T) { t.Run("Valid", func(t *testing.T) { // create a demo group var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ Name: "Doe Family Group", }, &cRsp) assert.NoError(t, err) - err = h.Update(context.TODO(), &pb.UpdateRequest{ + err = h.Update(microAccountCtx(), &pb.UpdateRequest{ Id: cRsp.Group.Id, Name: "Bar Family Group", }, &pb.UpdateResponse{}) assert.NoError(t, err) var rRsp pb.ReadResponse - err = h.Read(context.TODO(), &pb.ReadRequest{ + err = h.Read(microAccountCtx(), &pb.ReadRequest{ Ids: []string{cRsp.Group.Id}, }, &rRsp) assert.NoError(t, err) @@ -111,12 +108,12 @@ func TestDelete(t *testing.T) { h := testHandler(t) t.Run("MissingID", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{}, &pb.DeleteResponse{}) + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{}, &pb.DeleteResponse{}) assert.Equal(t, handler.ErrMissingID, err) }) t.Run("NotFound", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{ + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{ Id: uuid.New().String(), }, &pb.DeleteResponse{}) assert.NoError(t, err) @@ -124,19 +121,19 @@ func TestDelete(t *testing.T) { // create a demo group var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ Name: "Doe Family Group", }, &cRsp) assert.NoError(t, err) t.Run("Valid", func(t *testing.T) { - err := h.Delete(context.TODO(), &pb.DeleteRequest{ + err := h.Delete(microAccountCtx(), &pb.DeleteRequest{ Id: cRsp.Group.Id, }, &pb.DeleteResponse{}) assert.NoError(t, err) var rRsp pb.ReadResponse - err = h.Read(context.TODO(), &pb.ReadRequest{ + err = h.Read(microAccountCtx(), &pb.ReadRequest{ Ids: []string{cRsp.Group.Id}, }, &rRsp) assert.Nil(t, rRsp.Groups[cRsp.Group.Id]) @@ -147,27 +144,27 @@ func TestList(t *testing.T) { // create two demo groups var cRsp1 pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ Name: "Alpha Group", }, &cRsp1) assert.NoError(t, err) var cRsp2 pb.CreateResponse - err = h.Create(context.TODO(), &pb.CreateRequest{ + err = h.Create(microAccountCtx(), &pb.CreateRequest{ Name: "Bravo Group", }, &cRsp2) assert.NoError(t, err) // add a member to the first group uid := uuid.New().String() - err = h.AddMember(context.TODO(), &pb.AddMemberRequest{ + err = h.AddMember(microAccountCtx(), &pb.AddMemberRequest{ GroupId: cRsp1.Group.Id, MemberId: uid, }, &pb.AddMemberResponse{}) assert.NoError(t, err) t.Run("Unscoped", func(t *testing.T) { var rsp pb.ListResponse - err = h.List(context.TODO(), &pb.ListRequest{}, &rsp) + err = h.List(microAccountCtx(), &pb.ListRequest{}, &rsp) assert.NoError(t, err) assert.Lenf(t, rsp.Groups, 2, "Two groups should be returned") if len(rsp.Groups) != 2 { @@ -188,13 +185,12 @@ func TestList(t *testing.T) { t.Run("Scoped", func(t *testing.T) { var rsp pb.ListResponse - err = h.List(context.TODO(), &pb.ListRequest{MemberId: uid}, &rsp) + err = h.List(microAccountCtx(), &pb.ListRequest{MemberId: uid}, &rsp) assert.NoError(t, err) assert.Lenf(t, rsp.Groups, 1, "One group should be returned") if len(rsp.Groups) != 1 { return } - assert.Equal(t, cRsp1.Group.Id, rsp.Groups[0].Id) assert.Equal(t, cRsp1.Group.Name, rsp.Groups[0].Name) assert.Len(t, rsp.Groups[0].MemberIds, 1) @@ -206,21 +202,21 @@ func TestAddMember(t *testing.T) { h := testHandler(t) t.Run("MissingGroupID", func(t *testing.T) { - err := h.AddMember(context.TODO(), &pb.AddMemberRequest{ + err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{ MemberId: uuid.New().String(), }, &pb.AddMemberResponse{}) assert.Equal(t, handler.ErrMissingGroupID, err) }) t.Run("MissingMemberID", func(t *testing.T) { - err := h.AddMember(context.TODO(), &pb.AddMemberRequest{ + err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{ GroupId: uuid.New().String(), }, &pb.AddMemberResponse{}) assert.Equal(t, handler.ErrMissingMemberID, err) }) t.Run("GroupNotFound", func(t *testing.T) { - err := h.AddMember(context.TODO(), &pb.AddMemberRequest{ + err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{ GroupId: uuid.New().String(), MemberId: uuid.New().String(), }, &pb.AddMemberResponse{}) @@ -229,13 +225,13 @@ func TestAddMember(t *testing.T) { // create a test group var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ Name: "Alpha Group", }, &cRsp) assert.NoError(t, err) t.Run("Valid", func(t *testing.T) { - err := h.AddMember(context.TODO(), &pb.AddMemberRequest{ + err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{ GroupId: cRsp.Group.Id, MemberId: uuid.New().String(), }, &pb.AddMemberResponse{}) @@ -243,7 +239,7 @@ func TestAddMember(t *testing.T) { }) t.Run("Retry", func(t *testing.T) { - err := h.AddMember(context.TODO(), &pb.AddMemberRequest{ + err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{ GroupId: cRsp.Group.Id, MemberId: uuid.New().String(), }, &pb.AddMemberResponse{}) @@ -255,14 +251,14 @@ func TestRemoveMember(t *testing.T) { h := testHandler(t) t.Run("MissingGroupID", func(t *testing.T) { - err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{ + err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{ MemberId: uuid.New().String(), }, &pb.RemoveMemberResponse{}) assert.Equal(t, handler.ErrMissingGroupID, err) }) t.Run("MissingMemberID", func(t *testing.T) { - err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{ + err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{ GroupId: uuid.New().String(), }, &pb.RemoveMemberResponse{}) assert.Equal(t, handler.ErrMissingMemberID, err) @@ -270,13 +266,13 @@ func TestRemoveMember(t *testing.T) { // create a test group var cRsp pb.CreateResponse - err := h.Create(context.TODO(), &pb.CreateRequest{ + err := h.Create(microAccountCtx(), &pb.CreateRequest{ Name: "Alpha Group", }, &cRsp) assert.NoError(t, err) t.Run("Valid", func(t *testing.T) { - err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{ + err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{ GroupId: cRsp.Group.Id, MemberId: uuid.New().String(), }, &pb.RemoveMemberResponse{}) @@ -284,10 +280,16 @@ func TestRemoveMember(t *testing.T) { }) t.Run("Retry", func(t *testing.T) { - err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{ + err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{ GroupId: cRsp.Group.Id, MemberId: uuid.New().String(), }, &pb.RemoveMemberResponse{}) assert.NoError(t, err) }) } + +func microAccountCtx() context.Context { + return auth.ContextWithAccount(context.TODO(), &auth.Account{ + Issuer: "micro", + }) +} diff --git a/groups/main.go b/groups/main.go index 44663b8..f4243da 100644 --- a/groups/main.go +++ b/groups/main.go @@ -1,14 +1,13 @@ package main import ( - "github.com/micro/services/groups/handler" - pb "github.com/micro/services/groups/proto" - "gorm.io/driver/postgres" - "gorm.io/gorm" + "database/sql" "github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service/config" "github.com/micro/micro/v3/service/logger" + "github.com/micro/services/groups/handler" + pb "github.com/micro/services/groups/proto" ) var dbAddress = "postgresql://postgres:postgres@localhost:5432/groups?sslmode=disable" @@ -26,16 +25,14 @@ func main() { logger.Fatalf("Error loading config: %v", err) } addr := cfg.String(dbAddress) - db, err := gorm.Open(postgres.Open(addr), &gorm.Config{}) + sqlDB, err := sql.Open("pgx", addr) if err != nil { - logger.Fatalf("Error connecting to database: %v", err) + logger.Fatalf("Failed to open connection to DB %s", err) } - if err := db.AutoMigrate(&handler.Group{}, &handler.Membership{}); err != nil { - logger.Fatalf("Error migrating database: %v", err) - } - + h := &handler.Groups{} + h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{}) // Register handler - pb.RegisterGroupsHandler(srv.Server(), &handler.Groups{DB: db.Debug()}) + pb.RegisterGroupsHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil { diff --git a/pkg/gorm/wrapper.go b/pkg/gorm/wrapper.go new file mode 100644 index 0000000..3c10caa --- /dev/null +++ b/pkg/gorm/wrapper.go @@ -0,0 +1,444 @@ +package gorm + +import ( + "context" + "database/sql" + "fmt" + "strconv" + "strings" + "sync" + + "github.com/micro/micro/v3/service/auth" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +type Helper struct { + sync.RWMutex + gormConns map[string]*gorm.DB + dbConn *sql.DB + migrations []interface{} +} + +func (h *Helper) Migrations(migrations ...interface{}) *Helper { + h.migrations = migrations + return h +} + +func (h *Helper) DBConn(conn *sql.DB) *Helper { + h.dbConn = conn + h.gormConns = map[string]*gorm.DB{} + return h +} + +func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) { + acc, ok := auth.AccountFromContext(ctx) + if !ok { + return nil, fmt.Errorf("missing account from context") + } + h.RLock() + if conn, ok := h.gormConns[acc.Issuer]; ok { + h.RUnlock() + return conn, nil + } + h.RUnlock() + h.Lock() + // double check + if conn, ok := h.gormConns[acc.Issuer]; ok { + h.Unlock() + return conn, nil + } + defer h.Unlock() + ns := schema.NamingStrategy{ + TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")), + } + db, err := gorm.Open( + newGormDialector(postgres.Config{ + Conn: h.dbConn, + }, ns), + &gorm.Config{ + NamingStrategy: ns, + }) + if err != nil { + return nil, err + } + if len(h.migrations) == 0 { + // record success + h.gormConns[acc.Issuer] = db + return db, nil + } + + if err := db.AutoMigrate(h.migrations...); err != nil { + return nil, err + } + + // record success + h.gormConns[acc.Issuer] = db + return db, nil +} + +func newGormDialector(config postgres.Config, ns schema.NamingStrategy) gorm.Dialector { + return &postgresDial{ + Dialector: postgres.Dialector{Config: &config}, + namer: &ns, + } +} + +// postgresDial is a postgres dialector that prefixes index names with the table prefix when doing migrations. +// NOTE, it does not support the gorm tag priority option +type postgresDial struct { + postgres.Dialector + namer schema.Namer +} + +func (p postgresDial) Migrator(db *gorm.DB) gorm.Migrator { + return gormMigrator{ + postgres.Migrator{ + migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: p, + CreateIndexAfterCreateTable: true, + }}, + }, + p.namer, + } +} + +type gormMigrator struct { + postgres.Migrator + namer schema.Namer +} + +// AutoMigrate +func (m gormMigrator) AutoMigrate(values ...interface{}) error { + for _, value := range m.ReorderModels(values, true) { + tx := m.DB.Session(&gorm.Session{NewDB: true}) + if !tx.Migrator().HasTable(value) { + if err := tx.Migrator().CreateTable(value); err != nil { + return err + } + } else { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + + for _, field := range stmt.Schema.FieldsByDBName { + var foundColumn gorm.ColumnType + + for _, columnType := range columnTypes { + if columnType.Name() == field.DBName { + foundColumn = columnType + break + } + } + + if foundColumn == nil { + // not found, add column + if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { + return err + } + } else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { + // found, smart migrate + return err + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + if !tx.Migrator().HasConstraint(value, constraint.Name) { + if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil { + return err + } + } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + if !tx.Migrator().HasConstraint(value, chk.Name) { + if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil { + return err + } + } + } + } + + for _, idx := range m.ParseIndexes(stmt.Schema) { + if !tx.Migrator().HasIndex(value, idx.Name) { + if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil { + return err + } + } + } + + return nil + }); err != nil { + return err + } + } + } + + return nil +} + +func (m gormMigrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := m.LookIndex(stmt.Schema, name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + createIndexSQL += " ON ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + "(?)" + } else { + createIndexSQL += " ?" + } + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m gormMigrator) LookIndex(sch *schema.Schema, name string) *schema.Index { + if sch != nil { + indexes := m.ParseIndexes(sch) + for _, index := range indexes { + if index.Name == name { + return &index + } + + for _, field := range index.Fields { + if field.Name == name { + return &index + } + } + } + } + + return nil +} + +func (g gormMigrator) parseFieldIndexes(field *schema.Field) (indexes []schema.Index) { + for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { + if value != "" { + v := strings.Split(value, ":") + k := strings.TrimSpace(strings.ToUpper(v[0])) + if k == "INDEX" || k == "UNIQUEINDEX" { + var ( + name string + tag = strings.Join(v[1:], ":") + idx = strings.Index(tag, ",") + settings = schema.ParseTagSetting(tag, ",") + length, _ = strconv.Atoi(settings["LENGTH"]) + ) + + if idx == -1 { + idx = len(tag) + } + + if idx != -1 { + name = tag[0:idx] + } + + if name == "" { + name = g.namer.IndexName(field.Schema.Table, field.Name) + } else { + ns := g.namer.(*schema.NamingStrategy) + name = fmt.Sprintf("%s%s", ns.TablePrefix, name) + } + + if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { + settings["CLASS"] = "UNIQUE" + } + + //priority, err := strconv.Atoi(settings["PRIORITY"]) + //if err != nil { + // priority = 10 + //} + + indexes = append(indexes, schema.Index{ + Name: name, + Class: settings["CLASS"], + Type: settings["TYPE"], + Where: settings["WHERE"], + Comment: settings["COMMENT"], + Option: settings["OPTION"], + Fields: []schema.IndexOption{{ + Field: field, + Expression: settings["EXPRESSION"], + Sort: settings["SORT"], + Collate: settings["COLLATE"], + Length: length, + //priority: priority, // TODO does not support priority + }}, + }) + } + } + } + + return +} + +func (m gormMigrator) CreateTable(values ...interface{}) error { + for _, value := range m.ReorderModels(values, false) { + tx := m.DB.Session(&gorm.Session{NewDB: true}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + var ( + createTableSQL = "CREATE TABLE ? (" + values = []interface{}{m.CurrentTable(stmt)} + hasPrimaryKeyInDataType bool + ) + + for _, dbName := range stmt.Schema.DBNames { + field := stmt.Schema.FieldsByDBName[dbName] + createTableSQL += "? ?" + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) + createTableSQL += "," + } + + if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { + createTableSQL += "PRIMARY KEY ?," + primaryKeys := []interface{}{} + for _, field := range stmt.Schema.PrimaryFields { + primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) + } + + values = append(values, primaryKeys) + } + + for _, idx := range m.ParseIndexes(stmt.Schema) { + if m.CreateIndexAfterCreateTable { + defer func(value interface{}, name string) { + errr = tx.Migrator().CreateIndex(value, name) + }(value, idx.Name) + } else { + if idx.Class != "" { + createTableSQL += idx.Class + " " + } + createTableSQL += "INDEX ? ?" + + if idx.Option != "" { + createTableSQL += " " + idx.Option + } + + createTableSQL += "," + values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(migrator.BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) + } + } + + for _, rel := range stmt.Schema.Relationships.Relations { + if !m.DB.DisableForeignKeyConstraintWhenMigrating { + if constraint := rel.ParseConstraint(); constraint != nil { + if constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) + } + } + } + } + + for _, chk := range stmt.Schema.ParseCheckConstraints() { + createTableSQL += "CONSTRAINT ? CHECK (?)," + values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) + } + + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + + createTableSQL += ")" + + if tableOption, ok := m.DB.Get("gorm:table_options"); ok { + createTableSQL += fmt.Sprint(tableOption) + } + + errr = tx.Exec(createTableSQL, values...).Error + return errr + }); err != nil { + return err + } + } + return nil +} + +type gormIndexOption struct { + schema.IndexOption + priority int +} + +func (g gormMigrator) ParseIndexes(sch *schema.Schema) map[string]schema.Index { + var indexes = map[string]schema.Index{} + + for _, field := range sch.Fields { + if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" { + for _, index := range g.parseFieldIndexes(field) { + idx := indexes[index.Name] + idx.Name = index.Name + if idx.Class == "" { + idx.Class = index.Class + } + if idx.Type == "" { + idx.Type = index.Type + } + if idx.Where == "" { + idx.Where = index.Where + } + if idx.Comment == "" { + idx.Comment = index.Comment + } + if idx.Option == "" { + idx.Option = index.Option + } + + idx.Fields = append(idx.Fields, index.Fields...) + // TODO priority not supported + //sort.Slice(idx.Fields, func(i, j int) bool { + // return idx.Fields[i].priority < idx.Fields[j].priority + //}) + + indexes[index.Name] = idx + } + } + } + + return indexes +} + +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} diff --git a/pkg/skip b/pkg/skip new file mode 100644 index 0000000..e69de29 diff --git a/users/handler/create.go b/users/handler/create.go index 7f79150..d15b22f 100644 --- a/users/handler/create.go +++ b/users/handler/create.go @@ -44,7 +44,7 @@ func (u *Users) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Creat logger.Errorf("Error hashing and salting password: %v", err) return errors.InternalServerError("HASHING_ERROR", "Error hashing password") } - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/delete.go b/users/handler/delete.go index a258b00..5eb8043 100644 --- a/users/handler/delete.go +++ b/users/handler/delete.go @@ -20,7 +20,7 @@ func (u *Users) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Delet if len(req.Id) == 0 { return ErrMissingID } - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/handler.go b/users/handler/handler.go index f950085..568180f 100644 --- a/users/handler/handler.go +++ b/users/handler/handler.go @@ -1,21 +1,13 @@ package handler import ( - "context" - "database/sql" - "fmt" "regexp" - "strings" - "sync" "time" - "github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/errors" + gorm2 "github.com/micro/services/pkg/gorm" pb "github.com/micro/services/users/proto" "golang.org/x/crypto/bcrypt" - "gorm.io/driver/postgres" - "gorm.io/gorm" - "gorm.io/gorm/schema" ) var ( @@ -66,52 +58,12 @@ type Token struct { } type Users struct { - sync.RWMutex - Time func() time.Time - dbConn *sql.DB - gormConns map[string]*gorm.DB + gorm2.Helper + Time func() time.Time } -func NewHandler(t func() time.Time, dbConn *sql.DB) *Users { - return &Users{Time: t, dbConn: dbConn, gormConns: map[string]*gorm.DB{}} -} - -func (u *Users) getDBConn(ctx context.Context) (*gorm.DB, error) { - acc, ok := auth.AccountFromContext(ctx) - if !ok { - return nil, fmt.Errorf("missing account from context") - } - u.RLock() - if conn, ok := u.gormConns[acc.Issuer]; ok { - u.RUnlock() - return conn, nil - } - u.RUnlock() - u.Lock() - // double check - if conn, ok := u.gormConns[acc.Issuer]; ok { - u.Unlock() - return conn, nil - } - defer u.Unlock() - db, err := gorm.Open( - postgres.New(postgres.Config{ - Conn: u.dbConn, - }), - &gorm.Config{ - NamingStrategy: schema.NamingStrategy{ - TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")), - }, - }) - if err != nil { - return nil, err - } - if err := db.AutoMigrate(&User{}, &Token{}); err != nil { - return nil, err - } - // record success - u.gormConns[acc.Issuer] = db - return db, nil +func NewHandler(t func() time.Time) *Users { + return &Users{Time: t} } // isEmailValid checks if the email provided passes the required structure and length. diff --git a/users/handler/handler_test.go b/users/handler/handler_test.go index e871d5f..91778c8 100644 --- a/users/handler/handler_test.go +++ b/users/handler/handler_test.go @@ -31,7 +31,9 @@ func testHandler(t *testing.T) *handler.Users { t.Fatalf("Error cleaning database: %v", err) } - return handler.NewHandler(time.Now, sqlDB) + h := handler.NewHandler(time.Now) + h.DBConn(sqlDB).Migrations(&handler.User{}, &handler.Token{}) + return h } func assertUsersMatch(t *testing.T, exp, act *pb.User) { diff --git a/users/handler/list.go b/users/handler/list.go index bc2e496..8ecc225 100644 --- a/users/handler/list.go +++ b/users/handler/list.go @@ -16,7 +16,7 @@ func (u *Users) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListRespo errors.Unauthorized("UNAUTHORIZED", "Unauthorized") } // query the database - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/login.go b/users/handler/login.go index 13d116c..c275d9a 100644 --- a/users/handler/login.go +++ b/users/handler/login.go @@ -25,7 +25,7 @@ func (u *Users) Login(ctx context.Context, req *pb.LoginRequest, rsp *pb.LoginRe return ErrInvalidPassword } - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/logout.go b/users/handler/logout.go index 3d76a0d..e3f5b47 100644 --- a/users/handler/logout.go +++ b/users/handler/logout.go @@ -21,7 +21,7 @@ func (u *Users) Logout(ctx context.Context, req *pb.LogoutRequest, rsp *pb.Logou return ErrMissingID } - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/read.go b/users/handler/read.go index c86f0b9..20aa7fc 100644 --- a/users/handler/read.go +++ b/users/handler/read.go @@ -21,7 +21,7 @@ func (u *Users) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRespo } // query the database - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/read_by_email.go b/users/handler/read_by_email.go index dcdac59..adee532 100644 --- a/users/handler/read_by_email.go +++ b/users/handler/read_by_email.go @@ -26,7 +26,7 @@ func (u *Users) ReadByEmail(ctx context.Context, req *pb.ReadByEmailRequest, rsp } // query the database - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/update.go b/users/handler/update.go index aa37641..3c566ab 100644 --- a/users/handler/update.go +++ b/users/handler/update.go @@ -40,7 +40,7 @@ func (u *Users) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Updat // lookup the user var user User - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/handler/validate.go b/users/handler/validate.go index 18beb31..01306e4 100644 --- a/users/handler/validate.go +++ b/users/handler/validate.go @@ -21,7 +21,7 @@ func (u *Users) Validate(ctx context.Context, req *pb.ValidateRequest, rsp *pb.V return ErrMissingToken } - db, err := u.getDBConn(ctx) + db, err := u.GetDBConn(ctx) if err != nil { logger.Errorf("Error connecting to DB: %v", err) return errors.InternalServerError("DB_ERROR", "Error connecting to DB") diff --git a/users/main.go b/users/main.go index 90aeb42..b1d24ab 100644 --- a/users/main.go +++ b/users/main.go @@ -34,7 +34,9 @@ func main() { if err != nil { logger.Fatalf("Failed to open connection to DB %s", err) } - pb.RegisterUsersHandler(srv.Server(), handler.NewHandler(time.Now, sqlDB)) + h := handler.NewHandler(time.Now) + h.DBConn(sqlDB).Migrations(&handler.User{}, &handler.Token{}) + pb.RegisterUsersHandler(srv.Server(), h) // Run service if err := srv.Run(); err != nil {