futurehouse-client 0.3.18.dev185__tar.gz → 0.3.18.dev195__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/PKG-INFO +1 -1
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/clients/rest_client.py +40 -148
- futurehouse_client-0.3.18.dev195/futurehouse_client/utils/auth.py +89 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/PKG-INFO +1 -1
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/tests/test_rest.py +26 -84
- futurehouse_client-0.3.18.dev185/futurehouse_client/utils/auth.py +0 -107
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/LICENSE +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/README.md +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/docs/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/docs/client_notebook.ipynb +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/clients/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/clients/job_client.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/app.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/client.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/models/rest.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/general.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/module_utils.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/utils/monitoring.py +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/SOURCES.txt +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/dependency_links.txt +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/requires.txt +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client.egg-info/top_level.txt +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/pyproject.toml +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/setup.cfg +0 -0
- {futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/uv.lock +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: futurehouse-client
|
3
|
-
Version: 0.3.18.
|
3
|
+
Version: 0.3.18.dev195
|
4
4
|
Summary: A client for interacting with endpoints of the FutureHouse service.
|
5
5
|
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
6
6
|
Classifier: Operating System :: OS Independent
|
@@ -15,7 +15,7 @@ import uuid
|
|
15
15
|
from collections.abc import Collection
|
16
16
|
from pathlib import Path
|
17
17
|
from types import ModuleType
|
18
|
-
from typing import Any, ClassVar,
|
18
|
+
from typing import Any, ClassVar, cast
|
19
19
|
from uuid import UUID
|
20
20
|
|
21
21
|
import cloudpickle
|
@@ -45,7 +45,6 @@ from tqdm.asyncio import tqdm
|
|
45
45
|
|
46
46
|
from futurehouse_client.clients import JobNames
|
47
47
|
from futurehouse_client.models.app import (
|
48
|
-
APIKeyPayload,
|
49
48
|
AuthType,
|
50
49
|
JobDeploymentConfig,
|
51
50
|
PQATaskResponse,
|
@@ -55,11 +54,7 @@ from futurehouse_client.models.app import (
|
|
55
54
|
TaskResponseVerbose,
|
56
55
|
)
|
57
56
|
from futurehouse_client.models.rest import ExecutionStatus
|
58
|
-
from futurehouse_client.utils.auth import
|
59
|
-
AUTH_ERRORS_TO_RETRY_ON,
|
60
|
-
AuthError,
|
61
|
-
refresh_token_on_auth_error,
|
62
|
-
)
|
57
|
+
from futurehouse_client.utils.auth import RefreshingJWT
|
63
58
|
from futurehouse_client.utils.general import gather_with_concurrency
|
64
59
|
from futurehouse_client.utils.module_utils import (
|
65
60
|
OrganizationSelector,
|
@@ -128,8 +123,6 @@ retry_if_connection_error = retry_if_exception_type((
|
|
128
123
|
FileUploadError,
|
129
124
|
))
|
130
125
|
|
131
|
-
# 5 minute default for JWTs
|
132
|
-
JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
|
133
126
|
DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
|
134
127
|
|
135
128
|
|
@@ -163,69 +156,85 @@ class RestClient:
|
|
163
156
|
self.api_key = api_key
|
164
157
|
self._clients: dict[str, Client | AsyncClient] = {}
|
165
158
|
self.headers = headers or {}
|
166
|
-
self.
|
159
|
+
self.jwt = jwt
|
167
160
|
self.organizations: list[str] = self._filter_orgs(organization)
|
168
161
|
|
169
162
|
@property
|
170
163
|
def client(self) -> Client:
|
171
|
-
"""
|
172
|
-
return cast(Client, self.get_client("application/json",
|
164
|
+
"""Authenticated HTTP client for regular API calls."""
|
165
|
+
return cast(Client, self.get_client("application/json", authenticated=True))
|
173
166
|
|
174
167
|
@property
|
175
168
|
def async_client(self) -> AsyncClient:
|
176
|
-
"""
|
169
|
+
"""Authenticated async HTTP client for regular API calls."""
|
177
170
|
return cast(
|
178
171
|
AsyncClient,
|
179
|
-
self.get_client("application/json",
|
172
|
+
self.get_client("application/json", authenticated=True, async_client=True),
|
180
173
|
)
|
181
174
|
|
182
175
|
@property
|
183
|
-
def
|
184
|
-
"""
|
185
|
-
return cast(Client, self.get_client("application/json",
|
176
|
+
def unauthenticated_client(self) -> Client:
|
177
|
+
"""Unauthenticated HTTP client for auth operations to avoid recursion."""
|
178
|
+
return cast(Client, self.get_client("application/json", authenticated=False))
|
186
179
|
|
187
180
|
@property
|
188
181
|
def multipart_client(self) -> Client:
|
189
|
-
"""
|
190
|
-
return cast(Client, self.get_client(None,
|
182
|
+
"""Authenticated HTTP client for multipart uploads."""
|
183
|
+
return cast(Client, self.get_client(None, authenticated=True))
|
191
184
|
|
192
185
|
def get_client(
|
193
186
|
self,
|
194
187
|
content_type: str | None = "application/json",
|
195
|
-
|
196
|
-
|
188
|
+
authenticated: bool = True,
|
189
|
+
async_client: bool = False,
|
197
190
|
) -> Client | AsyncClient:
|
198
191
|
"""Return a cached HTTP client or create one if needed.
|
199
192
|
|
200
193
|
Args:
|
201
194
|
content_type: The desired content type header. Use None for multipart uploads.
|
202
|
-
|
203
|
-
|
195
|
+
authenticated: Whether the client should include authentication.
|
196
|
+
async_client: Whether to use an async client.
|
204
197
|
|
205
198
|
Returns:
|
206
199
|
An HTTP client configured with the appropriate headers.
|
207
200
|
"""
|
208
|
-
# Create a composite key based on content type and auth flag
|
209
|
-
key = f"{content_type or 'multipart'}_{
|
201
|
+
# Create a composite key based on content type and auth flag
|
202
|
+
key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
|
203
|
+
|
210
204
|
if key not in self._clients:
|
211
205
|
headers = copy.deepcopy(self.headers)
|
212
|
-
|
213
|
-
|
206
|
+
auth = None
|
207
|
+
|
208
|
+
if authenticated:
|
209
|
+
auth = RefreshingJWT(
|
210
|
+
# authenticated=False will always return a synchronous client
|
211
|
+
auth_client=cast(
|
212
|
+
Client, self.get_client("application/json", authenticated=False)
|
213
|
+
),
|
214
|
+
auth_type=self.auth_type,
|
215
|
+
api_key=self.api_key,
|
216
|
+
jwt=self.jwt,
|
217
|
+
)
|
218
|
+
|
214
219
|
if content_type:
|
215
220
|
headers["Content-Type"] = content_type
|
221
|
+
|
216
222
|
self._clients[key] = (
|
217
223
|
AsyncClient(
|
218
224
|
base_url=self.base_url,
|
219
225
|
headers=headers,
|
220
226
|
timeout=self.REQUEST_TIMEOUT,
|
227
|
+
auth=auth,
|
221
228
|
)
|
222
|
-
if
|
229
|
+
if async_client
|
223
230
|
else Client(
|
224
231
|
base_url=self.base_url,
|
225
232
|
headers=headers,
|
226
233
|
timeout=self.REQUEST_TIMEOUT,
|
234
|
+
auth=auth,
|
227
235
|
)
|
228
236
|
)
|
237
|
+
|
229
238
|
return self._clients[key]
|
230
239
|
|
231
240
|
def close(self):
|
@@ -255,32 +264,6 @@ class RestClient:
|
|
255
264
|
raise ValueError(f"Organization '{organization}' not found.")
|
256
265
|
return filtered_orgs
|
257
266
|
|
258
|
-
def _run_auth(self, jwt: str | None = None) -> str:
|
259
|
-
auth_payload: APIKeyPayload | None
|
260
|
-
if self.auth_type == AuthType.API_KEY:
|
261
|
-
auth_payload = APIKeyPayload(api_key=self.api_key)
|
262
|
-
elif self.auth_type == AuthType.JWT:
|
263
|
-
auth_payload = None
|
264
|
-
else:
|
265
|
-
assert_never(self.auth_type)
|
266
|
-
try:
|
267
|
-
# Use the unauthenticated client for login
|
268
|
-
if auth_payload:
|
269
|
-
response = self.auth_client.post(
|
270
|
-
"/auth/login", json=auth_payload.model_dump()
|
271
|
-
)
|
272
|
-
response.raise_for_status()
|
273
|
-
token_data = response.json()
|
274
|
-
elif jwt:
|
275
|
-
token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
|
276
|
-
else:
|
277
|
-
raise ValueError("JWT token required for JWT authentication.")
|
278
|
-
|
279
|
-
return token_data["access_token"]
|
280
|
-
except Exception as e:
|
281
|
-
raise RestClientError(f"Error authenticating: {e!s}") from e
|
282
|
-
|
283
|
-
@refresh_token_on_auth_error()
|
284
267
|
def _check_job(self, name: str, organization: str) -> dict[str, Any]:
|
285
268
|
try:
|
286
269
|
response = self.client.get(
|
@@ -288,19 +271,9 @@ class RestClient:
|
|
288
271
|
)
|
289
272
|
response.raise_for_status()
|
290
273
|
return response.json()
|
291
|
-
except HTTPStatusError as e:
|
292
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
293
|
-
raise AuthError(
|
294
|
-
e.response.status_code,
|
295
|
-
f"Authentication failed: {e}",
|
296
|
-
request=e.request,
|
297
|
-
response=e.response,
|
298
|
-
) from e
|
299
|
-
raise
|
300
274
|
except Exception as e:
|
301
275
|
raise JobFetchError(f"Error checking job: {e!s}") from e
|
302
276
|
|
303
|
-
@refresh_token_on_auth_error()
|
304
277
|
def _fetch_my_orgs(self) -> list[str]:
|
305
278
|
response = self.client.get(f"/v0.1/organizations?filter={True}")
|
306
279
|
response.raise_for_status()
|
@@ -358,7 +331,6 @@ class RestClient:
|
|
358
331
|
if not files:
|
359
332
|
raise TaskFetchError(f"No files found in {path}")
|
360
333
|
|
361
|
-
@refresh_token_on_auth_error()
|
362
334
|
@retry(
|
363
335
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
364
336
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -385,6 +357,7 @@ class RestClient:
|
|
385
357
|
),
|
386
358
|
self.client.stream("GET", url, params={"history": history}) as response,
|
387
359
|
):
|
360
|
+
response.raise_for_status()
|
388
361
|
json_data = "".join(response.iter_text(chunk_size=1024))
|
389
362
|
data = json.loads(json_data)
|
390
363
|
if "id" not in data:
|
@@ -399,19 +372,9 @@ class RestClient:
|
|
399
372
|
):
|
400
373
|
return PQATaskResponse(**data)
|
401
374
|
return TaskResponse(**data)
|
402
|
-
except HTTPStatusError as e:
|
403
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
404
|
-
raise AuthError(
|
405
|
-
e.response.status_code,
|
406
|
-
f"Authentication failed: {e}",
|
407
|
-
request=e.request,
|
408
|
-
response=e.response,
|
409
|
-
) from e
|
410
|
-
raise
|
411
375
|
except Exception as e:
|
412
376
|
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
413
377
|
|
414
|
-
@refresh_token_on_auth_error()
|
415
378
|
@retry(
|
416
379
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
417
380
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -455,19 +418,9 @@ class RestClient:
|
|
455
418
|
):
|
456
419
|
return PQATaskResponse(**data)
|
457
420
|
return TaskResponse(**data)
|
458
|
-
except HTTPStatusError as e:
|
459
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
460
|
-
raise AuthError(
|
461
|
-
e.response.status_code,
|
462
|
-
f"Authentication failed: {e}",
|
463
|
-
request=e.request,
|
464
|
-
response=e.response,
|
465
|
-
) from e
|
466
|
-
raise
|
467
421
|
except Exception as e:
|
468
422
|
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
469
423
|
|
470
|
-
@refresh_token_on_auth_error()
|
471
424
|
@retry(
|
472
425
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
473
426
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -491,20 +444,10 @@ class RestClient:
|
|
491
444
|
response.raise_for_status()
|
492
445
|
trajectory_id = response.json()["trajectory_id"]
|
493
446
|
self.trajectory_id = trajectory_id
|
494
|
-
except HTTPStatusError as e:
|
495
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
496
|
-
raise AuthError(
|
497
|
-
e.response.status_code,
|
498
|
-
f"Authentication failed: {e}",
|
499
|
-
request=e.request,
|
500
|
-
response=e.response,
|
501
|
-
) from e
|
502
|
-
raise
|
503
447
|
except Exception as e:
|
504
448
|
raise TaskFetchError(f"Error creating task: {e!s}") from e
|
505
449
|
return trajectory_id
|
506
450
|
|
507
|
-
@refresh_token_on_auth_error()
|
508
451
|
@retry(
|
509
452
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
510
453
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -528,15 +471,6 @@ class RestClient:
|
|
528
471
|
response.raise_for_status()
|
529
472
|
trajectory_id = response.json()["trajectory_id"]
|
530
473
|
self.trajectory_id = trajectory_id
|
531
|
-
except HTTPStatusError as e:
|
532
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
533
|
-
raise AuthError(
|
534
|
-
e.response.status_code,
|
535
|
-
f"Authentication failed: {e}",
|
536
|
-
request=e.request,
|
537
|
-
response=e.response,
|
538
|
-
) from e
|
539
|
-
raise
|
540
474
|
except Exception as e:
|
541
475
|
raise TaskFetchError(f"Error creating task: {e!s}") from e
|
542
476
|
return trajectory_id
|
@@ -682,7 +616,6 @@ class RestClient:
|
|
682
616
|
for task_id in trajectory_ids
|
683
617
|
]
|
684
618
|
|
685
|
-
@refresh_token_on_auth_error()
|
686
619
|
@retry(
|
687
620
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
688
621
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -694,19 +627,11 @@ class RestClient:
|
|
694
627
|
build_id = build_id or self.build_id
|
695
628
|
response = self.client.get(f"/v0.1/builds/{build_id}")
|
696
629
|
response.raise_for_status()
|
697
|
-
except
|
698
|
-
|
699
|
-
raise AuthError(
|
700
|
-
e.response.status_code,
|
701
|
-
f"Authentication failed: {e}",
|
702
|
-
request=e.request,
|
703
|
-
response=e.response,
|
704
|
-
) from e
|
705
|
-
raise
|
630
|
+
except Exception as e:
|
631
|
+
raise JobFetchError(f"Error getting build status: {e!s}") from e
|
706
632
|
return response.json()
|
707
633
|
|
708
634
|
# TODO: Refactor later so we don't have to ignore PLR0915
|
709
|
-
@refresh_token_on_auth_error()
|
710
635
|
@retry(
|
711
636
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
712
637
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -886,13 +811,6 @@ class RestClient:
|
|
886
811
|
build_context = response.json()
|
887
812
|
self.build_id = build_context["build_id"]
|
888
813
|
except HTTPStatusError as e:
|
889
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
890
|
-
raise AuthError(
|
891
|
-
e.response.status_code,
|
892
|
-
f"Authentication failed: {e}",
|
893
|
-
request=e.request,
|
894
|
-
response=e.response,
|
895
|
-
) from e
|
896
814
|
error_detail = response.json()
|
897
815
|
error_message = error_detail.get("detail", str(e))
|
898
816
|
raise JobCreationError(
|
@@ -973,7 +891,6 @@ class RestClient:
|
|
973
891
|
except Exception as e:
|
974
892
|
raise FileUploadError(f"Error uploading directory {dir_path}: {e}") from e
|
975
893
|
|
976
|
-
@refresh_token_on_auth_error()
|
977
894
|
def _upload_single_file(
|
978
895
|
self,
|
979
896
|
job_name: str,
|
@@ -1047,20 +964,10 @@ class RestClient:
|
|
1047
964
|
)
|
1048
965
|
|
1049
966
|
logger.info(f"Successfully uploaded {file_name}")
|
1050
|
-
except HTTPStatusError as e:
|
1051
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
1052
|
-
raise AuthError(
|
1053
|
-
e.response.status_code,
|
1054
|
-
f"Authentication failed: {e}",
|
1055
|
-
request=e.request,
|
1056
|
-
response=e.response,
|
1057
|
-
) from e
|
1058
|
-
raise
|
1059
967
|
except Exception as e:
|
1060
968
|
logger.exception(f"Error uploading file {file_path}")
|
1061
969
|
raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
|
1062
970
|
|
1063
|
-
@refresh_token_on_auth_error()
|
1064
971
|
@retry(
|
1065
972
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1066
973
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -1097,13 +1004,6 @@ class RestClient:
|
|
1097
1004
|
response.raise_for_status()
|
1098
1005
|
return response.json()
|
1099
1006
|
except HTTPStatusError as e:
|
1100
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
1101
|
-
raise AuthError(
|
1102
|
-
e.response.status_code,
|
1103
|
-
f"Authentication failed: {e}",
|
1104
|
-
request=e.request,
|
1105
|
-
response=e.response,
|
1106
|
-
) from e
|
1107
1007
|
logger.exception(
|
1108
1008
|
f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
|
1109
1009
|
)
|
@@ -1116,7 +1016,6 @@ class RestClient:
|
|
1116
1016
|
)
|
1117
1017
|
raise RestClientError(f"Error listing files: {e!s}") from e
|
1118
1018
|
|
1119
|
-
@refresh_token_on_auth_error()
|
1120
1019
|
@retry(
|
1121
1020
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1122
1021
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -1164,13 +1063,6 @@ class RestClient:
|
|
1164
1063
|
|
1165
1064
|
logger.info(f"File {file_path} downloaded to {destination_path}")
|
1166
1065
|
except HTTPStatusError as e:
|
1167
|
-
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
1168
|
-
raise AuthError(
|
1169
|
-
e.response.status_code,
|
1170
|
-
f"Authentication failed: {e}",
|
1171
|
-
request=e.request,
|
1172
|
-
response=e.response,
|
1173
|
-
) from e
|
1174
1066
|
logger.exception(
|
1175
1067
|
f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
|
1176
1068
|
)
|
@@ -0,0 +1,89 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import ClassVar, Final
|
3
|
+
|
4
|
+
import httpx
|
5
|
+
|
6
|
+
from futurehouse_client.models.app import APIKeyPayload, AuthType
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
|
10
|
+
INVALID_REFRESH_TYPE_MSG: Final[str] = (
|
11
|
+
"API key auth is required to refresh auth tokens."
|
12
|
+
)
|
13
|
+
JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
|
14
|
+
|
15
|
+
|
16
|
+
def _run_auth(
|
17
|
+
client: httpx.Client,
|
18
|
+
auth_type: AuthType = AuthType.API_KEY,
|
19
|
+
api_key: str | None = None,
|
20
|
+
jwt: str | None = None,
|
21
|
+
) -> str:
|
22
|
+
auth_payload: APIKeyPayload | None
|
23
|
+
if auth_type == AuthType.API_KEY:
|
24
|
+
auth_payload = APIKeyPayload(api_key=api_key)
|
25
|
+
elif auth_type == AuthType.JWT:
|
26
|
+
auth_payload = None
|
27
|
+
try:
|
28
|
+
if auth_payload:
|
29
|
+
response = client.post("/auth/login", json=auth_payload.model_dump())
|
30
|
+
response.raise_for_status()
|
31
|
+
token_data = response.json()
|
32
|
+
elif jwt:
|
33
|
+
token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
|
34
|
+
else:
|
35
|
+
raise ValueError("JWT token required for JWT authentication.")
|
36
|
+
|
37
|
+
return token_data["access_token"]
|
38
|
+
except Exception as e:
|
39
|
+
raise Exception("Failed to authenticate") from e # noqa: TRY002
|
40
|
+
|
41
|
+
|
42
|
+
class RefreshingJWT(httpx.Auth):
|
43
|
+
"""Automatically (re-)inject a JWT and transparently retry exactly once when we hit a 401/403."""
|
44
|
+
|
45
|
+
RETRY_STATUSES: ClassVar[set[int]] = {
|
46
|
+
httpx.codes.UNAUTHORIZED,
|
47
|
+
httpx.codes.FORBIDDEN,
|
48
|
+
}
|
49
|
+
|
50
|
+
def __init__(
|
51
|
+
self,
|
52
|
+
auth_client: httpx.Client,
|
53
|
+
auth_type: AuthType = AuthType.API_KEY,
|
54
|
+
api_key: str | None = None,
|
55
|
+
jwt: str | None = None,
|
56
|
+
):
|
57
|
+
self.auth_type = auth_type
|
58
|
+
self.auth_client = auth_client
|
59
|
+
self.api_key = api_key
|
60
|
+
self._jwt = _run_auth(
|
61
|
+
client=auth_client,
|
62
|
+
jwt=jwt,
|
63
|
+
auth_type=auth_type,
|
64
|
+
api_key=api_key,
|
65
|
+
)
|
66
|
+
|
67
|
+
def refresh_token(self):
|
68
|
+
if self.auth_type == AuthType.JWT:
|
69
|
+
logger.error(INVALID_REFRESH_TYPE_MSG)
|
70
|
+
raise ValueError(INVALID_REFRESH_TYPE_MSG)
|
71
|
+
self._jwt = _run_auth(
|
72
|
+
client=self.auth_client,
|
73
|
+
auth_type=self.auth_type,
|
74
|
+
api_key=self.api_key,
|
75
|
+
)
|
76
|
+
|
77
|
+
def auth_flow(self, request):
|
78
|
+
request.headers["Authorization"] = f"Bearer {self._jwt}"
|
79
|
+
response = yield request
|
80
|
+
|
81
|
+
# If it failed, refresh once and replay the request
|
82
|
+
if response.status_code in self.RETRY_STATUSES:
|
83
|
+
logger.info(
|
84
|
+
"Received %s, refreshing token and retrying …",
|
85
|
+
response.status_code,
|
86
|
+
)
|
87
|
+
self.refresh_token()
|
88
|
+
request.headers["Authorization"] = f"Bearer {self._jwt}"
|
89
|
+
yield request # second (and final) attempt, again or use a while loop
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: futurehouse-client
|
3
|
-
Version: 0.3.18.
|
3
|
+
Version: 0.3.18.dev195
|
4
4
|
Summary: A client for interacting with endpoints of the FutureHouse service.
|
5
5
|
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
6
6
|
Classifier: Operating System :: OS Independent
|
@@ -2,7 +2,6 @@
|
|
2
2
|
import asyncio
|
3
3
|
import os
|
4
4
|
import time
|
5
|
-
from unittest.mock import patch
|
6
5
|
|
7
6
|
import pytest
|
8
7
|
from futurehouse_client.clients import (
|
@@ -10,10 +9,9 @@ from futurehouse_client.clients import (
|
|
10
9
|
PQATaskResponse,
|
11
10
|
TaskResponseVerbose,
|
12
11
|
)
|
13
|
-
from futurehouse_client.clients.rest_client import RestClient
|
12
|
+
from futurehouse_client.clients.rest_client import RestClient, TaskFetchError
|
14
13
|
from futurehouse_client.models.app import Stage, TaskRequest
|
15
14
|
from futurehouse_client.models.rest import ExecutionStatus
|
16
|
-
from futurehouse_client.utils.auth import AuthError, refresh_token_on_auth_error
|
17
15
|
from pytest_subtests import SubTests
|
18
16
|
|
19
17
|
ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
|
@@ -40,7 +38,7 @@ def pub_client():
|
|
40
38
|
|
41
39
|
|
42
40
|
@pytest.fixture
|
43
|
-
def
|
41
|
+
def task_req():
|
44
42
|
"""Create a sample task request."""
|
45
43
|
return TaskRequest(
|
46
44
|
name=JobNames.from_string("dummy"),
|
@@ -49,7 +47,7 @@ def task_data():
|
|
49
47
|
|
50
48
|
|
51
49
|
@pytest.fixture
|
52
|
-
def
|
50
|
+
def pqa_task_req():
|
53
51
|
return TaskRequest(
|
54
52
|
name=JobNames.from_string("crow"),
|
55
53
|
query="How many moons does earth have?",
|
@@ -58,40 +56,40 @@ def pqa_task_data():
|
|
58
56
|
|
59
57
|
@pytest.mark.timeout(300)
|
60
58
|
@pytest.mark.flaky(reruns=3)
|
61
|
-
def test_futurehouse_dummy_env_crow(admin_client: RestClient,
|
62
|
-
admin_client.create_task(
|
59
|
+
def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_req: TaskRequest):
|
60
|
+
admin_client.create_task(task_req)
|
63
61
|
while (task_status := admin_client.get_task().status) in {"queued", "in progress"}:
|
64
62
|
time.sleep(5)
|
65
63
|
assert task_status == "success"
|
66
64
|
|
67
65
|
|
68
66
|
def test_insufficient_permissions_request(
|
69
|
-
pub_client: RestClient,
|
67
|
+
pub_client: RestClient, task_req: TaskRequest
|
70
68
|
):
|
71
69
|
# Create a new instance so that cached credentials aren't reused
|
72
|
-
with pytest.raises(
|
73
|
-
pub_client.create_task(
|
70
|
+
with pytest.raises(TaskFetchError) as exc_info:
|
71
|
+
pub_client.create_task(task_req)
|
74
72
|
|
75
|
-
assert "
|
73
|
+
assert "Error creating task" in str(exc_info.value)
|
76
74
|
|
77
75
|
|
78
76
|
@pytest.mark.timeout(300)
|
79
77
|
@pytest.mark.asyncio
|
80
78
|
async def test_job_response( # noqa: PLR0915
|
81
|
-
subtests: SubTests, admin_client: RestClient,
|
79
|
+
subtests: SubTests, admin_client: RestClient, pqa_task_req: TaskRequest
|
82
80
|
):
|
83
|
-
task_id = admin_client.create_task(
|
84
|
-
atask_id = await admin_client.acreate_task(
|
81
|
+
task_id = admin_client.create_task(pqa_task_req)
|
82
|
+
atask_id = await admin_client.acreate_task(pqa_task_req)
|
85
83
|
|
86
84
|
with subtests.test("Test TaskResponse with queued task"):
|
87
85
|
task_response = admin_client.get_task(task_id)
|
88
86
|
assert task_response.status in {"queued", "in progress"}
|
89
|
-
assert task_response.job_name ==
|
90
|
-
assert task_response.query ==
|
87
|
+
assert task_response.job_name == pqa_task_req.name
|
88
|
+
assert task_response.query == pqa_task_req.query
|
91
89
|
task_response = await admin_client.aget_task(atask_id)
|
92
90
|
assert task_response.status in {"queued", "in progress"}
|
93
|
-
assert task_response.job_name ==
|
94
|
-
assert task_response.query ==
|
91
|
+
assert task_response.job_name == pqa_task_req.name
|
92
|
+
assert task_response.query == pqa_task_req.query
|
95
93
|
|
96
94
|
for _ in range(TEST_MAX_POLLS):
|
97
95
|
task_response = admin_client.get_task(task_id)
|
@@ -111,8 +109,8 @@ async def test_job_response( # noqa: PLR0915
|
|
111
109
|
# assert it has general fields
|
112
110
|
assert task_response.status == "success"
|
113
111
|
assert task_response.task_id is not None
|
114
|
-
assert
|
115
|
-
assert
|
112
|
+
assert pqa_task_req.name in task_response.job_name
|
113
|
+
assert pqa_task_req.query in task_response.query
|
116
114
|
# assert it has PQA specific fields
|
117
115
|
assert task_response.answer is not None
|
118
116
|
# assert it's not verbose
|
@@ -125,8 +123,8 @@ async def test_job_response( # noqa: PLR0915
|
|
125
123
|
# assert it has general fields
|
126
124
|
assert task_response.status == "success"
|
127
125
|
assert task_response.task_id is not None
|
128
|
-
assert
|
129
|
-
assert
|
126
|
+
assert pqa_task_req.name in task_response.job_name
|
127
|
+
assert pqa_task_req.query in task_response.query
|
130
128
|
# assert it has PQA specific fields
|
131
129
|
assert task_response.answer is not None
|
132
130
|
# assert it's not verbose
|
@@ -151,9 +149,9 @@ async def test_job_response( # noqa: PLR0915
|
|
151
149
|
@pytest.mark.timeout(300)
|
152
150
|
@pytest.mark.flaky(reruns=3)
|
153
151
|
def test_run_until_done_futurehouse_dummy_env_crow(
|
154
|
-
admin_client: RestClient,
|
152
|
+
admin_client: RestClient, task_req: TaskRequest
|
155
153
|
):
|
156
|
-
tasks_to_do = [
|
154
|
+
tasks_to_do = [task_req, task_req]
|
157
155
|
|
158
156
|
results = admin_client.run_tasks_until_done(tasks_to_do)
|
159
157
|
|
@@ -165,9 +163,9 @@ def test_run_until_done_futurehouse_dummy_env_crow(
|
|
165
163
|
@pytest.mark.flaky(reruns=3)
|
166
164
|
@pytest.mark.asyncio
|
167
165
|
async def test_arun_until_done_futurehouse_dummy_env_crow(
|
168
|
-
admin_client: RestClient,
|
166
|
+
admin_client: RestClient, task_req: TaskRequest
|
169
167
|
):
|
170
|
-
tasks_to_do = [
|
168
|
+
tasks_to_do = [task_req, task_req]
|
171
169
|
|
172
170
|
results = await admin_client.arun_tasks_until_done(tasks_to_do)
|
173
171
|
|
@@ -179,9 +177,9 @@ async def test_arun_until_done_futurehouse_dummy_env_crow(
|
|
179
177
|
@pytest.mark.flaky(reruns=3)
|
180
178
|
@pytest.mark.asyncio
|
181
179
|
async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
|
182
|
-
admin_client: RestClient,
|
180
|
+
admin_client: RestClient, task_req: TaskRequest
|
183
181
|
):
|
184
|
-
tasks_to_do = [
|
182
|
+
tasks_to_do = [task_req, task_req]
|
185
183
|
|
186
184
|
results = await admin_client.arun_tasks_until_done(
|
187
185
|
tasks_to_do, verbose=True, timeout=5, progress_bar=True
|
@@ -202,59 +200,3 @@ async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
|
|
202
200
|
assert all(not isinstance(task, PQATaskResponse) for task in results), (
|
203
201
|
"Should be verbose."
|
204
202
|
)
|
205
|
-
|
206
|
-
|
207
|
-
def test_auth_refresh_flow(admin_client: RestClient):
|
208
|
-
refresh_calls = 0
|
209
|
-
func_calls = 0
|
210
|
-
|
211
|
-
def mock_run_auth(*args, **kwargs):
|
212
|
-
nonlocal refresh_calls
|
213
|
-
refresh_calls += 1
|
214
|
-
return f"fresh-token-{refresh_calls}"
|
215
|
-
|
216
|
-
@refresh_token_on_auth_error()
|
217
|
-
def test_func(self, *args):
|
218
|
-
nonlocal func_calls
|
219
|
-
func_calls += 1
|
220
|
-
|
221
|
-
if func_calls == 1:
|
222
|
-
raise AuthError(401, "Auth failed", None, None)
|
223
|
-
return "success"
|
224
|
-
|
225
|
-
with patch.object(admin_client, "_run_auth", mock_run_auth):
|
226
|
-
result = test_func(admin_client)
|
227
|
-
|
228
|
-
assert result == "success"
|
229
|
-
assert func_calls == 2, "Function should be called twice"
|
230
|
-
assert refresh_calls == 1, "Auth should be refreshed once"
|
231
|
-
assert admin_client.auth_jwt == "fresh-token-1"
|
232
|
-
|
233
|
-
|
234
|
-
@pytest.mark.asyncio
|
235
|
-
async def test_async_auth_refresh_flow(admin_client: RestClient):
|
236
|
-
refresh_calls = 0
|
237
|
-
func_calls = 0
|
238
|
-
|
239
|
-
def mock_run_auth(*args, **kwargs):
|
240
|
-
nonlocal refresh_calls
|
241
|
-
refresh_calls += 1
|
242
|
-
return f"fresh-token-{refresh_calls}"
|
243
|
-
|
244
|
-
@refresh_token_on_auth_error()
|
245
|
-
async def test_async_func(self, *args):
|
246
|
-
nonlocal func_calls
|
247
|
-
func_calls += 1
|
248
|
-
|
249
|
-
if func_calls == 1:
|
250
|
-
raise AuthError(401, "Auth failed", None, None)
|
251
|
-
await asyncio.sleep(1)
|
252
|
-
return "success"
|
253
|
-
|
254
|
-
with patch.object(admin_client, "_run_auth", mock_run_auth):
|
255
|
-
result = await test_async_func(admin_client)
|
256
|
-
|
257
|
-
assert result == "success"
|
258
|
-
assert func_calls == 2, "Function should be called twice"
|
259
|
-
assert refresh_calls == 1, "Auth should be refreshed once"
|
260
|
-
assert admin_client.auth_jwt == "fresh-token-1"
|
@@ -1,107 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import logging
|
3
|
-
from collections.abc import Callable, Coroutine
|
4
|
-
from functools import wraps
|
5
|
-
from typing import Any, Final, Optional, ParamSpec, TypeVar, overload
|
6
|
-
|
7
|
-
import httpx
|
8
|
-
from httpx import HTTPStatusError
|
9
|
-
|
10
|
-
logger = logging.getLogger(__name__)
|
11
|
-
|
12
|
-
T = TypeVar("T")
|
13
|
-
P = ParamSpec("P")
|
14
|
-
|
15
|
-
AUTH_ERRORS_TO_RETRY_ON: Final[set[int]] = {
|
16
|
-
httpx.codes.UNAUTHORIZED,
|
17
|
-
httpx.codes.FORBIDDEN,
|
18
|
-
}
|
19
|
-
|
20
|
-
|
21
|
-
class AuthError(Exception):
|
22
|
-
"""Raised when authentication fails with 401/403 status."""
|
23
|
-
|
24
|
-
def __init__(self, status_code: int, message: str, request=None, response=None):
|
25
|
-
self.status_code = status_code
|
26
|
-
self.request = request
|
27
|
-
self.response = response
|
28
|
-
super().__init__(message)
|
29
|
-
|
30
|
-
|
31
|
-
def is_auth_error(e: Exception) -> bool:
|
32
|
-
if isinstance(e, AuthError):
|
33
|
-
return True
|
34
|
-
if isinstance(e, HTTPStatusError):
|
35
|
-
return e.response.status_code in AUTH_ERRORS_TO_RETRY_ON
|
36
|
-
return False
|
37
|
-
|
38
|
-
|
39
|
-
def get_status_code(e: Exception) -> Optional[int]:
|
40
|
-
if isinstance(e, AuthError):
|
41
|
-
return e.status_code
|
42
|
-
if isinstance(e, HTTPStatusError):
|
43
|
-
return e.response.status_code
|
44
|
-
return None
|
45
|
-
|
46
|
-
|
47
|
-
@overload
|
48
|
-
def refresh_token_on_auth_error(
|
49
|
-
func: Callable[P, Coroutine[Any, Any, T]],
|
50
|
-
) -> Callable[P, Coroutine[Any, Any, T]]: ...
|
51
|
-
|
52
|
-
|
53
|
-
@overload
|
54
|
-
def refresh_token_on_auth_error(
|
55
|
-
func: None = None, *, max_retries: int = ...
|
56
|
-
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
57
|
-
|
58
|
-
|
59
|
-
def refresh_token_on_auth_error(func=None, max_retries=1):
|
60
|
-
"""Decorator that refreshes JWT token on 401/403 auth errors."""
|
61
|
-
|
62
|
-
def decorator(fn):
|
63
|
-
@wraps(fn)
|
64
|
-
def sync_wrapper(self, *args, **kwargs):
|
65
|
-
retries = 0
|
66
|
-
while True:
|
67
|
-
try:
|
68
|
-
return fn(self, *args, **kwargs)
|
69
|
-
except Exception as e:
|
70
|
-
if is_auth_error(e) and retries < max_retries:
|
71
|
-
retries += 1
|
72
|
-
status = get_status_code(e) or "Unknown"
|
73
|
-
logger.info(
|
74
|
-
f"Received auth error {status}, "
|
75
|
-
f"refreshing token and retrying (attempt {retries}/{max_retries})..."
|
76
|
-
)
|
77
|
-
self.auth_jwt = self._run_auth()
|
78
|
-
self._clients = {}
|
79
|
-
continue
|
80
|
-
raise
|
81
|
-
|
82
|
-
@wraps(fn)
|
83
|
-
async def async_wrapper(self, *args, **kwargs):
|
84
|
-
retries = 0
|
85
|
-
while True:
|
86
|
-
try:
|
87
|
-
return await fn(self, *args, **kwargs)
|
88
|
-
except Exception as e:
|
89
|
-
if is_auth_error(e) and retries < max_retries:
|
90
|
-
retries += 1
|
91
|
-
status = get_status_code(e) or "Unknown"
|
92
|
-
logger.info(
|
93
|
-
f"Received auth error {status}, "
|
94
|
-
f"refreshing token and retrying (attempt {retries}/{max_retries})..."
|
95
|
-
)
|
96
|
-
self.auth_jwt = self._run_auth()
|
97
|
-
self._clients = {}
|
98
|
-
continue
|
99
|
-
raise
|
100
|
-
|
101
|
-
if asyncio.iscoroutinefunction(fn):
|
102
|
-
return async_wrapper
|
103
|
-
return sync_wrapper
|
104
|
-
|
105
|
-
if callable(func):
|
106
|
-
return decorator(func)
|
107
|
-
return decorator
|
File without changes
|
File without changes
|
File without changes
|
{futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/docs/client_notebook.ipynb
RENAMED
File without changes
|
{futurehouse_client-0.3.18.dev185 → futurehouse_client-0.3.18.dev195}/futurehouse_client/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|