mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-11 19:04:35 +00:00
db: More sanitization for table names (#243)
This commit is contained in:
@@ -57,22 +57,35 @@ func correctFieldName(s string) string {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Db) tableName(ctx context.Context, t string) (string, error) {
|
||||||
|
tenantId, ok := tenant.FromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
tenantId = "micro"
|
||||||
|
}
|
||||||
|
if t == "" {
|
||||||
|
t = "default"
|
||||||
|
}
|
||||||
|
t = strings.ToLower(t)
|
||||||
|
t = strings.Replace(t, "-", "_", -1)
|
||||||
|
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||||
|
|
||||||
|
tableName := tenantId + "_" + t
|
||||||
|
if !re.Match([]byte(tableName)) {
|
||||||
|
return "", fmt.Errorf("table name %v is invalid", t)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tableName, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Call is a single request handler called via client.Call or the generated client code
|
// Call is a single request handler called via client.Call or the generated client code
|
||||||
func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateResponse) error {
|
func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateResponse) error {
|
||||||
if len(req.Record.AsMap()) == 0 {
|
if len(req.Record.AsMap()) == 0 {
|
||||||
return errors.BadRequest("db.create", "missing record")
|
return errors.BadRequest("db.create", "missing record")
|
||||||
}
|
}
|
||||||
tenantId, ok := tenant.FromContext(ctx)
|
|
||||||
if !ok {
|
tableName, err := e.tableName(ctx, req.Table)
|
||||||
tenantId = "micro"
|
if err != nil {
|
||||||
}
|
return err
|
||||||
if req.Table == "" {
|
|
||||||
req.Table = "default"
|
|
||||||
}
|
|
||||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
|
||||||
tableName := tenantId + "_" + req.Table
|
|
||||||
if !re.Match([]byte(tableName)) {
|
|
||||||
return errors.BadRequest("db.create", fmt.Sprintf("table name %v is invalid", req.Table))
|
|
||||||
}
|
}
|
||||||
logger.Infof("Inserting into table '%v'", tableName)
|
logger.Infof("Inserting into table '%v'", tableName)
|
||||||
|
|
||||||
@@ -80,7 +93,7 @@ func (e *Db) Create(ctx context.Context, req *db.CreateRequest, rsp *db.CreateRe
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, ok = c.Get(tableName)
|
_, ok := c.Get(tableName)
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.Infof("Creating table '%v'", tableName)
|
logger.Infof("Creating table '%v'", tableName)
|
||||||
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
|
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
|
||||||
@@ -111,17 +124,9 @@ func (e *Db) Update(ctx context.Context, req *db.UpdateRequest, rsp *db.UpdateRe
|
|||||||
if len(req.Record.AsMap()) == 0 {
|
if len(req.Record.AsMap()) == 0 {
|
||||||
return errors.BadRequest("db.update", "missing record")
|
return errors.BadRequest("db.update", "missing record")
|
||||||
}
|
}
|
||||||
tenantId, ok := tenant.FromContext(ctx)
|
tableName, err := e.tableName(ctx, req.Table)
|
||||||
if !ok {
|
if err != nil {
|
||||||
tenantId = "micro"
|
return err
|
||||||
}
|
|
||||||
if req.Table == "" {
|
|
||||||
req.Table = "default"
|
|
||||||
}
|
|
||||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
|
||||||
tableName := tenantId + "_" + req.Table
|
|
||||||
if !re.Match([]byte(tableName)) {
|
|
||||||
return errors.BadRequest("db.update", fmt.Sprintf("table name %v is invalid", req.Table))
|
|
||||||
}
|
}
|
||||||
logger.Infof("Updating table '%v'", tableName)
|
logger.Infof("Updating table '%v'", tableName)
|
||||||
|
|
||||||
@@ -179,26 +184,16 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tenantId, ok := tenant.FromContext(ctx)
|
tableName, err := e.tableName(ctx, req.Table)
|
||||||
if !ok {
|
if err != nil {
|
||||||
tenantId = "micro"
|
return err
|
||||||
}
|
|
||||||
if req.Table == "" {
|
|
||||||
req.Table = "default"
|
|
||||||
}
|
|
||||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
|
||||||
tableName := tenantId + "_" + req.Table
|
|
||||||
logger.Infof("Reading table '%v'", tableName)
|
|
||||||
|
|
||||||
if !re.Match([]byte(tableName)) {
|
|
||||||
return errors.BadRequest("db.read", fmt.Sprintf("table name %v is invalid", req.Table))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := e.GetDBConn(ctx)
|
db, err := e.GetDBConn(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, ok = c.Get(tableName)
|
_, ok := c.Get(tableName)
|
||||||
if !ok {
|
if !ok {
|
||||||
logger.Infof("Creating table '%v'", tableName)
|
logger.Infof("Creating table '%v'", tableName)
|
||||||
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
|
db.Exec(fmt.Sprintf(stmt, tableName, tableName, tableName))
|
||||||
@@ -295,18 +290,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe
|
|||||||
if len(req.Id) == 0 {
|
if len(req.Id) == 0 {
|
||||||
return errors.BadRequest("db.delete", "missing id")
|
return errors.BadRequest("db.delete", "missing id")
|
||||||
}
|
}
|
||||||
|
tableName, err := e.tableName(ctx, req.Table)
|
||||||
tenantId, ok := tenant.FromContext(ctx)
|
if err != nil {
|
||||||
if !ok {
|
return err
|
||||||
tenantId = "micro"
|
|
||||||
}
|
|
||||||
if req.Table == "" {
|
|
||||||
req.Table = "default"
|
|
||||||
}
|
|
||||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
|
||||||
tableName := tenantId + "_" + req.Table
|
|
||||||
if !re.Match([]byte(tableName)) {
|
|
||||||
return errors.BadRequest("db.delete", fmt.Sprintf("table name %v is invalid", req.Table))
|
|
||||||
}
|
}
|
||||||
logger.Infof("Deleting from table '%v'", tableName)
|
logger.Infof("Deleting from table '%v'", tableName)
|
||||||
|
|
||||||
@@ -321,17 +307,9 @@ func (e *Db) Delete(ctx context.Context, req *db.DeleteRequest, rsp *db.DeleteRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Db) Truncate(ctx context.Context, req *db.TruncateRequest, rsp *db.TruncateResponse) error {
|
func (e *Db) Truncate(ctx context.Context, req *db.TruncateRequest, rsp *db.TruncateResponse) error {
|
||||||
tenantId, ok := tenant.FromContext(ctx)
|
tableName, err := e.tableName(ctx, req.Table)
|
||||||
if !ok {
|
if err != nil {
|
||||||
tenantId = "micro"
|
return err
|
||||||
}
|
|
||||||
if req.Table == "" {
|
|
||||||
req.Table = "default"
|
|
||||||
}
|
|
||||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
|
||||||
tableName := tenantId + "_" + req.Table
|
|
||||||
if !re.Match([]byte(tableName)) {
|
|
||||||
return errors.BadRequest("db.truncate", fmt.Sprintf("table name %v is invalid", req.Table))
|
|
||||||
}
|
}
|
||||||
logger.Infof("Truncating table '%v'", tableName)
|
logger.Infof("Truncating table '%v'", tableName)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user