services/pkg/gorm/wrapper.go
Dominic Wong, 7ae45b522e: make multi tenancy based on namespace and ID of account (#85)

* make multi tenancy based on namespace and ID of account, not just namespace

* fix tests

2021-04-21 15:09:38 +01:00


// Package gorm wraps gorm.io/gorm with per-tenant database connections for
// micro services: tables and indexes for each tenant are prefixed with a key
// derived from the namespace and ID of the calling auth.Account.
package gorm

import (
	"context"
	"database/sql"
	"fmt"
	"strconv"
	"strings"
	"sync"

	"github.com/micro/micro/v3/service/auth"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/migrator"
	"gorm.io/gorm/schema"
)
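// Helper multiplexes a single shared *sql.DB into one cached *gorm.DB per
// tenant. Each tenant connection uses a NamingStrategy whose table prefix is
// built from the tenancy key, so tenants get their own tables on the same
// physical database.
//
// A minimal usage sketch; sqlDB, ctx and the Record model are placeholders
// for the caller's own values, not part of this package:
//
//	h := new(Helper).DBConn(sqlDB).Migrations(&Record{})
//	db, err := h.GetDBConn(ctx) // ctx must carry an auth.Account
//	if err != nil {
//		// handle error
//	}
//	_ = db // use the tenant-scoped *gorm.DB as usual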
type Helper struct {
	sync.RWMutex
	gormConns  map[string]*gorm.DB
	dbConn     *sql.DB
	migrations []interface{}
}
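// Migrations registers the models that are auto-migrated the first time a
// tenant's connection is opened by GetDBConn.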
func (h *Helper) Migrations(migrations ...interface{}) *Helper {
	h.migrations = migrations
	return h
}
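// DBConn sets the shared *sql.DB that all tenant connections are opened
// against and resets the per-tenant connection cache.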
func (h *Helper) DBConn(conn *sql.DB) *Helper {
	h.dbConn = conn
	h.gormConns = map[string]*gorm.DB{}
	return h
}
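// getTenancyKey builds the per-tenant key from the account's issuer
// (namespace) and ID. If the account carries an "apikey_owner" metadata
// entry, that owner's ID is used instead, so API keys resolve to their
// owner's tenant.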
func getTenancyKey(acc *auth.Account) string {
	owner := acc.Metadata["apikey_owner"]
	if len(owner) == 0 {
		owner = acc.ID
	}
	return fmt.Sprintf("%s_%s", acc.Issuer, owner)
}
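// GetDBConn returns the *gorm.DB for the tenant identified by the
// auth.Account in ctx, opening it (and running any registered migrations)
// on first use. Connections are cached per tenancy key.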
func (h *Helper) GetDBConn(ctx context.Context) (*gorm.DB, error) {
	acc, ok := auth.AccountFromContext(ctx)
	if !ok {
		return nil, fmt.Errorf("missing account from context")
	}
	h.RLock()
	tenancyKey := getTenancyKey(acc)
	if conn, ok := h.gormConns[tenancyKey]; ok {
		h.RUnlock()
		return conn, nil
	}
	h.RUnlock()

	h.Lock()
	// double check
	if conn, ok := h.gormConns[tenancyKey]; ok {
		h.Unlock()
		return conn, nil
	}
	defer h.Unlock()

	ns := schema.NamingStrategy{
		TablePrefix: fmt.Sprintf("%s_", strings.ReplaceAll(tenancyKey, "-", "")),
	}
	db, err := gorm.Open(
		newGormDialector(postgres.Config{
			Conn: h.dbConn,
		}, ns),
		&gorm.Config{
			NamingStrategy: ns,
		})
	if err != nil {
		return nil, err
	}
	if len(h.migrations) == 0 {
		// record success
		h.gormConns[tenancyKey] = db
		return db, nil
	}
	if err := db.AutoMigrate(h.migrations...); err != nil {
		return nil, err
	}
	// record success
	h.gormConns[tenancyKey] = db
	return db, nil
}
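// newGormDialector wraps gorm's postgres dialector in postgresDial so that
// migrations pick up the tenant-aware naming strategy.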
func newGormDialector(config postgres.Config, ns schema.NamingStrategy) gorm.Dialector {
	return &postgresDial{
		Dialector: postgres.Dialector{Config: &config},
		namer:     &ns,
	}
}
// postgresDial is a postgres dialector that prefixes index names with the
// table prefix when doing migrations.
// NOTE: it does not support the gorm tag priority option.
type postgresDial struct {
	postgres.Dialector
	namer schema.Namer
}
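// Migrator returns a migrator whose index handling uses tenant-prefixed
// names, behaving like gorm's postgres migrator otherwise.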
func (p postgresDial) Migrator(db *gorm.DB) gorm.Migrator {
	return gormMigrator{
		postgres.Migrator{
			migrator.Migrator{Config: migrator.Config{
				DB:                          db,
				Dialector:                   p,
				CreateIndexAfterCreateTable: true,
			}},
		},
		p.namer,
	}
}
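// gormMigrator overrides the parts of gorm's postgres migrator that deal
// with index names so they include the tenant table prefix.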
type gormMigrator struct {
	postgres.Migrator
	namer schema.Namer
}
// AutoMigrate mirrors gorm's default migrator.AutoMigrate, but is defined on
// gormMigrator so that index parsing and creation go through the overridden,
// tenant-prefix-aware methods below.
func (m gormMigrator) AutoMigrate(values ...interface{}) error {
	for _, value := range m.ReorderModels(values, true) {
		tx := m.DB.Session(&gorm.Session{NewDB: true})
		if !tx.Migrator().HasTable(value) {
			if err := tx.Migrator().CreateTable(value); err != nil {
				return err
			}
		} else {
			if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
				columnTypes, _ := m.DB.Migrator().ColumnTypes(value)
				for _, field := range stmt.Schema.FieldsByDBName {
					var foundColumn gorm.ColumnType
					for _, columnType := range columnTypes {
						if columnType.Name() == field.DBName {
							foundColumn = columnType
							break
						}
					}
					if foundColumn == nil {
						// not found, add column
						if err := tx.Migrator().AddColumn(value, field.DBName); err != nil {
							return err
						}
					} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil {
						// found, smart migrate
						return err
					}
				}
				for _, rel := range stmt.Schema.Relationships.Relations {
					if !m.DB.Config.DisableForeignKeyConstraintWhenMigrating {
						if constraint := rel.ParseConstraint(); constraint != nil {
							if constraint.Schema == stmt.Schema {
								if !tx.Migrator().HasConstraint(value, constraint.Name) {
									if err := tx.Migrator().CreateConstraint(value, constraint.Name); err != nil {
										return err
									}
								}
							}
						}
					}
					for _, chk := range stmt.Schema.ParseCheckConstraints() {
						if !tx.Migrator().HasConstraint(value, chk.Name) {
							if err := tx.Migrator().CreateConstraint(value, chk.Name); err != nil {
								return err
							}
						}
					}
				}
				for _, idx := range m.ParseIndexes(stmt.Schema) {
					if !tx.Migrator().HasIndex(value, idx.Name) {
						if err := tx.Migrator().CreateIndex(value, idx.Name); err != nil {
							return err
						}
					}
				}
				return nil
			}); err != nil {
				return err
			}
		}
	}
	return nil
}
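// CreateIndex mirrors the postgres migrator's CreateIndex, but resolves the
// index through LookIndex so the tenant-prefixed name is used in the SQL.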
func (m gormMigrator) CreateIndex(value interface{}, name string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		if idx := m.LookIndex(stmt.Schema, name); idx != nil {
			opts := m.BuildIndexOptions(idx.Fields, stmt)
			values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}

			createIndexSQL := "CREATE "
			if idx.Class != "" {
				createIndexSQL += idx.Class + " "
			}
			createIndexSQL += "INDEX ?"
			createIndexSQL += " ON ?"

			if idx.Type != "" {
				createIndexSQL += " USING " + idx.Type + "(?)"
			} else {
				createIndexSQL += " ?"
			}

			if idx.Where != "" {
				createIndexSQL += " WHERE " + idx.Where
			}

			return m.DB.Exec(createIndexSQL, values...).Error
		}
		return fmt.Errorf("failed to create index with name %v", name)
	})
}
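// LookIndex finds an index by name (or by the name of one of its fields)
// among the indexes returned by ParseIndexes.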
func (m gormMigrator) LookIndex(sch *schema.Schema, name string) *schema.Index {
	if sch != nil {
		indexes := m.ParseIndexes(sch)
		for _, index := range indexes {
			if index.Name == name {
				return &index
			}
			for _, field := range index.Fields {
				if field.Name == name {
					return &index
				}
			}
		}
	}
	return nil
}
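// parseFieldIndexes reads index definitions from a field's gorm tag.
// Unnamed indexes are named from the (already prefixed) table name via the
// namer; explicitly named indexes get the tenant table prefix prepended.
// The gorm priority option is not supported.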
func (g gormMigrator) parseFieldIndexes(field *schema.Field) (indexes []schema.Index) {
	for _, value := range strings.Split(field.Tag.Get("gorm"), ";") {
		if value != "" {
			v := strings.Split(value, ":")
			k := strings.TrimSpace(strings.ToUpper(v[0]))
			if k == "INDEX" || k == "UNIQUEINDEX" {
				var (
					name      string
					tag       = strings.Join(v[1:], ":")
					idx       = strings.Index(tag, ",")
					settings  = schema.ParseTagSetting(tag, ",")
					length, _ = strconv.Atoi(settings["LENGTH"])
				)
				if idx == -1 {
					idx = len(tag)
				}
				if idx != -1 {
					name = tag[0:idx]
				}

				if name == "" {
					name = g.namer.IndexName(field.Schema.Table, field.Name)
				} else {
					ns := g.namer.(*schema.NamingStrategy)
					name = fmt.Sprintf("%s%s", ns.TablePrefix, name)
				}

				if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" {
					settings["CLASS"] = "UNIQUE"
				}

				//priority, err := strconv.Atoi(settings["PRIORITY"])
				//if err != nil {
				//	priority = 10
				//}

				indexes = append(indexes, schema.Index{
					Name:    name,
					Class:   settings["CLASS"],
					Type:    settings["TYPE"],
					Where:   settings["WHERE"],
					Comment: settings["COMMENT"],
					Option:  settings["OPTION"],
					Fields: []schema.IndexOption{{
						Field:      field,
						Expression: settings["EXPRESSION"],
						Sort:       settings["SORT"],
						Collate:    settings["COLLATE"],
						Length:     length,
						//priority: priority, // TODO does not support priority
					}},
				})
			}
		}
	}
	return
}
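// CreateTable mirrors gorm's default migrator.CreateTable, but parses indexes
// through this migrator's ParseIndexes so both inline and deferred index
// creation use tenant-prefixed names.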
func (m gormMigrator) CreateTable(values ...interface{}) error {
	for _, value := range m.ReorderModels(values, false) {
		tx := m.DB.Session(&gorm.Session{NewDB: true})
		if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
			var (
				createTableSQL          = "CREATE TABLE ? ("
				values                  = []interface{}{m.CurrentTable(stmt)}
				hasPrimaryKeyInDataType bool
			)

			for _, dbName := range stmt.Schema.DBNames {
				field := stmt.Schema.FieldsByDBName[dbName]
				createTableSQL += "? ?"
				hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY")
				values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field))
				createTableSQL += ","
			}

			if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 {
				createTableSQL += "PRIMARY KEY ?,"
				primaryKeys := []interface{}{}
				for _, field := range stmt.Schema.PrimaryFields {
					primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName})
				}
				values = append(values, primaryKeys)
			}

			for _, idx := range m.ParseIndexes(stmt.Schema) {
				if m.CreateIndexAfterCreateTable {
					defer func(value interface{}, name string) {
						errr = tx.Migrator().CreateIndex(value, name)
					}(value, idx.Name)
				} else {
					if idx.Class != "" {
						createTableSQL += idx.Class + " "
					}
					createTableSQL += "INDEX ? ?"
					if idx.Option != "" {
						createTableSQL += " " + idx.Option
					}
					createTableSQL += ","
					values = append(values, clause.Expr{SQL: idx.Name}, tx.Migrator().(migrator.BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt))
				}
			}

			for _, rel := range stmt.Schema.Relationships.Relations {
				if !m.DB.DisableForeignKeyConstraintWhenMigrating {
					if constraint := rel.ParseConstraint(); constraint != nil {
						if constraint.Schema == stmt.Schema {
							sql, vars := buildConstraint(constraint)
							createTableSQL += sql + ","
							values = append(values, vars...)
						}
					}
				}
			}

			for _, chk := range stmt.Schema.ParseCheckConstraints() {
				createTableSQL += "CONSTRAINT ? CHECK (?),"
				values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
			}

			createTableSQL = strings.TrimSuffix(createTableSQL, ",")
			createTableSQL += ")"

			if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
				createTableSQL += fmt.Sprint(tableOption)
			}

			errr = tx.Exec(createTableSQL, values...).Error
			return errr
		}); err != nil {
			return err
		}
	}
	return nil
}
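// gormIndexOption would carry the index priority option; it is unused here
// because priority is not supported by this migrator (see the TODOs above).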
type gormIndexOption struct {
	schema.IndexOption
	priority int
}
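// ParseIndexes collects the indexes declared on a schema's fields, merging
// multi-field indexes by name. Field priority within an index is not
// supported, so fields keep their declaration order.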
func (g gormMigrator) ParseIndexes(sch *schema.Schema) map[string]schema.Index {
	var indexes = map[string]schema.Index{}
	for _, field := range sch.Fields {
		if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
			for _, index := range g.parseFieldIndexes(field) {
				idx := indexes[index.Name]
				idx.Name = index.Name
				if idx.Class == "" {
					idx.Class = index.Class
				}
				if idx.Type == "" {
					idx.Type = index.Type
				}
				if idx.Where == "" {
					idx.Where = index.Where
				}
				if idx.Comment == "" {
					idx.Comment = index.Comment
				}
				if idx.Option == "" {
					idx.Option = index.Option
				}
				idx.Fields = append(idx.Fields, index.Fields...)
				// TODO priority not supported
				//sort.Slice(idx.Fields, func(i, j int) bool {
				//	return idx.Fields[i].priority < idx.Fields[j].priority
				//})
				indexes[index.Name] = idx
			}
		}
	}
	return indexes
}
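// buildConstraint renders the foreign key constraint fragment used by
// CreateTable above.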
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
	sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
	if constraint.OnDelete != "" {
		sql += " ON DELETE " + constraint.OnDelete
	}

	if constraint.OnUpdate != "" {
		sql += " ON UPDATE " + constraint.OnUpdate
	}

	var foreignKeys, references []interface{}
	for _, field := range constraint.ForeignKeys {
		foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
	}

	for _, field := range constraint.References {
		references = append(references, clause.Column{Name: field.DBName})
	}

	results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
	return
}