From b0cbf847ac26f7fd5691c1bd000828c29743d308 Mon Sep 17 00:00:00 2001 From: Janos Dobronszki Date: Wed, 9 Jun 2021 16:31:31 +0100 Subject: [PATCH] DB: Table name fix (#145) --- db/handler/db.go | 48 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/db/handler/db.go b/db/handler/db.go index dfcdda8..2c956fe 100644 --- a/db/handler/db.go +++ b/db/handler/db.go @@ -45,20 +45,23 @@ func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateRe if !ok { tenantId = "micro" } - tenantId = strings.Replace(tenantId, "/", "_", -1) + if req.Table == "" { + req.Table = "default" + } + tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) tableName := tenantId + "_" + req.Table if !re.Match([]byte(tableName)) { - return errors.BadRequest("db.create", "table name is invalid") + return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table)) } db, err := e.GetDBConn(ctx) if err != nil { return err } - _, ok = c.Get(req.Table) + _, ok = c.Get(tableName) if !ok { db.Exec(fmt.Sprintf(stmt, tableName)) - c.Set(req.Table, true, 0) + c.Set(tableName, true, 0) } m := req.Record.AsMap() @@ -89,7 +92,14 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe if !ok { tenantId = "micro" } - tenantId = strings.Replace(tenantId, "/", "_", -1) + if req.Table == "" { + req.Table = "default" + } + tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) + tableName := tenantId + "_" + req.Table + if !re.Match([]byte(tableName)) { + return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table)) + } db, err := e.GetDBConn(ctx) if err != nil { return err @@ -105,7 +115,7 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe db.Transaction(func(tx *gorm.DB) error { rec := []Record{} - err = tx.Table(tenantId+"_"+req.Table).Where("ID = ?", id).Find(&rec).Error + err = tx.Table(tableName).Where("ID = ?", id).Find(&rec).Error if err != nil { return err } @@ -122,7 +132,7 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe } bs, _ := json.Marshal(m) - return tx.Table(tenantId + "_" + req.Table).Save(Record{ + return tx.Table(tableName).Save(Record{ ID: m[idKey].(string), Data: bs, }).Error @@ -140,12 +150,20 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse if !ok { tenantId = "micro" } - tenantId = strings.Replace(tenantId, "/", "_", -1) + if req.Table == "" { + req.Table = "default" + } + tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) + tableName := tenantId + "_" + req.Table + if !re.Match([]byte(tableName)) { + return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table)) + } + db, err := e.GetDBConn(ctx) if err != nil { return err } - db = db.Table(tenantId + "_" + req.Table) + db = db.Table(tableName) for _, query := range queries { typ := "text" switch query.Value.(type) { @@ -206,15 +224,21 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe if !ok { tenantId = "micro" } - - tenantId = strings.Replace(tenantId, "/", "_", -1) + if req.Table == "" { + req.Table = "default" + } + tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) + tableName := tenantId + "_" + req.Table + if !re.Match([]byte(tableName)) { + return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table)) + } db, err := e.GetDBConn(ctx) if err != nil { return err } - return db.Table(tenantId + "_" + req.Table).Delete(Record{ + return db.Table(tableName).Delete(Record{ ID: req.Id, }).Error }