mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-13 19:45:26 +00:00
Multitenant streams api (#72)
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
@@ -11,6 +14,7 @@ import (
|
||||
var (
|
||||
TokenTTL = time.Minute
|
||||
ErrMissingTopic = errors.BadRequest("MISSING_TOPIC", "Missing topic")
|
||||
ErrInvalidTopic = errors.BadRequest("MISSING_TOPIC", "Invalid topic")
|
||||
ErrMissingToken = errors.BadRequest("MISSING_TOKEN", "Missing token")
|
||||
ErrMissingMessage = errors.BadRequest("MISSING_MESSAGE", "Missing message")
|
||||
ErrInvalidToken = errors.Forbidden("INVALID_TOKEN", "Invalid token")
|
||||
@@ -22,6 +26,7 @@ type Token struct {
|
||||
Token string `gorm:"primaryKey"`
|
||||
Topic string
|
||||
ExpiresAt time.Time
|
||||
Namespace string
|
||||
}
|
||||
|
||||
type Streams struct {
|
||||
@@ -29,3 +34,18 @@ type Streams struct {
|
||||
Events events.Stream
|
||||
Time func() time.Time
|
||||
}
|
||||
|
||||
// fmtTopic returns a topic string with namespace prefix and hyphens replaced with dots
|
||||
func fmtTopic(ns, topic string) string {
|
||||
// events topic names can only be alphanumeric and "."
|
||||
return fmt.Sprintf("%s.%s", strings.ReplaceAll(ns, "-", "."), topic)
|
||||
}
|
||||
|
||||
// validateTopicInput validates that topic is alphanumeric
|
||||
func validateTopicInput(topic string) error {
|
||||
reg := regexp.MustCompile("^[a-zA-Z0-9]+$")
|
||||
if len(reg.FindString(topic)) == 0 {
|
||||
return ErrInvalidTopic
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -12,7 +13,11 @@ import (
|
||||
|
||||
func testHandler(t *testing.T) *handler.Streams {
|
||||
// connect to the database
|
||||
db, err := gorm.Open(postgres.Open("postgresql://postgres@localhost:5432/postgres?sslmode=disable"), &gorm.Config{})
|
||||
addr := os.Getenv("POSTGRES_URL")
|
||||
if len(addr) == 0 {
|
||||
addr = "postgresql://postgres@localhost:5432/postgres?sslmode=disable"
|
||||
}
|
||||
db, err := gorm.Open(postgres.Open(addr), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error connecting to database: %v", err)
|
||||
}
|
||||
@@ -22,11 +27,6 @@ func testHandler(t *testing.T) *handler.Streams {
|
||||
t.Fatalf("Error migrating database: %v", err)
|
||||
}
|
||||
|
||||
// clean any data from a previous run
|
||||
if err := db.Exec("TRUNCATE TABLE tokens CASCADE").Error; err != nil {
|
||||
t.Fatalf("Error cleaning database: %v", err)
|
||||
}
|
||||
|
||||
return &handler.Streams{
|
||||
DB: db,
|
||||
Events: new(eventsMock),
|
||||
|
||||
@@ -3,6 +3,8 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
)
|
||||
@@ -12,11 +14,18 @@ func (s *Streams) Publish(ctx context.Context, req *pb.Message, rsp *pb.PublishR
|
||||
if len(req.Topic) == 0 {
|
||||
return ErrMissingTopic
|
||||
}
|
||||
if err := validateTopicInput(req.Topic); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(req.Message) == 0 {
|
||||
return ErrMissingMessage
|
||||
}
|
||||
acc, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
return errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
|
||||
// publish the message
|
||||
logger.Infof("Publishing message to topic: %v", req.Topic)
|
||||
return s.Events.Publish(req.Topic, req.Message)
|
||||
return s.Events.Publish(fmtTopic(acc.Issuer, req.Topic), req.Message)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/services/streams/handler"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -12,30 +14,34 @@ import (
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
msg := "{\"foo\":\"bar\"}"
|
||||
topic := uuid.New().String()
|
||||
topic := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
t.Run("MissingTopic", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
err := h.Publish(context.TODO(), &pb.Message{Message: msg}, &pb.PublishResponse{})
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Publish(ctx, &pb.Message{Message: msg}, &pb.PublishResponse{})
|
||||
assert.Equal(t, handler.ErrMissingTopic, err)
|
||||
assert.Zero(t, h.Events.(*eventsMock).PublishCount)
|
||||
})
|
||||
|
||||
t.Run("MissingMessage", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
err := h.Publish(context.TODO(), &pb.Message{Topic: topic}, &pb.PublishResponse{})
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Publish(ctx, &pb.Message{Topic: topic}, &pb.PublishResponse{})
|
||||
assert.Equal(t, handler.ErrMissingMessage, err)
|
||||
assert.Zero(t, h.Events.(*eventsMock).PublishCount)
|
||||
})
|
||||
|
||||
t.Run("ValidMessage", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
err := h.Publish(context.TODO(), &pb.Message{
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Publish(ctx, &pb.Message{
|
||||
Topic: topic, Message: msg,
|
||||
}, &pb.PublishResponse{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, h.Events.(*eventsMock).PublishCount)
|
||||
assert.Equal(t, msg, h.Events.(*eventsMock).PublishMessage)
|
||||
assert.Equal(t, topic, h.Events.(*eventsMock).PublishTopic)
|
||||
// topic is prefixed with acc issuer to implement multitenancy
|
||||
assert.Equal(t, "foo."+topic, h.Events.(*eventsMock).PublishTopic)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/events"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
@@ -13,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, stream pb.Streams_SubscribeStream) error {
|
||||
logger.Infof("Recieved subscribe request. Topic: '%v', Token: '%v'", req.Topic, req.Token)
|
||||
logger.Infof("Received subscribe request. Topic: '%v', Token: '%v'", req.Topic, req.Token)
|
||||
|
||||
// validate the request
|
||||
if len(req.Token) == 0 {
|
||||
@@ -22,10 +23,18 @@ func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, strea
|
||||
if len(req.Topic) == 0 {
|
||||
return ErrMissingTopic
|
||||
}
|
||||
if err := validateTopicInput(req.Topic); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
acc, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
return errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
|
||||
// find the token and check to see if it has expired
|
||||
var token Token
|
||||
if err := s.DB.Where(&Token{Token: req.Token}).First(&token).Error; err == gorm.ErrRecordNotFound {
|
||||
if err := s.DB.Where(&Token{Token: req.Token, Namespace: acc.Issuer}).First(&token).Error; err == gorm.ErrRecordNotFound {
|
||||
return ErrInvalidToken
|
||||
} else if err != nil {
|
||||
logger.Errorf("Error reading token from store: %v", err)
|
||||
@@ -42,12 +51,11 @@ func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, strea
|
||||
|
||||
// start the subscription
|
||||
logger.Infof("Subscribing to %v via queue %v", req.Topic, token.Token)
|
||||
evChan, err := s.Events.Consume(req.Topic, events.WithGroup(token.Token))
|
||||
evChan, err := s.Events.Consume(fmtTopic(acc.Issuer, req.Topic), events.WithGroup(token.Token))
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to events stream: %v", err)
|
||||
return errors.InternalServerError("EVENTS_ERROR", "Error connecting to events stream")
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
for {
|
||||
msg, ok := <-evChan
|
||||
@@ -57,7 +65,7 @@ func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, strea
|
||||
|
||||
logger.Infof("Sending message to subscriber %v", token.Topic)
|
||||
pbMsg := &pb.Message{
|
||||
Topic: msg.Topic,
|
||||
Topic: req.Topic, // use req.Topic not msg.Topic because topic is munged for multitenancy
|
||||
Message: string(msg.Payload),
|
||||
SentAt: timestamppb.New(msg.Timestamp),
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/micro/v3/service/events"
|
||||
"github.com/micro/services/streams/handler"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
@@ -17,7 +19,8 @@ func TestSubscribe(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
s := new(streamMock)
|
||||
|
||||
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
}, s)
|
||||
|
||||
@@ -29,7 +32,8 @@ func TestSubscribe(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
s := new(streamMock)
|
||||
|
||||
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Token: uuid.New().String(),
|
||||
}, s)
|
||||
|
||||
@@ -41,7 +45,8 @@ func TestSubscribe(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
s := new(streamMock)
|
||||
|
||||
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: uuid.New().String(),
|
||||
}, s)
|
||||
@@ -54,7 +59,8 @@ func TestSubscribe(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{
|
||||
Topic: "helloworld",
|
||||
}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
@@ -62,7 +68,7 @@ func TestSubscribe(t *testing.T) {
|
||||
ct := h.Time()
|
||||
h.Time = func() time.Time { return ct.Add(handler.TokenTTL * 2) }
|
||||
s := new(streamMock)
|
||||
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
err = h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
@@ -75,13 +81,14 @@ func TestSubscribe(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{
|
||||
Topic: "helloworldx",
|
||||
}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := new(streamMock)
|
||||
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
err = h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
@@ -91,23 +98,33 @@ func TestSubscribe(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
defer func() {
|
||||
if i := recover(); i != nil {
|
||||
t.Logf("%+v", i)
|
||||
}
|
||||
}()
|
||||
h := testHandler(t)
|
||||
c := make(chan events.Event)
|
||||
h.Events.(*eventsMock).ConsumeChan = c
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{
|
||||
Topic: "helloworld",
|
||||
}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := &streamMock{Messages: []*pb.Message{}}
|
||||
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "helloworld", h.Events.(*eventsMock).ConsumeTopic)
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
var subsErr error
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
subsErr = h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
}()
|
||||
|
||||
e1 := events.Event{
|
||||
ID: uuid.New().String(),
|
||||
@@ -138,6 +155,13 @@ func TestSubscribe(t *testing.T) {
|
||||
t.Log("Event2 consumed")
|
||||
}
|
||||
|
||||
close(c)
|
||||
wg.Wait()
|
||||
assert.NoError(t, subsErr)
|
||||
assert.Equal(t, "foo.helloworld", h.Events.(*eventsMock).ConsumeTopic)
|
||||
|
||||
// sleep to wait for the subscribe loop to push the message to the stream
|
||||
//time.Sleep(1 * time.Second)
|
||||
if len(s.Messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %v", len(s.Messages))
|
||||
return
|
||||
@@ -151,6 +175,45 @@ func TestSubscribe(t *testing.T) {
|
||||
assert.Equal(t, string(e2.Payload), s.Messages[1].Message)
|
||||
assert.True(t, e2.Timestamp.Equal(s.Messages[1].SentAt.AsTime()))
|
||||
})
|
||||
|
||||
t.Run("TokenForDifferentIssuer", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{
|
||||
Topic: "tokfordiff",
|
||||
}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := new(streamMock)
|
||||
ctx = auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "bar"})
|
||||
err = h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Topic: "tokfordiff",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
assert.Equal(t, handler.ErrInvalidToken, err)
|
||||
assert.Empty(t, s.Messages)
|
||||
})
|
||||
|
||||
t.Run("BadTopic", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := new(streamMock)
|
||||
ctx = auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "bar"})
|
||||
err = h.Subscribe(ctx, &pb.SubscribeRequest{
|
||||
Topic: "tok-for-diff",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
assert.Equal(t, handler.ErrInvalidTopic, err)
|
||||
assert.Empty(t, s.Messages)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
type streamMock struct {
|
||||
|
||||
@@ -4,17 +4,29 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
)
|
||||
|
||||
func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.TokenResponse) error {
|
||||
acc, ok := auth.AccountFromContext(ctx)
|
||||
if !ok {
|
||||
return errors.Unauthorized("UNAUTHORIZED", "Unauthorized")
|
||||
}
|
||||
if len(req.Topic) > 0 {
|
||||
if err := validateTopicInput(req.Topic); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// construct the token and write it to the database
|
||||
t := Token{
|
||||
Token: uuid.New().String(),
|
||||
ExpiresAt: s.Time().Add(TokenTTL),
|
||||
Topic: req.Topic,
|
||||
Namespace: acc.Issuer,
|
||||
}
|
||||
if err := s.DB.Create(&t).Error; err != nil {
|
||||
logger.Errorf("Error creating token in store: %v", err)
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/micro/micro/v3/service/auth"
|
||||
"github.com/micro/services/streams/handler"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@@ -13,15 +15,24 @@ func TestToken(t *testing.T) {
|
||||
|
||||
t.Run("WithoutTopic", func(t *testing.T) {
|
||||
var rsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{}, &rsp)
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, rsp.Token)
|
||||
})
|
||||
|
||||
t.Run("WithTopic", func(t *testing.T) {
|
||||
var rsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{Topic: "helloworld"}, &rsp)
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{Topic: "helloworld"}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, rsp.Token)
|
||||
})
|
||||
|
||||
t.Run("WithBadTopic", func(t *testing.T) {
|
||||
var rsp pb.TokenResponse
|
||||
ctx := auth.ContextWithAccount(context.TODO(), &auth.Account{Issuer: "foo"})
|
||||
err := h.Token(ctx, &pb.TokenRequest{Topic: "helloworld-1"}, &rsp)
|
||||
assert.Equal(t, handler.ErrInvalidTopic, err)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user