diff --git a/db/handler/db.go b/db/handler/db.go index 386c36c..98ae89a 100644 --- a/db/handler/db.go +++ b/db/handler/db.go @@ -4,17 +4,31 @@ import ( "context" "encoding/json" "fmt" + "regexp" + "strings" + "time" "github.com/google/uuid" "github.com/micro/micro/v3/service/errors" db "github.com/micro/services/db/proto" gorm2 "github.com/micro/services/pkg/gorm" + "github.com/micro/services/pkg/tenant" + "github.com/patrickmn/go-cache" "gorm.io/datatypes" + "gorm.io/gorm" ) +const idKey = "id" +const stmt = "create table %v(id text not null, data jsonb, primary key(id));" + +var re = regexp.MustCompile("^[a-zA-Z0-9_]*$") +var c = cache.New(5*time.Minute, 10*time.Minute) + type Record struct { ID string Data datatypes.JSON `json:"data"` + // private field, ignored from gorm + table string `gorm:"-"` } type Db struct { @@ -26,23 +40,38 @@ func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateRe if len(req.Record) == 0 { return errors.BadRequest("db.create", "missing record") } + tenantId, ok := tenant.FromContext(ctx) + if !ok { + tenantId = "micro" + } + tenantId = strings.Replace(tenantId, "/", "_", -1) + tableName := tenantId + "_" + req.Table + if !re.Match([]byte(tableName)) { + return errors.BadRequest("db.create", "table name is invalid") + } db, err := e.GetDBConn(ctx) if err != nil { return err } + _, ok = c.Get(req.Table) + if !ok { + db.Exec(fmt.Sprintf(stmt, tableName)) + c.Set(req.Table, true, 0) + } + m := map[string]interface{}{} err = json.Unmarshal([]byte(req.Record), &m) if err != nil { return err } - if _, ok := m["ID"].(string); !ok { - m["ID"] = uuid.New().String() + if _, ok := m[idKey].(string); !ok { + m[idKey] = uuid.New().String() } bs, _ := json.Marshal(m) - err = db.Table(req.Table).Create(Record{ - ID: m["ID"].(string), + err = db.Table(tableName).Create(Record{ + ID: m[idKey].(string), Data: bs, }).Error if err != nil { @@ -50,7 +79,7 @@ func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateRe } // set the response id - rsp.Id = m["ID"].(string) + rsp.Id = m[idKey].(string) return nil } @@ -59,7 +88,11 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe if len(req.Record) == 0 { return errors.BadRequest("db.update", "missing record") } - + tenantId, ok := tenant.FromContext(ctx) + if !ok { + tenantId = "micro" + } + tenantId = strings.Replace(tenantId, "/", "_", -1) db, err := e.GetDBConn(ctx) if err != nil { return err @@ -71,59 +104,37 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe return err } - // do we really need to remarshal this? - data, _ := json.Marshal(m) - // where ID is specified do a single update record update - if id, ok := m["ID"].(string); ok { - // apply the update to a single record - return db.Table(req.Table).First(&Record{ID: id}).Updates(Record{Data: data}).Error + id, ok := m[idKey].(string) + if !ok { + return fmt.Errorf("update failed: missing id") } - // define the db - db = db.Table(req.Table) - - // no ID param so we're expecting a query - if len(req.Query) == 0 { - // apply the updates to all records - return db.Find(&Record{}).Updates(Record{Data: data}).Error - } - - // parse the query - queries, err := Parse(req.Query) - if err != nil { - return err - } - - // get the filters - for _, query := range queries { - typ := "text" - switch query.Value.(type) { - case int64: - typ = "int" - case bool: - typ = "boolean" + db.Transaction(func(tx *gorm.DB) error { + rec := []Record{} + err = tx.Table(tenantId+"_"+req.Table).Where("ID = ?", id).Find(&rec).Error + if err != nil { + return err } - op := "" - switch query.Op { - case itemEquals: - op = "=" - case itemGreaterThan: - op = ">" - case itemGreaterThanEquals: - op = ">=" - case itemLessThan: - op = "<" - case itemLessThanEquals: - op = "<=" - case itemNotEquals: - op = "!=" + if len(rec) == 0 { + return fmt.Errorf("update failed: not found") } - db = db.Where(fmt.Sprintf("(data ->> '%v')::%v %v ?", query.Field, typ, op), query.Value) - } + old := map[string]interface{}{} + err = json.Unmarshal(rec[0].Data, &old) + if err != nil { + return err + } + for k, v := range old { + m[k] = v + } + bs, _ := json.Marshal(m) - // apply updates to the filtered records - return db.Updates(Record{Data: data}).Error + return tx.Table(tenantId + "_" + req.Table).Save(Record{ + ID: m[idKey].(string), + Data: bs, + }).Error + }) + return nil } func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse) error { @@ -132,11 +143,16 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse if err != nil { return err } + tenantId, ok := tenant.FromContext(ctx) + if !ok { + tenantId = "micro" + } + tenantId = strings.Replace(tenantId, "/", "_", -1) db, err := e.GetDBConn(ctx) if err != nil { return err } - db = db.Table(req.Table) + db = db.Table(tenantId + "_" + req.Table) for _, query := range queries { typ := "text" switch query.Value.(type) { @@ -174,7 +190,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse } ma := map[string]interface{}{} json.Unmarshal(m, &ma) - ma["ID"] = rec.ID + ma[idKey] = rec.ID ret = append(ret, ma) } bs, _ := json.Marshal(ret) diff --git a/db/main.go b/db/main.go index 923cfb3..124147d 100644 --- a/db/main.go +++ b/db/main.go @@ -35,13 +35,13 @@ func main() { logger.Fatalf("Failed to open connection to DB %s", err) } h := &handler.Db{} - h.DBConn(sqlDB).Migrations(&handler.Record{}) + h.DBConn(sqlDB) // Register handler pb.RegisterDbHandler(srv.Server(), h) // Register handler - pb.RegisterDbHandler(srv.Server(), new(handler.Db)) + pb.RegisterDbHandler(srv.Server(), &handler.Db{}) // Run service if err := srv.Run(); err != nil { diff --git a/go.mod b/go.mod index 0aa8495..410a24a 100644 --- a/go.mod +++ b/go.mod @@ -41,7 +41,7 @@ require ( gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 // indirect gorm.io/datatypes v1.0.1 gorm.io/driver/postgres v1.0.8 - gorm.io/gorm v1.21.6 + gorm.io/gorm v1.21.10 ) replace google.golang.org/grpc => google.golang.org/grpc v1.26.0 diff --git a/go.sum b/go.sum index 47babd7..a485d39 100644 --- a/go.sum +++ b/go.sum @@ -780,6 +780,8 @@ gorm.io/gorm v1.21.3/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.21.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= gorm.io/gorm v1.21.6 h1:xEFbH7WShsnAM+HeRNv7lOeyqmDAK+dDnf1AMf/cVPQ= gorm.io/gorm v1.21.6/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= +gorm.io/gorm v1.21.10 h1:kBGiBsaqOQ+8f6S2U6mvGFz6aWWyCeIiuaFcaBozp4M= +gorm.io/gorm v1.21.10/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=