searcher.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import time
  2. from dataclasses import dataclass
  3. from typing import List, Optional, Tuple
  4. from grpc import (
  5. StreamStreamClientInterceptor,
  6. StreamUnaryClientInterceptor,
  7. UnaryStreamClientInterceptor,
  8. UnaryUnaryClientInterceptor,
  9. intercept_channel,
  10. secure_channel,
  11. ssl_channel_credentials,
  12. )
  13. from grpc.aio import ClientCallDetails
  14. from solders.keypair import Keypair
  15. from .generated.auth_pb2 import (
  16. GenerateAuthChallengeRequest,
  17. GenerateAuthTokensRequest,
  18. GenerateAuthTokensResponse,
  19. RefreshAccessTokenRequest,
  20. RefreshAccessTokenResponse,
  21. Role,
  22. )
  23. from .generated.auth_pb2_grpc import AuthServiceStub
  24. from .generated.searcher_pb2_grpc import SearcherServiceStub
  25. @dataclass
  26. class JwtToken:
  27. # jwt token string
  28. token: str
  29. # time in seconds since epoch when the token expires
  30. expiration: int
  31. class SearcherInterceptor(
  32. UnaryUnaryClientInterceptor,
  33. UnaryStreamClientInterceptor,
  34. StreamUnaryClientInterceptor,
  35. StreamStreamClientInterceptor,
  36. ):
  37. """
  38. The jito_searcher_client interceptor is responsible for authenticating with the block engine.
  39. Authentication happens in a challenge-response handshake.
  40. 1. Request a challenge and provide your public key.
  41. 2. Get challenge and sign a message "{pubkey}-{challenge}".
  42. 3. Get back a refresh token and access token.
  43. When the access token expires, use the refresh token to get a new one.
  44. When the refresh token expires, perform the challenge-response handshake again.
  45. """
  46. def __init__(self, url: str, kp: Keypair):
  47. """
  48. :param url: url of the Block Engine without http or https.
  49. :param kp: block engine authentication keypair
  50. """
  51. self._url = url
  52. self._kp = kp
  53. self._access_token: Optional[JwtToken] = None
  54. self._refresh_token: Optional[JwtToken] = None
  55. def intercept_unary_stream(self, continuation, client_call_details, request):
  56. self.authenticate_if_needed()
  57. client_call_details = self._insert_headers(
  58. [("authorization", f"Bearer {self._access_token.token}")],
  59. client_call_details,
  60. )
  61. return continuation(client_call_details, request)
  62. def intercept_stream_unary(
  63. self, continuation, client_call_details, request_iterator
  64. ):
  65. self.authenticate_if_needed()
  66. client_call_details = self._insert_headers(
  67. [("authorization", f"Bearer {self._access_token.token}")],
  68. client_call_details,
  69. )
  70. return continuation(client_call_details, request_iterator)
  71. def intercept_stream_stream(
  72. self, continuation, client_call_details, request_iterator
  73. ):
  74. self.authenticate_if_needed()
  75. client_call_details = self._insert_headers(
  76. [("authorization", f"Bearer {self._access_token.token}")],
  77. client_call_details,
  78. )
  79. return continuation(client_call_details, request_iterator)
  80. def intercept_unary_unary(self, continuation, client_call_details, request):
  81. self.authenticate_if_needed()
  82. client_call_details = self._insert_headers(
  83. [("authorization", f"Bearer {self._access_token.token}")],
  84. client_call_details,
  85. )
  86. return continuation(client_call_details, request)
  87. @staticmethod
  88. def _insert_headers(
  89. new_metadata: List[Tuple[str, str]], client_call_details
  90. ) -> ClientCallDetails:
  91. metadata = []
  92. if client_call_details.metadata is not None:
  93. metadata = list(client_call_details.metadata)
  94. metadata.extend(new_metadata)
  95. return ClientCallDetails(
  96. client_call_details.method,
  97. client_call_details.timeout,
  98. metadata,
  99. client_call_details.credentials,
  100. False,
  101. )
  102. def authenticate_if_needed(self):
  103. """
  104. Maybe authenticates depending on state of access + refresh tokens
  105. """
  106. now = int(time.time())
  107. if self._access_token is None or self._refresh_token is None or now >= self._refresh_token.expiration:
  108. self.full_authentication()
  109. elif now >= self._access_token.expiration:
  110. self.refresh_authentication()
  111. def refresh_authentication(self):
  112. """
  113. Performs an authentication refresh with the block engine, which involves using the refresh token to get a new
  114. access token.
  115. """
  116. credentials = ssl_channel_credentials()
  117. channel = secure_channel(self._url, credentials)
  118. auth_client = AuthServiceStub(channel)
  119. new_access_token: RefreshAccessTokenResponse = auth_client.RefreshAccessToken(
  120. RefreshAccessTokenRequest(refresh_token=self._refresh_token.token))
  121. self._access_token = JwtToken(token=new_access_token.access_token.value,
  122. expiration=new_access_token.access_token.expires_at_utc.seconds)
  123. def full_authentication(self):
  124. """
  125. Performs full authentication with the block engine
  126. """
  127. credentials = ssl_channel_credentials()
  128. channel = secure_channel(self._url, credentials)
  129. auth_client = AuthServiceStub(channel)
  130. challenge = auth_client.GenerateAuthChallenge(
  131. GenerateAuthChallengeRequest(
  132. role=Role.SEARCHER, pubkey=bytes(self._kp.pubkey())
  133. )
  134. ).challenge
  135. challenge_to_sign = f"{str(self._kp.pubkey())}-{challenge}"
  136. signed = self._kp.sign_message(bytes(challenge_to_sign, "utf8"))
  137. auth_tokens_response: GenerateAuthTokensResponse = (
  138. auth_client.GenerateAuthTokens(
  139. GenerateAuthTokensRequest(
  140. challenge=challenge_to_sign,
  141. client_pubkey=bytes(self._kp.pubkey()),
  142. signed_challenge=bytes(signed),
  143. )
  144. )
  145. )
  146. self._access_token = JwtToken(
  147. token=auth_tokens_response.access_token.value,
  148. expiration=auth_tokens_response.access_token.expires_at_utc.seconds,
  149. )
  150. self._refresh_token = JwtToken(
  151. token=auth_tokens_response.refresh_token.value,
  152. expiration=auth_tokens_response.refresh_token.expires_at_utc.seconds,
  153. )
  154. def get_searcher_client(url: str, kp: Keypair) -> SearcherServiceStub:
  155. """
  156. Returns a Searcher Service client that intercepts requests and authenticates with the block engine.
  157. :param url: url of the block engine without http/https
  158. :param kp: keypair of the block engine
  159. :return: SearcherServiceStub which handles authentication on requests
  160. """
  161. # Authenticate immediately
  162. searcher_interceptor = SearcherInterceptor(url, kp)
  163. searcher_interceptor.authenticate_if_needed()
  164. credentials = ssl_channel_credentials()
  165. channel = secure_channel(url, credentials)
  166. intercepted_channel = intercept_channel(channel, searcher_interceptor)
  167. return SearcherServiceStub(intercepted_channel)