mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-11 19:04:35 +00:00
db: More sanitization for table names (#243)
This commit is contained in:
@@ -57,22 +57,35 @@ func correctFieldName(s string) string {
|
||||
return ret
|
||||
}
|
||||
|
||||
func (e *Db) tableName(ctx context.Context, t string) (string, error) {
|
||||
tenantId, ok := tenant.FromContext(ctx)
|
||||
if !ok {
|
||||
tenantId = "micro"
|
||||
}
|
||||
if t == "" {
|
||||
t = "default"
|
||||
}
|
||||
t = strings.ToLower(t)
|
||||
t = strings.Replace(t, "-", "_", -1)
|
||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||
|
||||
tableName := tenantId + "_" + t
|
||||
if !re.Match([]byte(tableName)) {
|
||||
return "", fmt.Errorf("table name %v is invalid", t)
|
||||
}
|
||||
|
||||
return tableName, nil
|
||||
}
|
||||
|
||||
// Call is a single request handler called via client.Call or the generated client code
|
||||
func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateResponse) error {
|
||||
if len(req.Record.AsMap()) == 0 {
|
||||
return errors.BadRequest("db.create", "missing record")
|
||||
}
|
||||
tenantId, ok := tenant.FromContext(ctx)
|
||||
if !ok {
|
||||
tenantId = "micro"
|
||||
}
|
||||
if req.Table == "" {
|
||||
req.Table = "default"
|
||||
}
|
||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||
tableName := tenantId + "_" + req.Table
|
||||
if !re.Match([]byte(tableName)) {
|
||||
return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table))
|
||||
|
||||
tableName, err := e.tableName(ctx, req.Table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("Inserting into table '%v'", tableName)
|
||||
|
||||
@@ -80,7 +93,7 @@ func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateRe
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, ok = c.Get(tableName)
|
||||
_, ok := c.Get(tableName)
|
||||
if !ok {
|
||||
logger.Infof("Creating table '%v'", tableName)
|
||||
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
|
||||
@@ -111,17 +124,9 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe
|
||||
if len(req.Record.AsMap()) == 0 {
|
||||
return errors.BadRequest("db.update", "missing record")
|
||||
}
|
||||
tenantId, ok := tenant.FromContext(ctx)
|
||||
if !ok {
|
||||
tenantId = "micro"
|
||||
}
|
||||
if req.Table == "" {
|
||||
req.Table = "default"
|
||||
}
|
||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||
tableName := tenantId + "_" + req.Table
|
||||
if !re.Match([]byte(tableName)) {
|
||||
return errors.BadRequest("db.update", fmt.Sprintf("table name %v is invalid", req.Table))
|
||||
tableName, err := e.tableName(ctx, req.Table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("Updating table '%v'", tableName)
|
||||
|
||||
@@ -179,26 +184,16 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tenantId, ok := tenant.FromContext(ctx)
|
||||
if !ok {
|
||||
tenantId = "micro"
|
||||
}
|
||||
if req.Table == "" {
|
||||
req.Table = "default"
|
||||
}
|
||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||
tableName := tenantId + "_" + req.Table
|
||||
logger.Infof("Reading table '%v'", tableName)
|
||||
|
||||
if !re.Match([]byte(tableName)) {
|
||||
return errors.BadRequest("db.read", fmt.Sprintf("table name %v is invalid", req.Table))
|
||||
tableName, err := e.tableName(ctx, req.Table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := e.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, ok = c.Get(tableName)
|
||||
_, ok := c.Get(tableName)
|
||||
if !ok {
|
||||
logger.Infof("Creating table '%v'", tableName)
|
||||
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
|
||||
@@ -295,18 +290,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe
|
||||
if len(req.Id) == 0 {
|
||||
return errors.BadRequest("db.delete", "missing id")
|
||||
}
|
||||
|
||||
tenantId, ok := tenant.FromContext(ctx)
|
||||
if !ok {
|
||||
tenantId = "micro"
|
||||
}
|
||||
if req.Table == "" {
|
||||
req.Table = "default"
|
||||
}
|
||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||
tableName := tenantId + "_" + req.Table
|
||||
if !re.Match([]byte(tableName)) {
|
||||
return errors.BadRequest("db.delete", fmt.Sprintf("table name %v is invalid", req.Table))
|
||||
tableName, err := e.tableName(ctx, req.Table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("Deleting from table '%v'", tableName)
|
||||
|
||||
@@ -321,17 +307,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe
|
||||
}
|
||||
|
||||
func (e *Db) Truncate(ctx context.Context, req *db.TruncateRequest, rsp *db.TruncateResponse) error {
|
||||
tenantId, ok := tenant.FromContext(ctx)
|
||||
if !ok {
|
||||
tenantId = "micro"
|
||||
}
|
||||
if req.Table == "" {
|
||||
req.Table = "default"
|
||||
}
|
||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||
tableName := tenantId + "_" + req.Table
|
||||
if !re.Match([]byte(tableName)) {
|
||||
return errors.BadRequest("db.truncate", fmt.Sprintf("table name %v is invalid", req.Table))
|
||||
tableName, err := e.tableName(ctx, req.Table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("Truncating table '%v'", tableName)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user