db: More sanitization for table names (#243)

This commit is contained in:
Janos Dobronszki
2021-10-25 14:47:14 +01:00
committed by GitHub
parent edc4a72755
commit 493f520a99

View File

@@ -57,22 +57,35 @@ func correctFieldName(s string) string {
return ret
}
func (e *Db) tableName(ctx context.Context, t string) (string, error) {
tenantId, ok := tenant.FromContext(ctx)
if !ok {
tenantId = "micro"
}
if t == "" {
t = "default"
}
t = strings.ToLower(t)
t = strings.Replace(t, "-", "_", -1)
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
tableName := tenantId + "_" + t
if !re.Match([]byte(tableName)) {
return "", fmt.Errorf("table name %v is invalid", t)
}
return tableName, nil
}
// Call is a single request handler called via client.Call or the generated client code
func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateResponse) error {
if len(req.Record.AsMap()) == 0 {
return errors.BadRequest("db.create", "missing record")
}
tenantId, ok := tenant.FromContext(ctx)
if !ok {
tenantId = "micro"
}
if req.Table == "" {
req.Table = "default"
}
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
tableName := tenantId + "_" + req.Table
if !re.Match([]byte(tableName)) {
return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table))
tableName, err := e.tableName(ctx, req.Table)
if err != nil {
return err
}
logger.Infof("Inserting into table '%v'", tableName)
@@ -80,7 +93,7 @@ func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateRe
if err != nil {
return err
}
_, ok = c.Get(tableName)
_, ok := c.Get(tableName)
if !ok {
logger.Infof("Creating table '%v'", tableName)
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
@@ -111,17 +124,9 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe
if len(req.Record.AsMap()) == 0 {
return errors.BadRequest("db.update", "missing record")
}
tenantId, ok := tenant.FromContext(ctx)
if !ok {
tenantId = "micro"
}
if req.Table == "" {
req.Table = "default"
}
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
tableName := tenantId + "_" + req.Table
if !re.Match([]byte(tableName)) {
return errors.BadRequest("db.update", fmt.Sprintf("table name %v is invalid", req.Table))
tableName, err := e.tableName(ctx, req.Table)
if err != nil {
return err
}
logger.Infof("Updating table '%v'", tableName)
@@ -179,26 +184,16 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
if err != nil {
return err
}
tenantId, ok := tenant.FromContext(ctx)
if !ok {
tenantId = "micro"
}
if req.Table == "" {
req.Table = "default"
}
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
tableName := tenantId + "_" + req.Table
logger.Infof("Reading table '%v'", tableName)
if !re.Match([]byte(tableName)) {
return errors.BadRequest("db.read", fmt.Sprintf("table name %v is invalid", req.Table))
tableName, err := e.tableName(ctx, req.Table)
if err != nil {
return err
}
db, err := e.GetDBConn(ctx)
if err != nil {
return err
}
_, ok = c.Get(tableName)
_, ok := c.Get(tableName)
if !ok {
logger.Infof("Creating table '%v'", tableName)
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
@@ -295,18 +290,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe
if len(req.Id) == 0 {
return errors.BadRequest("db.delete", "missing id")
}
tenantId, ok := tenant.FromContext(ctx)
if !ok {
tenantId = "micro"
}
if req.Table == "" {
req.Table = "default"
}
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
tableName := tenantId + "_" + req.Table
if !re.Match([]byte(tableName)) {
return errors.BadRequest("db.delete", fmt.Sprintf("table name %v is invalid", req.Table))
tableName, err := e.tableName(ctx, req.Table)
if err != nil {
return err
}
logger.Infof("Deleting from table '%v'", tableName)
@@ -321,17 +307,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe
}
func (e *Db) Truncate(ctx context.Context, req *db.TruncateRequest, rsp *db.TruncateResponse) error {
tenantId, ok := tenant.FromContext(ctx)
if !ok {
tenantId = "micro"
}
if req.Table == "" {
req.Table = "default"
}
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
tableName := tenantId + "_" + req.Table
if !re.Match([]byte(tableName)) {
return errors.BadRequest("db.truncate", fmt.Sprintf("table name %v is invalid", req.Table))
tableName, err := e.tableName(ctx, req.Table)
if err != nil {
return err
}
logger.Infof("Truncating table '%v'", tableName)