diff --git a/db/handler/db.go b/db/handler/db.go index bbc9d18..43b6e98 100644 --- a/db/handler/db.go +++ b/db/handler/db.go @@ -57,22 +57,35 @@ func correctFieldName(s string) string { return ret } +func (e *Db) tableName(ctx context.Context, t string) (string, error) { + tenantId, ok := tenant.FromContext(ctx) + if !ok { + tenantId = "micro" + } + if t == "" { + t = "default" + } + t = strings.ToLower(t) + t = strings.Replace(t, "-", "_", -1) + tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) + + tableName := tenantId + "_" + t + if !re.Match([]byte(tableName)) { + return "", fmt.Errorf("table name %v is invalid", t) + } + + return tableName, nil +} + // Call is a single request handler called via client.Call or the generated client code func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateResponse) error { if len(req.Record.AsMap()) == 0 { return errors.BadRequest("db.create", "missing record") } - tenantId, ok := tenant.FromContext(ctx) - if !ok { - tenantId = "micro" - } - if req.Table == "" { - req.Table = "default" - } - tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) - tableName := tenantId + "_" + req.Table - if !re.Match([]byte(tableName)) { - return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table)) + + tableName, err := e.tableName(ctx, req.Table) + if err != nil { + return err } logger.Infof("Inserting into table '%v'", tableName) @@ -80,7 +93,7 @@ func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateRe if err != nil { return err } - _, ok = c.Get(tableName) + _, ok := c.Get(tableName) if !ok { logger.Infof("Creating table '%v'", tableName) db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName)) @@ -111,17 +124,9 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe if len(req.Record.AsMap()) == 0 { return errors.BadRequest("db.update", "missing record") } - tenantId, ok := tenant.FromContext(ctx) - if !ok { - tenantId = "micro" - } - if req.Table == "" { - req.Table = "default" - } - tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) - tableName := tenantId + "_" + req.Table - if !re.Match([]byte(tableName)) { - return errors.BadRequest("db.update", fmt.Sprintf("table name %v is invalid", req.Table)) + tableName, err := e.tableName(ctx, req.Table) + if err != nil { + return err } logger.Infof("Updating table '%v'", tableName) @@ -179,26 +184,16 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse if err != nil { return err } - tenantId, ok := tenant.FromContext(ctx) - if !ok { - tenantId = "micro" - } - if req.Table == "" { - req.Table = "default" - } - tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) - tableName := tenantId + "_" + req.Table - logger.Infof("Reading table '%v'", tableName) - - if !re.Match([]byte(tableName)) { - return errors.BadRequest("db.read", fmt.Sprintf("table name %v is invalid", req.Table)) + tableName, err := e.tableName(ctx, req.Table) + if err != nil { + return err } db, err := e.GetDBConn(ctx) if err != nil { return err } - _, ok = c.Get(tableName) + _, ok := c.Get(tableName) if !ok { logger.Infof("Creating table '%v'", tableName) db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName)) @@ -295,18 +290,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe if len(req.Id) == 0 { return errors.BadRequest("db.delete", "missing id") } - - tenantId, ok := tenant.FromContext(ctx) - if !ok { - tenantId = "micro" - } - if req.Table == "" { - req.Table = "default" - } - tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) - tableName := tenantId + "_" + req.Table - if !re.Match([]byte(tableName)) { - return errors.BadRequest("db.delete", fmt.Sprintf("table name %v is invalid", req.Table)) + tableName, err := e.tableName(ctx, req.Table) + if err != nil { + return err } logger.Infof("Deleting from table '%v'", tableName) @@ -321,17 +307,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe } func (e *Db) Truncate(ctx context.Context, req *db.TruncateRequest, rsp *db.TruncateResponse) error { - tenantId, ok := tenant.FromContext(ctx) - if !ok { - tenantId = "micro" - } - if req.Table == "" { - req.Table = "default" - } - tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1) - tableName := tenantId + "_" + req.Table - if !re.Match([]byte(tableName)) { - return errors.BadRequest("db.truncate", fmt.Sprintf("table name %v is invalid", req.Table)) + tableName, err := e.tableName(ctx, req.Table) + if err != nil { + return err } logger.Infof("Truncating table '%v'", tableName)