
Node: Minor tweaks and spy improvement (#3974)

* Node: Minor tweaks and spy improvement

* Add tests
bruce-riley 1 year ago
commit 0e2ba6270c

+ 2 - 31
node/cmd/spy/spy.go

@@ -340,12 +340,6 @@ func runSpy(cmd *cobra.Command, args []string) {
 	// Outbound gossip message queue
 	sendC := make(chan []byte)

-	// Inbound observations
-	obsvC := make(chan *common.MsgWithTimeStamp[gossipv1.SignedObservation], 1024)
-
-	// Inbound observation requests
-	obsvReqC := make(chan *gossipv1.ObservationRequest, 1024)
-
 	// Inbound signed VAAs
 	signedInC := make(chan *gossipv1.SignedVAAWithQuorum, 1024)

@@ -370,29 +364,6 @@ func runSpy(cmd *cobra.Command, args []string) {
 		}
 	}

-	// Ignore observations
-	go func() {
-		for {
-			select {
-			case <-rootCtx.Done():
-				return
-			case <-obsvC:
-			}
-		}
-	}()
-
-	// Ignore observation requests
-	// Note: without this, the whole program hangs on observation requests
-	go func() {
-		for {
-			select {
-			case <-rootCtx.Done():
-				return
-			case <-obsvReqC:
-			}
-		}
-	}()
-
 	// Log signed VAAs
 	go func() {
 		for {
@@ -422,8 +393,8 @@ func runSpy(cmd *cobra.Command, args []string) {
 		components.Port = *p2pPort
 		if err := supervisor.Run(ctx,
 			"p2p",
-			p2p.Run(obsvC,
-				obsvReqC,
+			p2p.Run(nil, // Ignore incoming observations.
+				nil, // Ignore observation requests.
 				nil,
 				sendC,
 				signedInC,

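Background on the nil-channel pattern used above: in Go, a send on a nil channel blocks forever, which is why the spy previously had to drain obsvC and obsvReqC in dedicated goroutines, and why p2p.Run (see node/pkg/p2p/p2p.go below) now checks for nil before delivering. A minimal, self-contained sketch of the guard; the deliver helper is hypothetical and only illustrates the pattern, it is not part of this change:

package main

import "fmt"

// deliver forwards msg to out, silently dropping it when the consumer opted
// out by passing a nil channel. Without the nil check, `out <- msg` would
// block forever, because sends on a nil channel never proceed.
func deliver(out chan<- string, msg string) {
	if out != nil {
		out <- msg
	}
}

func main() {
	ch := make(chan string, 1)
	deliver(ch, "kept")     // delivered normally
	deliver(nil, "ignored") // dropped, does not block
	fmt.Println(<-ch)       // prints "kept"
}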
+ 1 - 16
node/pkg/accountant/submit_obs.go

@@ -67,7 +67,7 @@ func (acct *Accountant) handleBatch(ctx context.Context, subChan chan *common.Me
 	ctx, cancel := context.WithTimeout(ctx, delayInMS)
 	defer cancel()

-	msgs, err := readFromChannel[*common.MessagePublication](ctx, subChan, batchSize)
+	msgs, err := common.ReadFromChannelWithTimeout[*common.MessagePublication](ctx, subChan, batchSize)
 	if err != nil && !errors.Is(err, context.DeadlineExceeded) {
 		return fmt.Errorf("failed to read messages from channel for %s: %w", tag, err)
 	}
@@ -95,21 +95,6 @@ func (acct *Accountant) handleBatch(ctx context.Context, subChan chan *common.Me
 	return nil
 }

-// readFromChannel reads events from the channel until a timeout occurs or the batch is full, and returns them.
-func readFromChannel[T any](ctx context.Context, ch <-chan T, count int) ([]T, error) {
-	out := make([]T, 0, count)
-	for len(out) < count {
-		select {
-		case <-ctx.Done():
-			return out, ctx.Err()
-		case msg := <-ch:
-			out = append(out, msg)
-		}
-	}
-
-	return out, nil
-}
-
 // removeCompleted drops any messages that are no longer in the pending transfer map. This is to handle the case where the contract reports
 // that a transfer is committed while it is in the channel. There is no point in submitting the observation once the transfer is committed.
 func (acct *Accountant) removeCompleted(msgs []*common.MessagePublication) []*common.MessagePublication {

+ 20 - 0
node/pkg/common/channel_utils.go

@@ -0,0 +1,20 @@
+package common
+
+import (
+	"context"
+)
+
+// ReadFromChannelWithTimeout reads events from the channel until a timeout occurs or maxCount events have been read.
+func ReadFromChannelWithTimeout[T any](ctx context.Context, ch <-chan T, maxCount int) ([]T, error) {
+	out := make([]T, 0, maxCount)
+	for len(out) < maxCount {
+		select {
+		case <-ctx.Done():
+			return out, ctx.Err()
+		case msg := <-ch:
+			out = append(out, msg)
+		}
+	}
+
+	return out, nil
+}

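Usage sketch for the new helper, mirroring how the accountant calls it above. The channel contents and limits are made up for illustration; the import path matches the one used elsewhere in this PR:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/certusone/wormhole/node/pkg/common"
)

func main() {
	msgC := make(chan int, 10)
	msgC <- 1
	msgC <- 2

	// Collect up to five items, but stop waiting after 100ms.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	batch, err := common.ReadFromChannelWithTimeout[int](ctx, msgC, 5)
	if err != nil && !errors.Is(err, context.DeadlineExceeded) {
		panic(err) // hitting the deadline is expected here; anything else is a real error
	}
	fmt.Println(batch) // [1 2]
}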
+ 80 - 0
node/pkg/common/channel_utils_test.go

@@ -0,0 +1,80 @@
+package common
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+const myDelay = time.Millisecond * 100
+const myMaxSize = 2
+const myQueueSize = myMaxSize * 10
+
+func TestReadFromChannelWithTimeout_NoData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+
+	// No data should timeout.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.Equal(t, err, context.DeadlineExceeded)
+	assert.Equal(t, 0, len(observations))
+}
+
+func TestReadFromChannelWithTimeout_SomeData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+	myChan <- 1
+
+	// Some data but not enough to fill a message should timeout and return the data.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.Equal(t, err, context.DeadlineExceeded)
+	require.Equal(t, 1, len(observations))
+	assert.Equal(t, 1, observations[0])
+}
+
+func TestReadFromChannelWithTimeout_JustEnoughData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+	myChan <- 1
+	myChan <- 2
+
+	// Just enough data should return the data and no error.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.NoError(t, err)
+	require.Equal(t, 2, len(observations))
+	assert.Equal(t, 1, observations[0])
+	assert.Equal(t, 2, observations[1])
+}
+
+func TestReadFromChannelWithTimeout_TooMuchData(t *testing.T) {
+	ctx := context.Background()
+	myChan := make(chan int, myQueueSize)
+	myChan <- 1
+	myChan <- 2
+	myChan <- 3
+
+	// If there is more data than will fit, it should immediately return a full message, then timeout and return the remainder.
+	timeout, cancel := context.WithTimeout(ctx, myDelay)
+	defer cancel()
+	observations, err := ReadFromChannelWithTimeout[int](timeout, myChan, myMaxSize)
+	assert.NoError(t, err)
+	require.Equal(t, 2, len(observations))
+	assert.Equal(t, 1, observations[0])
+	assert.Equal(t, 2, observations[1])
+
+	timeout2, cancel2 := context.WithTimeout(ctx, myDelay)
+	defer cancel2()
+	observations, err = ReadFromChannelWithTimeout[int](timeout2, myChan, myMaxSize)
+	assert.Equal(t, err, context.DeadlineExceeded)
+	require.Equal(t, 1, len(observations))
+	assert.Equal(t, 3, observations[0])
+}

+ 8 - 3
node/pkg/common/guardianset.go

@@ -54,8 +54,8 @@ type GuardianSet struct {
 	// On-chain set index
 	Index uint32

-	// Quorum value for this set of keys
-	Quorum int
+	// quorum value for this set of keys
+	quorum int

 	// A map from address to index. Testing showed that, on average, a map is almost three times faster than a sequential search of the key slice.
 	// Testing also showed that the map was twice as fast as using a sorted slice and `slices.BinarySearchFunc`. That being said, on a 4GHz CPU,
@@ -63,6 +63,11 @@ type GuardianSet struct {
 	keyMap map[common.Address]int
 }

+// Quorum returns the current quorum value.
+func (gs *GuardianSet) Quorum() int {
+	return gs.quorum
+}
+
 func NewGuardianSet(keys []common.Address, index uint32) *GuardianSet {
 	keyMap := map[common.Address]int{}
 	for idx, key := range keys {
@@ -71,7 +76,7 @@ func NewGuardianSet(keys []common.Address, index uint32) *GuardianSet {
 	return &GuardianSet{
 		Keys:   keys,
 		Index:  index,
-		Quorum: vaa.CalculateQuorum(len(keys)),
+		quorum: vaa.CalculateQuorum(len(keys)),
 		keyMap: keyMap,
 	}
 }

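Since the quorum value is now an unexported field, callers outside the package go through the new Quorum() accessor. A small sketch of the accessor in use; the package aliases match the imports shown in broadcast.go below, and the guardian addresses are placeholders:

package main

import (
	"fmt"

	ethcommon "github.com/ethereum/go-ethereum/common"

	node_common "github.com/certusone/wormhole/node/pkg/common"
)

func main() {
	// Placeholder guardian keys; real sets come from the chain.
	keys := []ethcommon.Address{
		ethcommon.HexToAddress("0x0000000000000000000000000000000000000001"),
		ethcommon.HexToAddress("0x0000000000000000000000000000000000000002"),
		ethcommon.HexToAddress("0x0000000000000000000000000000000000000003"),
	}

	gs := node_common.NewGuardianSet(keys, 0)

	// The quorum is computed once in NewGuardianSet (via vaa.CalculateQuorum)
	// and is now read through the accessor instead of the old exported field.
	fmt.Println(gs.Index, gs.Quorum())
}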
+ 1 - 1
node/pkg/common/guardianset_test.go

@@ -34,7 +34,7 @@ func TestNewGuardianSet(t *testing.T) {
 	gs := NewGuardianSet(keys, 1)
 	assert.True(t, reflect.DeepEqual(keys, gs.Keys))
 	assert.Equal(t, uint32(1), gs.Index)
-	assert.Equal(t, vaa.CalculateQuorum(len(keys)), gs.Quorum)
+	assert.Equal(t, vaa.CalculateQuorum(len(keys)), gs.Quorum())
 }

 func TestKeyIndex(t *testing.T) {

+ 53 - 45
node/pkg/p2p/p2p.go

@@ -590,7 +590,9 @@ func Run(
 					}

 					// Send to local observation request queue (the loopback message is ignored)
-					obsvReqC <- msg
+					if obsvReqC != nil {
+						obsvReqC <- msg
+					}

 					err = th.Publish(ctx, b)
 					p2pMessagesSent.Inc()
@@ -699,59 +701,65 @@ func Run(
 					}()
 				}
 			case *gossipv1.GossipMessage_SignedObservation:
-				if err := common.PostMsgWithTimestamp[gossipv1.SignedObservation](m.SignedObservation, obsvC); err == nil {
-					p2pMessagesReceived.WithLabelValues("observation").Inc()
-				} else {
-					if components.WarnChannelOverflow {
-						logger.Warn("Ignoring SignedObservation because obsvC full", zap.String("hash", hex.EncodeToString(m.SignedObservation.Hash)))
+				if obsvC != nil {
+					if err := common.PostMsgWithTimestamp[gossipv1.SignedObservation](m.SignedObservation, obsvC); err == nil {
+						p2pMessagesReceived.WithLabelValues("observation").Inc()
+					} else {
+						if components.WarnChannelOverflow {
+							logger.Warn("Ignoring SignedObservation because obsvC full", zap.String("hash", hex.EncodeToString(m.SignedObservation.Hash)))
+						}
+						p2pReceiveChannelOverflow.WithLabelValues("observation").Inc()
 					}
-					p2pReceiveChannelOverflow.WithLabelValues("observation").Inc()
 				}
 			case *gossipv1.GossipMessage_SignedVaaWithQuorum:
-				select {
-				case signedInC <- m.SignedVaaWithQuorum:
-					p2pMessagesReceived.WithLabelValues("signed_vaa_with_quorum").Inc()
-				default:
-					if components.WarnChannelOverflow {
-						// TODO do not log this in production
-						var hexStr string
-						if vaa, err := vaa.Unmarshal(m.SignedVaaWithQuorum.Vaa); err == nil {
-							hexStr = vaa.HexDigest()
+				if signedInC != nil {
+					select {
+					case signedInC <- m.SignedVaaWithQuorum:
+						p2pMessagesReceived.WithLabelValues("signed_vaa_with_quorum").Inc()
+					default:
+						if components.WarnChannelOverflow {
+							// TODO do not log this in production
+							var hexStr string
+							if vaa, err := vaa.Unmarshal(m.SignedVaaWithQuorum.Vaa); err == nil {
+								hexStr = vaa.HexDigest()
+							}
+							logger.Warn("Ignoring SignedVaaWithQuorum because signedInC full", zap.String("hash", hexStr))
 						}
-						logger.Warn("Ignoring SignedVaaWithQuorum because signedInC full", zap.String("hash", hexStr))
+						p2pReceiveChannelOverflow.WithLabelValues("signed_vaa_with_quorum").Inc()
 					}
-					p2pReceiveChannelOverflow.WithLabelValues("signed_vaa_with_quorum").Inc()
 				}
 			case *gossipv1.GossipMessage_SignedObservationRequest:
-				s := m.SignedObservationRequest
-				gs := gst.Get()
-				if gs == nil {
-					if logger.Level().Enabled(zapcore.DebugLevel) {
-						logger.Debug("dropping SignedObservationRequest - no guardian set", zap.Any("value", s), zap.String("from", envelope.GetFrom().String()))
-					}
-					break
-				}
-				r, err := processSignedObservationRequest(s, gs)
-				if err != nil {
-					p2pMessagesReceived.WithLabelValues("invalid_signed_observation_request").Inc()
-					if logger.Level().Enabled(zapcore.DebugLevel) {
-						logger.Debug("invalid signed observation request received",
-							zap.Error(err),
-							zap.Any("payload", msg.Message),
-							zap.Any("value", s),
-							zap.Binary("raw", envelope.Data),
-							zap.String("from", envelope.GetFrom().String()))
-					}
-				} else {
-					if logger.Level().Enabled(zapcore.DebugLevel) {
-						logger.Debug("valid signed observation request received", zap.Any("value", r), zap.String("from", envelope.GetFrom().String()))
+				if obsvReqC != nil {
+					s := m.SignedObservationRequest
+					gs := gst.Get()
+					if gs == nil {
+						if logger.Level().Enabled(zapcore.DebugLevel) {
+							logger.Debug("dropping SignedObservationRequest - no guardian set", zap.Any("value", s), zap.String("from", envelope.GetFrom().String()))
+						}
+						break
 					}
 					}
+					if err != nil {
+						p2pMessagesReceived.WithLabelValues("invalid_signed_observation_request").Inc()
+						if logger.Level().Enabled(zapcore.DebugLevel) {
+							logger.Debug("invalid signed observation request received",
+								zap.Error(err),
+								zap.Any("payload", msg.Message),
+								zap.Any("value", s),
+								zap.Binary("raw", envelope.Data),
+								zap.String("from", envelope.GetFrom().String()))
+						}
+					} else {
+						if logger.Level().Enabled(zapcore.DebugLevel) {
+							logger.Debug("valid signed observation request received", zap.Any("value", r), zap.String("from", envelope.GetFrom().String()))
+						}
 
 
-					select {
-					case obsvReqC <- r:
-						p2pMessagesReceived.WithLabelValues("signed_observation_request").Inc()
-					default:
-						p2pReceiveChannelOverflow.WithLabelValues("signed_observation_request").Inc()
+						select {
+						case obsvReqC <- r:
+							p2pMessagesReceived.WithLabelValues("signed_observation_request").Inc()
+						default:
+							p2pReceiveChannelOverflow.WithLabelValues("signed_observation_request").Inc()
+						}
 					}
 					}
 				}
 			case *gossipv1.GossipMessage_SignedChainGovernorConfig:
+ 1 - 2
node/pkg/processor/broadcast.go

@@ -8,7 +8,6 @@ import (
 	"github.com/prometheus/client_golang/prometheus/promauto"

 	ethcommon "github.com/ethereum/go-ethereum/common"
-	"github.com/ethereum/go-ethereum/crypto"
 	"google.golang.org/protobuf/proto"

 	node_common "github.com/certusone/wormhole/node/pkg/common"
@@ -43,7 +42,7 @@ func (p *Processor) broadcastSignature(
 ) {
 	digest := o.SigningDigest()
 	obsv := gossipv1.SignedObservation{
-		Addr:      crypto.PubkeyToAddress(p.gk.PublicKey).Bytes(),
+		Addr:      p.ourAddr.Bytes(),
 		Hash:      digest.Bytes(),
 		Signature: signature,
 		TxHash:    txhash,

+ 4 - 4
node/pkg/processor/cleanup.go

@@ -115,7 +115,7 @@ func (p *Processor) handleCleanup(ctx context.Context) {
 			}

 			hasSigs := len(s.signatures)
-			quorum := hasSigs >= gs.Quorum
+			quorum := hasSigs >= gs.Quorum()

 			var chain vaa.ChainID
 			if s.ourObservation != nil {
@@ -128,7 +128,7 @@ func (p *Processor) handleCleanup(ctx context.Context) {
 					zap.String("digest", hash),
 					zap.Duration("delta", delta),
 					zap.Int("have_sigs", hasSigs),
-					zap.Int("required_sigs", gs.Quorum),
+					zap.Int("required_sigs", gs.Quorum()),
 					zap.Bool("quorum", quorum),
 					zap.Stringer("emitter_chain", chain),
 				)
@@ -245,8 +245,8 @@ func (p *Processor) handleCleanup(ctx context.Context) {
 						zap.String("digest", hash),
 						zap.Duration("delta", delta),
 						zap.Int("have_sigs", hasSigs),
-						zap.Int("required_sigs", p.gs.Quorum),
-						zap.Bool("quorum", hasSigs >= p.gs.Quorum),
+						zap.Int("required_sigs", p.gs.Quorum()),
+						zap.Bool("quorum", hasSigs >= p.gs.Quorum()),
 					)
 				}
 				delete(p.state.signatures, hash)

+ 4 - 4
node/pkg/processor/observation.go

@@ -228,7 +228,7 @@ func (p *Processor) handleObservation(ctx context.Context, obs *node_common.MsgW
 		// Hence, if len(s.signatures) < quorum, then there is definitely no quorum and we can return early to save additional computation,
 		// but if len(s.signatures) >= quorum, there is not necessarily quorum for the active guardian set.
 		// We will later check for quorum again after assembling the VAA for a particular guardian set.
-		if len(s.signatures) < gs.Quorum {
+		if len(s.signatures) < gs.Quorum() {
 			// no quorum yet, we're done here
 			if p.logger.Level().Enabled(zapcore.DebugLevel) {
 				p.logger.Debug("quorum not yet met",
@@ -250,13 +250,13 @@ func (p *Processor) handleObservation(ctx context.Context, obs *node_common.MsgW
 				zap.Any("set", gs.KeysAsHexStrings()),
 				zap.Uint32("index", gs.Index),
 				zap.Bools("aggregation", agg),
-				zap.Int("required_sigs", gs.Quorum),
+				zap.Int("required_sigs", gs.Quorum()),
 				zap.Int("have_sigs", len(sigsVaaFormat)),
-				zap.Bool("quorum", len(sigsVaaFormat) >= gs.Quorum),
+				zap.Bool("quorum", len(sigsVaaFormat) >= gs.Quorum()),
 			)
 		}

-		if len(sigsVaaFormat) >= gs.Quorum {
+		if len(sigsVaaFormat) >= gs.Quorum() {
 			// we have reached quorum *with the active guardian set*
 			s.ourObservation.HandleQuorum(sigsVaaFormat, hash, p)
 		} else {

+ 1 - 1
node/pkg/processor/processor.go

@@ -223,7 +223,7 @@ func (p *Processor) Run(ctx context.Context) error {
 			p.logger.Info("guardian set updated",
 				zap.Strings("set", p.gs.KeysAsHexStrings()),
 				zap.Uint32("index", p.gs.Index),
-				zap.Int("quorum", p.gs.Quorum),
+				zap.Int("quorum", p.gs.Quorum()),
 			)
 			p.gst.Set(p.gs)
 		case k := <-p.msgC: