From 5368aec8e41b51531aa9277b01677455b885a919 Mon Sep 17 00:00:00 2001 From: Asim Aslam Date: Sun, 2 May 2021 18:11:22 +0100 Subject: [PATCH] Add tenancy to seen and otp services --- otp/handler/otp.go | 6 +++--- pkg/cache/cache.go | 38 +++++++++++++++++++++++++++++++++----- pkg/tenant/tenant.go | 27 +++++++++++++++++++++++++++ seen/handler/handler.go | 31 +++++++++++++++++++++++-------- 4 files changed, 86 insertions(+), 16 deletions(-) create mode 100644 pkg/tenant/tenant.go diff --git a/otp/handler/otp.go b/otp/handler/otp.go index 7137da7..19a9246 100644 --- a/otp/handler/otp.go +++ b/otp/handler/otp.go @@ -23,7 +23,7 @@ func (e *Otp) Generate(ctx context.Context, req *pb.GenerateRequest, rsp *pb.Gen // check if a key exists for the user var secret string - if err := cache.Get(req.Id, &secret); err != nil { + if err := cache.Context(ctx).Get(req.Id, &secret); err != nil { // generate a key key, err := totp.Generate(totp.GenerateOpts{ Issuer: "Micro", @@ -38,7 +38,7 @@ func (e *Otp) Generate(ctx context.Context, req *pb.GenerateRequest, rsp *pb.Gen secret = key.Secret() - if err := cache.Put(req.Id, secret, time.Now().Add(time.Minute*5)); err != nil { + if err := cache.Context(ctx).Put(req.Id, secret, time.Now().Add(time.Minute*5)); err != nil { logger.Error("Failed to store secret: %v", err) return errors.InternalServerError("otp.generate", "failed to generate code") } @@ -72,7 +72,7 @@ func (e *Otp) Validate(ctx context.Context, req *pb.ValidateRequest, rsp *pb.Val var secret string - if err := cache.Get(req.Id, &secret); err != nil { + if err := cache.Context(ctx).Get(req.Id, &secret); err != nil { logger.Error("Failed to get secret from store: %v", err) return errors.InternalServerError("otp.generate", "failed to validate code") } diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index cd5e4ec..b8934e8 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -2,20 +2,26 @@ package cache import ( + "context" "encoding/json" + "fmt" "time" "github.com/micro/micro/v3/service/store" + "github.com/micro/services/pkg/tenant" ) type Cache interface { + // Context returns a tenant scoped Cache + Context(ctx context.Context) Cache Get(key string, val interface{}) error Put(key string, val interface{}, expires time.Time) error Delete(key string) error } type cache struct { - Store store.Store + Store store.Store + Prefix string } var ( @@ -23,7 +29,25 @@ var ( ) func New(st store.Store) Cache { - return &cache{st} + return &cache{Store: st} +} + +func (c *cache) Key(k string) string { + if len(c.Prefix) > 0 { + return fmt.Sprintf("%s/%s", c.Prefix, k) + } + return k +} + +func (c *cache) Context(ctx context.Context) Cache { + t, ok := tenant.FromContext(ctx) + if !ok { + return c + } + return &cache{ + Store: c.Store, + Prefix: t, + } } func (c *cache) Get(key string, val interface{}) error { @@ -31,7 +55,7 @@ func (c *cache) Get(key string, val interface{}) error { c.Store = store.DefaultStore } - recs, err := c.Store.Read(key, store.ReadLimit(1)) + recs, err := c.Store.Read(c.Key(key), store.ReadLimit(1)) if err != nil { return err } @@ -57,7 +81,7 @@ func (c *cache) Put(key string, val interface{}, expires time.Time) error { expiry = time.Duration(0) } return c.Store.Write(&store.Record{ - Key: key, + Key: c.Key(key), Value: b, Expiry: expiry, }) @@ -67,7 +91,11 @@ func (c *cache) Delete(key string) error { if c.Store == nil { c.Store = store.DefaultStore } - return c.Store.Delete(key) + return c.Store.Delete(c.Key(key)) +} + +func Context(ctx context.Context) Cache { + return DefaultCache.Context(ctx) } func Get(key string, val interface{}) error { diff --git a/pkg/tenant/tenant.go b/pkg/tenant/tenant.go new file mode 100644 index 0000000..3d8b3c4 --- /dev/null +++ b/pkg/tenant/tenant.go @@ -0,0 +1,27 @@ +// Package tenant provides multi-tenancy helpers +package tenant + +import ( + "context" + "fmt" + + "github.com/micro/micro/v3/service/auth" +) + +// FromContext returns a tenant from the context +func FromContext(ctx context.Context) (string, bool) { + acc, ok := auth.AccountFromContext(ctx) + if !ok { + return "", false + } + return FromAccount(acc), true +} + +// FromAccount returns a tenant from +func FromAccount(acc *auth.Account) string { + owner := acc.Metadata["apikey_owner"] + if len(owner) == 0 { + owner = acc.ID + } + return fmt.Sprintf("%s/%s", acc.Issuer, owner) +} diff --git a/seen/handler/handler.go b/seen/handler/handler.go index 98c4e82..cb39b44 100644 --- a/seen/handler/handler.go +++ b/seen/handler/handler.go @@ -12,6 +12,7 @@ import ( "github.com/micro/micro/v3/service/errors" "github.com/micro/micro/v3/service/logger" "github.com/micro/micro/v3/service/store" + "github.com/micro/services/pkg/tenant" pb "github.com/micro/services/seen/proto" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -34,8 +35,15 @@ type Record struct { Timestamp time.Time } -func (r *Record) Key() string { - return fmt.Sprintf("%s:%s:%s", r.UserID, r.ResourceType, r.ResourceID) +func (r *Record) Key(ctx context.Context) string { + key := fmt.Sprintf("%s:%s:%s", r.UserID, r.ResourceType, r.ResourceID) + + t, ok := tenant.FromContext(ctx) + if !ok { + return key + } + + return fmt.Sprintf("%s/%s", t, key) } func (r *Record) Marshal() []byte { @@ -76,7 +84,7 @@ func (s *Seen) Set(ctx context.Context, req *pb.SetRequest, rsp *pb.SetResponse) ResourceType: req.ResourceType, } - _, err := store.Read(instance.Key(), store.ReadLimit(1)) + _, err := store.Read(instance.Key(ctx), store.ReadLimit(1)) if err == store.ErrNotFound { instance.ID = uuid.New().String() } else if err != nil { @@ -88,7 +96,7 @@ func (s *Seen) Set(ctx context.Context, req *pb.SetRequest, rsp *pb.SetResponse) instance.Timestamp = req.Timestamp.AsTime() if err := store.Write(&store.Record{ - Key: instance.Key(), + Key: instance.Key(ctx), Value: instance.Marshal(), }); err != nil { logger.Errorf("Error with store: %v", err) @@ -123,7 +131,7 @@ func (s *Seen) Unset(ctx context.Context, req *pb.UnsetRequest, rsp *pb.UnsetRes } // delete the object from the store - if err := store.Delete(instance.Key()); err != nil { + if err := store.Delete(instance.Key(ctx)); err != nil { logger.Errorf("Error with store: %v", err) return ErrStore } @@ -150,8 +158,10 @@ func (s *Seen) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRespon return ErrMissingResourceType } - // create a key prefix - key := fmt.Sprintf("%s:%s:", req.UserId, req.ResourceType) + rec := &Record{ + UserID: req.UserId, + ResourceType: req.ResourceType, + } var recs []*store.Record var err error @@ -159,9 +169,14 @@ func (s *Seen) Read(ctx context.Context, req *pb.ReadRequest, rsp *pb.ReadRespon // get the records for the resource type if len(req.ResourceIds) == 1 { // read the key itself - key = key + req.ResourceIds[0] + rec.ResourceId = req.ResourceIds[0] + // gen key + key = rec.Key(ctx) + // get the record recs, err = store.Read(key, store.ReadLimit(1)) } else { + // create a key prefix + key := rec.Key(ctx) // otherwise read the prefix recs, err = store.Read(key, store.ReadPrefix()) }