use store in groups service (#95)

* use store in groups service

* fix message

* fix main
This commit is contained in:
Asim Aslam
2021-05-06 10:22:06 +01:00
committed by GitHub
parent b4c1b48b56
commit e14a246604
3 changed files with 216 additions and 153 deletions

View File

@@ -2,16 +2,14 @@ package handler
import ( import (
"context" "context"
"strings" "fmt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/store"
pb "github.com/micro/services/groups/proto" pb "github.com/micro/services/groups/proto"
gorm2 "github.com/micro/services/pkg/gorm" "github.com/micro/services/pkg/tenant"
"gorm.io/gorm"
) )
var ( var (
@@ -25,28 +23,48 @@ var (
) )
type Group struct { type Group struct {
ID string ID string
Name string Name string
Memberships []Membership Members []string
} }
type Membership struct { type Member struct {
MemberID string `gorm:"uniqueIndex:idx_membership"` ID string
GroupID string `gorm:"uniqueIndex:idx_membership"` Group string
Group Group }
func (g *Group) Key(ctx context.Context) string {
key := fmt.Sprintf("group:%s", g.ID)
t, ok := tenant.FromContext(ctx)
if !ok {
return key
}
return fmt.Sprintf("%s/%s", t, key)
}
func (m *Member) Key(ctx context.Context) string {
key := fmt.Sprintf("member:%s:%s", m.ID, m.Group)
t, ok := tenant.FromContext(ctx)
if !ok {
return key
}
return fmt.Sprintf("%s/%s", t, key)
} }
func (g *Group) Serialize() *pb.Group { func (g *Group) Serialize() *pb.Group {
memberIDs := make([]string, len(g.Memberships)) memberIDs := make([]string, len(g.Members))
for i, m := range g.Memberships { for i, m := range g.Members {
memberIDs[i] = m.MemberID memberIDs[i] = m
} }
return &pb.Group{Id: g.ID, Name: g.Name, MemberIds: memberIDs} return &pb.Group{Id: g.ID, Name: g.Name, MemberIds: memberIDs}
} }
type Groups struct { type Groups struct{}
gorm2.Helper
}
func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error { func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error {
_, ok := auth.AccountFromContext(ctx) _, ok := auth.AccountFromContext(ctx)
@@ -60,18 +78,15 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea
// create the group object // create the group object
group := &Group{ID: uuid.New().String(), Name: req.Name} group := &Group{ID: uuid.New().String(), Name: req.Name}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
if err := db.Create(group).Error; err != nil { // write the group record
if err := store.Write(store.NewRecord(group.Key(ctx), group)); err != nil {
return ErrStore return ErrStore
} }
// return the group // return the group
rsp.Group = group.Serialize() rsp.Group = group.Serialize()
return nil return nil
} }
@@ -85,21 +100,24 @@ func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResp
return ErrMissingIDs return ErrMissingIDs
} }
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// query the database
var groups []Group
if err := db.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil {
return ErrStore
}
// serialize the response // serialize the response
rsp.Groups = make(map[string]*pb.Group, len(groups)) rsp.Groups = make(map[string]*pb.Group)
for _, g := range groups {
rsp.Groups[g.ID] = g.Serialize() for _, id := range req.Ids {
group := &Group{
ID: id,
}
recs, err := store.Read(group.Key(ctx), store.ReadLimit(1))
if err != nil {
return ErrStore
}
if len(recs) == 0 {
continue
}
if err := recs[0].Decode(&group); err != nil {
continue
}
rsp.Groups[group.ID] = group.Serialize()
} }
return nil return nil
@@ -117,31 +135,31 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda
if len(req.Name) == 0 { if len(req.Name) == 0 {
return ErrMissingName return ErrMissingName
} }
db, err := g.GetDBConn(ctx)
if err != nil { group := &Group{ID: req.Id}
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB") recs, err := store.Read(group.Key(ctx), store.ReadLimit(1))
if err == store.ErrNotFound {
return ErrNotFound
} else if err != nil {
return ErrStore
} }
return db.Transaction(func(tx *gorm.DB) error { // decode the record
// find the group recs[0].Decode(&group)
var group Group
if err := tx.Where(&Group{ID: req.Id}).First(&group).Error; err == gorm.ErrRecordNotFound {
return ErrNotFound
} else if err != nil {
return ErrStore
}
// update the group // set the name
group.Name = req.Name group.Name = req.Name
if err := tx.Save(&group).Error; err != nil {
return ErrStore
}
// serialize the response // save the record
rsp.Group = group.Serialize() if err := store.Write(store.NewRecord(group.Key(ctx), group)); err != nil {
return nil return ErrStore
}) }
// serialize the response
rsp.Group = group.Serialize()
return nil
} }
func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error { func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error {
@@ -154,18 +172,35 @@ func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Dele
return ErrMissingID return ErrMissingID
} }
db, err := g.GetDBConn(ctx) group := &Group{ID: req.Id}
if err != nil {
logger.Errorf("Error connecting to DB: %v", err) // get the group
return errors.InternalServerError("DB_ERROR", "Error connecting to DB") recs, err := store.Read(group.Key(ctx), store.ReadLimit(1))
} if err == store.ErrNotFound {
// delete from the database
if err := db.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound {
return nil return nil
} else if err != nil { } else if err != nil {
return ErrStore return ErrStore
} }
// decode the record
recs[0].Decode(&group)
// delete the record
if err := store.Delete(group.Key(ctx)); err == store.ErrNotFound {
return nil
} else if err != nil {
return ErrStore
}
// delete all the members
for _, memberId := range group.Members {
m := &Member{
ID: memberId,
}
// delete the member
store.Delete(m.Key(ctx))
}
return nil return nil
} }
@@ -174,35 +209,47 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp
if !ok { if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized") errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
} }
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
if len(req.MemberId) > 0 { if len(req.MemberId) > 0 {
// only list groups the user is a member of // only list groups the user is a member of
var ms []Membership m := &Member{ID: req.MemberId}
q := db.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships") recs, err := store.Read(m.Key(ctx), store.ReadPrefix())
if err := q.Find(&ms).Error; err != nil { if err != nil {
return err return ErrStore
} }
rsp.Groups = make([]*pb.Group, len(ms))
for i, m := range ms { for _, rec := range recs {
rsp.Groups[i] = m.Group.Serialize() m := &Member{ID: req.MemberId}
rec.Decode(&m)
// get the group
group := &Group{ID: m.Group}
grecs, err := store.Read(group.Key(ctx), store.ReadLimit(1))
if err != nil {
return ErrStore
}
grecs[0].Decode(&group)
rsp.Groups = append(rsp.Groups, group.Serialize())
} }
return nil return nil
} }
// load all groups group := &Group{}
var groups []Group
if err := db.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil { // read all the prefixes
recs, err := store.Read(group.Key(ctx), store.ReadPrefix())
if err != nil {
return ErrStore return ErrStore
} }
// serialize the response // serialize and return response
rsp.Groups = make([]*pb.Group, len(groups)) for _, rec := range recs {
for i, g := range groups { group := new(Group)
rsp.Groups[i] = g.Serialize() rec.Decode(&group)
rsp.Groups = append(rsp.Groups, group.Serialize())
} }
return nil return nil
@@ -220,33 +267,54 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p
if len(req.MemberId) == 0 { if len(req.MemberId) == 0 {
return ErrMissingMemberID return ErrMissingMemberID
} }
db, err := g.GetDBConn(ctx)
if err != nil { // read the group
logger.Errorf("Error connecting to DB: %v", err) group := &Group{ID: req.GroupId}
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
recs, err := store.Read(group.Key(ctx), store.ReadLimit(1))
if err == store.ErrNotFound {
return ErrNotFound
} else if err != nil {
return ErrStore
} }
return db.Transaction(func(tx *gorm.DB) error { // decode the record
// check the group exists recs[0].Decode(group)
var group Group
if err := tx.Where(&Group{ID: req.GroupId}).First(&group).Error; err == gorm.ErrRecordNotFound {
return ErrNotFound
} else if err != nil {
return err
}
// create the membership var seen bool
m := &Membership{MemberID: req.MemberId, GroupID: req.GroupId} for _, member := range group.Members {
err := tx.Create(m).Error if member == req.MemberId {
// check for membership already existing (unique index violation) seen = true
if err != nil && strings.Contains(err.Error(), "fk_groups_memberships") { break
return nil
} else if err != nil {
return ErrStore
} }
}
// already a member
if seen {
return nil return nil
}) }
// add the member
group.Members = append(group.Members, req.MemberId)
// save the record
if err := store.Write(store.NewRecord(group.Key(ctx), group)); err != nil {
return ErrStore
}
// add the member record
m := &Member{
ID: req.MemberId,
Group: group.ID,
}
// write the record
if err := store.Write(store.NewRecord(m.Key(ctx), m)); err != nil {
return ErrStore
}
return nil
} }
func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest, rsp *pb.RemoveMemberResponse) error { func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest, rsp *pb.RemoveMemberResponse) error {
@@ -262,14 +330,44 @@ func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest,
return ErrMissingMemberID return ErrMissingMemberID
} }
db, err := g.GetDBConn(ctx) // read the group
if err != nil { group := &Group{ID: req.GroupId}
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB") // read the gruop
recs, err := store.Read(group.Key(ctx), store.ReadLimit(1))
if err == store.ErrNotFound {
return ErrNotFound
} else if err != nil {
return ErrStore
} }
// delete the membership
m := &Membership{MemberID: req.MemberId, GroupID: req.GroupId} // decode the record
if err := db.Where(m).Delete(m).Error; err != nil { recs[0].Decode(&group)
// new member id list
var members []string
for _, member := range group.Members {
if member == req.MemberId {
continue
}
members = append(members, member)
}
// update the member
group.Members = members
// save the record
if err := store.Write(store.NewRecord(group.Key(ctx), group)); err != nil {
return ErrStore
}
// delete the member
m := &Member{
ID: req.MemberId,
Group: group.ID,
}
if err := store.Delete(m.Key(ctx)); err != nil {
return ErrStore return ErrStore
} }

View File

@@ -2,37 +2,21 @@ package handler_test
import ( import (
"context" "context"
"database/sql"
"os"
"sort" "sort"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/micro/micro/v3/service/auth" "github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/store"
"github.com/micro/micro/v3/service/store/memory"
"github.com/micro/services/groups/handler" "github.com/micro/services/groups/handler"
pb "github.com/micro/services/groups/proto" pb "github.com/micro/services/groups/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func testHandler(t *testing.T) *handler.Groups { func testHandler(t *testing.T) *handler.Groups {
// connect to the database store.DefaultStore = memory.NewStore()
addr := os.Getenv("POSTGRES_URL") return &handler.Groups{}
if len(addr) == 0 {
addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable"
}
sqlDB, err := sql.Open("pgx", addr)
if err != nil {
t.Fatalf("Failed to open connection to DB %s", err)
}
// clean any data from a previous run
if _, err := sqlDB.Exec(`DROP TABLE IF EXISTS "micro_someID_groups", "micro_someID_memberships" CASCADE`); err != nil {
t.Fatalf("Error cleaning database: %v", err)
}
h := &handler.Groups{}
h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{})
return h
} }
func TestCreate(t *testing.T) { func TestCreate(t *testing.T) {
h := testHandler(t) h := testHandler(t)

View File

@@ -1,19 +1,12 @@
package main package main
import ( import (
"database/sql"
"github.com/micro/micro/v3/service" "github.com/micro/micro/v3/service"
"github.com/micro/micro/v3/service/config"
"github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/logger"
"github.com/micro/services/groups/handler" "github.com/micro/services/groups/handler"
pb "github.com/micro/services/groups/proto" pb "github.com/micro/services/groups/proto"
_ "github.com/jackc/pgx/v4/stdlib"
) )
var dbAddress = "postgresql://postgres:postgres@localhost:5432/groups?sslmode=disable"
func main() { func main() {
// Create service // Create service
srv := service.New( srv := service.New(
@@ -21,20 +14,8 @@ func main() {
service.Version("latest"), service.Version("latest"),
) )
// Connect to the database
cfg, err := config.Get("groups.database")
if err != nil {
logger.Fatalf("Error loading config: %v", err)
}
addr := cfg.String(dbAddress)
sqlDB, err := sql.Open("pgx", addr)
if err != nil {
logger.Fatalf("Failed to open connection to DB %s", err)
}
h := &handler.Groups{}
h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{})
// Register handler // Register handler
pb.RegisterGroupsHandler(srv.Server(), h) pb.RegisterGroupsHandler(srv.Server(), new(handler.Groups))
// Run service // Run service
if err := srv.Run(); err != nil { if err := srv.Run(); err != nil {