db: Fix jsonb field access (#265)

This commit is contained in:
Janos Dobronszki
2021-11-11 11:53:13 +00:00
committed by GitHub
parent 8ac424dd15
commit fff15c6f5b
6 changed files with 659 additions and 58 deletions

View File

@@ -23,6 +23,7 @@ import (
const idKey = "id"
const stmt = "create table if not exists %v(id text not null, data jsonb, primary key(id)); alter table %v add created_at timestamptz; alter table %v add updated_at timestamptz"
const truncateStmt = `truncate table "%v"`
const renameTableStmt = `ALTER TABLE "%v" RENAME TO "%v"`
var re = regexp.MustCompile("^[a-zA-Z0-9_]*$")
var c = cache.New(5*time.Minute, 10*time.Minute)
@@ -40,19 +41,28 @@ type Db struct {
gorm2.Helper
}
func correctFieldName(s string) string {
func correctFieldName(s string, isText bool) string {
operator := "->"
if isText {
// https: //stackoverflow.com/questions/27215216/postgres-how-to-convert-a-json-string-to-text
operator = "->>"
}
switch s {
// top level fields can stay top level
case "id": // "created_at", "updated_at", <-- these are not special fields for now
return s
}
if !strings.Contains(s, ".") {
return fmt.Sprintf("data ->> '%v'", s)
return fmt.Sprintf("data %v '%v'", operator, s)
}
paths := strings.Split(s, ".")
ret := "data"
for _, path := range paths {
ret += fmt.Sprintf(" ->> '%v'", path)
for i, path := range paths {
if i == len(paths)-1 && isText {
ret += fmt.Sprintf(" ->> '%v'", path)
break
}
ret += fmt.Sprintf(" -> '%v'", path)
}
return ret
}
@@ -213,7 +223,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
db = db.Where("id = ?", req.Id)
} else {
for _, query := range queries {
logger.Infof("Query field: %v, op: %v, type: %v", query.Field, query.Op, query.Value)
logger.Infof("Query field: %v, op: %v, value: %v", query.Field, query.Op, query.Value)
typ := "text"
switch query.Value.(type) {
case int64:
@@ -236,7 +246,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
case itemNotEquals:
op = "!="
}
queryField := correctFieldName(query.Field)
queryField := correctFieldName(query.Field, typ == "text")
db = db.Where(fmt.Sprintf("(%v)::%v %v ?", queryField, typ, op), query.Value)
}
}
@@ -245,7 +255,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
if req.OrderBy != "" {
orderField = req.OrderBy
}
orderField = correctFieldName(orderField)
orderField = correctFieldName(orderField, false)
ordering := "asc"
if req.Order != "" {
@@ -260,7 +270,7 @@ func (e *Db) Read(ctx context.Context, req *db.ReadRequest, rsp *db.ReadResponse
}
db = db.Order(orderField + " " + ordering).Offset(int(req.Offset)).Limit(int(req.Limit))
err = db.Find(&recs).Error
err = db.Debug().Find(&recs).Error
if err != nil {
return err
}
@@ -343,3 +353,53 @@ func (e *Db) Count(ctx context.Context, req *db.CountRequest, rsp *db.CountRespo
rsp.Count = int32(a)
return nil
}
func (e *Db) RenameTable(ctx context.Context, req *db.RenameTableRequest, rsp *db.RenameTableResponse) error {
if req.From == "" || req.To == "" {
return errors.BadRequest("db.renameTable", "must provide table names")
}
oldtableName, err := e.tableName(ctx, req.From)
if err != nil {
return err
}
newtableName, err := e.tableName(ctx, req.To)
if err != nil {
return err
}
db, err := e.GetDBConn(ctx)
if err != nil {
return err
}
stmt := fmt.Sprintf(renameTableStmt, oldtableName, newtableName)
logger.Info(stmt)
return db.Debug().Exec(stmt).Error
}
func (e *Db) ListTables(ctx context.Context, req *db.ListTablesRequest, rsp *db.ListTablesResponse) error {
tenantId, ok := tenant.FromContext(ctx)
if !ok {
tenantId = "micro"
}
tenantId = strings.Replace(strings.Replace(tenantId, "/", "_", -1), "-", "_", -1)
db, err := e.GetDBConn(ctx)
if err != nil {
return err
}
var tables []string
if err := db.Table("information_schema.tables").Select("table_name").Where("table_schema = ?", "public").Find(&tables).Error; err != nil {
return err
}
rsp.Tables = []string{}
for _, v := range tables {
if strings.HasPrefix(v, tenantId) {
rsp.Tables = append(rsp.Tables, strings.Replace(v, tenantId+"_", "", -1))
}
}
return nil
}

View File

@@ -0,0 +1,234 @@
package handler
import (
"context"
"encoding/json"
"testing"
"database/sql"
"github.com/micro/micro/v3/service/auth"
db "github.com/micro/services/db/proto"
"google.golang.org/protobuf/types/known/structpb"
)
const dbAddr = "postgresql://postgres:postgres@postgres:5432/postgres?sslmode=disable"
func getHandler(t *testing.T) *Db {
sqlDB, err := sql.Open("pgx", dbAddr)
if err != nil {
t.Fatalf("Failed to open connection to DB %s", err)
}
h := &Db{}
h.DBConn(sqlDB)
return h
}
func TestBasic(t *testing.T) {
h := getHandler(t)
ctx := auth.ContextWithAccount(context.Background(), &auth.Account{Issuer: "basic_test", ID: "test"})
rs := []map[string]interface{}{
{
"name": "Jane",
"age": 42,
"isActive": true,
"id": "1",
},
{
"name": "Joe",
"age": 112,
"isActive": false,
"id": "2",
},
}
for _, v := range rs {
record, _ := json.Marshal(v)
rec := &structpb.Struct{}
err := rec.UnmarshalJSON(record)
if err != nil {
t.Fatal(err)
}
err = h.Create(ctx, &db.CreateRequest{
Table: "users",
Record: rec,
}, &db.CreateResponse{})
if err != nil {
t.Fatal(err)
}
}
t.Run("number ==", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "age == 112",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "2" {
t.Fatal(readRsp)
}
})
t.Run("number <", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "age < 100",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "1" {
t.Fatal(readRsp)
}
})
t.Run("number >", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "age > 100",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "2" {
t.Fatal(readRsp)
}
})
t.Run("number !=", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "age != 42",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "2" {
t.Fatal(readRsp)
}
})
t.Run("bool ==", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "isActive == false",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "2" {
t.Fatal(readRsp)
}
})
t.Run("bool !=", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "isActive != false",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "1" {
t.Fatal(readRsp)
}
})
t.Run("string ==", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "name == 'Jane'",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "1" {
t.Fatal(readRsp)
}
})
t.Run("string !=", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
Query: "name != 'Jane'",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "2" {
t.Fatal(readRsp)
}
})
t.Run("order number asc", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
OrderBy: "age",
Order: "asc",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 2 || readRsp.Records[0].AsMap()["id"].(string) != "1" || readRsp.Records[1].AsMap()["id"].(string) != "2" {
t.Fatal(readRsp)
}
})
t.Run("order number desc", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
OrderBy: "age",
Order: "desc",
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 2 || readRsp.Records[0].AsMap()["id"].(string) != "2" || readRsp.Records[1].AsMap()["id"].(string) != "1" {
t.Fatal(readRsp)
}
})
t.Run("order number desc, limit", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
OrderBy: "age",
Order: "desc",
Limit: 1,
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "2" {
t.Fatal(readRsp)
}
})
t.Run("order number desc, limit, offset", func(t *testing.T) {
readRsp := &db.ReadResponse{}
err := h.Read(ctx, &db.ReadRequest{
Table: "users",
OrderBy: "age",
Order: "desc",
Limit: 1,
Offset: 1,
}, readRsp)
if err != nil {
t.Fatal(err)
}
if len(readRsp.Records) != 1 || readRsp.Records[0].AsMap()["id"].(string) != "1" {
t.Fatal(readRsp)
}
})
}

View File

@@ -9,8 +9,8 @@ import (
)
func TestCorrectFieldName(t *testing.T) {
f := correctFieldName("a.b.c")
if f != "data ->> 'a' ->> 'b' ->> 'c'" {
f := correctFieldName("a.b.c", true)
if f != "data -> 'a' -> 'b' ->> 'c'" {
t.Fatal(f)
}
}