pendingmessage.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. package common
  2. import (
  3. "bytes"
  4. "cmp"
  5. "container/heap"
  6. "encoding/binary"
  7. "errors"
  8. "fmt"
  9. "slices"
  10. "sync"
  11. "time"
  12. "github.com/wormhole-foundation/wormhole/sdk/vaa"
  13. )
  14. const (
  15. // marshaledPendingMsgLenMin is the minimum length of a marshaled pending message.
  16. // It includes 8 bytes for the timestamp.
  17. marshaledPendingMsgLenMin = 8 + marshaledMsgLenMin
  18. )
  19. // PendingMessage is a wrapper type around a [MessagePublication] that includes the time for which it
  20. // should be released.
  21. type PendingMessage struct {
  22. ReleaseTime time.Time
  23. Msg MessagePublication
  24. }
  25. func (p PendingMessage) Compare(other PendingMessage) int {
  26. return cmp.Compare(p.ReleaseTime.Unix(), other.ReleaseTime.Unix())
  27. }
  28. // MarshalBinary implements BinaryMarshaler for [PendingMessage].
  29. func (p *PendingMessage) MarshalBinary() ([]byte, error) {
  30. msgPubBz, err := p.Msg.MarshalBinary()
  31. if err != nil {
  32. return nil, fmt.Errorf("marshal pending message: %w", err)
  33. }
  34. buf := new(bytes.Buffer)
  35. // Compare with [PendingTransfer.Marshal].
  36. // #nosec G115 -- int64 and uint64 have the same number of bytes, and Unix time won't be negative.
  37. vaa.MustWrite(buf, binary.BigEndian, uint64(p.ReleaseTime.Unix()))
  38. buf.Write(msgPubBz)
  39. return buf.Bytes(), nil
  40. }
  41. // UnmarshalBinary implements BinaryUnmarshaler for [PendingMessage].
  42. func (p *PendingMessage) UnmarshalBinary(data []byte) error {
  43. if len(data) < marshaledPendingMsgLenMin {
  44. return ErrInputSize{Msg: "pending message too short", Want: marshaledPendingMsgLenMin, Got: len(data)}
  45. }
  46. // Compare with [UnmarshalPendingTransfer].
  47. p.ReleaseTime = time.Unix(
  48. // #nosec G115 -- int64 and uint64 have the same number of bytes, and Unix time won't be negative.
  49. int64(binary.BigEndian.Uint64(data[0:8])),
  50. 0,
  51. )
  52. err := p.Msg.UnmarshalBinary(data[8:])
  53. if err != nil {
  54. return fmt.Errorf("unmarshal pending message: %w", err)
  55. }
  56. return nil
  57. }
  58. // A pendingMessageHeap is a min-heap of [PendingMessage] and uses the heap interface
  59. // by implementing the methods below.
  60. // As a result:
  61. // - The heap is always sorted by timestamp.
  62. // - the oldest (smallest) timestamp is always the last element.
  63. // This allows us to pop from the heap in order to get the oldest timestamp. If
  64. // that value greater than whatever time threshold we specify, we know that
  65. // there are no other messages that need to be released because their
  66. // timestamps must be greater. This should allow for constant-time lookups when
  67. // looking for messages to release.
  68. //
  69. // See: https://pkg.go.dev/container/heap#Interface
  70. type pendingMessageHeap []*PendingMessage
  71. func (h pendingMessageHeap) Len() int {
  72. return len(h)
  73. }
  74. func (h pendingMessageHeap) Less(i, j int) bool {
  75. return h[i].ReleaseTime.Before(h[j].ReleaseTime)
  76. }
  77. func (h pendingMessageHeap) Swap(i, j int) {
  78. h[i], h[j] = h[j], h[i]
  79. }
  80. // Push dangerously pushes a value to the heap.
  81. func (h *pendingMessageHeap) Push(x any) {
  82. // Push and Pop use pointer receivers because they modify the slice's length,
  83. // not just its contents.
  84. item, ok := x.(*PendingMessage)
  85. if !ok {
  86. panic("PendingMessageHeap: cannot push non-*PendingMessage")
  87. }
  88. // Null check
  89. if item == nil {
  90. panic("PendingMessageHeap: cannot push nil *PendingMessage")
  91. }
  92. *h = append(*h, item)
  93. }
  94. // Pops dangerously pops a value from the heap.
  95. func (h *pendingMessageHeap) Pop() any {
  96. old := *h
  97. n := len(old)
  98. if n == 0 {
  99. panic("PendingMessageHeap: cannot Pop from empty heap")
  100. }
  101. last := old[n-1]
  102. *h = old[0 : n-1]
  103. return last
  104. }
  105. // PendingMessageQueue is a thread-safe min-heap that sorts [PendingMessage] in descending order of Timestamp.
  106. // It also prevents duplicate [MessagePublication]s from being added to the queue.
  107. type PendingMessageQueue struct {
  108. // pendingMessageHeap exposes dangerous APIs as a necessary consequence of implementing [heap.Interface].
  109. // Wrap it and expose only a safe API.
  110. heap pendingMessageHeap
  111. mu sync.RWMutex
  112. }
  113. func NewPendingMessageQueue() *PendingMessageQueue {
  114. q := &PendingMessageQueue{heap: pendingMessageHeap{}}
  115. heap.Init(&q.heap)
  116. return q
  117. }
  118. // Push adds an element to the heap. If the pending message's message ID is invalid, or if it already exists in the queue, nothing is added.
  119. func (q *PendingMessageQueue) Push(pMsg *PendingMessage) {
  120. // noop if the message is nil or already in the queue.
  121. if pMsg == nil {
  122. return
  123. }
  124. if len(pMsg.Msg.MessageID()) < MinMsgIdLen {
  125. return
  126. }
  127. // FetchMessagePublication acquires and releases a read lock, so we don't need to write lock
  128. // until we're inside the if statement.
  129. if q.FetchMessagePublication(pMsg.Msg.MessageID()) == nil {
  130. q.mu.Lock()
  131. heap.Push(&q.heap, pMsg)
  132. defer q.mu.Unlock()
  133. }
  134. }
  135. // Pop removes the last element from the heap and returns its value.
  136. // Returns nil if the heap is empty or if the value is not a *[PendingMessage].
  137. func (q *PendingMessageQueue) Pop() *PendingMessage {
  138. if q.heap.Len() == 0 {
  139. return nil
  140. }
  141. q.mu.Lock()
  142. defer q.mu.Unlock()
  143. last, ok := heap.Pop(&q.heap).(*PendingMessage)
  144. if !ok {
  145. return nil
  146. }
  147. return last
  148. }
  149. func (q *PendingMessageQueue) Len() int {
  150. return q.heap.Len()
  151. }
  152. // Peek returns the element at the top of the heap without removing it.
  153. func (q *PendingMessageQueue) Peek() *PendingMessage {
  154. if q.heap.Len() == 0 {
  155. return nil
  156. }
  157. // container/heap stores the "next" element at the first offset.
  158. last := *q.heap[0]
  159. return &last
  160. }
  161. // RemoveItem removes target MessagePublication with the message ID from the heap. Returns the element that was removed or nil
  162. // if the item was not found. No error is returned if the item was not found.
  163. func (q *PendingMessageQueue) RemoveItem(msgID []byte) (*PendingMessage, error) {
  164. if msgID == nil {
  165. return nil, errors.New("pendingmessage: nil argument for RemoveItem")
  166. }
  167. q.mu.Lock()
  168. defer q.mu.Unlock()
  169. var removed *PendingMessage
  170. for i, item := range q.heap {
  171. // Assumption: MsgIDs are unique across MessagePublications.
  172. if bytes.Equal(item.Msg.MessageID(), msgID) {
  173. pMsg, ok := heap.Remove(&q.heap, i).(*PendingMessage)
  174. if !ok {
  175. return nil, errors.New("pendingmessage: item removed from heap is not PendingMessage")
  176. }
  177. removed = pMsg
  178. break
  179. }
  180. }
  181. return removed, nil
  182. }
  183. // Contains determines whether the queue contains a [PendingMessage].
  184. func (q *PendingMessageQueue) Contains(pMsg *PendingMessage) bool {
  185. if pMsg == nil {
  186. return false
  187. }
  188. q.mu.RLock()
  189. defer q.mu.RUnlock()
  190. return slices.Contains(q.heap, pMsg)
  191. }
  192. // FetchMessagePublication returns a [MessagePublication] with the given ID if it exists in the queue, and nil
  193. // otherwise.
  194. func (q *PendingMessageQueue) FetchMessagePublication(msgID []byte) (msgPub *MessagePublication) {
  195. if len(msgID) == 0 {
  196. return nil
  197. }
  198. q.mu.RLock()
  199. defer q.mu.RUnlock()
  200. // Relies on MessageIDString to be unique.
  201. for _, pMsg := range q.heap {
  202. if bytes.Equal(pMsg.Msg.MessageID(), msgID) {
  203. return &pMsg.Msg
  204. }
  205. }
  206. return nil
  207. }