mirror of
https://github.com/kevin-DL/services.git
synced 2026-01-16 04:54:42 +00:00
Streams (#68)
This commit is contained in:
31
streams/handler/handler.go
Normal file
31
streams/handler/handler.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/events"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
TokenTTL = time.Minute
|
||||
ErrMissingTopic = errors.BadRequest("MISSING_TOPIC", "Missing topic")
|
||||
ErrMissingToken = errors.BadRequest("MISSING_TOKEN", "Missing token")
|
||||
ErrMissingMessage = errors.BadRequest("MISSING_MESSAGE", "Missing message")
|
||||
ErrInvalidToken = errors.Forbidden("INVALID_TOKEN", "Invalid token")
|
||||
ErrExpiredToken = errors.Forbidden("EXPIRED_TOKEN", "Token expired")
|
||||
ErrForbiddenTopic = errors.Forbidden("FORBIDDEN_TOPIC", "Token has not have permission to subscribe to this topic")
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Token string `gorm:"primaryKey"`
|
||||
Topic string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type Streams struct {
|
||||
DB *gorm.DB
|
||||
Events events.Stream
|
||||
Time func() time.Time
|
||||
}
|
||||
58
streams/handler/handler_test.go
Normal file
58
streams/handler/handler_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/micro/micro/v3/service/events"
|
||||
"github.com/micro/services/streams/handler"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func testHandler(t *testing.T) *handler.Streams {
|
||||
// connect to the database
|
||||
db, err := gorm.Open(postgres.Open("postgresql://postgres@localhost:5432/postgres?sslmode=disable"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("Error connecting to database: %v", err)
|
||||
}
|
||||
|
||||
// migrate the database
|
||||
if err := db.AutoMigrate(&handler.Token{}); err != nil {
|
||||
t.Fatalf("Error migrating database: %v", err)
|
||||
}
|
||||
|
||||
// clean any data from a previous run
|
||||
if err := db.Exec("TRUNCATE TABLE tokens CASCADE").Error; err != nil {
|
||||
t.Fatalf("Error cleaning database: %v", err)
|
||||
}
|
||||
|
||||
return &handler.Streams{
|
||||
DB: db,
|
||||
Events: new(eventsMock),
|
||||
Time: func() time.Time {
|
||||
return time.Unix(1612787045, 0)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type eventsMock struct {
|
||||
PublishCount int
|
||||
PublishTopic string
|
||||
PublishMessage interface{}
|
||||
|
||||
ConsumeTopic string
|
||||
ConsumeChan <-chan events.Event
|
||||
}
|
||||
|
||||
func (e *eventsMock) Publish(topic string, msg interface{}, opts ...events.PublishOption) error {
|
||||
e.PublishCount++
|
||||
e.PublishTopic = topic
|
||||
e.PublishMessage = msg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *eventsMock) Consume(topic string, opts ...events.ConsumeOption) (<-chan events.Event, error) {
|
||||
e.ConsumeTopic = topic
|
||||
return e.ConsumeChan, nil
|
||||
}
|
||||
20
streams/handler/publish.go
Normal file
20
streams/handler/publish.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
)
|
||||
|
||||
func (s *Streams) Publish(ctx context.Context, req *pb.Message, rsp *pb.PublishResponse) error {
|
||||
// validate the request
|
||||
if len(req.Topic) == 0 {
|
||||
return ErrMissingTopic
|
||||
}
|
||||
if len(req.Message) == 0 {
|
||||
return ErrMissingMessage
|
||||
}
|
||||
|
||||
// publish the message
|
||||
return s.Events.Publish(req.Topic, req.Message)
|
||||
}
|
||||
41
streams/handler/publish_test.go
Normal file
41
streams/handler/publish_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/services/streams/handler"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
msg := "{\"foo\":\"bar\"}"
|
||||
topic := uuid.New().String()
|
||||
|
||||
t.Run("MissingTopic", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
err := h.Publish(context.TODO(), &pb.Message{Message: msg}, &pb.PublishResponse{})
|
||||
assert.Equal(t, handler.ErrMissingTopic, err)
|
||||
assert.Zero(t, h.Events.(*eventsMock).PublishCount)
|
||||
})
|
||||
|
||||
t.Run("MissingMessage", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
err := h.Publish(context.TODO(), &pb.Message{Topic: topic}, &pb.PublishResponse{})
|
||||
assert.Equal(t, handler.ErrMissingMessage, err)
|
||||
assert.Zero(t, h.Events.(*eventsMock).PublishCount)
|
||||
})
|
||||
|
||||
t.Run("ValidMessage", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
err := h.Publish(context.TODO(), &pb.Message{
|
||||
Topic: topic, Message: msg,
|
||||
}, &pb.PublishResponse{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, h.Events.(*eventsMock).PublishCount)
|
||||
assert.Equal(t, msg, h.Events.(*eventsMock).PublishMessage)
|
||||
assert.Equal(t, topic, h.Events.(*eventsMock).PublishTopic)
|
||||
})
|
||||
}
|
||||
64
streams/handler/subscribe.go
Normal file
64
streams/handler/subscribe.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/events"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (s *Streams) Subscribe(ctx context.Context, req *pb.SubscribeRequest, stream pb.Streams_SubscribeStream) error {
|
||||
// validate the request
|
||||
if len(req.Token) == 0 {
|
||||
return ErrMissingToken
|
||||
}
|
||||
if len(req.Topic) == 0 {
|
||||
return ErrMissingTopic
|
||||
}
|
||||
|
||||
// find the token and check to see if it has expired
|
||||
var token Token
|
||||
if err := s.DB.Where(&Token{Token: req.Token}).First(&token).Error; err == gorm.ErrRecordNotFound {
|
||||
return ErrInvalidToken
|
||||
} else if err != nil {
|
||||
logger.Errorf("Error reading token from store: %v", err)
|
||||
return errors.InternalServerError("DATABASE_ERROR", "Error reading token from database")
|
||||
}
|
||||
if token.ExpiresAt.Before(s.Time()) {
|
||||
return ErrExpiredToken
|
||||
}
|
||||
|
||||
// if the token was scoped to a channel, ensure the channel is the one being requested
|
||||
if len(token.Topic) > 0 && token.Topic != req.Topic {
|
||||
return ErrForbiddenTopic
|
||||
}
|
||||
|
||||
// start the subscription
|
||||
evChan, err := s.Events.Consume(req.Topic, events.WithGroup(token.Token))
|
||||
if err != nil {
|
||||
logger.Errorf("Error connecting to events stream: %v", err)
|
||||
return errors.InternalServerError("EVENTS_ERROR", "Error connecting to events stream")
|
||||
}
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
for {
|
||||
msg, ok := <-evChan
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := stream.Send(&pb.Message{
|
||||
Topic: msg.Topic,
|
||||
Message: string(msg.Payload),
|
||||
SentAt: timestamppb.New(msg.Timestamp),
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
164
streams/handler/subscribe_test.go
Normal file
164
streams/handler/subscribe_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/events"
|
||||
"github.com/micro/services/streams/handler"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSubscribe(t *testing.T) {
|
||||
t.Run("MissingToken", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
s := new(streamMock)
|
||||
|
||||
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
}, s)
|
||||
|
||||
assert.Equal(t, handler.ErrMissingToken, err)
|
||||
assert.Empty(t, s.Messages)
|
||||
})
|
||||
|
||||
t.Run("MissingTopic", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
s := new(streamMock)
|
||||
|
||||
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
Token: uuid.New().String(),
|
||||
}, s)
|
||||
|
||||
assert.Equal(t, handler.ErrMissingTopic, err)
|
||||
assert.Empty(t, s.Messages)
|
||||
})
|
||||
|
||||
t.Run("InvalidToken", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
s := new(streamMock)
|
||||
|
||||
err := h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: uuid.New().String(),
|
||||
}, s)
|
||||
|
||||
assert.Equal(t, handler.ErrInvalidToken, err)
|
||||
assert.Empty(t, s.Messages)
|
||||
})
|
||||
|
||||
t.Run("ExpiredToken", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{
|
||||
Topic: "helloworld",
|
||||
}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ct := h.Time()
|
||||
h.Time = func() time.Time { return ct.Add(handler.TokenTTL * 2) }
|
||||
s := new(streamMock)
|
||||
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
|
||||
assert.Equal(t, handler.ErrExpiredToken, err)
|
||||
assert.Empty(t, s.Messages)
|
||||
})
|
||||
|
||||
t.Run("ForbiddenTopic", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{
|
||||
Topic: "helloworldx",
|
||||
}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := new(streamMock)
|
||||
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
|
||||
assert.Equal(t, handler.ErrForbiddenTopic, err)
|
||||
assert.Empty(t, s.Messages)
|
||||
})
|
||||
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
c := make(chan events.Event)
|
||||
h.Events.(*eventsMock).ConsumeChan = c
|
||||
|
||||
var tRsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{
|
||||
Topic: "helloworld",
|
||||
}, &tRsp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := &streamMock{Messages: []*pb.Message{}}
|
||||
err = h.Subscribe(context.TODO(), &pb.SubscribeRequest{
|
||||
Topic: "helloworld",
|
||||
Token: tRsp.Token,
|
||||
}, s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "helloworld", h.Events.(*eventsMock).ConsumeTopic)
|
||||
|
||||
e1 := events.Event{
|
||||
ID: uuid.New().String(),
|
||||
Topic: "helloworld",
|
||||
Timestamp: h.Time().Add(time.Second * -2),
|
||||
Payload: []byte("abc"),
|
||||
}
|
||||
e2 := events.Event{
|
||||
ID: uuid.New().String(),
|
||||
Topic: "helloworld",
|
||||
Timestamp: h.Time().Add(time.Second * -1),
|
||||
Payload: []byte("123"),
|
||||
}
|
||||
|
||||
timeout := time.NewTimer(time.Millisecond * 100).C
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Events not consumed from stream")
|
||||
return
|
||||
case c <- e1:
|
||||
t.Log("Event1 consumed")
|
||||
}
|
||||
select {
|
||||
case <-timeout:
|
||||
t.Fatal("Events not consumed from stream")
|
||||
return
|
||||
case c <- e2:
|
||||
t.Log("Event2 consumed")
|
||||
}
|
||||
|
||||
if len(s.Messages) != 2 {
|
||||
t.Fatalf("Expected 2 messages, got %v", len(s.Messages))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, e1.Topic, s.Messages[0].Topic)
|
||||
assert.Equal(t, string(e1.Payload), s.Messages[0].Message)
|
||||
assert.True(t, e1.Timestamp.Equal(s.Messages[0].SentAt.AsTime()))
|
||||
|
||||
assert.Equal(t, e2.Topic, s.Messages[1].Topic)
|
||||
assert.Equal(t, string(e2.Payload), s.Messages[1].Message)
|
||||
assert.True(t, e2.Timestamp.Equal(s.Messages[1].SentAt.AsTime()))
|
||||
})
|
||||
}
|
||||
|
||||
type streamMock struct {
|
||||
Messages []*pb.Message
|
||||
pb.Streams_SubscribeStream
|
||||
}
|
||||
|
||||
func (x *streamMock) Send(m *pb.Message) error {
|
||||
x.Messages = append(x.Messages, m)
|
||||
return nil
|
||||
}
|
||||
27
streams/handler/token.go
Normal file
27
streams/handler/token.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/micro/micro/v3/service/errors"
|
||||
"github.com/micro/micro/v3/service/logger"
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
)
|
||||
|
||||
func (s *Streams) Token(ctx context.Context, req *pb.TokenRequest, rsp *pb.TokenResponse) error {
|
||||
// construct the token and write it to the database
|
||||
t := Token{
|
||||
Token: uuid.New().String(),
|
||||
ExpiresAt: s.Time().Add(TokenTTL),
|
||||
Topic: req.Topic,
|
||||
}
|
||||
if err := s.DB.Create(&t).Error; err != nil {
|
||||
logger.Errorf("Error creating token in store: %v", err)
|
||||
return errors.InternalServerError("DATABASE_ERROR", "Error writing token to database")
|
||||
}
|
||||
|
||||
// return the token in the response
|
||||
rsp.Token = t.Token
|
||||
return nil
|
||||
}
|
||||
27
streams/handler/token_test.go
Normal file
27
streams/handler/token_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
pb "github.com/micro/services/streams/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestToken(t *testing.T) {
|
||||
h := testHandler(t)
|
||||
|
||||
t.Run("WithoutTopic", func(t *testing.T) {
|
||||
var rsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, rsp.Token)
|
||||
})
|
||||
|
||||
t.Run("WithTopic", func(t *testing.T) {
|
||||
var rsp pb.TokenResponse
|
||||
err := h.Token(context.TODO(), &pb.TokenRequest{Topic: "helloworld"}, &rsp)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, rsp.Token)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user