This commit is contained in:
ben-toogood
2021-02-08 14:42:47 +00:00
committed by GitHub
parent f4ad5d6bd2
commit 6e49584049
16 changed files with 1202 additions and 0 deletions

View File

@@ -0,0 +1,31 @@
package handler
import (
"time"
"github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/events"
"gorm.io/gorm"
)
var (
TokenTTL = time.Minute
ErrMissingTopic = errors.BadRequest("MISSING_TOPIC", "Missing topic")
ErrMissingToken = errors.BadRequest("MISSING_TOKEN", "Missing token")
ErrMissingMessage = errors.BadRequest("MISSING_MESSAGE", "Missing message")
ErrInvalidToken = errors.Forbidden("INVALID_TOKEN", "Invalid token")
ErrExpiredToken = errors.Forbidden("EXPIRED_TOKEN", "Token expired")
ErrForbiddenTopic = errors.Forbidden("FORBIDDEN_TOPIC", "Token has not have permission to subscribe to this topic")
)
type Token struct {
Token string `gorm:"primaryKey"`
Topic string
ExpiresAt time.Time
}
type Streams struct {
DB *gorm.DB
Events events.Stream
Time func() time.Time
}

View File

@@ -0,0 +1,58 @@
package handler_test
import (
"testing"
"time"
"github.com/micro/micro/v3/service/events"
"github.com/micro/services/streams/handler"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
func testHandler(t *testing.T) *handler.Streams {
// connect to the database
db, err := gorm.Open(postgres.Open("postgresql://postgres@localhost:5432/postgres?sslmode=disable"), &gorm.Config{})
if err != nil {
t.Fatalf("Error connecting to database: %v", err)
}
// migrate the database
if err := db.AutoMigrate(&handler.Token{}); err != nil {
t.Fatalf("Error migrating database: %v", err)
}
// clean any data from a previous run
if err := db.Exec("TRUNCATE TABLE tokens CASCADE").Error; err != nil {
t.Fatalf("Error cleaning database: %v", err)
}
return &handler.Streams{
DB: db,
Events: new(eventsMock),
Time: func() time.Time {
return time.Unix(1612787045, 0)
},
}
}
type eventsMock struct {
PublishCount int
PublishTopic string
PublishMessage interface{}
ConsumeTopic string
ConsumeChan <-chan events.Event
}
func (e *eventsMock) Publish(topic string, msg interface{}, opts ...events.PublishOption) error {
e.PublishCount++
e.PublishTopic = topic
e.PublishMessage = msg
return nil
}
func (e *eventsMock) Consume(topic string, opts ...events.ConsumeOption) (<-chan events.Event, error) {
e.ConsumeTopic = topic
return e.ConsumeChan, nil
}

View File

@@ -0,0 +1,20 @@
package handler
import (
"context"
pb "github.com/micro/services/streams/proto"
)
func (s *Streams) Publish(ctx context.Context, req *pb.Message, rsp *pb.PublishResponse) error {
// validate the request
if len(req.Topic) == 0 {
return ErrMissingTopic
}
if len(req.Message) == 0 {
return ErrMissingMessage
}
// publish the message
return s.Events.Publish(req.Topic, req.Message)
}

View File

@@ -0,0 +1,41 @@
package handler_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/micro/services/streams/handler"
pb "github.com/micro/services/streams/proto"
"github.com/stretchr/testify/assert"
)
func TestPublish(t *testing.T) {
msg := "{\"foo\":\"bar\"}"
topic := uuid.New().String()
t.Run("MissingTopic", func(t *testing.T) {
h := testHandler(t)
err := h.Publish(context.TODO(), &pb.Message{Message: msg}, &pb.PublishResponse{})
assert.Equal(t, handler.ErrMissingTopic, err)
assert.Zero(t, h.Events.(*eventsMock).PublishCount)
})
t.Run("MissingMessage", func(t *testing.T) {
h := testHandler(t)
err := h.Publish(context.TODO(), &pb.Message{Topic: topic}, &pb.PublishResponse{})
assert.Equal(t, handler.ErrMissingMessage, err)
assert.Zero(t, h.Events.(*eventsMock).PublishCount)
})
t.Run("ValidMessage", func(t *testing.T) {
h := testHandler(t)
err := h.Publish(context.TODO(), &pb.Message{
Topic: topic, Message: msg,
}, &pb.PublishResponse{})
assert.NoError(t, err)
assert.Equal(t, 1, h.Events.(*eventsMock).PublishCount)
assert.Equal(t, msg, h.Events.(*eventsMock).PublishMessage)
assert.Equal(t, topic, h.Events.(*eventsMock).PublishTopic)
})
}

View File

@@ -0,0 +1,64 @@
package handler
import (
"context"
"github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/events"
"github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/streams/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
)
func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, stream pb.Streams_SubscribeStream) error {
// validate the request
if len(req.Token) == 0 {
return ErrMissingToken
}
if len(req.Topic) == 0 {
return ErrMissingTopic
}
// find the token and check to see if it has expired
var token Token
if err := s.DB.Where(&Token{Token: req.Token}).First(&token).Error; err == gorm.ErrRecordNotFound {
return ErrInvalidToken
} else if err != nil {
logger.Errorf("Error reading token from store: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error reading token from database")
}
if token.ExpiresAt.Before(s.Time()) {
return ErrExpiredToken
}
// if the token was scoped to a channel, ensure the channel is the one being requested
if len(token.Topic) > 0 && token.Topic != req.Topic {
return ErrForbiddenTopic
}
// start the subscription
evChan, err := s.Events.Consume(req.Topic, events.WithGroup(token.Token))
if err != nil {
logger.Errorf("Error connecting to events stream: %v", err)
return errors.InternalServerError("EVENTS_ERROR", "Error connecting to events stream")
}
go func() {
defer stream.Close()
for {
msg, ok := <-evChan
if !ok {
return
}
if err := stream.Send(&pb.Message{
Topic: msg.Topic,
Message: string(msg.Payload),
SentAt: timestamppb.New(msg.Timestamp),
}); err != nil {
return
}
}
}()
return nil
}

View File

@@ -0,0 +1,164 @@
package handler_test
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/micro/micro/v3/service/events"
"github.com/micro/services/streams/handler"
pb "github.com/micro/services/streams/proto"
"github.com/stretchr/testify/assert"
)
func TestSubscribe(t *testing.T) {
t.Run("MissingToken", func(t *testing.T) {
h := testHandler(t)
s := new(streamMock)
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
Topic: "helloworld",
}, s)
assert.Equal(t, handler.ErrMissingToken, err)
assert.Empty(t, s.Messages)
})
t.Run("MissingTopic", func(t *testing.T) {
h := testHandler(t)
s := new(streamMock)
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
Token: uuid.New().String(),
}, s)
assert.Equal(t, handler.ErrMissingTopic, err)
assert.Empty(t, s.Messages)
})
t.Run("InvalidToken", func(t *testing.T) {
h := testHandler(t)
s := new(streamMock)
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
Topic: "helloworld",
Token: uuid.New().String(),
}, s)
assert.Equal(t, handler.ErrInvalidToken, err)
assert.Empty(t, s.Messages)
})
t.Run("ExpiredToken", func(t *testing.T) {
h := testHandler(t)
var tRsp pb.TokenResponse
err := h.Token(context.TODO(), &pb.TokenRequest{
Topic: "helloworld",
}, &tRsp)
assert.NoError(t, err)
ct := h.Time()
h.Time = func() time.Time { return ct.Add(handler.TokenTTL * 2) }
s := new(streamMock)
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
Topic: "helloworld",
Token: tRsp.Token,
}, s)
assert.Equal(t, handler.ErrExpiredToken, err)
assert.Empty(t, s.Messages)
})
t.Run("ForbiddenTopic", func(t *testing.T) {
h := testHandler(t)
var tRsp pb.TokenResponse
err := h.Token(context.TODO(), &pb.TokenRequest{
Topic: "helloworldx",
}, &tRsp)
assert.NoError(t, err)
s := new(streamMock)
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
Topic: "helloworld",
Token: tRsp.Token,
}, s)
assert.Equal(t, handler.ErrForbiddenTopic, err)
assert.Empty(t, s.Messages)
})
t.Run("Valid", func(t *testing.T) {
h := testHandler(t)
c := make(chan events.Event)
h.Events.(*eventsMock).ConsumeChan = c
var tRsp pb.TokenResponse
err := h.Token(context.TODO(), &pb.TokenRequest{
Topic: "helloworld",
}, &tRsp)
assert.NoError(t, err)
s := &streamMock{Messages: []*pb.Message{}}
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
Topic: "helloworld",
Token: tRsp.Token,
}, s)
assert.NoError(t, err)
assert.Equal(t, "helloworld", h.Events.(*eventsMock).ConsumeTopic)
e1 := events.Event{
ID: uuid.New().String(),
Topic: "helloworld",
Timestamp: h.Time().Add(time.Second * -2),
Payload: []byte("abc"),
}
e2 := events.Event{
ID: uuid.New().String(),
Topic: "helloworld",
Timestamp: h.Time().Add(time.Second * -1),
Payload: []byte("123"),
}
timeout := time.NewTimer(time.Millisecond * 100).C
select {
case <-timeout:
t.Fatal("Events not consumed from stream")
return
case c <- e1:
t.Log("Event1 consumed")
}
select {
case <-timeout:
t.Fatal("Events not consumed from stream")
return
case c <- e2:
t.Log("Event2 consumed")
}
if len(s.Messages) != 2 {
t.Fatalf("Expected 2 messages, got %v", len(s.Messages))
return
}
assert.Equal(t, e1.Topic, s.Messages[0].Topic)
assert.Equal(t, string(e1.Payload), s.Messages[0].Message)
assert.True(t, e1.Timestamp.Equal(s.Messages[0].SentAt.AsTime()))
assert.Equal(t, e2.Topic, s.Messages[1].Topic)
assert.Equal(t, string(e2.Payload), s.Messages[1].Message)
assert.True(t, e2.Timestamp.Equal(s.Messages[1].SentAt.AsTime()))
})
}
type streamMock struct {
Messages []*pb.Message
pb.Streams_SubscribeStream
}
func (x *streamMock) Send(m *pb.Message) error {
x.Messages = append(x.Messages, m)
return nil
}

27
streams/handler/token.go Normal file
View File

@@ -0,0 +1,27 @@
package handler
import (
"context"
"github.com/google/uuid"
"github.com/micro/micro/v3/service/errors"
"github.com/micro/micro/v3/service/logger"
pb "github.com/micro/services/streams/proto"
)
func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.TokenResponse) error {
// construct the token and write it to the database
t := Token{
Token: uuid.New().String(),
ExpiresAt: s.Time().Add(TokenTTL),
Topic: req.Topic,
}
if err := s.DB.Create(&t).Error; err != nil {
logger.Errorf("Error creating token in store: %v", err)
return errors.InternalServerError("DATABASE_ERROR", "Error writing token to database")
}
// return the token in the response
rsp.Token = t.Token
return nil
}

View File

@@ -0,0 +1,27 @@
package handler_test
import (
"context"
"testing"
pb "github.com/micro/services/streams/proto"
"github.com/stretchr/testify/assert"
)
func TestToken(t *testing.T) {
h := testHandler(t)
t.Run("WithoutTopic", func(t *testing.T) {
var rsp pb.TokenResponse
err := h.Token(context.TODO(), &pb.TokenRequest{}, &rsp)
assert.NoError(t, err)
assert.NotEmpty(t, rsp.Token)
})
t.Run("WithTopic", func(t *testing.T) {
var rsp pb.TokenResponse
err := h.Token(context.TODO(), &pb.TokenRequest{Topic: "helloworld"}, &rsp)
assert.NoError(t, err)
assert.NotEmpty(t, rsp.Token)
})
}