Multi tenant groups (#77)

* multitenant groups

* switch users service to use new wrapper

* fix tests

* skip pkg dir

* Check for auth
This commit is contained in:
Dominic Wong
2021-03-25 15:53:14 +00:00
committed by GitHub
parent b37cc09835
commit c42aeaa0a9
17 changed files with 592 additions and 125 deletions

View File

@@ -5,8 +5,12 @@ import (
"strings"
"github.com/google/uuid"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/groups/proto"
gorm2 "github.com/micro/services/pkg/gorm"
"gorm.io/gorm"
)
@@ -41,10 +45,14 @@ func (g *Group) Serialize() *pb.Group {
}
type Groups struct {
DB *gorm.DB
gorm2.Helper
}
func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.CreateResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Name) == 0 {
return ErrMissingName
@@ -52,7 +60,13 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea
// create the group object
group := &Group{ID: uuid.New().String(), Name: req.Name}
if err := g.DB.Create(group).Error; err != nil {
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
if err := db.Create(group).Error; err != nil {
return ErrStore
}
@@ -62,14 +76,23 @@ func (g *Groups) Create(ctx context.Context, req *pb.CreateRequest, rsp *pb.Crea
}
func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Ids) == 0 {
return ErrMissingIDs
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// query the database
var groups []Group
if err := g.DB.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil {
if err := db.Model(&Group{}).Preload("Memberships").Where("id IN (?)", req.Ids).Find(&groups).Error; err != nil {
return ErrStore
}
@@ -83,6 +106,10 @@ func (g *Groups) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadResp
}
func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.UpdateResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Id) == 0 {
return ErrMissingID
@@ -90,8 +117,13 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda
if len(req.Name) == 0 {
return ErrMissingName
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return g.DB.Transaction(func(tx *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
// find the group
var group Group
if err := tx.Where(&Group{ID: req.Id}).First(&group).Error; err == gorm.ErrRecordNotFound {
@@ -113,13 +145,22 @@ func (g *Groups) Update(ctx context.Context, req *pb.UpdateRequest, rsp *pb.Upda
}
func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.DeleteResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.Id) == 0 {
return ErrMissingID
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// delete from the database
if err := g.DB.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound {
if err := db.Delete(&Group{ID: req.Id}).Error; err == gorm.ErrRecordNotFound {
return nil
} else if err != nil {
return ErrStore
@@ -129,10 +170,19 @@ func (g *Groups) Delete(ctx context.Context, req *pb.DeleteRequest, rsp *pb.Dele
}
func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
if len(req.MemberId) > 0 {
// only list groups the user is a member of
var ms []Membership
q := g.DB.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships")
q := db.Where(&Membership{MemberID: req.MemberId}).Preload("Group.Memberships")
if err := q.Find(&ms).Error; err != nil {
return err
}
@@ -145,7 +195,7 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp
// load all groups
var groups []Group
if err := g.DB.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil {
if err := db.Model(&Group{}).Preload("Memberships").Find(&groups).Error; err != nil {
return ErrStore
}
@@ -159,6 +209,10 @@ func (g *Groups) List(ctx context.Context, req *pb.ListRequest, rsp *pb.ListResp
}
func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *pb.AddMemberResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.GroupId) == 0 {
return ErrMissingGroupID
@@ -166,8 +220,13 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p
if len(req.MemberId) == 0 {
return ErrMissingMemberID
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
return g.DB.Transaction(func(tx *gorm.DB) error {
return db.Transaction(func(tx *gorm.DB) error {
// check the group exists
var group Group
if err := tx.Where(&Group{ID: req.GroupId}).First(&group).Error; err == gorm.ErrRecordNotFound {
@@ -191,6 +250,10 @@ func (g *Groups) AddMember(ctx context.Context, req *pb.AddMemberRequest, rsp *p
}
func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest, rsp *pb.RemoveMemberResponse) error {
_, ok := auth.AccountFromContext(ctx)
if !ok {
errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
}
// validate the request
if len(req.GroupId) == 0 {
return ErrMissingGroupID
@@ -199,9 +262,14 @@ func (g *Groups) RemoveMember(ctx context.Context, req *pb.RemoveMemberRequest,
return ErrMissingMemberID
}
db, err := g.GetDBConn(ctx)
if err != nil {
logger.Errorf("Error connecting to DB: %v", err)
return errors.InternalServerError("DB_ERROR", "Error connecting to DB")
}
// delete the membership
m := &Membership{MemberID: req.MemberId, GroupID: req.GroupId}
if err := g.DB.Where(m).Delete(m).Error; err != nil {
if err := db.Where(m).Delete(m).Error; err != nil {
return ErrStore
}

View File

@@ -2,16 +2,16 @@ package handler_test
import (
"context"
"database/sql"
"os"
"sort"
"testing"
"github.com/google/uuid"
"github.com/micro/micro/v3/service/auth"
"github.com/micro/services/groups/handler"
pb "github.com/micro/services/groups/proto"
"github.com/stretchr/testify/assert"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
func testHandler(t *testing.T) *handler.Groups {
@@ -20,33 +20,30 @@ func testHandler(t *testing.T) *handler.Groups {
if len(addr) == 0 {
addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable"
}
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
sqlDB, err := sql.Open("pgx", addr)
if err != nil {
t.Fatalf("Error connecting to database: %v", err)
t.Fatalf("Failed to open connection to DB %s", err)
}
// clean any data from a previous run
if err := db.Exec("DROP TABLE IF EXISTS groups, memberships CASCADE").Error; err != nil {
if _, err := sqlDB.Exec("DROP TABLE IF EXISTS micro_groups, micro_memberships CASCADE"); err != nil {
t.Fatalf("Error cleaning database: %v", err)
}
// migrate the database
if err := db.AutoMigrate(&handler.Group{}, &handler.Membership{}); err != nil {
t.Fatalf("Error migrating database: %v", err)
}
return &handler.Groups{DB: db}
h := &handler.Groups{}
h.DBConn(sqlDB).Migrations(&handler.Group{}, &handler.Membership{})
return h
}
func TestCreate(t *testing.T) {
h := testHandler(t)
t.Run("MissingName", func(t *testing.T) {
err := h.Create(context.TODO(), &pb.CreateRequest{}, &pb.CreateResponse{})
err := h.Create(microAccountCtx(), &pb.CreateRequest{}, &pb.CreateResponse{})
assert.Equal(t, handler.ErrMissingName, err)
})
t.Run("Valid", func(t *testing.T) {
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Doe Family Group",
}, &pb.CreateResponse{})
assert.NoError(t, err)
@@ -57,21 +54,21 @@ func TestUpdate(t *testing.T) {
h := testHandler(t)
t.Run("MissingID", func(t *testing.T) {
err := h.Update(context.TODO(), &pb.UpdateRequest{
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Name: "Doe Family Group",
}, &pb.UpdateResponse{})
assert.Equal(t, handler.ErrMissingID, err)
})
t.Run("MissingName", func(t *testing.T) {
err := h.Update(context.TODO(), &pb.UpdateRequest{
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: uuid.New().String(),
}, &pb.UpdateResponse{})
assert.Equal(t, handler.ErrMissingName, err)
})
t.Run("NotFound", func(t *testing.T) {
err := h.Update(context.TODO(), &pb.UpdateRequest{
err := h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: uuid.New().String(),
Name: "Bar Family Group",
}, &pb.UpdateResponse{})
@@ -81,19 +78,19 @@ func TestUpdate(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
// create a demo group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Doe Family Group",
}, &cRsp)
assert.NoError(t, err)
err = h.Update(context.TODO(), &pb.UpdateRequest{
err = h.Update(microAccountCtx(), &pb.UpdateRequest{
Id: cRsp.Group.Id,
Name: "Bar Family Group",
}, &pb.UpdateResponse{})
assert.NoError(t, err)
var rRsp pb.ReadResponse
err = h.Read(context.TODO(), &pb.ReadRequest{
err = h.Read(microAccountCtx(), &pb.ReadRequest{
Ids: []string{cRsp.Group.Id},
}, &rRsp)
assert.NoError(t, err)
@@ -111,12 +108,12 @@ func TestDelete(t *testing.T) {
h := testHandler(t)
t.Run("MissingID", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{}, &pb.DeleteResponse{})
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{}, &pb.DeleteResponse{})
assert.Equal(t, handler.ErrMissingID, err)
})
t.Run("NotFound", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
Id: uuid.New().String(),
}, &pb.DeleteResponse{})
assert.NoError(t, err)
@@ -124,19 +121,19 @@ func TestDelete(t *testing.T) {
// create a demo group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Doe Family Group",
}, &cRsp)
assert.NoError(t, err)
t.Run("Valid", func(t *testing.T) {
err := h.Delete(context.TODO(), &pb.DeleteRequest{
err := h.Delete(microAccountCtx(), &pb.DeleteRequest{
Id: cRsp.Group.Id,
}, &pb.DeleteResponse{})
assert.NoError(t, err)
var rRsp pb.ReadResponse
err = h.Read(context.TODO(), &pb.ReadRequest{
err = h.Read(microAccountCtx(), &pb.ReadRequest{
Ids: []string{cRsp.Group.Id},
}, &rRsp)
assert.Nil(t, rRsp.Groups[cRsp.Group.Id])
@@ -147,27 +144,27 @@ func TestList(t *testing.T) {
// create two demo groups
var cRsp1 pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Alpha Group",
}, &cRsp1)
assert.NoError(t, err)
var cRsp2 pb.CreateResponse
err = h.Create(context.TODO(), &pb.CreateRequest{
err = h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Bravo Group",
}, &cRsp2)
assert.NoError(t, err)
// add a member to the first group
uid := uuid.New().String()
err = h.AddMember(context.TODO(), &pb.AddMemberRequest{
err = h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: cRsp1.Group.Id, MemberId: uid,
}, &pb.AddMemberResponse{})
assert.NoError(t, err)
t.Run("Unscoped", func(t *testing.T) {
var rsp pb.ListResponse
err = h.List(context.TODO(), &pb.ListRequest{}, &rsp)
err = h.List(microAccountCtx(), &pb.ListRequest{}, &rsp)
assert.NoError(t, err)
assert.Lenf(t, rsp.Groups, 2, "Two groups should be returned")
if len(rsp.Groups) != 2 {
@@ -188,13 +185,12 @@ func TestList(t *testing.T) {
t.Run("Scoped", func(t *testing.T) {
var rsp pb.ListResponse
err = h.List(context.TODO(), &pb.ListRequest{MemberId: uid}, &rsp)
err = h.List(microAccountCtx(), &pb.ListRequest{MemberId: uid}, &rsp)
assert.NoError(t, err)
assert.Lenf(t, rsp.Groups, 1, "One group should be returned")
if len(rsp.Groups) != 1 {
return
}
assert.Equal(t, cRsp1.Group.Id, rsp.Groups[0].Id)
assert.Equal(t, cRsp1.Group.Name, rsp.Groups[0].Name)
assert.Len(t, rsp.Groups[0].MemberIds, 1)
@@ -206,21 +202,21 @@ func TestAddMember(t *testing.T) {
h := testHandler(t)
t.Run("MissingGroupID", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
assert.Equal(t, handler.ErrMissingGroupID, err)
})
t.Run("MissingMemberID", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: uuid.New().String(),
}, &pb.AddMemberResponse{})
assert.Equal(t, handler.ErrMissingMemberID, err)
})
t.Run("GroupNotFound", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: uuid.New().String(),
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
@@ -229,13 +225,13 @@ func TestAddMember(t *testing.T) {
// create a test group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Alpha Group",
}, &cRsp)
assert.NoError(t, err)
t.Run("Valid", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
@@ -243,7 +239,7 @@ func TestAddMember(t *testing.T) {
})
t.Run("Retry", func(t *testing.T) {
err := h.AddMember(context.TODO(), &pb.AddMemberRequest{
err := h.AddMember(microAccountCtx(), &pb.AddMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.AddMemberResponse{})
@@ -255,14 +251,14 @@ func TestRemoveMember(t *testing.T) {
h := testHandler(t)
t.Run("MissingGroupID", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
MemberId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
assert.Equal(t, handler.ErrMissingGroupID, err)
})
t.Run("MissingMemberID", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
GroupId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
assert.Equal(t, handler.ErrMissingMemberID, err)
@@ -270,13 +266,13 @@ func TestRemoveMember(t *testing.T) {
// create a test group
var cRsp pb.CreateResponse
err := h.Create(context.TODO(), &pb.CreateRequest{
err := h.Create(microAccountCtx(), &pb.CreateRequest{
Name: "Alpha Group",
}, &cRsp)
assert.NoError(t, err)
t.Run("Valid", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
@@ -284,10 +280,16 @@ func TestRemoveMember(t *testing.T) {
})
t.Run("Retry", func(t *testing.T) {
err := h.RemoveMember(context.TODO(), &pb.RemoveMemberRequest{
err := h.RemoveMember(microAccountCtx(), &pb.RemoveMemberRequest{
GroupId: cRsp.Group.Id,
MemberId: uuid.New().String(),
}, &pb.RemoveMemberResponse{})
assert.NoError(t, err)
})
}
func microAccountCtx() context.Context {
return auth.ContextWithAccount(context.TODO(), &auth.Account{
Issuer: "micro",
})
}