mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-16 13:04:34 +00:00
Multi tenant groups (#77)
* multitenant groups * switch users service to use new wrapper * fix tests * skip pkg dir * Check for auth
This commit is contained in:
444
pkg/gorm/wrapper.go
Normal file
444
pkg/gorm/wrapper.go
Normal file
@@ -0,0 +1,444 @@
|
||||
package gorm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Helper struct {
|
||||
sync.RWMutex
|
||||
gormConns map[string]*gorm.DB
|
||||
dbConn *sql.DB
|
||||
migrations []interface{}
|
||||
}
|
||||
|
||||
func (h *Helper) Migrations(migrations ...interface{}) *Helper {
|
||||
h.migrations = migrations
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Helper) DBConn(conn *sql.DB) *Helper {
|
||||
h.dbConn = conn
|
||||
h.gormConns = map[string]*gorm.DB{}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) {
|
||||
acc, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing account from context")
|
||||
}
|
||||
h.RLock()
|
||||
if conn, ok := h.gormConns[acc.Issuer]; ok {
|
||||
h.RUnlock()
|
||||
return conn, nil
|
||||
}
|
||||
h.RUnlock()
|
||||
h.Lock()
|
||||
// double check
|
||||
if conn, ok := h.gormConns[acc.Issuer]; ok {
|
||||
h.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
defer h.Unlock()
|
||||
ns := schema.NamingStrategy{
|
||||
TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(acc.Issuer, "-", "")),
|
||||
}
|
||||
db, err := gorm.Open(
|
||||
newGormDialector(postgres.Config{
|
||||
Conn: h.dbConn,
|
||||
}, ns),
|
||||
&gorm.Config{
|
||||
NamingStrategy: ns,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(h.migrations) == 0 {
|
||||
// record success
|
||||
h.gormConns[acc.Issuer] = db
|
||||
return db, nil
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(h.migrations...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// record success
|
||||
h.gormConns[acc.Issuer] = db
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func newGormDialector(config postgres.Config, ns schema.NamingStrategy) gorm.Dialector {
|
||||
return &postgresDial{
|
||||
Dialector: postgres.Dialector{Config: &config},
|
||||
namer: &ns,
|
||||
}
|
||||
}
|
||||
|
||||
// postgresDial is a postgres dialector that prefixes index names with the table prefix when doing migrations.
|
||||
// NOTE, it does not support the gorm tag priority option
|
||||
type postgresDial struct {
|
||||
postgres.Dialector
|
||||
namer schema.Namer
|
||||
}
|
||||
|
||||
func (p postgresDial) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return gormMigrator{
|
||||
postgres.Migrator{
|
||||
migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: p,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}},
|
||||
},
|
||||
p.namer,
|
||||
}
|
||||
}
|
||||
|
||||
type gormMigrator struct {
|
||||
postgres.Migrator
|
||||
namer schema.Namer
|
||||
}
|
||||
|
||||
// AutoMigrate
|
||||
func (m gormMigrator) AutoMigrate(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, true) {
|
||||
tx := m.DB.Session(&gorm.Session{NewDB: true})
|
||||
if !tx.Migrator().HasTable(value) {
|
||||
if err := tx.Migrator().CreateTable(value); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
|
||||
|
||||
for _, field := range stmt.Schema.FieldsByDBName {
|
||||
var foundColumn gorm.ColumnType
|
||||
|
||||
for _, columnType := range columnTypes {
|
||||
if columnType.Name() == field.DBName {
|
||||
foundColumn = columnType
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundColumn == nil {
|
||||
// not found, add column
|
||||
if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
|
||||
// found, smart migrate
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
if !tx.Migrator().HasConstraint(value, constraint.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
if !tx.Migrator().HasConstraint(value, chk.Name) {
|
||||
if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range m.ParseIndexes(stmt.Schema) {
|
||||
if !tx.Migrator().HasIndex(value, idx.Name) {
|
||||
if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m gormMigrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if idx := m.LookIndex(stmt.Schema, name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
createIndexSQL += " ON ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type + "(?)"
|
||||
} else {
|
||||
createIndexSQL += " ?"
|
||||
}
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m gormMigrator) LookIndex(sch *schema.Schema, name string) *schema.Index {
|
||||
if sch != nil {
|
||||
indexes := m.ParseIndexes(sch)
|
||||
for _, index := range indexes {
|
||||
if index.Name == name {
|
||||
return &index
|
||||
}
|
||||
|
||||
for _, field := range index.Fields {
|
||||
if field.Name == name {
|
||||
return &index
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g gormMigrator) parseFieldIndexes(field *schema.Field) (indexes []schema.Index) {
|
||||
for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
|
||||
if value != "" {
|
||||
v := strings.Split(value, ":")
|
||||
k := strings.TrimSpace(strings.ToUpper(v[0]))
|
||||
if k == "INDEX" || k == "UNIQUEINDEX" {
|
||||
var (
|
||||
name string
|
||||
tag = strings.Join(v[1:], ":")
|
||||
idx = strings.Index(tag, ",")
|
||||
settings = schema.ParseTagSetting(tag, ",")
|
||||
length, _ = strconv.Atoi(settings["LENGTH"])
|
||||
)
|
||||
|
||||
if idx == -1 {
|
||||
idx = len(tag)
|
||||
}
|
||||
|
||||
if idx != -1 {
|
||||
name = tag[0:idx]
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
name = g.namer.IndexName(field.Schema.Table, field.Name)
|
||||
} else {
|
||||
ns := g.namer.(*schema.NamingStrategy)
|
||||
name = fmt.Sprintf("%s%s", ns.TablePrefix, name)
|
||||
}
|
||||
|
||||
if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
|
||||
settings["CLASS"] = "UNIQUE"
|
||||
}
|
||||
|
||||
//priority, err := strconv.Atoi(settings["PRIORITY"])
|
||||
//if err != nil {
|
||||
// priority = 10
|
||||
//}
|
||||
|
||||
indexes = append(indexes, schema.Index{
|
||||
Name: name,
|
||||
Class: settings["CLASS"],
|
||||
Type: settings["TYPE"],
|
||||
Where: settings["WHERE"],
|
||||
Comment: settings["COMMENT"],
|
||||
Option: settings["OPTION"],
|
||||
Fields: []schema.IndexOption{{
|
||||
Field: field,
|
||||
Expression: settings["EXPRESSION"],
|
||||
Sort: settings["SORT"],
|
||||
Collate: settings["COLLATE"],
|
||||
Length: length,
|
||||
//priority: priority, // TODO does not support priority
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (m gormMigrator) CreateTable(values ...interface{}) error {
|
||||
for _, value := range m.ReorderModels(values, false) {
|
||||
tx := m.DB.Session(&gorm.Session{NewDB: true})
|
||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||
var (
|
||||
createTableSQL = "CREATE TABLE ? ("
|
||||
values = []interface{}{m.CurrentTable(stmt)}
|
||||
hasPrimaryKeyInDataType bool
|
||||
)
|
||||
|
||||
for _, dbName := range stmt.Schema.DBNames {
|
||||
field := stmt.Schema.FieldsByDBName[dbName]
|
||||
createTableSQL += "? ?"
|
||||
hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
|
||||
values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
|
||||
createTableSQL += ","
|
||||
}
|
||||
|
||||
if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
|
||||
createTableSQL += "PRIMARY KEY ?,"
|
||||
primaryKeys := []interface{}{}
|
||||
for _, field := range stmt.Schema.PrimaryFields {
|
||||
primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
values = append(values, primaryKeys)
|
||||
}
|
||||
|
||||
for _, idx := range m.ParseIndexes(stmt.Schema) {
|
||||
if m.CreateIndexAfterCreateTable {
|
||||
defer func(value interface{}, name string) {
|
||||
errr = tx.Migrator().CreateIndex(value, name)
|
||||
}(value, idx.Name)
|
||||
} else {
|
||||
if idx.Class != "" {
|
||||
createTableSQL += idx.Class + " "
|
||||
}
|
||||
createTableSQL += "INDEX ? ?"
|
||||
|
||||
if idx.Option != "" {
|
||||
createTableSQL += " " + idx.Option
|
||||
}
|
||||
|
||||
createTableSQL += ","
|
||||
values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(migrator.BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
|
||||
}
|
||||
}
|
||||
|
||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
||||
if constraint.Schema == stmt.Schema {
|
||||
sql, vars := buildConstraint(constraint)
|
||||
createTableSQL += sql + ","
|
||||
values = append(values, vars...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||
}
|
||||
|
||||
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
|
||||
|
||||
createTableSQL += ")"
|
||||
|
||||
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
|
||||
createTableSQL += fmt.Sprint(tableOption)
|
||||
}
|
||||
|
||||
errr = tx.Exec(createTableSQL, values...).Error
|
||||
return errr
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type gormIndexOption struct {
|
||||
schema.IndexOption
|
||||
priority int
|
||||
}
|
||||
|
||||
func (g gormMigrator) ParseIndexes(sch *schema.Schema) map[string]schema.Index {
|
||||
var indexes = map[string]schema.Index{}
|
||||
|
||||
for _, field := range sch.Fields {
|
||||
if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
|
||||
for _, index := range g.parseFieldIndexes(field) {
|
||||
idx := indexes[index.Name]
|
||||
idx.Name = index.Name
|
||||
if idx.Class == "" {
|
||||
idx.Class = index.Class
|
||||
}
|
||||
if idx.Type == "" {
|
||||
idx.Type = index.Type
|
||||
}
|
||||
if idx.Where == "" {
|
||||
idx.Where = index.Where
|
||||
}
|
||||
if idx.Comment == "" {
|
||||
idx.Comment = index.Comment
|
||||
}
|
||||
if idx.Option == "" {
|
||||
idx.Option = index.Option
|
||||
}
|
||||
|
||||
idx.Fields = append(idx.Fields, index.Fields...)
|
||||
// TODO priority not supported
|
||||
//sort.Slice(idx.Fields, func(i, j int) bool {
|
||||
// return idx.Fields[i].priority < idx.Fields[j].priority
|
||||
//})
|
||||
|
||||
indexes[index.Name] = idx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return indexes
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user