futurehouse-client 0.3.18.dev186__py3-none-any.whl → 0.3.19.dev111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
futurehouse_client/__init__.py
@@ -1,12 +1,11 @@
  from .clients.job_client import JobClient, JobNames
- from .clients.rest_client import PQATaskResponse, TaskResponse, TaskResponseVerbose
  from .clients.rest_client import RestClient as FutureHouseClient
+ from .clients.rest_client import TaskResponse, TaskResponseVerbose

  __all__ = [
      "FutureHouseClient",
      "JobClient",
      "JobNames",
-     "PQATaskResponse",
      "TaskResponse",
      "TaskResponseVerbose",
  ]
futurehouse_client/clients/__init__.py
@@ -1,12 +1,11 @@
  from .job_client import JobClient, JobNames
- from .rest_client import PQATaskResponse, TaskResponse, TaskResponseVerbose
  from .rest_client import RestClient as FutureHouseClient
+ from .rest_client import TaskResponse, TaskResponseVerbose

  __all__ = [
      "FutureHouseClient",
      "JobClient",
      "JobNames",
-     "PQATaskResponse",
      "TaskResponse",
      "TaskResponseVerbose",
  ]
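Note for downstream imports: `PQATaskResponse` is no longer re-exported from the package root or from `futurehouse_client.clients`. A sketch of the migration (assuming the model continues to live in `futurehouse_client.models.app`, which is where the `job_client.py` hunk below imports it from):

```python
# Before (0.3.18.dev186):
# from futurehouse_client import PQATaskResponse

# After (0.3.19.dev111), import response models from models.app instead:
from futurehouse_client.models.app import PQATaskResponse
```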
futurehouse_client/clients/job_client.py
@@ -8,7 +8,13 @@ from aviary.env import Frame
  from pydantic import BaseModel
  from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential

- from futurehouse_client.models.app import Stage
+ from futurehouse_client.models.app import (
+     FinchTaskResponse,
+     PhoenixTaskResponse,
+     PQATaskResponse,
+     Stage,
+     TaskResponse,
+ )
  from futurehouse_client.models.rest import (
      FinalEnvironmentRequest,
      StoreAgentStatePostRequest,
@@ -31,6 +37,19 @@ class JobNames(StrEnum):
      DUMMY = "job-futurehouse-dummy-env"
      PHOENIX = "job-futurehouse-phoenix"
      FINCH = "job-futurehouse-data-analysis-crow-high"
+     CHIMP = "job-futurehouse-chimp"
+
+     @classmethod
+     def _get_response_mapping(cls) -> dict[str, type[TaskResponse]]:
+         return {
+             cls.CROW: PQATaskResponse,
+             cls.FALCON: PQATaskResponse,
+             cls.OWL: PQATaskResponse,
+             cls.CHIMP: PQATaskResponse,
+             cls.PHOENIX: PhoenixTaskResponse,
+             cls.FINCH: FinchTaskResponse,
+             cls.DUMMY: TaskResponse,
+         }

      @classmethod
      def from_stage(cls, job_name: str, stage: Stage | None = None) -> str:
@@ -52,6 +71,13 @@ class JobNames(StrEnum):
              f"Invalid job name: {job_name}. \nOptions are: {', '.join([name.name for name in cls])}"
          ) from e

+     @staticmethod
+     def get_response_object_from_job(job_name: str) -> type[TaskResponse]:
+         return JobNames._get_response_mapping()[job_name]
+
+     def get_response_object(self) -> type[TaskResponse]:
+         return self._get_response_mapping()[self.name]
+

  class JobClient:
      REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec
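The new `_get_response_mapping` centralizes the job-to-response-model lookup that `rest_client.py` previously hard-coded (see the `get_task` hunks below). A minimal usage sketch; the printed class name is only illustrative:

```python
from futurehouse_client.clients import JobNames

# JobNames is a StrEnum, so members hash and compare equal to their string
# values; a raw job-name string can therefore index the mapping directly.
response_cls = JobNames.get_response_object_from_job(
    "job-futurehouse-data-analysis-crow-high"  # value of JobNames.FINCH
)
print(response_cls.__name__)  # FinchTaskResponse
```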
futurehouse_client/clients/rest_client.py
@@ -13,9 +13,10 @@ import tempfile
  import time
  import uuid
  from collections.abc import Collection
+ from concurrent.futures import ThreadPoolExecutor, as_completed
  from pathlib import Path
  from types import ModuleType
- from typing import Any, ClassVar, assert_never, cast
+ from typing import Any, ClassVar, cast
  from uuid import UUID

  import cloudpickle
@@ -31,6 +32,7 @@ from httpx import (
      ReadError,
      ReadTimeout,
      RemoteProtocolError,
+     codes,
  )
  from ldp.agent import AgentConfig
  from requests.exceptions import RequestException, Timeout
@@ -45,21 +47,15 @@ from tqdm.asyncio import tqdm

  from futurehouse_client.clients import JobNames
  from futurehouse_client.models.app import (
-     APIKeyPayload,
      AuthType,
      JobDeploymentConfig,
-     PQATaskResponse,
      Stage,
      TaskRequest,
      TaskResponse,
      TaskResponseVerbose,
  )
  from futurehouse_client.models.rest import ExecutionStatus
- from futurehouse_client.utils.auth import (
-     AUTH_ERRORS_TO_RETRY_ON,
-     AuthError,
-     refresh_token_on_auth_error,
- )
+ from futurehouse_client.utils.auth import RefreshingJWT
  from futurehouse_client.utils.general import gather_with_concurrency
  from futurehouse_client.utils.module_utils import (
      OrganizationSelector,
@@ -128,8 +124,6 @@ retry_if_connection_error = retry_if_exception_type((
      FileUploadError,
  ))

- # 5 minute default for JWTs
- JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
  DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds


@@ -140,6 +134,9 @@ class RestClient:
      MAX_RETRY_WAIT: ClassVar[int] = 10
      DEFAULT_POLLING_TIME: ClassVar[int] = 5  # seconds
      CHUNK_SIZE: ClassVar[int] = 16 * 1024 * 1024  # 16MB chunks
+     ASSEMBLY_POLLING_INTERVAL: ClassVar[int] = 10  # seconds
+     MAX_ASSEMBLY_WAIT_TIME: ClassVar[int] = 1800  # 30 minutes
+     MAX_CONCURRENT_CHUNKS: ClassVar[int] = 12  # Maximum concurrent chunk uploads

      def __init__(
          self,
@@ -163,69 +160,87 @@ class RestClient:
          self.api_key = api_key
          self._clients: dict[str, Client | AsyncClient] = {}
          self.headers = headers or {}
-         self.auth_jwt = self._run_auth(jwt=jwt)
+         self.jwt = jwt
          self.organizations: list[str] = self._filter_orgs(organization)

      @property
      def client(self) -> Client:
-         """Lazily initialized and cached HTTP client with authentication."""
-         return cast(Client, self.get_client("application/json", with_auth=True))
+         """Authenticated HTTP client for regular API calls."""
+         return cast(Client, self.get_client("application/json", authenticated=True))

      @property
      def async_client(self) -> AsyncClient:
-         """Lazily initialized and cached HTTP client with authentication."""
+         """Authenticated async HTTP client for regular API calls."""
          return cast(
              AsyncClient,
-             self.get_client("application/json", with_auth=True, with_async=True),
+             self.get_client("application/json", authenticated=True, async_client=True),
          )

      @property
-     def auth_client(self) -> Client:
-         """Lazily initialized and cached HTTP client without authentication."""
-         return cast(Client, self.get_client("application/json", with_auth=False))
+     def unauthenticated_client(self) -> Client:
+         """Unauthenticated HTTP client for auth operations."""
+         return cast(Client, self.get_client("application/json", authenticated=False))

      @property
      def multipart_client(self) -> Client:
-         """Lazily initialized and cached HTTP client for multipart uploads."""
-         return cast(Client, self.get_client(None, with_auth=True))
+         """Authenticated HTTP client for multipart uploads."""
+         return cast(Client, self.get_client(None, authenticated=True))

      def get_client(
          self,
          content_type: str | None = "application/json",
-         with_auth: bool = True,
-         with_async: bool = False,
+         authenticated: bool = True,
+         async_client: bool = False,
      ) -> Client | AsyncClient:
          """Return a cached HTTP client or create one if needed.

          Args:
              content_type: The desired content type header. Use None for multipart uploads.
-             with_auth: Whether the client should include an Authorization header.
-             with_async: Whether to use an async client.
+             authenticated: Whether the client should include authentication.
+             async_client: Whether to use an async client.

          Returns:
              An HTTP client configured with the appropriate headers.
          """
-         # Create a composite key based on content type and auth flag.
-         key = f"{content_type or 'multipart'}_{with_auth}_{with_async}"
+         # Create a composite key based on content type and auth flag
+         key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
+
          if key not in self._clients:
              headers = copy.deepcopy(self.headers)
-             if with_auth:
-                 headers["Authorization"] = f"Bearer {self.auth_jwt}"
+             auth = None
+
+             if authenticated:
+                 auth = RefreshingJWT(
+                     # authenticated=False will always return a synchronous client
+                     auth_client=cast(
+                         Client, self.get_client("application/json", authenticated=False)
+                     ),
+                     auth_type=self.auth_type,
+                     api_key=self.api_key,
+                     jwt=self.jwt,
+                 )
+
              if content_type:
                  headers["Content-Type"] = content_type
+
+             headers["x-client"] = "sdk"
+
              self._clients[key] = (
                  AsyncClient(
                      base_url=self.base_url,
                      headers=headers,
                      timeout=self.REQUEST_TIMEOUT,
+                     auth=auth,
                  )
-                 if with_async
+                 if async_client
                  else Client(
                      base_url=self.base_url,
                      headers=headers,
                      timeout=self.REQUEST_TIMEOUT,
+                     auth=auth,
                  )
              )
+
          return self._clients[key]

      def close(self):
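The refactor delegates all token handling to an `httpx.Auth` subclass passed via the `auth=` keyword (see the `RefreshingJWT` hunk in `futurehouse_client/utils/auth.py` near the end of this diff). A self-contained sketch of the httpx mechanism it relies on, with a toy callable standing in for the real refresh logic:

```python
import httpx

class RetryOnceBearerAuth(httpx.Auth):
    """Toy stand-in for RefreshingJWT: inject a bearer token, replay once on 401/403."""

    def __init__(self, get_token):
        self._get_token = get_token  # callable returning a (possibly refreshed) token

    def auth_flow(self, request):
        request.headers["Authorization"] = f"Bearer {self._get_token()}"
        response = yield request  # httpx sends the request and hands back the response
        if response.status_code in (401, 403):
            request.headers["Authorization"] = f"Bearer {self._get_token()}"
            yield request  # single retry with a fresh token

client = httpx.Client(
    base_url="https://api.example.invalid",  # placeholder URL
    auth=RetryOnceBearerAuth(lambda: "dummy-token"),
)
```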
@@ -255,32 +270,6 @@ class RestClient:
              raise ValueError(f"Organization '{organization}' not found.")
          return filtered_orgs

-     def _run_auth(self, jwt: str | None = None) -> str:
-         auth_payload: APIKeyPayload | None
-         if self.auth_type == AuthType.API_KEY:
-             auth_payload = APIKeyPayload(api_key=self.api_key)
-         elif self.auth_type == AuthType.JWT:
-             auth_payload = None
-         else:
-             assert_never(self.auth_type)
-         try:
-             # Use the unauthenticated client for login
-             if auth_payload:
-                 response = self.auth_client.post(
-                     "/auth/login", json=auth_payload.model_dump()
-                 )
-                 response.raise_for_status()
-                 token_data = response.json()
-             elif jwt:
-                 token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
-             else:
-                 raise ValueError("JWT token required for JWT authentication.")
-
-             return token_data["access_token"]
-         except Exception as e:
-             raise RestClientError(f"Error authenticating: {e!s}") from e
-
-     @refresh_token_on_auth_error()
      def _check_job(self, name: str, organization: str) -> dict[str, Any]:
          try:
              response = self.client.get(
@@ -288,25 +277,113 @@ class RestClient:
              )
              response.raise_for_status()
              return response.json()
-         except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
-             raise
          except Exception as e:
              raise JobFetchError(f"Error checking job: {e!s}") from e

-     @refresh_token_on_auth_error()
      def _fetch_my_orgs(self) -> list[str]:
          response = self.client.get(f"/v0.1/organizations?filter={True}")
          response.raise_for_status()
          orgs = response.json()
          return [org["name"] for org in orgs]

+     def _check_assembly_status(
+         self, job_name: str, upload_id: str, file_name: str
+     ) -> dict[str, Any]:
+         """Check the assembly status of an uploaded file.
+
+         Args:
+             job_name: The name of the futurehouse job
+             upload_id: The upload ID
+             file_name: The name of the file
+
+         Returns:
+             Dict containing status information
+
+         Raises:
+             RestClientError: If there's an error checking status
+         """
+         try:
+             url = f"/v0.1/crows/{job_name}/assembly-status/{upload_id}/{file_name}"
+             response = self.client.get(url)
+             response.raise_for_status()
+             return response.json()
+         except Exception as e:
+             raise RestClientError(f"Error checking assembly status: {e}") from e
+
+     def _wait_for_all_assemblies_completion(
+         self,
+         job_name: str,
+         upload_id: str,
+         file_names: list[str],
+         timeout: int = MAX_ASSEMBLY_WAIT_TIME,
+     ) -> bool:
+         """Wait for all file assemblies to complete.
+
+         Args:
+             job_name: The name of the futurehouse job
+             upload_id: The upload ID
+             file_names: List of file names to wait for
+             timeout: Maximum time to wait in seconds
+
+         Returns:
+             True if all assemblies succeeded, False if any failed or timed out
+
+         Raises:
+             RestClientError: If any assembly fails
+         """
+         if not file_names:
+             return True
+
+         start_time = time.time()
+         logger.info(f"Waiting for assembly of {len(file_names)} file(s) to complete...")
+
+         completed_files: set[str] = set()
+
+         while (time.time() - start_time) < timeout and len(completed_files) < len(
+             file_names
+         ):
+             for file_name in file_names:
+                 if file_name in completed_files:
+                     continue
+
+                 try:
+                     status_data = self._check_assembly_status(
+                         job_name, upload_id, file_name
+                     )
+                     status = status_data.get("status")
+
+                     if status == ExecutionStatus.SUCCESS.value:
+                         logger.info(f"Assembly completed for {file_name}")
+                         completed_files.add(file_name)
+                     elif status == ExecutionStatus.FAIL.value:
+                         error_msg = status_data.get("error", "Unknown assembly error")
+                         raise RestClientError(
+                             f"Assembly failed for {file_name}: {error_msg}"
+                         )
+                     elif status == ExecutionStatus.IN_PROGRESS.value:
+                         logger.debug(f"Assembly in progress for {file_name}...")
+
+                 except RestClientError:
+                     raise  # Re-raise assembly errors
+                 except Exception as e:
+                     logger.warning(
+                         f"Error checking assembly status for {file_name}: {e}"
+                     )
+
+             # Don't sleep if all files are complete
+             if len(completed_files) < len(file_names):
+                 time.sleep(self.ASSEMBLY_POLLING_INTERVAL)
+
+         if len(completed_files) < len(file_names):
+             remaining_files = set(file_names) - completed_files
+             logger.warning(
+                 f"Assembly timeout for files: {remaining_files} after {timeout} seconds"
+             )
+             return False
+
+         logger.info(f"All {len(file_names)} file assemblies completed successfully")
+         return True
+
      @staticmethod
      def _validate_module_path(path: Path) -> None:
          """Validates that the given path exists and is a directory.
@@ -358,7 +435,6 @@ class RestClient:
          if not files:
              raise TaskFetchError(f"No files found in {path}")

-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -394,25 +470,12 @@

              if verbose:
                  return verbose_response
-             if any(
-                 JobNames.from_string(job_name) in verbose_response.job_name
-                 for job_name in ["crow", "falcon", "owl", "dummy"]
-             ):
-                 return PQATaskResponse(**data)
-             return TaskResponse(**data)
-         except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
-             raise
+             return JobNames.get_response_object_from_job(verbose_response.job_name)(
+                 **data
+             )
          except Exception as e:
              raise TaskFetchError(f"Error getting task: {e!s}") from e

-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -450,25 +513,12 @@

              if verbose:
                  return verbose_response
-             if any(
-                 JobNames.from_string(job_name) in verbose_response.job_name
-                 for job_name in ["crow", "falcon", "owl", "dummy"]
-             ):
-                 return PQATaskResponse(**data)
-             return TaskResponse(**data)
-         except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
-             raise
+             return JobNames.get_response_object_from_job(verbose_response.job_name)(
+                 **data
+             )
          except Exception as e:
              raise TaskFetchError(f"Error getting task: {e!s}") from e

-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -492,20 +542,10 @@
              response.raise_for_status()
              trajectory_id = response.json()["trajectory_id"]
              self.trajectory_id = trajectory_id
-         except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
-             raise
          except Exception as e:
              raise TaskFetchError(f"Error creating task: {e!s}") from e
          return trajectory_id

-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -529,15 +569,6 @@
              response.raise_for_status()
              trajectory_id = response.json()["trajectory_id"]
              self.trajectory_id = trajectory_id
-         except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
-             raise
          except Exception as e:
              raise TaskFetchError(f"Error creating task: {e!s}") from e
          return trajectory_id
@@ -683,7 +714,6 @@ class RestClient:
              for task_id in trajectory_ids
          ]

-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -695,19 +725,11 @@
              build_id = build_id or self.build_id
              response = self.client.get(f"/v0.1/builds/{build_id}")
              response.raise_for_status()
-         except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
-             raise
+         except Exception as e:
+             raise JobFetchError(f"Error getting build status: {e!s}") from e
          return response.json()

      # TODO: Refactor later so we don't have to ignore PLR0915
-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -887,13 +909,6 @@ class RestClient:
              build_context = response.json()
              self.build_id = build_context["build_id"]
          except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
              error_detail = response.json()
              error_message = error_detail.get("detail", str(e))
              raise JobCreationError(
@@ -903,6 +918,8 @@ class RestClient:
              raise JobCreationError(f"Error generating docker image: {e!s}") from e
          return build_context

+     # TODO: we should have an async upload_file, check_assembly_status,
+     # wait_for_assembly_completion, upload_directory, upload_single_file
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -913,6 +930,8 @@ class RestClient:
          job_name: str,
          file_path: str | os.PathLike,
          upload_id: str | None = None,
+         wait_for_assembly: bool = True,
+         assembly_timeout: int = MAX_ASSEMBLY_WAIT_TIME,
      ) -> str:
          """Upload a file or directory to a futurehouse job bucket.

@@ -920,29 +939,47 @@
              job_name: The name of the futurehouse job to upload to.
              file_path: The local path to the file or directory to upload.
              upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
+             wait_for_assembly: After file chunking, wait for the assembly to be processed.
+             assembly_timeout: Maximum time to wait for assembly in seconds.

          Returns:
              The upload ID used for the upload.

          Raises:
              FileUploadError: If there's an error uploading the file.
+             RestClientError: If assembly fails or times out.
          """
          file_path = Path(file_path)
          if not file_path.exists():
              raise FileNotFoundError(f"File or directory not found: {file_path}")

          upload_id = upload_id or str(uuid.uuid4())
+         uploaded_files: list[str] = []

          if file_path.is_dir():
              # Process directory recursively
-             self._upload_directory(job_name, file_path, upload_id)
+             uploaded_files = self._upload_directory(job_name, file_path, upload_id)
          else:
              # Process single file
              self._upload_single_file(job_name, file_path, upload_id)
+             uploaded_files = [file_path.name]
+
+         # Wait for all assemblies if requested and we have files
+         if wait_for_assembly and uploaded_files:
+             success = self._wait_for_all_assemblies_completion(
+                 job_name, upload_id, uploaded_files, assembly_timeout
+             )
+             if not success:
+                 raise RestClientError(
+                     f"Assembly failed or timed out for one or more files: {uploaded_files}"
+                 )
+
          logger.info(f"Successfully uploaded {file_path} to {upload_id}")
          return upload_id

-     def _upload_directory(self, job_name: str, dir_path: Path, upload_id: str) -> None:
+     def _upload_directory(
+         self, job_name: str, dir_path: Path, upload_id: str
+     ) -> list[str]:
          """Upload all files in a directory recursively.

          Args:
@@ -950,12 +987,17 @@
              dir_path: The path to the directory to upload.
              upload_id: The upload ID to use.

+         Returns:
+             List of uploaded file names.
+
          Raises:
              FileUploadError: If there's an error uploading any file.
          """
          # Skip common directories that shouldn't be uploaded
          if any(ignore in dir_path.parts for ignore in FILE_UPLOAD_IGNORE_PARTS):
-             return
+             return []
+
+         uploaded_files: list[str] = []

          try:
              # Upload all files in the directory recursively
@@ -965,24 +1007,27 @@
              ):
                  # Use path relative to the original directory as file name
                  rel_path = path.relative_to(dir_path)
+                 file_name = str(rel_path)
                  self._upload_single_file(
                      job_name,
                      path,
                      upload_id,
-                     file_name=str(rel_path),
+                     file_name=file_name,
                  )
+                 uploaded_files.append(file_name)
          except Exception as e:
              raise FileUploadError(f"Error uploading directory {dir_path}: {e}") from e

-     @refresh_token_on_auth_error()
+         return uploaded_files
+
      def _upload_single_file(
          self,
          job_name: str,
          file_path: Path,
          upload_id: str,
          file_name: str | None = None,
-     ) -> None:
-         """Upload a single file in chunks.
+     ) -> str | None:
+         """Upload a single file in chunks using parallel uploads.

          Args:
              job_name: The key of the crow to upload to.
@@ -990,6 +1035,9 @@
              upload_id: The upload ID to use.
              file_name: Optional name to use for the file. If not provided, the file's name will be used.

+         Returns:
+             The status URL if this was the last chunk, None otherwise.
+
          Raises:
              FileUploadError: If there's an error uploading the file.
          """
@@ -999,16 +1047,190 @@
          # Skip empty files
          if file_size == 0:
              logger.warning(f"Skipping upload of empty file: {file_path}")
-             return
+             return None

          total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE

          logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")

+         status_url = None
+
          try:
-             with open(file_path, "rb") as f:
-                 for chunk_index in range(total_chunks):
-                     # Read the chunk from the file
+             # Upload all chunks except the last one in parallel
+             if total_chunks > 1:
+                 self._upload_chunks_parallel(
+                     job_name,
+                     file_path,
+                     file_name,
+                     upload_id,
+                     total_chunks - 1,
+                     total_chunks,
+                 )
+
+             # Upload the last chunk separately (handles assembly)
+             status_url = self._upload_final_chunk(
+                 job_name,
+                 file_path,
+                 file_name,
+                 upload_id,
+                 total_chunks - 1,
+                 total_chunks,
+             )
+
+             logger.info(f"Successfully uploaded {file_name}")
+         except Exception as e:
+             logger.exception(f"Error uploading file {file_path}")
+             raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
+         return status_url
+
+     def _upload_chunks_parallel(
+         self,
+         job_name: str,
+         file_path: Path,
+         file_name: str,
+         upload_id: str,
+         num_regular_chunks: int,
+         total_chunks: int,
+     ) -> None:
+         """Upload chunks in parallel batches.
+
+         Args:
+             job_name: The key of the crow to upload to.
+             file_path: The path to the file to upload.
+             file_name: The name to use for the file.
+             upload_id: The upload ID to use.
+             num_regular_chunks: Number of regular chunks (excluding final chunk).
+             total_chunks: Total number of chunks.
+
+         Raises:
+             FileUploadError: If there's an error uploading any chunk.
+         """
+         if num_regular_chunks <= 0:
+             return
+
+         # Process chunks in batches
+         for batch_start in range(0, num_regular_chunks, self.MAX_CONCURRENT_CHUNKS):
+             batch_end = min(
+                 batch_start + self.MAX_CONCURRENT_CHUNKS, num_regular_chunks
+             )
+
+             # Upload chunks in this batch concurrently
+             with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_CHUNKS) as executor:
+                 futures = {
+                     executor.submit(
+                         self._upload_single_chunk,
+                         job_name,
+                         file_path,
+                         file_name,
+                         upload_id,
+                         chunk_index,
+                         total_chunks,
+                     ): chunk_index
+                     for chunk_index in range(batch_start, batch_end)
+                 }
+
+                 for future in as_completed(futures):
+                     chunk_index = futures[future]
+                     try:
+                         future.result()
+                         logger.debug(
+                             f"Uploaded chunk {chunk_index + 1}/{total_chunks} of {file_name}"
+                         )
+                     except Exception as e:
+                         logger.error(f"Error uploading chunk {chunk_index}: {e}")
+                         raise FileUploadError(
+                             f"Error uploading chunk {chunk_index} of {file_name}: {e}"
+                         ) from e
+
+     def _upload_single_chunk(
+         self,
+         job_name: str,
+         file_path: Path,
+         file_name: str,
+         upload_id: str,
+         chunk_index: int,
+         total_chunks: int,
+     ) -> None:
+         """Upload a single chunk.
+
+         Args:
+             job_name: The key of the crow to upload to.
+             file_path: The path to the file to upload.
+             file_name: The name to use for the file.
+             upload_id: The upload ID to use.
+             chunk_index: The index of this chunk.
+             total_chunks: Total number of chunks.
+
+         Raises:
+             Exception: If there's an error uploading the chunk.
+         """
+         with open(file_path, "rb") as f:
+             # Read the chunk from the file
+             f.seek(chunk_index * self.CHUNK_SIZE)
+             chunk_data = f.read(self.CHUNK_SIZE)
+
+             # Prepare and send the chunk
+             with tempfile.NamedTemporaryFile() as temp_file:
+                 temp_file.write(chunk_data)
+                 temp_file.flush()
+
+                 # Create form data
+                 with open(temp_file.name, "rb") as chunk_file_obj:
+                     files = {
+                         "chunk": (
+                             file_name,
+                             chunk_file_obj,
+                             "application/octet-stream",
+                         )
+                     }
+                     data = {
+                         "file_name": file_name,
+                         "chunk_index": chunk_index,
+                         "total_chunks": total_chunks,
+                         "upload_id": upload_id,
+                     }
+
+                     # Send the chunk
+                     response = self.multipart_client.post(
+                         f"/v0.1/crows/{job_name}/upload-chunk",
+                         files=files,
+                         data=data,
+                     )
+                     response.raise_for_status()
+
+     def _upload_final_chunk(
+         self,
+         job_name: str,
+         file_path: Path,
+         file_name: str,
+         upload_id: str,
+         chunk_index: int,
+         total_chunks: int,
+     ) -> str | None:
+         """Upload the final chunk with retry logic for missing chunks.
+
+         Args:
+             job_name: The key of the crow to upload to.
+             file_path: The path to the file to upload.
+             file_name: The name to use for the file.
+             upload_id: The upload ID to use.
+             chunk_index: The index of the final chunk.
+             total_chunks: Total number of chunks.
+
+         Returns:
+             The status URL from the response.
+
+         Raises:
+             FileUploadError: If there's an error uploading the final chunk.
+         """
+         retries = 0
+         max_retries = 3
+         retry_delay = 2.0  # seconds
+
+         while retries < max_retries:
+             try:
+                 with open(file_path, "rb") as f:
+                     # Read the final chunk from the file
                      f.seek(chunk_index * self.CHUNK_SIZE)
                      chunk_data = f.read(self.CHUNK_SIZE)

@@ -1033,35 +1255,47 @@ class RestClient:
                              "upload_id": upload_id,
                          }

-                         # Send the chunk
+                         # Send the final chunk
                          response = self.multipart_client.post(
                              f"/v0.1/crows/{job_name}/upload-chunk",
                              files=files,
                              data=data,
                          )
+
+                         # Handle missing chunks (status 409)
+                         if response.status_code == codes.CONFLICT:
+                             retries += 1
+                             if retries < max_retries:
+                                 logger.warning(
+                                     f"Missing chunks detected for {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries})"
+                                 )
+                                 time.sleep(retry_delay)
+                                 continue
+
                          response.raise_for_status()
+                         response_data = response.json()
+                         status_url = response_data.get("status_url")

-                     # Call progress callback if provided
+                         logger.debug(
+                             f"Uploaded final chunk {chunk_index + 1}/{total_chunks} of {file_name}"
+                         )
+                         return status_url

-                     logger.debug(
-                         f"Uploaded chunk {chunk_index + 1}/{total_chunks} of {file_name}"
-                     )
+             except Exception as e:
+                 if retries >= max_retries - 1:
+                     raise FileUploadError(
+                         f"Error uploading final chunk of {file_name}: {e}"
+                     ) from e
+                 retries += 1
+                 logger.warning(
+                     f"Error uploading final chunk of {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries}): {e}"
+                 )
+                 time.sleep(retry_delay)

-             logger.info(f"Successfully uploaded {file_name}")
-         except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
-             raise
-         except Exception as e:
-             logger.exception(f"Error uploading file {file_path}")
-             raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
+         raise FileUploadError(
+             f"Failed to upload final chunk of {file_name} after {max_retries} retries"
+         )

-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1098,13 +1332,6 @@ class RestClient:
              response.raise_for_status()
              return response.json()
          except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
              logger.exception(
                  f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
              )
@@ -1117,7 +1344,6 @@ class RestClient:
              )
              raise RestClientError(f"Error listing files: {e!s}") from e

-     @refresh_token_on_auth_error()
      @retry(
          stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
          wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1165,13 +1391,6 @@

              logger.info(f"File {file_path} downloaded to {destination_path}")
          except HTTPStatusError as e:
-             if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                 raise AuthError(
-                     e.response.status_code,
-                     f"Authentication failed: {e}",
-                     request=e.request,
-                     response=e.response,
-                 ) from e
              logger.exception(
                  f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
              )
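Taken together, the upload changes mean `upload_file` now sends non-final chunks in parallel and, by default, blocks until server-side assembly reports success. A hedged usage sketch (the job name, path, and constructor arguments are placeholders; consult the package documentation for the exact `FutureHouseClient` signature):

```python
from futurehouse_client import FutureHouseClient

client = FutureHouseClient(api_key="fh-...")  # assumed API-key auth
upload_id = client.upload_file(
    "job-futurehouse-data-analysis-crow-high",
    "data/experiment.csv",        # hypothetical local file
    wait_for_assembly=True,       # default: poll assembly status until SUCCESS
    assembly_timeout=1800,        # seconds, mirrors RestClient.MAX_ASSEMBLY_WAIT_TIME
)
print(f"Uploaded under {upload_id}")
```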
futurehouse_client/models/app.py
@@ -1,3 +1,4 @@
+ import copy
  import json
  import os
  import re
@@ -675,7 +676,8 @@ class TaskResponse(BaseModel):

      @model_validator(mode="before")
      @classmethod
-     def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
+     def validate_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
+         data = copy.deepcopy(original_data)  # Avoid mutating the original data
          # Extract fields from environment frame state
          if not isinstance(data, dict):
              return data
@@ -690,7 +692,72 @@ class TaskResponse(BaseModel):
          return data


+ class PhoenixTaskResponse(TaskResponse):
+     """
+     Response schema for tasks executed with Phoenix.
+
+     Additional fields:
+         answer: Final answer from Phoenix
+     """
+
+     model_config = ConfigDict(extra="ignore")
+     answer: str | None = None
+
+     @model_validator(mode="before")
+     @classmethod
+     def validate_phoenix_fields(
+         cls, original_data: Mapping[str, Any]
+     ) -> Mapping[str, Any]:
+         data = copy.deepcopy(original_data)
+         if not isinstance(data, dict):
+             return data
+         if not (env_frame := data.get("environment_frame", {})):
+             return data
+         state = env_frame.get("state", {}).get("state", {})
+         data["answer"] = state.get("answer")
+         return data
+
+
+ class FinchTaskResponse(TaskResponse):
+     """
+     Response schema for tasks executed with Finch.
+
+     Additional fields:
+         answer: Final answer from Finch
+         notebook: a dictionary with `cells` and `metadata` regarding the notebook content
+     """
+
+     model_config = ConfigDict(extra="ignore")
+     answer: str | None = None
+     notebook: dict[str, Any] | None = None
+
+     @model_validator(mode="before")
+     @classmethod
+     def validate_finch_fields(
+         cls, original_data: Mapping[str, Any]
+     ) -> Mapping[str, Any]:
+         data = copy.deepcopy(original_data)
+         if not isinstance(data, dict):
+             return data
+         if not (env_frame := data.get("environment_frame", {})):
+             return data
+         state = env_frame.get("state", {}).get("state", {})
+         data["answer"] = state.get("answer")
+         data["notebook"] = state.get("nb_state")
+         return data
+
+
  class PQATaskResponse(TaskResponse):
+     """
+     Response schema for tasks executed with PQA.
+
+     Additional fields:
+         answer: Final answer from PQA
+         formatted_answer: Formatted answer from PQA
+         answer_reasoning: Reasoning used to generate the final answer, if available
+         has_successful_answer: Whether the answer is successful
+     """
+
      model_config = ConfigDict(extra="ignore")

      answer: str | None = None
@@ -702,7 +769,8 @@ class PQATaskResponse(TaskResponse):

      @model_validator(mode="before")
      @classmethod
-     def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
+     def validate_pqa_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
+         data = copy.deepcopy(original_data)  # Avoid mutating the original data
          if not isinstance(data, dict):
              return data
          if not (env_frame := data.get("environment_frame", {})):
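All of the `mode="before"` validators now deep-copy their input, so constructing a response model no longer mutates the caller's dict as a side effect. A standalone sketch of the pattern on a minimal model (names are illustrative, not from the package):

```python
import copy
from typing import Any

from pydantic import BaseModel, model_validator

class ExampleResponse(BaseModel):
    answer: str | None = None

    @model_validator(mode="before")
    @classmethod
    def lift_nested_answer(cls, original_data: dict[str, Any]) -> dict[str, Any]:
        data = copy.deepcopy(original_data)  # leave the caller's dict untouched
        state = data.get("environment_frame", {}).get("state", {}).get("state", {})
        data["answer"] = state.get("answer")
        return data

raw = {"environment_frame": {"state": {"state": {"answer": "42"}}}}
print(ExampleResponse.model_validate(raw).answer)  # 42
assert "answer" not in raw  # the input dict is unchanged
```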
futurehouse_client/utils/auth.py
@@ -1,107 +1,92 @@
- import asyncio
  import logging
- from collections.abc import Callable, Coroutine
- from functools import wraps
- from typing import Any, Final, Optional, ParamSpec, TypeVar, overload
+ from collections.abc import Collection, Generator
+ from typing import ClassVar, Final

  import httpx
- from httpx import HTTPStatusError

- logger = logging.getLogger(__name__)
-
- T = TypeVar("T")
- P = ParamSpec("P")
-
- AUTH_ERRORS_TO_RETRY_ON: Final[set[int]] = {
-     httpx.codes.UNAUTHORIZED,
-     httpx.codes.FORBIDDEN,
- }
-
-
- class AuthError(Exception):
-     """Raised when authentication fails with 401/403 status."""
-
-     def __init__(self, status_code: int, message: str, request=None, response=None):
-         self.status_code = status_code
-         self.request = request
-         self.response = response
-         super().__init__(message)
-
-
- def is_auth_error(e: Exception) -> bool:
-     if isinstance(e, AuthError):
-         return True
-     if isinstance(e, HTTPStatusError):
-         return e.response.status_code in AUTH_ERRORS_TO_RETRY_ON
-     return False
-
-
- def get_status_code(e: Exception) -> Optional[int]:
-     if isinstance(e, AuthError):
-         return e.status_code
-     if isinstance(e, HTTPStatusError):
-         return e.response.status_code
-     return None
+ from futurehouse_client.models.app import APIKeyPayload, AuthType

+ logger = logging.getLogger(__name__)

- @overload
- def refresh_token_on_auth_error(
-     func: Callable[P, Coroutine[Any, Any, T]],
- ) -> Callable[P, Coroutine[Any, Any, T]]: ...
-
-
- @overload
- def refresh_token_on_auth_error(
-     func: None = None, *, max_retries: int = ...
- ) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
-
-
- def refresh_token_on_auth_error(func=None, max_retries=1):
-     """Decorator that refreshes JWT token on 401/403 auth errors."""
-
-     def decorator(fn):
-         @wraps(fn)
-         def sync_wrapper(self, *args, **kwargs):
-             retries = 0
-             while True:
-                 try:
-                     return fn(self, *args, **kwargs)
-                 except Exception as e:
-                     if is_auth_error(e) and retries < max_retries:
-                         retries += 1
-                         status = get_status_code(e) or "Unknown"
-                         logger.info(
-                             f"Received auth error {status}, "
-                             f"refreshing token and retrying (attempt {retries}/{max_retries})..."
-                         )
-                         self.auth_jwt = self._run_auth()
-                         self._clients = {}
-                         continue
-                     raise
-
-         @wraps(fn)
-         async def async_wrapper(self, *args, **kwargs):
-             retries = 0
-             while True:
-                 try:
-                     return await fn(self, *args, **kwargs)
-                 except Exception as e:
-                     if is_auth_error(e) and retries < max_retries:
-                         retries += 1
-                         status = get_status_code(e) or "Unknown"
-                         logger.info(
-                             f"Received auth error {status}, "
-                             f"refreshing token and retrying (attempt {retries}/{max_retries})..."
-                         )
-                         self.auth_jwt = self._run_auth()
-                         self._clients = {}
-                         continue
-                     raise
-
-         if asyncio.iscoroutinefunction(fn):
-             return async_wrapper
-         return sync_wrapper
-
-     if callable(func):
-         return decorator(func)
-     return decorator
+ INVALID_REFRESH_TYPE_MSG: Final[str] = (
+     "API key auth is required to refresh auth tokens."
+ )
+ JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
+
+
+ def _run_auth(
+     client: httpx.Client,
+     auth_type: AuthType = AuthType.API_KEY,
+     api_key: str | None = None,
+     jwt: str | None = None,
+ ) -> str:
+     auth_payload: APIKeyPayload | None
+     if auth_type == AuthType.API_KEY:
+         auth_payload = APIKeyPayload(api_key=api_key)
+     elif auth_type == AuthType.JWT:
+         auth_payload = None
+     try:
+         if auth_payload:
+             response = client.post("/auth/login", json=auth_payload.model_dump())
+             response.raise_for_status()
+             token_data = response.json()
+         elif jwt:
+             token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
+         else:
+             raise ValueError("JWT token required for JWT authentication.")
+
+         return token_data["access_token"]
+     except Exception as e:
+         raise Exception("Failed to authenticate") from e  # noqa: TRY002
+
+
+ class RefreshingJWT(httpx.Auth):
+     """Automatically (re-)inject a JWT and transparently retry exactly once when we hit a 401/403."""
+
+     RETRY_STATUSES: ClassVar[Collection[httpx.codes]] = {
+         httpx.codes.UNAUTHORIZED,
+         httpx.codes.FORBIDDEN,
+     }
+
+     def __init__(
+         self,
+         auth_client: httpx.Client,
+         auth_type: AuthType = AuthType.API_KEY,
+         api_key: str | None = None,
+         jwt: str | None = None,
+     ):
+         self.auth_type = auth_type
+         self.auth_client = auth_client
+         self.api_key = api_key
+         self._jwt = _run_auth(
+             client=auth_client,
+             jwt=jwt,
+             auth_type=auth_type,
+             api_key=api_key,
+         )
+
+     def refresh_token(self) -> None:
+         if self.auth_type == AuthType.JWT:
+             logger.error(INVALID_REFRESH_TYPE_MSG)
+             raise ValueError(INVALID_REFRESH_TYPE_MSG)
+         self._jwt = _run_auth(
+             client=self.auth_client,
+             auth_type=self.auth_type,
+             api_key=self.api_key,
+         )
+
+     def auth_flow(
+         self, request: httpx.Request
+     ) -> Generator[httpx.Request, httpx.Response, None]:
+         request.headers["Authorization"] = f"Bearer {self._jwt}"
+         response = yield request
+
+         # If it failed, refresh once and replay the request
+         if response.status_code in self.RETRY_STATUSES:
+             logger.info(
+                 "Received %s, refreshing token and retrying …",
+                 response.status_code,
+             )
+             self.refresh_token()
+             request.headers["Authorization"] = f"Bearer {self._jwt}"
+             yield request  # second (and final) attempt
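For reference, a condensed sketch of how `RestClient.get_client` wires this class into an httpx client (base URL and key are placeholders; note that constructing `RefreshingJWT` performs the initial `/auth/login` call):

```python
import httpx

from futurehouse_client.models.app import AuthType
from futurehouse_client.utils.auth import RefreshingJWT

login_client = httpx.Client(base_url="https://api.example.invalid")  # unauthenticated
auth = RefreshingJWT(
    auth_client=login_client,
    auth_type=AuthType.API_KEY,
    api_key="fh-...",
)

# Every request carries "Authorization: Bearer <jwt>"; on a 401/403 the token
# is refreshed via /auth/login and the request is replayed once.
api_client = httpx.Client(base_url="https://api.example.invalid", auth=auth)
```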
futurehouse_client-0.3.19.dev111.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: futurehouse-client
- Version: 0.3.18.dev186
+ Version: 0.3.19.dev111
  Summary: A client for interacting with endpoints of the FutureHouse service.
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
  Classifier: Operating System :: OS Independent
futurehouse_client-0.3.19.dev111.dist-info/RECORD
@@ -0,0 +1,17 @@
+ futurehouse_client/__init__.py,sha256=OzGDkVm5UTUzd4n8yOmRjMF73YrK0FaIQX5gS3Dk8Zo,304
+ futurehouse_client/clients/__init__.py,sha256=-HXNj-XJ3LRO5XM6MZ709iPs29YpApss0Q2YYg1qMZw,280
+ futurehouse_client/clients/job_client.py,sha256=JgB5IUAyCmnhGRsYc3bgKldA-lkM1JLwHRwwUeOCdus,11944
+ futurehouse_client/clients/rest_client.py,sha256=_XgkzA9OhUKjL9vpkU6ixh2lUW9StgqfGgLk2qHjGgI,55518
+ futurehouse_client/models/__init__.py,sha256=5x-f9AoM1hGzJBEHcHAXSt7tPeImST5oZLuMdwp0mXc,554
+ futurehouse_client/models/app.py,sha256=VCtg0ygd-TSrR6DtfljTBt9jnl1eBNal8UXHFdkDg88,28587
+ futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
+ futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
+ futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ futurehouse_client/utils/auth.py,sha256=tgWELjKfg8eWme_qdcRmc8TjQN9DVZuHHaVXZNHLchk,2960
+ futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
+ futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
+ futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
+ futurehouse_client-0.3.19.dev111.dist-info/METADATA,sha256=N4Msi8W4mMBXFs_-Pl8Ii12RcLRm2eBl9NiIFCy5--E,12767
+ futurehouse_client-0.3.19.dev111.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ futurehouse_client-0.3.19.dev111.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
+ futurehouse_client-0.3.19.dev111.dist-info/RECORD,,
futurehouse_client-0.3.19.dev111.dist-info/WHEEL
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.7.1)
+ Generator: setuptools (80.9.0)
  Root-Is-Purelib: true
  Tag: py3-none-any

futurehouse_client-0.3.18.dev186.dist-info/RECORD
@@ -1,17 +0,0 @@
- futurehouse_client/__init__.py,sha256=ddxO7JE97c6bt7LjNglZZ2Ql8bYCGI9laSFeh9MP6VU,344
- futurehouse_client/clients/__init__.py,sha256=tFWqwIAY5PvwfOVsCje4imjTpf6xXNRMh_UHIKVI1_0,320
- futurehouse_client/clients/job_client.py,sha256=uNkqQbeZw7wbA0qDWcIOwOykrosza-jev58paJZ_mbA,11150
- futurehouse_client/clients/rest_client.py,sha256=W9ASP1ZKYS7UL5J9b-Km77YXEiDQ9hCf4X_9PqaZZZc,47914
- futurehouse_client/models/__init__.py,sha256=5x-f9AoM1hGzJBEHcHAXSt7tPeImST5oZLuMdwp0mXc,554
- futurehouse_client/models/app.py,sha256=w_1e4F0IiC-BKeOLqYkABYo4U-Nka1S-F64S_eHB2KM,26421
- futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
- futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
- futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- futurehouse_client/utils/auth.py,sha256=Lq9mjSGc7iuRP6fmLICCS6KjzLHN6-tJUuhYp0XXrkE,3342
- futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
- futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
- futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
- futurehouse_client-0.3.18.dev186.dist-info/METADATA,sha256=PvjehEQZu2ihl7kG1uDvWJVUxyYbV7J-VmAe42Ml3zo,12767
- futurehouse_client-0.3.18.dev186.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- futurehouse_client-0.3.18.dev186.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
- futurehouse_client-0.3.18.dev186.dist-info/RECORD,,