futurehouse-client 0.3.18.dev186__tar.gz → 0.3.18.dev195__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/PKG-INFO +1 -1
  2. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/clients/rest_client.py +39 -148
  3. futurehouse_client-0.3.18.dev195/futurehouse_client/utils/auth.py +89 -0
  4. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/PKG-INFO +1 -1
  5. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/tests/test_rest.py +26 -84
  6. futurehouse_client-0.3.18.dev186/futurehouse_client/utils/auth.py +0 -107
  7. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/LICENSE +0 -0
  8. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/README.md +0 -0
  9. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/docs/__init__.py +0 -0
  10. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/docs/client_notebook.ipynb +0 -0
  11. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/__init__.py +0 -0
  12. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/clients/__init__.py +0 -0
  13. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/clients/job_client.py +0 -0
  14. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/__init__.py +0 -0
  15. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/app.py +0 -0
  16. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/client.py +0 -0
  17. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/rest.py +0 -0
  18. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/__init__.py +0 -0
  19. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/general.py +0 -0
  20. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/module_utils.py +0 -0
  21. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/monitoring.py +0 -0
  22. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/SOURCES.txt +0 -0
  23. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/dependency_links.txt +0 -0
  24. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/requires.txt +0 -0
  25. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/top_level.txt +0 -0
  26. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/pyproject.toml +0 -0
  27. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/setup.cfg +0 -0
  28. {futurehouse_client-0.3.18.dev186 → futurehouse_client-0.3.18.dev195}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: futurehouse-client
3
- Version: 0.3.18.dev186
3
+ Version: 0.3.18.dev195
4
4
  Summary: A client for interacting with endpoints of the FutureHouse service.
5
5
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
6
  Classifier: Operating System :: OS Independent
@@ -15,7 +15,7 @@ import uuid
15
15
  from collections.abc import Collection
16
16
  from pathlib import Path
17
17
  from types import ModuleType
18
- from typing import Any, ClassVar, assert_never, cast
18
+ from typing import Any, ClassVar, cast
19
19
  from uuid import UUID
20
20
 
21
21
  import cloudpickle
@@ -45,7 +45,6 @@ from tqdm.asyncio import tqdm
45
45
 
46
46
  from futurehouse_client.clients import JobNames
47
47
  from futurehouse_client.models.app import (
48
- APIKeyPayload,
49
48
  AuthType,
50
49
  JobDeploymentConfig,
51
50
  PQATaskResponse,
@@ -55,11 +54,7 @@ from futurehouse_client.models.app import (
55
54
  TaskResponseVerbose,
56
55
  )
57
56
  from futurehouse_client.models.rest import ExecutionStatus
58
- from futurehouse_client.utils.auth import (
59
- AUTH_ERRORS_TO_RETRY_ON,
60
- AuthError,
61
- refresh_token_on_auth_error,
62
- )
57
+ from futurehouse_client.utils.auth import RefreshingJWT
63
58
  from futurehouse_client.utils.general import gather_with_concurrency
64
59
  from futurehouse_client.utils.module_utils import (
65
60
  OrganizationSelector,
@@ -128,8 +123,6 @@ retry_if_connection_error = retry_if_exception_type((
128
123
  FileUploadError,
129
124
  ))
130
125
 
131
- # 5 minute default for JWTs
132
- JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
133
126
  DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
134
127
 
135
128
 
@@ -163,69 +156,85 @@ class RestClient:
163
156
  self.api_key = api_key
164
157
  self._clients: dict[str, Client | AsyncClient] = {}
165
158
  self.headers = headers or {}
166
- self.auth_jwt = self._run_auth(jwt=jwt)
159
+ self.jwt = jwt
167
160
  self.organizations: list[str] = self._filter_orgs(organization)
168
161
 
169
162
  @property
170
163
  def client(self) -> Client:
171
- """Lazily initialized and cached HTTP client with authentication."""
172
- return cast(Client, self.get_client("application/json", with_auth=True))
164
+ """Authenticated HTTP client for regular API calls."""
165
+ return cast(Client, self.get_client("application/json", authenticated=True))
173
166
 
174
167
  @property
175
168
  def async_client(self) -> AsyncClient:
176
- """Lazily initialized and cached HTTP client with authentication."""
169
+ """Authenticated async HTTP client for regular API calls."""
177
170
  return cast(
178
171
  AsyncClient,
179
- self.get_client("application/json", with_auth=True, with_async=True),
172
+ self.get_client("application/json", authenticated=True, async_client=True),
180
173
  )
181
174
 
182
175
  @property
183
- def auth_client(self) -> Client:
184
- """Lazily initialized and cached HTTP client without authentication."""
185
- return cast(Client, self.get_client("application/json", with_auth=False))
176
+ def unauthenticated_client(self) -> Client:
177
+ """Unauthenticated HTTP client for auth operations to avoid recursion."""
178
+ return cast(Client, self.get_client("application/json", authenticated=False))
186
179
 
187
180
  @property
188
181
  def multipart_client(self) -> Client:
189
- """Lazily initialized and cached HTTP client for multipart uploads."""
190
- return cast(Client, self.get_client(None, with_auth=True))
182
+ """Authenticated HTTP client for multipart uploads."""
183
+ return cast(Client, self.get_client(None, authenticated=True))
191
184
 
192
185
  def get_client(
193
186
  self,
194
187
  content_type: str | None = "application/json",
195
- with_auth: bool = True,
196
- with_async: bool = False,
188
+ authenticated: bool = True,
189
+ async_client: bool = False,
197
190
  ) -> Client | AsyncClient:
198
191
  """Return a cached HTTP client or create one if needed.
199
192
 
200
193
  Args:
201
194
  content_type: The desired content type header. Use None for multipart uploads.
202
- with_auth: Whether the client should include an Authorization header.
203
- with_async: Whether to use an async client.
195
+ authenticated: Whether the client should include authentication.
196
+ async_client: Whether to use an async client.
204
197
 
205
198
  Returns:
206
199
  An HTTP client configured with the appropriate headers.
207
200
  """
208
- # Create a composite key based on content type and auth flag.
209
- key = f"{content_type or 'multipart'}_{with_auth}_{with_async}"
201
+ # Create a composite key based on content type and auth flag
202
+ key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
203
+
210
204
  if key not in self._clients:
211
205
  headers = copy.deepcopy(self.headers)
212
- if with_auth:
213
- headers["Authorization"] = f"Bearer {self.auth_jwt}"
206
+ auth = None
207
+
208
+ if authenticated:
209
+ auth = RefreshingJWT(
210
+ # authenticated=False will always return a synchronous client
211
+ auth_client=cast(
212
+ Client, self.get_client("application/json", authenticated=False)
213
+ ),
214
+ auth_type=self.auth_type,
215
+ api_key=self.api_key,
216
+ jwt=self.jwt,
217
+ )
218
+
214
219
  if content_type:
215
220
  headers["Content-Type"] = content_type
221
+
216
222
  self._clients[key] = (
217
223
  AsyncClient(
218
224
  base_url=self.base_url,
219
225
  headers=headers,
220
226
  timeout=self.REQUEST_TIMEOUT,
227
+ auth=auth,
221
228
  )
222
- if with_async
229
+ if async_client
223
230
  else Client(
224
231
  base_url=self.base_url,
225
232
  headers=headers,
226
233
  timeout=self.REQUEST_TIMEOUT,
234
+ auth=auth,
227
235
  )
228
236
  )
237
+
229
238
  return self._clients[key]
230
239
 
231
240
  def close(self):
@@ -255,32 +264,6 @@ class RestClient:
255
264
  raise ValueError(f"Organization '{organization}' not found.")
256
265
  return filtered_orgs
257
266
 
258
- def _run_auth(self, jwt: str | None = None) -> str:
259
- auth_payload: APIKeyPayload | None
260
- if self.auth_type == AuthType.API_KEY:
261
- auth_payload = APIKeyPayload(api_key=self.api_key)
262
- elif self.auth_type == AuthType.JWT:
263
- auth_payload = None
264
- else:
265
- assert_never(self.auth_type)
266
- try:
267
- # Use the unauthenticated client for login
268
- if auth_payload:
269
- response = self.auth_client.post(
270
- "/auth/login", json=auth_payload.model_dump()
271
- )
272
- response.raise_for_status()
273
- token_data = response.json()
274
- elif jwt:
275
- token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
276
- else:
277
- raise ValueError("JWT token required for JWT authentication.")
278
-
279
- return token_data["access_token"]
280
- except Exception as e:
281
- raise RestClientError(f"Error authenticating: {e!s}") from e
282
-
283
- @refresh_token_on_auth_error()
284
267
  def _check_job(self, name: str, organization: str) -> dict[str, Any]:
285
268
  try:
286
269
  response = self.client.get(
@@ -288,19 +271,9 @@ class RestClient:
288
271
  )
289
272
  response.raise_for_status()
290
273
  return response.json()
291
- except HTTPStatusError as e:
292
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
293
- raise AuthError(
294
- e.response.status_code,
295
- f"Authentication failed: {e}",
296
- request=e.request,
297
- response=e.response,
298
- ) from e
299
- raise
300
274
  except Exception as e:
301
275
  raise JobFetchError(f"Error checking job: {e!s}") from e
302
276
 
303
- @refresh_token_on_auth_error()
304
277
  def _fetch_my_orgs(self) -> list[str]:
305
278
  response = self.client.get(f"/v0.1/organizations?filter={True}")
306
279
  response.raise_for_status()
@@ -358,7 +331,6 @@ class RestClient:
358
331
  if not files:
359
332
  raise TaskFetchError(f"No files found in {path}")
360
333
 
361
- @refresh_token_on_auth_error()
362
334
  @retry(
363
335
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
364
336
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -400,19 +372,9 @@ class RestClient:
400
372
  ):
401
373
  return PQATaskResponse(**data)
402
374
  return TaskResponse(**data)
403
- except HTTPStatusError as e:
404
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
405
- raise AuthError(
406
- e.response.status_code,
407
- f"Authentication failed: {e}",
408
- request=e.request,
409
- response=e.response,
410
- ) from e
411
- raise
412
375
  except Exception as e:
413
376
  raise TaskFetchError(f"Error getting task: {e!s}") from e
414
377
 
415
- @refresh_token_on_auth_error()
416
378
  @retry(
417
379
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
418
380
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -456,19 +418,9 @@ class RestClient:
456
418
  ):
457
419
  return PQATaskResponse(**data)
458
420
  return TaskResponse(**data)
459
- except HTTPStatusError as e:
460
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
461
- raise AuthError(
462
- e.response.status_code,
463
- f"Authentication failed: {e}",
464
- request=e.request,
465
- response=e.response,
466
- ) from e
467
- raise
468
421
  except Exception as e:
469
422
  raise TaskFetchError(f"Error getting task: {e!s}") from e
470
423
 
471
- @refresh_token_on_auth_error()
472
424
  @retry(
473
425
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
474
426
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -492,20 +444,10 @@ class RestClient:
492
444
  response.raise_for_status()
493
445
  trajectory_id = response.json()["trajectory_id"]
494
446
  self.trajectory_id = trajectory_id
495
- except HTTPStatusError as e:
496
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
497
- raise AuthError(
498
- e.response.status_code,
499
- f"Authentication failed: {e}",
500
- request=e.request,
501
- response=e.response,
502
- ) from e
503
- raise
504
447
  except Exception as e:
505
448
  raise TaskFetchError(f"Error creating task: {e!s}") from e
506
449
  return trajectory_id
507
450
 
508
- @refresh_token_on_auth_error()
509
451
  @retry(
510
452
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
511
453
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -529,15 +471,6 @@ class RestClient:
529
471
  response.raise_for_status()
530
472
  trajectory_id = response.json()["trajectory_id"]
531
473
  self.trajectory_id = trajectory_id
532
- except HTTPStatusError as e:
533
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
534
- raise AuthError(
535
- e.response.status_code,
536
- f"Authentication failed: {e}",
537
- request=e.request,
538
- response=e.response,
539
- ) from e
540
- raise
541
474
  except Exception as e:
542
475
  raise TaskFetchError(f"Error creating task: {e!s}") from e
543
476
  return trajectory_id
@@ -683,7 +616,6 @@ class RestClient:
683
616
  for task_id in trajectory_ids
684
617
  ]
685
618
 
686
- @refresh_token_on_auth_error()
687
619
  @retry(
688
620
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
689
621
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -695,19 +627,11 @@ class RestClient:
695
627
  build_id = build_id or self.build_id
696
628
  response = self.client.get(f"/v0.1/builds/{build_id}")
697
629
  response.raise_for_status()
698
- except HTTPStatusError as e:
699
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
700
- raise AuthError(
701
- e.response.status_code,
702
- f"Authentication failed: {e}",
703
- request=e.request,
704
- response=e.response,
705
- ) from e
706
- raise
630
+ except Exception as e:
631
+ raise JobFetchError(f"Error getting build status: {e!s}") from e
707
632
  return response.json()
708
633
 
709
634
  # TODO: Refactor later so we don't have to ignore PLR0915
710
- @refresh_token_on_auth_error()
711
635
  @retry(
712
636
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
713
637
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -887,13 +811,6 @@ class RestClient:
887
811
  build_context = response.json()
888
812
  self.build_id = build_context["build_id"]
889
813
  except HTTPStatusError as e:
890
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
891
- raise AuthError(
892
- e.response.status_code,
893
- f"Authentication failed: {e}",
894
- request=e.request,
895
- response=e.response,
896
- ) from e
897
814
  error_detail = response.json()
898
815
  error_message = error_detail.get("detail", str(e))
899
816
  raise JobCreationError(
@@ -974,7 +891,6 @@ class RestClient:
974
891
  except Exception as e:
975
892
  raise FileUploadError(f"Error uploading directory {dir_path}: {e}") from e
976
893
 
977
- @refresh_token_on_auth_error()
978
894
  def _upload_single_file(
979
895
  self,
980
896
  job_name: str,
@@ -1048,20 +964,10 @@ class RestClient:
1048
964
  )
1049
965
 
1050
966
  logger.info(f"Successfully uploaded {file_name}")
1051
- except HTTPStatusError as e:
1052
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
1053
- raise AuthError(
1054
- e.response.status_code,
1055
- f"Authentication failed: {e}",
1056
- request=e.request,
1057
- response=e.response,
1058
- ) from e
1059
- raise
1060
967
  except Exception as e:
1061
968
  logger.exception(f"Error uploading file {file_path}")
1062
969
  raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
1063
970
 
1064
- @refresh_token_on_auth_error()
1065
971
  @retry(
1066
972
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
1067
973
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1098,13 +1004,6 @@ class RestClient:
1098
1004
  response.raise_for_status()
1099
1005
  return response.json()
1100
1006
  except HTTPStatusError as e:
1101
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
1102
- raise AuthError(
1103
- e.response.status_code,
1104
- f"Authentication failed: {e}",
1105
- request=e.request,
1106
- response=e.response,
1107
- ) from e
1108
1007
  logger.exception(
1109
1008
  f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
1110
1009
  )
@@ -1117,7 +1016,6 @@ class RestClient:
1117
1016
  )
1118
1017
  raise RestClientError(f"Error listing files: {e!s}") from e
1119
1018
 
1120
- @refresh_token_on_auth_error()
1121
1019
  @retry(
1122
1020
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
1123
1021
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1165,13 +1063,6 @@ class RestClient:
1165
1063
 
1166
1064
  logger.info(f"File {file_path} downloaded to {destination_path}")
1167
1065
  except HTTPStatusError as e:
1168
- if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
1169
- raise AuthError(
1170
- e.response.status_code,
1171
- f"Authentication failed: {e}",
1172
- request=e.request,
1173
- response=e.response,
1174
- ) from e
1175
1066
  logger.exception(
1176
1067
  f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
1177
1068
  )
@@ -0,0 +1,89 @@
1
+ import logging
2
+ from typing import ClassVar, Final
3
+
4
+ import httpx
5
+
6
+ from futurehouse_client.models.app import APIKeyPayload, AuthType
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ INVALID_REFRESH_TYPE_MSG: Final[str] = (
11
+ "API key auth is required to refresh auth tokens."
12
+ )
13
+ JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
14
+
15
+
16
+ def _run_auth(
17
+ client: httpx.Client,
18
+ auth_type: AuthType = AuthType.API_KEY,
19
+ api_key: str | None = None,
20
+ jwt: str | None = None,
21
+ ) -> str:
22
+ auth_payload: APIKeyPayload | None
23
+ if auth_type == AuthType.API_KEY:
24
+ auth_payload = APIKeyPayload(api_key=api_key)
25
+ elif auth_type == AuthType.JWT:
26
+ auth_payload = None
27
+ try:
28
+ if auth_payload:
29
+ response = client.post("/auth/login", json=auth_payload.model_dump())
30
+ response.raise_for_status()
31
+ token_data = response.json()
32
+ elif jwt:
33
+ token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
34
+ else:
35
+ raise ValueError("JWT token required for JWT authentication.")
36
+
37
+ return token_data["access_token"]
38
+ except Exception as e:
39
+ raise Exception("Failed to authenticate") from e # noqa: TRY002
40
+
41
+
42
+ class RefreshingJWT(httpx.Auth):
43
+ """Automatically (re-)inject a JWT and transparently retry exactly once when we hit a 401/403."""
44
+
45
+ RETRY_STATUSES: ClassVar[set[int]] = {
46
+ httpx.codes.UNAUTHORIZED,
47
+ httpx.codes.FORBIDDEN,
48
+ }
49
+
50
+ def __init__(
51
+ self,
52
+ auth_client: httpx.Client,
53
+ auth_type: AuthType = AuthType.API_KEY,
54
+ api_key: str | None = None,
55
+ jwt: str | None = None,
56
+ ):
57
+ self.auth_type = auth_type
58
+ self.auth_client = auth_client
59
+ self.api_key = api_key
60
+ self._jwt = _run_auth(
61
+ client=auth_client,
62
+ jwt=jwt,
63
+ auth_type=auth_type,
64
+ api_key=api_key,
65
+ )
66
+
67
+ def refresh_token(self):
68
+ if self.auth_type == AuthType.JWT:
69
+ logger.error(INVALID_REFRESH_TYPE_MSG)
70
+ raise ValueError(INVALID_REFRESH_TYPE_MSG)
71
+ self._jwt = _run_auth(
72
+ client=self.auth_client,
73
+ auth_type=self.auth_type,
74
+ api_key=self.api_key,
75
+ )
76
+
77
+ def auth_flow(self, request):
78
+ request.headers["Authorization"] = f"Bearer {self._jwt}"
79
+ response = yield request
80
+
81
+ # If it failed, refresh once and replay the request
82
+ if response.status_code in self.RETRY_STATUSES:
83
+ logger.info(
84
+ "Received %s, refreshing token and retrying …",
85
+ response.status_code,
86
+ )
87
+ self.refresh_token()
88
+ request.headers["Authorization"] = f"Bearer {self._jwt}"
89
+ yield request # second (and final) attempt, again or use a while loop
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: futurehouse-client
3
- Version: 0.3.18.dev186
3
+ Version: 0.3.18.dev195
4
4
  Summary: A client for interacting with endpoints of the FutureHouse service.
5
5
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
6
  Classifier: Operating System :: OS Independent
@@ -2,7 +2,6 @@
2
2
  import asyncio
3
3
  import os
4
4
  import time
5
- from unittest.mock import patch
6
5
 
7
6
  import pytest
8
7
  from futurehouse_client.clients import (
@@ -10,10 +9,9 @@ from futurehouse_client.clients import (
10
9
  PQATaskResponse,
11
10
  TaskResponseVerbose,
12
11
  )
13
- from futurehouse_client.clients.rest_client import RestClient
12
+ from futurehouse_client.clients.rest_client import RestClient, TaskFetchError
14
13
  from futurehouse_client.models.app import Stage, TaskRequest
15
14
  from futurehouse_client.models.rest import ExecutionStatus
16
- from futurehouse_client.utils.auth import AuthError, refresh_token_on_auth_error
17
15
  from pytest_subtests import SubTests
18
16
 
19
17
  ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
@@ -40,7 +38,7 @@ def pub_client():
40
38
 
41
39
 
42
40
  @pytest.fixture
43
- def task_data():
41
+ def task_req():
44
42
  """Create a sample task request."""
45
43
  return TaskRequest(
46
44
  name=JobNames.from_string("dummy"),
@@ -49,7 +47,7 @@ def task_data():
49
47
 
50
48
 
51
49
  @pytest.fixture
52
- def pqa_task_data():
50
+ def pqa_task_req():
53
51
  return TaskRequest(
54
52
  name=JobNames.from_string("crow"),
55
53
  query="How many moons does earth have?",
@@ -58,40 +56,40 @@ def pqa_task_data():
58
56
 
59
57
  @pytest.mark.timeout(300)
60
58
  @pytest.mark.flaky(reruns=3)
61
- def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_data: TaskRequest):
62
- admin_client.create_task(task_data)
59
+ def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_req: TaskRequest):
60
+ admin_client.create_task(task_req)
63
61
  while (task_status := admin_client.get_task().status) in {"queued", "in progress"}:
64
62
  time.sleep(5)
65
63
  assert task_status == "success"
66
64
 
67
65
 
68
66
  def test_insufficient_permissions_request(
69
- pub_client: RestClient, task_data: TaskRequest
67
+ pub_client: RestClient, task_req: TaskRequest
70
68
  ):
71
69
  # Create a new instance so that cached credentials aren't reused
72
- with pytest.raises(AuthError) as exc_info:
73
- pub_client.create_task(task_data)
70
+ with pytest.raises(TaskFetchError) as exc_info:
71
+ pub_client.create_task(task_req)
74
72
 
75
- assert "403 Forbidden" in str(exc_info.value)
73
+ assert "Error creating task" in str(exc_info.value)
76
74
 
77
75
 
78
76
  @pytest.mark.timeout(300)
79
77
  @pytest.mark.asyncio
80
78
  async def test_job_response( # noqa: PLR0915
81
- subtests: SubTests, admin_client: RestClient, pqa_task_data: TaskRequest
79
+ subtests: SubTests, admin_client: RestClient, pqa_task_req: TaskRequest
82
80
  ):
83
- task_id = admin_client.create_task(pqa_task_data)
84
- atask_id = await admin_client.acreate_task(pqa_task_data)
81
+ task_id = admin_client.create_task(pqa_task_req)
82
+ atask_id = await admin_client.acreate_task(pqa_task_req)
85
83
 
86
84
  with subtests.test("Test TaskResponse with queued task"):
87
85
  task_response = admin_client.get_task(task_id)
88
86
  assert task_response.status in {"queued", "in progress"}
89
- assert task_response.job_name == pqa_task_data.name
90
- assert task_response.query == pqa_task_data.query
87
+ assert task_response.job_name == pqa_task_req.name
88
+ assert task_response.query == pqa_task_req.query
91
89
  task_response = await admin_client.aget_task(atask_id)
92
90
  assert task_response.status in {"queued", "in progress"}
93
- assert task_response.job_name == pqa_task_data.name
94
- assert task_response.query == pqa_task_data.query
91
+ assert task_response.job_name == pqa_task_req.name
92
+ assert task_response.query == pqa_task_req.query
95
93
 
96
94
  for _ in range(TEST_MAX_POLLS):
97
95
  task_response = admin_client.get_task(task_id)
@@ -111,8 +109,8 @@ async def test_job_response( # noqa: PLR0915
111
109
  # assert it has general fields
112
110
  assert task_response.status == "success"
113
111
  assert task_response.task_id is not None
114
- assert pqa_task_data.name in task_response.job_name
115
- assert pqa_task_data.query in task_response.query
112
+ assert pqa_task_req.name in task_response.job_name
113
+ assert pqa_task_req.query in task_response.query
116
114
  # assert it has PQA specific fields
117
115
  assert task_response.answer is not None
118
116
  # assert it's not verbose
@@ -125,8 +123,8 @@ async def test_job_response( # noqa: PLR0915
125
123
  # assert it has general fields
126
124
  assert task_response.status == "success"
127
125
  assert task_response.task_id is not None
128
- assert pqa_task_data.name in task_response.job_name
129
- assert pqa_task_data.query in task_response.query
126
+ assert pqa_task_req.name in task_response.job_name
127
+ assert pqa_task_req.query in task_response.query
130
128
  # assert it has PQA specific fields
131
129
  assert task_response.answer is not None
132
130
  # assert it's not verbose
@@ -151,9 +149,9 @@ async def test_job_response( # noqa: PLR0915
151
149
  @pytest.mark.timeout(300)
152
150
  @pytest.mark.flaky(reruns=3)
153
151
  def test_run_until_done_futurehouse_dummy_env_crow(
154
- admin_client: RestClient, task_data: TaskRequest
152
+ admin_client: RestClient, task_req: TaskRequest
155
153
  ):
156
- tasks_to_do = [task_data, task_data]
154
+ tasks_to_do = [task_req, task_req]
157
155
 
158
156
  results = admin_client.run_tasks_until_done(tasks_to_do)
159
157
 
@@ -165,9 +163,9 @@ def test_run_until_done_futurehouse_dummy_env_crow(
165
163
  @pytest.mark.flaky(reruns=3)
166
164
  @pytest.mark.asyncio
167
165
  async def test_arun_until_done_futurehouse_dummy_env_crow(
168
- admin_client: RestClient, task_data: TaskRequest
166
+ admin_client: RestClient, task_req: TaskRequest
169
167
  ):
170
- tasks_to_do = [task_data, task_data]
168
+ tasks_to_do = [task_req, task_req]
171
169
 
172
170
  results = await admin_client.arun_tasks_until_done(tasks_to_do)
173
171
 
@@ -179,9 +177,9 @@ async def test_arun_until_done_futurehouse_dummy_env_crow(
179
177
  @pytest.mark.flaky(reruns=3)
180
178
  @pytest.mark.asyncio
181
179
  async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
182
- admin_client: RestClient, task_data: TaskRequest
180
+ admin_client: RestClient, task_req: TaskRequest
183
181
  ):
184
- tasks_to_do = [task_data, task_data]
182
+ tasks_to_do = [task_req, task_req]
185
183
 
186
184
  results = await admin_client.arun_tasks_until_done(
187
185
  tasks_to_do, verbose=True, timeout=5, progress_bar=True
@@ -202,59 +200,3 @@ async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
202
200
  assert all(not isinstance(task, PQATaskResponse) for task in results), (
203
201
  "Should be verbose."
204
202
  )
205
-
206
-
207
- def test_auth_refresh_flow(admin_client: RestClient):
208
- refresh_calls = 0
209
- func_calls = 0
210
-
211
- def mock_run_auth(*args, **kwargs):
212
- nonlocal refresh_calls
213
- refresh_calls += 1
214
- return f"fresh-token-{refresh_calls}"
215
-
216
- @refresh_token_on_auth_error()
217
- def test_func(self, *args):
218
- nonlocal func_calls
219
- func_calls += 1
220
-
221
- if func_calls == 1:
222
- raise AuthError(401, "Auth failed", None, None)
223
- return "success"
224
-
225
- with patch.object(admin_client, "_run_auth", mock_run_auth):
226
- result = test_func(admin_client)
227
-
228
- assert result == "success"
229
- assert func_calls == 2, "Function should be called twice"
230
- assert refresh_calls == 1, "Auth should be refreshed once"
231
- assert admin_client.auth_jwt == "fresh-token-1"
232
-
233
-
234
- @pytest.mark.asyncio
235
- async def test_async_auth_refresh_flow(admin_client: RestClient):
236
- refresh_calls = 0
237
- func_calls = 0
238
-
239
- def mock_run_auth(*args, **kwargs):
240
- nonlocal refresh_calls
241
- refresh_calls += 1
242
- return f"fresh-token-{refresh_calls}"
243
-
244
- @refresh_token_on_auth_error()
245
- async def test_async_func(self, *args):
246
- nonlocal func_calls
247
- func_calls += 1
248
-
249
- if func_calls == 1:
250
- raise AuthError(401, "Auth failed", None, None)
251
- await asyncio.sleep(1)
252
- return "success"
253
-
254
- with patch.object(admin_client, "_run_auth", mock_run_auth):
255
- result = await test_async_func(admin_client)
256
-
257
- assert result == "success"
258
- assert func_calls == 2, "Function should be called twice"
259
- assert refresh_calls == 1, "Auth should be refreshed once"
260
- assert admin_client.auth_jwt == "fresh-token-1"
@@ -1,107 +0,0 @@
1
- import asyncio
2
- import logging
3
- from collections.abc import Callable, Coroutine
4
- from functools import wraps
5
- from typing import Any, Final, Optional, ParamSpec, TypeVar, overload
6
-
7
- import httpx
8
- from httpx import HTTPStatusError
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- T = TypeVar("T")
13
- P = ParamSpec("P")
14
-
15
- AUTH_ERRORS_TO_RETRY_ON: Final[set[int]] = {
16
- httpx.codes.UNAUTHORIZED,
17
- httpx.codes.FORBIDDEN,
18
- }
19
-
20
-
21
- class AuthError(Exception):
22
- """Raised when authentication fails with 401/403 status."""
23
-
24
- def __init__(self, status_code: int, message: str, request=None, response=None):
25
- self.status_code = status_code
26
- self.request = request
27
- self.response = response
28
- super().__init__(message)
29
-
30
-
31
- def is_auth_error(e: Exception) -> bool:
32
- if isinstance(e, AuthError):
33
- return True
34
- if isinstance(e, HTTPStatusError):
35
- return e.response.status_code in AUTH_ERRORS_TO_RETRY_ON
36
- return False
37
-
38
-
39
- def get_status_code(e: Exception) -> Optional[int]:
40
- if isinstance(e, AuthError):
41
- return e.status_code
42
- if isinstance(e, HTTPStatusError):
43
- return e.response.status_code
44
- return None
45
-
46
-
47
- @overload
48
- def refresh_token_on_auth_error(
49
- func: Callable[P, Coroutine[Any, Any, T]],
50
- ) -> Callable[P, Coroutine[Any, Any, T]]: ...
51
-
52
-
53
- @overload
54
- def refresh_token_on_auth_error(
55
- func: None = None, *, max_retries: int = ...
56
- ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
57
-
58
-
59
- def refresh_token_on_auth_error(func=None, max_retries=1):
60
- """Decorator that refreshes JWT token on 401/403 auth errors."""
61
-
62
- def decorator(fn):
63
- @wraps(fn)
64
- def sync_wrapper(self, *args, **kwargs):
65
- retries = 0
66
- while True:
67
- try:
68
- return fn(self, *args, **kwargs)
69
- except Exception as e:
70
- if is_auth_error(e) and retries < max_retries:
71
- retries += 1
72
- status = get_status_code(e) or "Unknown"
73
- logger.info(
74
- f"Received auth error {status}, "
75
- f"refreshing token and retrying (attempt {retries}/{max_retries})..."
76
- )
77
- self.auth_jwt = self._run_auth()
78
- self._clients = {}
79
- continue
80
- raise
81
-
82
- @wraps(fn)
83
- async def async_wrapper(self, *args, **kwargs):
84
- retries = 0
85
- while True:
86
- try:
87
- return await fn(self, *args, **kwargs)
88
- except Exception as e:
89
- if is_auth_error(e) and retries < max_retries:
90
- retries += 1
91
- status = get_status_code(e) or "Unknown"
92
- logger.info(
93
- f"Received auth error {status}, "
94
- f"refreshing token and retrying (attempt {retries}/{max_retries})..."
95
- )
96
- self.auth_jwt = self._run_auth()
97
- self._clients = {}
98
- continue
99
- raise
100
-
101
- if asyncio.iscoroutinefunction(fn):
102
- return async_wrapper
103
- return sync_wrapper
104
-
105
- if callable(func):
106
- return decorator(func)
107
- return decorator