diff --git a/packages/go-kosu/abci/client.go b/packages/go-kosu/abci/client.go index 268082d7..570cf455 100644 --- a/packages/go-kosu/abci/client.go +++ b/packages/go-kosu/abci/client.go @@ -101,18 +101,10 @@ func (c *Client) Subscribe(ctx context.Context, q string) (<-chan rpctypes.Resul return nil, nil, err } - closer := func() { - _ = c.Client.Unsubscribe(ctx, "kosu", q) - } - + closer := func() { _ = c.Client.Unsubscribe(ctx, "kosu", q) } return ch, closer, nil } -// Unsubscribe unsubscribes given subscriber from query. -func (c *Client) Unsubscribe(ctx context.Context, query string) error { - return c.Client.Unsubscribe(ctx, "kosu", query) -} - // QueryRoundInfo performs a ABCIQuery to "/roundinfo" func (c *Client) QueryRoundInfo() (*types.RoundInfo, error) { var pb types.RoundInfo diff --git a/packages/go-kosu/rpc/cmd.go b/packages/go-kosu/rpc/cmd.go index a1a4aa83..580bbef9 100644 --- a/packages/go-kosu/rpc/cmd.go +++ b/packages/go-kosu/rpc/cmd.go @@ -51,11 +51,11 @@ func NewCommand() *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - client, err := abci.NewHTTPClient(url, key) + fn := func() (*abci.Client, error) { return abci.NewHTTPClient(url, key) } + srv, err := NewServer(fn) if err != nil { return err } - srv := NewServer(client) wg := sync.WaitGroup{} if http { diff --git a/packages/go-kosu/rpc/rpc.go b/packages/go-kosu/rpc/rpc.go index 7972de80..981dea11 100644 --- a/packages/go-kosu/rpc/rpc.go +++ b/packages/go-kosu/rpc/rpc.go @@ -11,13 +11,23 @@ const ( Version = "1.0" ) +// ClientFactory is a function that returns a new abci.Client +// It is invoked by the Server each time we subscribe to an event. +type ClientFactory func() (*abci.Client, error) + // NewServer returns a new Server which holds the registered rpc service -func NewServer(abci *abci.Client) *rpc.Server { - srv := rpc.NewServer() - if err := srv.RegisterName("kosu", &Service{abci: abci}); err != nil { - panic(err) +func NewServer(fn ClientFactory) (*rpc.Server, error) { + server := rpc.NewServer() + service, err := NewService(fn) + if err != nil { + return nil, err + } + + if err := server.RegisterName("kosu", service); err != nil { + return nil, err } - return srv + + return server, nil } // DialInProc wraps rpc.DialInProc constructor diff --git a/packages/go-kosu/rpc/rpc_test.go b/packages/go-kosu/rpc/rpc_test.go index a4b3cafc..56cc654c 100644 --- a/packages/go-kosu/rpc/rpc_test.go +++ b/packages/go-kosu/rpc/rpc_test.go @@ -52,7 +52,10 @@ func TestRPC(t *testing.T) { require.NoError(t, err) defer appClient.Stop() // nolint:errcheck - rpcClient := rpc.DialInProc(NewServer(appClient)) + srv, err := NewServer(app.NewClient) + require.NoError(t, err) + + rpcClient := rpc.DialInProc(srv) defer rpcClient.Close() test.run(t, app, appClient, rpcClient) diff --git a/packages/go-kosu/rpc/service.go b/packages/go-kosu/rpc/service.go index 84ffb6d7..4e45ebbc 100644 --- a/packages/go-kosu/rpc/service.go +++ b/packages/go-kosu/rpc/service.go @@ -22,14 +22,21 @@ import ( // Service is a RPC service type Service struct { - abci *abci.Client + abci *abci.Client + newClient ClientFactory } // NewService returns a new service given a abci client -func NewService(abci *abci.Client) *Service { - return &Service{ - abci: abci, +func NewService(fn ClientFactory) (*Service, error) { + c, err := fn() + if err != nil { + return nil, err } + + return &Service{ + abci: c, + newClient: fn, + }, nil } func eventDecoder(event tmtypes.TMEventData) (interface{}, error) { @@ -48,24 +55,28 @@ func eventDecoder(event tmtypes.TMEventData) (interface{}, error) { } func (s *Service) subscribeTM(ctx context.Context, query string) (*rpc.Subscription, error) { + client, err := s.newClient() + if err != nil { + return nil, err + } + notifier, supported := rpc.NotifierFromContext(ctx) if !supported { return nil, rpc.ErrNotificationsUnsupported } - events, closer, err := s.abci.Subscribe(ctx, query) + events, closer, err := client.Subscribe(ctx, query) if err != nil { return nil, err } rpcSub := notifier.CreateSubscription() go func() { - defer s.abci.Unsubscribe(ctx, query) // nolint + defer closer() for { select { case <-rpcSub.Err(): - closer() return case <-notifier.Closed(): return @@ -330,13 +341,18 @@ for { ``` */ func (s *Service) NewOrders(ctx context.Context) (*rpc.Subscription, error) { + client, err := s.newClient() + if err != nil { + return nil, err + } + notifier, supported := rpc.NotifierFromContext(ctx) if !supported { return nil, rpc.ErrNotificationsUnsupported } query := "tm.event='NewBlock'" - events, closer, err := s.abci.Subscribe(ctx, query) + events, closer, err := client.Subscribe(ctx, query) if err != nil { return nil, err } @@ -344,8 +360,8 @@ func (s *Service) NewOrders(ctx context.Context) (*rpc.Subscription, error) { rpcSub := notifier.CreateSubscription() blocks := make(chan *tmtypes.Block, 1024) go func() { - defer s.abci.Unsubscribe(ctx, query) // nolint defer close(blocks) + defer closer() for { select { @@ -353,7 +369,6 @@ func (s *Service) NewOrders(ctx context.Context) (*rpc.Subscription, error) { log.Printf("ctx.Err() = %+v\n", ctx.Err()) return case <-rpcSub.Err(): - closer() return case <-notifier.Closed(): return diff --git a/packages/go-kosu/tests/order_test.go b/packages/go-kosu/tests/order_test.go index b054df2e..03dac11a 100644 --- a/packages/go-kosu/tests/order_test.go +++ b/packages/go-kosu/tests/order_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "testing" + "github.com/ParadigmFoundation/kosu-monorepo/packages/go-kosu/abci" "github.com/ParadigmFoundation/kosu-monorepo/packages/go-kosu/abci/types" "github.com/ParadigmFoundation/kosu-monorepo/packages/go-kosu/rpc" @@ -78,9 +79,11 @@ func (suite *IntegrationTestSuite) TestOrders() { suite.Run("RPCEvents", func() { tx := NewOrderTx(suite.T()) - rpcClient := rpc.DialInProc( - rpc.NewServer(suite.Client()), - ) + + srv, err := rpc.NewServer(func() (*abci.Client, error) { return suite.Client(), nil }) + suite.Require().NoError(err) + + rpcClient := rpc.DialInProc(srv) ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/packages/go-kosu/tests/support.go b/packages/go-kosu/tests/support.go index 6ccf769b..e83e7fa1 100644 --- a/packages/go-kosu/tests/support.go +++ b/packages/go-kosu/tests/support.go @@ -3,6 +3,7 @@ package tests import ( "io/ioutil" "os" + "strings" "testing" "time" @@ -20,6 +21,10 @@ func StartServer(t *testing.T, db db.DB) (*abci.App, func()) { for { app, closer, err := startServer(t, db) if err != nil { + if strings.Contains(err.Error(), "address already in use") { + t.Fatal(err) + } + closer() time.Sleep(100 * time.Millisecond) continue