From 2f3e1bbf3c01971b8a176b8e750a189cbfa44aa1 Mon Sep 17 00:00:00 2001 From: Silas Davis <silas@monax.io> Date: Wed, 14 Feb 2018 23:38:51 +0000 Subject: [PATCH] Allow callbacks to close subscriptions. In particular we can stop trying to write to dropped websocket connections Signed-off-by: Silas Davis <silas@monax.io> --- event/cache_test.go | 24 +++++++++++++----------- event/convention.go | 16 ++++++++++++---- event/convention_test.go | 3 ++- execution/events/events.go | 3 ++- execution/evm/events/events.go | 6 ++++-- rpc/service.go | 14 ++++++++------ rpc/tm/methods.go | 20 ++++++++++++++------ rpc/tm/server.go | 2 +- rpc/v0/subscriptions.go | 3 ++- rpc/v0/websocket_service.go | 3 ++- 10 files changed, 60 insertions(+), 34 deletions(-) diff --git a/event/cache_test.go b/event/cache_test.go index b4957a71..877d0438 100644 --- a/event/cache_test.go +++ b/event/cache_test.go @@ -18,30 +18,32 @@ func TestEventCache_Flush(t *testing.T) { flushed := false em := NewEmitter(loggers.NewNoopInfoTraceLogger()) - SubscribeCallback(ctx, em, "nothingness", NewQueryBuilder(), func(message interface{}) { + SubscribeCallback(ctx, em, "nothingness", NewQueryBuilder(), func(message interface{}) bool { // Check against sending a buffer of zeroed messages if message == nil { errCh <- fmt.Errorf("recevied empty message but none sent") } + return false }) evc := NewEventCache(em) evc.Flush() // Check after reset evc.Flush() - SubscribeCallback(ctx, em, "somethingness", NewQueryBuilder().AndEquals("foo", "bar"), func(interface{}) { - if flushed { - errCh <- nil - } else { - errCh <- fmt.Errorf("callback was run before messages were flushed") - } - }) + SubscribeCallback(ctx, em, "somethingness", NewQueryBuilder().AndEquals("foo", "bar"), + func(interface{}) bool { + if flushed { + errCh <- nil + return true + } else { + errCh <- fmt.Errorf("callback was run before messages were flushed") + return false + } + }) numMessages := 3 tags := map[string]interface{}{"foo": "bar"} for i := 0; i < numMessages; i++ { - evc.Publish(ctx, "something", tags) - evc.Publish(ctx, "something", tags) - evc.Publish(ctx, "something", tags) + evc.Publish(ctx, fmt.Sprintf("something_%v", i), tags) } flushed = true evc.Flush() diff --git a/event/convention.go b/event/convention.go index 1fc69abf..16832d5f 100644 --- a/event/convention.go +++ b/event/convention.go @@ -39,16 +39,24 @@ func PublishWithEventID(publisher Publisher, eventID string, eventData interface // Subscribe to messages matching query and launch a goroutine to run a callback for each one. The goroutine will exit // when the context is done or the subscription is removed. func SubscribeCallback(ctx context.Context, subscribable Subscribable, subscriber string, query Queryable, - callback func(message interface{})) error { + callback func(message interface{}) bool) error { out := make(chan interface{}) go func() { for { msg, ok := <-out if !ok { + // Channel closed, no need to unsubscribe or drain + return + } + if !callback(msg) { + // Callback is requesting stop so unsubscribe and drain channel + subscribable.Unsubscribe(context.Background(), subscriber, query) + // Not draining channel can starve other subscribers + for range out { + } return } - callback(msg) } }() err := subscribable.Subscribe(ctx, subscriber, query, out) @@ -62,13 +70,13 @@ func SubscribeCallback(ctx context.Context, subscribable Subscribable, subscribe func PublishAll(ctx context.Context, subscribable Subscribable, subscriber string, query Queryable, publisher Publisher, extraTags map[string]interface{}) error { - return SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) { + return SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) bool { tags := make(map[string]interface{}) for k, v := range extraTags { tags[k] = v } - // Help! I can't tell which tags the original publisher used - so I can't forward them on publisher.Publish(ctx, message, tags) + return true }) } diff --git a/event/convention_test.go b/event/convention_test.go index 71e4a718..31e9571b 100644 --- a/event/convention_test.go +++ b/event/convention_test.go @@ -13,8 +13,9 @@ func TestSubscribeCallback(t *testing.T) { ctx := context.Background() em := NewEmitter(loggers.NewNoopInfoTraceLogger()) ch := make(chan interface{}) - SubscribeCallback(ctx, em, "TestSubscribeCallback", MatchAllQueryable(), func(msg interface{}) { + SubscribeCallback(ctx, em, "TestSubscribeCallback", MatchAllQueryable(), func(msg interface{}) bool { ch <- msg + return true }) sent := "FROTHY" diff --git a/execution/events/events.go b/execution/events/events.go index 0ee74ea4..63577bd8 100644 --- a/execution/events/events.go +++ b/execution/events/events.go @@ -67,12 +67,13 @@ func SubscribeAccountOutputSendTx(ctx context.Context, subscribable event.Subscr query := sendTxQuery.And(event.QueryForEventID(EventStringAccountOutput(address))). AndEquals(event.TxHashKey, hex.EncodeUpperToString(txHash)) - return event.SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) { + return event.SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) bool { if eventDataCall, ok := message.(*EventDataTx); ok { if sendTx, ok := eventDataCall.Tx.(*txs.SendTx); ok { ch <- sendTx } } + return true }) } diff --git a/execution/evm/events/events.go b/execution/evm/events/events.go index f018528e..64c8ec44 100644 --- a/execution/evm/events/events.go +++ b/execution/evm/events/events.go @@ -68,11 +68,12 @@ func SubscribeAccountCall(ctx context.Context, subscribable event.Subscribable, query = query.AndEquals(event.TxHashKey, hex.EncodeUpperToString(txHash)) } - return event.SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) { + return event.SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) bool { eventDataCall, ok := message.(*EventDataCall) if ok { ch <- eventDataCall } + return true }) } @@ -81,11 +82,12 @@ func SubscribeLogEvent(ctx context.Context, subscribable event.Subscribable, sub query := event.QueryForEventID(EventStringLogEvent(address)) - return event.SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) { + return event.SubscribeCallback(ctx, subscribable, subscriber, query, func(message interface{}) bool { eventDataLog, ok := message.(*EventDataLog) if ok { ch <- eventDataLog } + return true }) } diff --git a/rpc/service.go b/rpc/service.go index cdf139c7..a9001d51 100644 --- a/rpc/service.go +++ b/rpc/service.go @@ -38,7 +38,7 @@ const MaxBlockLookback = 100 type SubscribableService interface { // Events - Subscribe(ctx context.Context, subscriptionID string, eventID string, callback func(*ResultEvent)) error + Subscribe(ctx context.Context, subscriptionID string, eventID string, callback func(*ResultEvent) bool) error Unsubscribe(ctx context.Context, subscriptionID string) error } @@ -134,22 +134,24 @@ func (s *service) ListUnconfirmedTxs(maxTxs int) (*ResultListUnconfirmedTxs, err } func (s *service) Subscribe(ctx context.Context, subscriptionID string, eventID string, - callback func(resultEvent *ResultEvent)) error { + callback func(resultEvent *ResultEvent) bool) error { queryBuilder := event.QueryForEventID(eventID) logging.InfoMsg(s.logger, "Subscribing to events", "query", queryBuilder.String(), - "subscription_id", subscriptionID) + "subscription_id", subscriptionID, + "event_id", eventID) return event.SubscribeCallback(ctx, s.subscribable, subscriptionID, queryBuilder, - func(message interface{}) { + func(message interface{}) bool { resultEvent, err := NewResultEvent(eventID, message) if err != nil { logging.InfoMsg(s.logger, "Received event that could not be mapped to ResultEvent", structure.ErrorKey, err, + "subscription_id", subscriptionID, "event_id", eventID) - return + return true } - callback(resultEvent) + return callback(resultEvent) }) } diff --git a/rpc/tm/methods.go b/rpc/tm/methods.go index 2686f35a..7da5d22b 100644 --- a/rpc/tm/methods.go +++ b/rpc/tm/methods.go @@ -1,14 +1,15 @@ package tm import ( - "fmt" - "context" + "fmt" "time" acm "github.com/hyperledger/burrow/account" "github.com/hyperledger/burrow/event" "github.com/hyperledger/burrow/execution" + "github.com/hyperledger/burrow/logging" + logging_types "github.com/hyperledger/burrow/logging/types" "github.com/hyperledger/burrow/rpc" "github.com/hyperledger/burrow/txs" gorpc "github.com/tendermint/tendermint/rpc/lib/server" @@ -57,7 +58,8 @@ const ( const SubscriptionTimeoutSeconds = 5 * time.Second -func GetRoutes(service rpc.Service) map[string]*gorpc.RPCFunc { +func GetRoutes(service rpc.Service, logger logging_types.InfoTraceLogger) map[string]*gorpc.RPCFunc { + logger = logging.WithScope(logger, "GetRoutes") return map[string]*gorpc.RPCFunc{ // Transact BroadcastTx: gorpc.NewRPCFunc(func(tx txs.Wrapper) (*rpc.ResultBroadcastTx, error) { @@ -101,9 +103,15 @@ func GetRoutes(service rpc.Service) map[string]*gorpc.RPCFunc { } ctx, cancel := context.WithTimeout(context.Background(), SubscriptionTimeoutSeconds*time.Second) defer cancel() - err = service.Subscribe(ctx, subscriptionID, eventID, func(resultEvent *rpc.ResultEvent) { - wsCtx.TryWriteRPCResponse(rpctypes.NewRPCSuccessResponse(EventResponseID(wsCtx.Request.ID, eventID), - resultEvent)) + err = service.Subscribe(ctx, subscriptionID, eventID, func(resultEvent *rpc.ResultEvent) bool { + keepAlive := wsCtx.TryWriteRPCResponse(rpctypes.NewRPCSuccessResponse( + EventResponseID(wsCtx.Request.ID, eventID), resultEvent)) + if !keepAlive { + logging.InfoMsg(logger, "dropping subscription because could not write to websocket", + "subscription_id", subscriptionID, + "event_id", eventID) + } + return keepAlive }) if err != nil { return nil, err diff --git a/rpc/tm/server.go b/rpc/tm/server.go index 13193f34..25f5f5cd 100644 --- a/rpc/tm/server.go +++ b/rpc/tm/server.go @@ -30,7 +30,7 @@ func StartServer(service rpc.Service, pattern, listenAddress string, emitter eve logger logging_types.InfoTraceLogger) (net.Listener, error) { logger = logger.With(structure.ComponentKey, "RPC_TM") - routes := GetRoutes(service) + routes := GetRoutes(service, logger) mux := http.NewServeMux() wm := rpcserver.NewWebsocketManager(routes, rpcserver.EventSubscriber(tendermint.SubscribableAsEventBus(emitter))) mux.HandleFunc(pattern, wm.WebsocketHandler) diff --git a/rpc/v0/subscriptions.go b/rpc/v0/subscriptions.go index 80a3d372..dbfab6f7 100644 --- a/rpc/v0/subscriptions.go +++ b/rpc/v0/subscriptions.go @@ -106,10 +106,11 @@ func (subs *Subscriptions) Add(eventId string) (string, error) { return "", err } cache := newSubscriptionsCache() - err = subs.service.Subscribe(context.Background(), subId, eventId, func(resultEvent *rpc.ResultEvent) { + err = subs.service.Subscribe(context.Background(), subId, eventId, func(resultEvent *rpc.ResultEvent) bool { cache.mtx.Lock() defer cache.mtx.Unlock() cache.events = append(cache.events, resultEvent) + return true }) if err != nil { return "", err diff --git a/rpc/v0/websocket_service.go b/rpc/v0/websocket_service.go index 0e0e036f..7917d84c 100644 --- a/rpc/v0/websocket_service.go +++ b/rpc/v0/websocket_service.go @@ -125,8 +125,9 @@ func (ws *WebsocketService) EventSubscribe(request *rpc.RPCRequest, return nil, rpc.INTERNAL_ERROR, err } - err = ws.service.Subscribe(context.Background(), subId, eventId, func(resultEvent *rpc.ResultEvent) { + err = ws.service.Subscribe(context.Background(), subId, eventId, func(resultEvent *rpc.ResultEvent) bool { ws.writeResponse(subId, resultEvent, session) + return true }) if err != nil { return nil, rpc.INTERNAL_ERROR, err -- GitLab