From 86c64d2a9e4353c633031acc5978800ea1161029 Mon Sep 17 00:00:00 2001
From: Silas Davis <silas@monax.io>
Date: Thu, 14 Jun 2018 12:32:44 +0100
Subject: [PATCH] Add panic handling to GRPC and use safe getters

Signed-off-by: Silas Davis <silas@monax.io>
---
 core/kernel.go                   |  5 ++--
 rpc/burrow/transaction_server.go | 47 ++++++++++++++++++--------------
 rpc/grpc.go                      | 46 +++++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 23 deletions(-)
 create mode 100644 rpc/grpc.go

diff --git a/core/kernel.go b/core/kernel.go
index cd1cf3fd..f5a8549f 100644
--- a/core/kernel.go
+++ b/core/kernel.go
@@ -47,7 +47,6 @@ import (
 	tm_config "github.com/tendermint/tendermint/config"
 	tm_types "github.com/tendermint/tendermint/types"
 	dbm "github.com/tendermint/tmlibs/db"
-	"google.golang.org/grpc"
 )
 
 const (
@@ -206,7 +205,7 @@ func NewKernel(ctx context.Context, keyClient keys.KeyClient, privValidator tm_t
 			},
 		},
 		{
-			Name:    "GRPC service",
+			Name:    "GRPC",
 			Enabled: rpcConfig.GRPC.Enabled,
 			Launch: func() (process.Process, error) {
 				listen, err := net.Listen("tcp", rpcConfig.GRPC.ListenAddress)
@@ -214,7 +213,7 @@ func NewKernel(ctx context.Context, keyClient keys.KeyClient, privValidator tm_t
 					return nil, err
 				}
 
-				grpcServer := grpc.NewServer()
+				grpcServer := rpc.NewGRPCServer(logger)
 				var ks keys.KeyStore
 				if keyStore != nil {
 					ks = *keyStore
diff --git a/rpc/burrow/transaction_server.go b/rpc/burrow/transaction_server.go
index e8cd8347..d2c5dc61 100644
--- a/rpc/burrow/transaction_server.go
+++ b/rpc/burrow/transaction_server.go
@@ -4,6 +4,7 @@ import (
 	acm "github.com/hyperledger/burrow/account"
 	"github.com/hyperledger/burrow/account/state"
 	"github.com/hyperledger/burrow/crypto"
+	"github.com/hyperledger/burrow/execution"
 	"github.com/hyperledger/burrow/execution/evm/events"
 	"github.com/hyperledger/burrow/rpc"
 	"github.com/hyperledger/burrow/txs"
@@ -25,7 +26,7 @@ func NewTransactionServer(service *rpc.Service, reader state.Reader, txCodec txs
 }
 
 func (ts *transactionServer) BroadcastTx(ctx context.Context, param *TxParam) (*TxReceipt, error) {
-	receipt, err := ts.service.Transactor().BroadcastTxRaw(param.Tx)
+	receipt, err := ts.service.Transactor().BroadcastTxRaw(param.GetTx())
 	if err != nil {
 		return nil, err
 	}
@@ -33,15 +34,15 @@ func (ts *transactionServer) BroadcastTx(ctx context.Context, param *TxParam) (*
 }
 
 func (ts *transactionServer) Call(ctx context.Context, param *CallParam) (*CallResult, error) {
-	fromAddress, err := crypto.AddressFromBytes(param.From)
+	fromAddress, err := crypto.AddressFromBytes(param.GetFrom())
 	if err != nil {
 		return nil, err
 	}
-	address, err := crypto.AddressFromBytes(param.Address)
+	address, err := crypto.AddressFromBytes(param.GetAddress())
 	if err != nil {
 		return nil, err
 	}
-	call, err := ts.service.Transactor().Call(ts.reader, fromAddress, address, param.Data)
+	call, err := ts.service.Transactor().Call(ts.reader, fromAddress, address, param.GetData())
 	return &CallResult{
 		Return:  call.Return,
 		GasUsed: call.GasUsed,
@@ -49,11 +50,11 @@ func (ts *transactionServer) Call(ctx context.Context, param *CallParam) (*CallR
 }
 
 func (ts *transactionServer) CallCode(ctx context.Context, param *CallCodeParam) (*CallResult, error) {
-	fromAddress, err := crypto.AddressFromBytes(param.From)
+	fromAddress, err := crypto.AddressFromBytes(param.GetFrom())
 	if err != nil {
 		return nil, err
 	}
-	call, err := ts.service.Transactor().CallCode(ts.reader, fromAddress, param.Code, param.Data)
+	call, err := ts.service.Transactor().CallCode(ts.reader, fromAddress, param.GetCode(), param.GetData())
 	return &CallResult{
 		Return:  call.Return,
 		GasUsed: call.GasUsed,
@@ -61,15 +62,16 @@ func (ts *transactionServer) CallCode(ctx context.Context, param *CallCodeParam)
 }
 
 func (ts *transactionServer) Transact(ctx context.Context, param *TransactParam) (*TxReceipt, error) {
-	inputAccount, err := ts.service.SigningAccount(param.InputAccount.Address, param.InputAccount.PrivateKey)
+	inputAccount, err := ts.inputAccount(param.GetInputAccount())
 	if err != nil {
 		return nil, err
 	}
-	address, err := crypto.MaybeAddressFromBytes(param.Address)
+	address, err := crypto.MaybeAddressFromBytes(param.GetAddress())
 	if err != nil {
 		return nil, err
 	}
-	receipt, err := ts.service.Transactor().Transact(inputAccount, address, param.Data, param.GasLimit, param.Fee)
+	receipt, err := ts.service.Transactor().Transact(inputAccount, address, param.GetData(), param.GetGasLimit(),
+		param.GetFee())
 	if err != nil {
 		return nil, err
 	}
@@ -77,15 +79,16 @@ func (ts *transactionServer) Transact(ctx context.Context, param *TransactParam)
 }
 
 func (ts *transactionServer) TransactAndHold(ctx context.Context, param *TransactParam) (*EventDataCall, error) {
-	inputAccount, err := ts.service.SigningAccount(param.InputAccount.Address, param.InputAccount.PrivateKey)
+	inputAccount, err := ts.inputAccount(param.GetInputAccount())
 	if err != nil {
 		return nil, err
 	}
-	address, err := crypto.MaybeAddressFromBytes(param.Address)
+	address, err := crypto.MaybeAddressFromBytes(param.GetAddress())
 	if err != nil {
 		return nil, err
 	}
-	edt, err := ts.service.Transactor().TransactAndHold(ctx, inputAccount, address, param.Data, param.GasLimit, param.Fee)
+	edt, err := ts.service.Transactor().TransactAndHold(ctx, inputAccount, address, param.GetData(),
+		param.GetGasLimit(), param.GetFee())
 	if err != nil {
 		return nil, err
 	}
@@ -93,15 +96,15 @@ func (ts *transactionServer) TransactAndHold(ctx context.Context, param *Transac
 }
 
 func (ts *transactionServer) Send(ctx context.Context, param *SendParam) (*TxReceipt, error) {
-	inputAccount, err := ts.service.SigningAccount(param.InputAccount.Address, param.InputAccount.PrivateKey)
+	inputAccount, err := ts.inputAccount(param.GetInputAccount())
 	if err != nil {
 		return nil, err
 	}
-	toAddress, err := crypto.AddressFromBytes(param.ToAddress)
+	toAddress, err := crypto.AddressFromBytes(param.GetToAddress())
 	if err != nil {
 		return nil, err
 	}
-	receipt, err := ts.service.Transactor().Send(inputAccount, toAddress, param.Amount)
+	receipt, err := ts.service.Transactor().Send(inputAccount, toAddress, param.GetAmount())
 	if err != nil {
 		return nil, err
 	}
@@ -109,15 +112,15 @@ func (ts *transactionServer) Send(ctx context.Context, param *SendParam) (*TxRec
 }
 
 func (ts *transactionServer) SendAndHold(ctx context.Context, param *SendParam) (*TxReceipt, error) {
-	inputAccount, err := ts.service.SigningAccount(param.InputAccount.Address, param.InputAccount.PrivateKey)
+	inputAccount, err := ts.inputAccount(param.GetInputAccount())
 	if err != nil {
 		return nil, err
 	}
-	toAddress, err := crypto.AddressFromBytes(param.ToAddress)
+	toAddress, err := crypto.AddressFromBytes(param.GetToAddress())
 	if err != nil {
 		return nil, err
 	}
-	receipt, err := ts.service.Transactor().SendAndHold(ctx, inputAccount, toAddress, param.Amount)
+	receipt, err := ts.service.Transactor().SendAndHold(ctx, inputAccount, toAddress, param.GetAmount())
 	if err != nil {
 		return nil, err
 	}
@@ -125,11 +128,11 @@ func (ts *transactionServer) SendAndHold(ctx context.Context, param *SendParam)
 }
 
 func (ts *transactionServer) SignTx(ctx context.Context, param *SignTxParam) (*SignedTx, error) {
-	txEnv, err := ts.txCodec.DecodeTx(param.Tx)
+	txEnv, err := ts.txCodec.DecodeTx(param.GetTx())
 	if err != nil {
 		return nil, err
 	}
-	signers, err := signersFromPrivateAccounts(param.PrivateAccounts)
+	signers, err := signersFromPrivateAccounts(param.GetPrivateAccounts())
 	if err != nil {
 		return nil, err
 	}
@@ -146,6 +149,10 @@ func (ts *transactionServer) SignTx(ctx context.Context, param *SignTxParam) (*S
 	}, nil
 }
 
+func (ts *transactionServer) inputAccount(inAcc *InputAccount) (*execution.SequentialSigningAccount, error) {
+	return ts.service.SigningAccount(inAcc.GetAddress(), inAcc.GetPrivateKey())
+}
+
 func eventDataCall(edt *events.EventDataCall) *EventDataCall {
 	return &EventDataCall{
 		Origin:     edt.Origin.Bytes(),
diff --git a/rpc/grpc.go b/rpc/grpc.go
new file mode 100644
index 00000000..49a42fd0
--- /dev/null
+++ b/rpc/grpc.go
@@ -0,0 +1,46 @@
+package rpc
+
+import (
+	"fmt"
+
+	"github.com/hyperledger/burrow/logging"
+	"github.com/hyperledger/burrow/logging/structure"
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+)
+
+func NewGRPCServer(logger *logging.Logger) *grpc.Server {
+	return grpc.NewServer(grpc.UnaryInterceptor(unaryInterceptor(logger)),
+		grpc.StreamInterceptor(streamInterceptor(logger.WithScope("NewGRPCServer"))))
+}
+
+func unaryInterceptor(logger *logging.Logger) grpc.UnaryServerInterceptor {
+	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
+		handler grpc.UnaryHandler) (resp interface{}, err error) {
+
+		defer func() {
+			if r := recover(); r != nil {
+				logger.InfoMsg("panic in GRPC unary call", "method", info.FullMethod,
+					structure.ErrorKey, fmt.Sprintf("%v", r))
+				err = fmt.Errorf("panic in GRPC unary call %s: %v", info.FullMethod, r)
+			}
+		}()
+		return handler(ctx, req)
+	}
+}
+
+func streamInterceptor(logger *logging.Logger) grpc.StreamServerInterceptor {
+	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo,
+		handler grpc.StreamHandler) (err error) {
+
+		defer func() {
+			if r := recover(); r != nil {
+				logger.InfoMsg("panic in GRPC stream", "method", info.FullMethod,
+					"is_client_stream", info.IsClientStream, "is_server_stream", info.IsServerStream,
+					structure.ErrorKey, fmt.Sprintf("%v", r))
+				err = fmt.Errorf("panic in GRPC stream %s: %v", info.FullMethod, r)
+			}
+		}()
+		return handler(srv, ss)
+	}
+}
-- 
GitLab