mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-16 04:54:42 +00:00
db: Fix jsonb field access (#265)
This commit is contained in:
@@ -23,6 +23,7 @@ import (
|
||||
const idKey = "id"
|
||||
const stmt = "create table if not exists %v(id text not null, data jsonb, primary key(id)); alter table %v add created_at timestamptz; alter table %v add updated_at timestamptz"
|
||||
const truncateStmt = `truncate table "%v"`
|
||||
const renameTableStmt = `ALTER TABLE "%v" RENAME TO "%v"`
|
||||
|
||||
var re = regexp.MustCompile("^[a-zA-Z0-9_]*$")
|
||||
var c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
@@ -40,19 +41,28 @@ type Db struct {
|
||||
gorm2.Helper
|
||||
}
|
||||
|
||||
func correctFieldName(s string) string {
|
||||
func correctFieldName(s string, isText bool) string {
|
||||
operator := "->"
|
||||
if isText {
|
||||
// https: //stackoverflow.com/questions/27215216/postgres-how-to-convert-a-json-string-to-text
|
||||
operator = "->>"
|
||||
}
|
||||
switch s {
|
||||
// top level fields can stay top level
|
||||
case "id": // "created_at", "updated_at", <-- these are not special fields for now
|
||||
return s
|
||||
}
|
||||
if !strings.Contains(s, ".") {
|
||||
return fmt.Sprintf("data ->> '%v'", s)
|
||||
return fmt.Sprintf("data %v '%v'", operator, s)
|
||||
}
|
||||
paths := strings.Split(s, ".")
|
||||
ret := "data"
|
||||
for _, path := range paths {
|
||||
ret += fmt.Sprintf(" ->> '%v'", path)
|
||||
for i, path := range paths {
|
||||
if i == len(paths)-1 && isText {
|
||||
ret += fmt.Sprintf(" ->> '%v'", path)
|
||||
break
|
||||
}
|
||||
ret += fmt.Sprintf(" -> '%v'", path)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
@@ -213,7 +223,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
|
||||
db = db.Where("id = ?", req.Id)
|
||||
} else {
|
||||
for _, query := range queries {
|
||||
logger.Infof("Query field: %v, op: %v, type: %v", query.Field, query.Op, query.Value)
|
||||
logger.Infof("Query field: %v, op: %v, value: %v", query.Field, query.Op, query.Value)
|
||||
typ := "text"
|
||||
switch query.Value.(type) {
|
||||
case int64:
|
||||
@@ -236,7 +246,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
|
||||
case itemNotEquals:
|
||||
op = "!="
|
||||
}
|
||||
queryField := correctFieldName(query.Field)
|
||||
queryField := correctFieldName(query.Field, typ == "text")
|
||||
db = db.Where(fmt.Sprintf("(%v)::%v %v ?", queryField, typ, op), query.Value)
|
||||
}
|
||||
}
|
||||
@@ -245,7 +255,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
|
||||
if req.OrderBy != "" {
|
||||
orderField = req.OrderBy
|
||||
}
|
||||
orderField = correctFieldName(orderField)
|
||||
orderField = correctFieldName(orderField, false)
|
||||
|
||||
ordering := "asc"
|
||||
if req.Order != "" {
|
||||
@@ -260,7 +270,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
|
||||
}
|
||||
|
||||
db = db.Order(orderField + " " + ordering).Offset(int(req.Offset)).Limit(int(req.Limit))
|
||||
err = db.Find(&recs).Error
|
||||
err = db.Debug().Find(&recs).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -343,3 +353,53 @@ func (e *Db) Count(ctx context.Context, req *db.CountRequest, rsp *db.CountRespo
|
||||
rsp.Count = int32(a)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Db) RenameTable(ctx context.Context, req *db.RenameTableRequest, rsp *db.RenameTableResponse) error {
|
||||
if req.From == "" || req.To == "" {
|
||||
return errors.BadRequest("db.renameTable", "must provide table names")
|
||||
}
|
||||
|
||||
oldtableName, err := e.tableName(ctx, req.From)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newtableName, err := e.tableName(ctx, req.To)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := e.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stmt := fmt.Sprintf(renameTableStmt, oldtableName, newtableName)
|
||||
logger.Info(stmt)
|
||||
return db.Debug().Exec(stmt).Error
|
||||
}
|
||||
|
||||
func (e *Db) ListTables(ctx context.Context, req *db.ListTablesRequest, rsp *db.ListTablesResponse) error {
|
||||
tenantId, ok := tenant.FromContext(ctx)
|
||||
if !ok {
|
||||
tenantId = "micro"
|
||||
}
|
||||
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
|
||||
|
||||
db, err := e.GetDBConn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var tables []string
|
||||
if err := db.Table("information_schema.tables").Select("table_name").Where("table_schema = ?", "public").Find(&tables).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
rsp.Tables = []string{}
|
||||
for _, v := range tables {
|
||||
if strings.HasPrefix(v, tenantId) {
|
||||
rsp.Tables = append(rsp.Tables, strings.Replace(v, tenantId+"_", "", -1))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user