access_token.rs 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. pub use goauth::scopes::Scope;
  2. /// A module for managing a Google API access token
  3. use {
  4. crate::CredentialType,
  5. goauth::{
  6. auth::{JwtClaims, Token},
  7. credentials::Credentials,
  8. },
  9. log::*,
  10. smpl_jwt::Jwt,
  11. std::{
  12. str::FromStr,
  13. sync::{
  14. atomic::{AtomicBool, Ordering},
  15. {Arc, RwLock},
  16. },
  17. time::Instant,
  18. },
  19. tokio::time,
  20. };
  21. fn load_credentials(filepath: Option<String>) -> Result<Credentials, String> {
  22. let path = match filepath {
  23. Some(f) => f,
  24. None => std::env::var("GOOGLE_APPLICATION_CREDENTIALS").map_err(|_| {
  25. "GOOGLE_APPLICATION_CREDENTIALS environment variable not found".to_string()
  26. })?,
  27. };
  28. Credentials::from_file(&path)
  29. .map_err(|err| format!("Failed to read GCP credentials from {path}: {err}"))
  30. }
  31. fn load_stringified_credentials(credential: String) -> Result<Credentials, String> {
  32. Credentials::from_str(&credential).map_err(|err| format!("{err}"))
  33. }
  34. pub struct AccessTokenInner {
  35. credentials: Credentials,
  36. scope: Scope,
  37. token: RwLock<(Token, Instant)>,
  38. refresh_active: AtomicBool,
  39. }
  40. #[derive(Clone)]
  41. pub struct AccessToken {
  42. inner: Arc<AccessTokenInner>,
  43. }
  44. impl std::ops::Deref for AccessToken {
  45. type Target = AccessTokenInner;
  46. fn deref(&self) -> &Self::Target {
  47. &self.inner
  48. }
  49. }
  50. impl AccessToken {
  51. pub async fn new(scope: Scope, credential_type: CredentialType) -> Result<Self, String> {
  52. let credentials = match credential_type {
  53. CredentialType::Filepath(fp) => load_credentials(fp)?,
  54. CredentialType::Stringified(s) => load_stringified_credentials(s)?,
  55. };
  56. if let Err(err) = credentials.rsa_key() {
  57. Err(format!("Invalid rsa key: {err}"))
  58. } else {
  59. let token = RwLock::new(Self::get_token(&credentials, &scope).await?);
  60. let access_token = Self {
  61. inner: Arc::new(AccessTokenInner {
  62. credentials,
  63. scope,
  64. token,
  65. refresh_active: AtomicBool::new(false),
  66. }),
  67. };
  68. Ok(access_token)
  69. }
  70. }
  71. /// The project that this token grants access to
  72. pub fn project(&self) -> String {
  73. self.credentials.project()
  74. }
  75. async fn get_token(
  76. credentials: &Credentials,
  77. scope: &Scope,
  78. ) -> Result<(Token, Instant), String> {
  79. info!("Requesting token for {scope:?} scope");
  80. let claims = JwtClaims::new(
  81. credentials.iss(),
  82. scope,
  83. credentials.token_uri(),
  84. None,
  85. None,
  86. );
  87. let jwt = Jwt::new(claims, credentials.rsa_key().unwrap(), None);
  88. let token = goauth::get_token(&jwt, credentials)
  89. .await
  90. .map_err(|err| format!("Failed to refresh access token: {err}"))?;
  91. info!("Token expires in {} seconds", token.expires_in());
  92. Ok((token, Instant::now()))
  93. }
  94. /// Call this function regularly to ensure the access token does not expire
  95. pub fn refresh(&self) {
  96. // Check if it's time to try a token refresh
  97. let token_r = self.token.read().unwrap();
  98. if token_r.1.elapsed().as_secs() < token_r.0.expires_in() as u64 / 2 {
  99. debug!("Token is not expired yet");
  100. return;
  101. }
  102. drop(token_r);
  103. // Refresh already is progress
  104. let refresh_progress =
  105. self.refresh_active
  106. .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed);
  107. if refresh_progress.is_err() {
  108. debug!("Token update is already in progress");
  109. return;
  110. }
  111. let this = self.clone();
  112. tokio::spawn(async move {
  113. match time::timeout(
  114. time::Duration::from_secs(5),
  115. Self::get_token(&this.credentials, &this.scope),
  116. )
  117. .await
  118. {
  119. Ok(new_token) => match new_token {
  120. Ok(new_token) => {
  121. let mut token_w = this.token.write().unwrap();
  122. *token_w = new_token;
  123. }
  124. Err(err) => error!("Failed to fetch new token: {err}"),
  125. },
  126. Err(_timeout) => {
  127. warn!("Token refresh timeout")
  128. }
  129. }
  130. this.refresh_active.store(false, Ordering::Relaxed);
  131. info!("Token refreshed");
  132. });
  133. }
  134. /// Return an access token suitable for use in an HTTP authorization header
  135. pub fn get(&self) -> String {
  136. let token_r = self.token.read().unwrap();
  137. format!("{} {}", token_r.0.token_type(), token_r.0.access_token())
  138. }
  139. }