futurehouse-client 0.3.18.dev186__py3-none-any.whl → 0.3.19__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
@@ -13,9 +13,10 @@ import tempfile
 import time
 import uuid
 from collections.abc import Collection
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from types import ModuleType
-from typing import Any, ClassVar, assert_never, cast
+from typing import Any, ClassVar, cast
 from uuid import UUID
 
 import cloudpickle
@@ -31,6 +32,7 @@ from httpx import (
     ReadError,
     ReadTimeout,
     RemoteProtocolError,
+    codes,
 )
 from ldp.agent import AgentConfig
 from requests.exceptions import RequestException, Timeout
@@ -45,21 +47,15 @@ from tqdm.asyncio import tqdm
 
 from futurehouse_client.clients import JobNames
 from futurehouse_client.models.app import (
-    APIKeyPayload,
     AuthType,
     JobDeploymentConfig,
-    PQATaskResponse,
     Stage,
     TaskRequest,
     TaskResponse,
     TaskResponseVerbose,
 )
 from futurehouse_client.models.rest import ExecutionStatus
-from futurehouse_client.utils.auth import (
-    AUTH_ERRORS_TO_RETRY_ON,
-    AuthError,
-    refresh_token_on_auth_error,
-)
+from futurehouse_client.utils.auth import RefreshingJWT
 from futurehouse_client.utils.general import gather_with_concurrency
 from futurehouse_client.utils.module_utils import (
     OrganizationSelector,
@@ -128,8 +124,6 @@ retry_if_connection_error = retry_if_exception_type((
     FileUploadError,
 ))
 
-# 5 minute default for JWTs
-JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
 DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds
 
 
@@ -140,6 +134,9 @@ class RestClient:
     MAX_RETRY_WAIT: ClassVar[int] = 10
     DEFAULT_POLLING_TIME: ClassVar[int] = 5  # seconds
     CHUNK_SIZE: ClassVar[int] = 16 * 1024 * 1024  # 16MB chunks
+    ASSEMBLY_POLLING_INTERVAL: ClassVar[int] = 10  # seconds
+    MAX_ASSEMBLY_WAIT_TIME: ClassVar[int] = 1800  # 30 minutes
+    MAX_CONCURRENT_CHUNKS: ClassVar[int] = 12  # Maximum concurrent chunk uploads
 
     def __init__(
         self,
@@ -163,69 +160,87 @@ class RestClient:
         self.api_key = api_key
         self._clients: dict[str, Client | AsyncClient] = {}
         self.headers = headers or {}
-        self.auth_jwt = self._run_auth(jwt=jwt)
+        self.jwt = jwt
         self.organizations: list[str] = self._filter_orgs(organization)
 
     @property
     def client(self) -> Client:
-        """Lazily initialized and cached HTTP client with authentication."""
-        return cast(Client, self.get_client("application/json", with_auth=True))
+        """Authenticated HTTP client for regular API calls."""
+        return cast(Client, self.get_client("application/json", authenticated=True))
 
     @property
     def async_client(self) -> AsyncClient:
-        """Lazily initialized and cached HTTP client with authentication."""
+        """Authenticated async HTTP client for regular API calls."""
         return cast(
             AsyncClient,
-            self.get_client("application/json", with_auth=True, with_async=True),
+            self.get_client("application/json", authenticated=True, async_client=True),
         )
 
     @property
-    def auth_client(self) -> Client:
-        """Lazily initialized and cached HTTP client without authentication."""
-        return cast(Client, self.get_client("application/json", with_auth=False))
+    def unauthenticated_client(self) -> Client:
+        """Unauthenticated HTTP client for auth operations."""
+        return cast(Client, self.get_client("application/json", authenticated=False))
 
     @property
     def multipart_client(self) -> Client:
-        """Lazily initialized and cached HTTP client for multipart uploads."""
-        return cast(Client, self.get_client(None, with_auth=True))
+        """Authenticated HTTP client for multipart uploads."""
+        return cast(Client, self.get_client(None, authenticated=True))
 
     def get_client(
         self,
         content_type: str | None = "application/json",
-        with_auth: bool = True,
-        with_async: bool = False,
+        authenticated: bool = True,
+        async_client: bool = False,
     ) -> Client | AsyncClient:
         """Return a cached HTTP client or create one if needed.
 
         Args:
             content_type: The desired content type header. Use None for multipart uploads.
-            with_auth: Whether the client should include an Authorization header.
-            with_async: Whether to use an async client.
+            authenticated: Whether the client should include authentication.
+            async_client: Whether to use an async client.
 
         Returns:
             An HTTP client configured with the appropriate headers.
         """
-        # Create a composite key based on content type and auth flag.
-        key = f"{content_type or 'multipart'}_{with_auth}_{with_async}"
+        # Create a composite key based on content type and auth flag
+        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
+
         if key not in self._clients:
             headers = copy.deepcopy(self.headers)
-            if with_auth:
-                headers["Authorization"] = f"Bearer {self.auth_jwt}"
+            auth = None
+
+            if authenticated:
+                auth = RefreshingJWT(
+                    # authenticated=False will always return a synchronous client
+                    auth_client=cast(
+                        Client, self.get_client("application/json", authenticated=False)
+                    ),
+                    auth_type=self.auth_type,
+                    api_key=self.api_key,
+                    jwt=self.jwt,
+                )
+
             if content_type:
                 headers["Content-Type"] = content_type
+
+            headers["x-client"] = "sdk"
+
             self._clients[key] = (
                 AsyncClient(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
-                if with_async
+                if async_client
                 else Client(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
             )
+
        return self._clients[key]
 
     def close(self):
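The decorator-based retry (`refresh_token_on_auth_error`) is gone: authentication now rides on httpx's native `auth=` hook, so every client built by `get_client` re-authenticates transparently. A minimal sketch of the idea, assuming `RefreshingJWT` follows httpx's documented `Auth.auth_flow` generator protocol; the class below and its login payload are illustrative, not the package's actual implementation:

```python
import httpx


class RefreshingTokenAuth(httpx.Auth):
    """Sketch: attach a bearer token, refresh once on 401, replay the request."""

    def __init__(self, auth_client: httpx.Client, api_key: str):
        self.auth_client = auth_client
        self.api_key = api_key
        self.token: str | None = None

    def _login(self) -> str:
        # Hypothetical login call; the real client posts an APIKeyPayload.
        response = self.auth_client.post("/auth/login", json={"api_key": self.api_key})
        response.raise_for_status()
        return response.json()["access_token"]

    def auth_flow(self, request: httpx.Request):
        if self.token is None:
            self.token = self._login()
        request.headers["Authorization"] = f"Bearer {self.token}"
        response = yield request
        if response.status_code == httpx.codes.UNAUTHORIZED:
            # Token expired: refresh and replay the same request once.
            self.token = self._login()
            request.headers["Authorization"] = f"Bearer {self.token}"
            yield request
```

Centralizing the refresh in the transport layer is what lets the many `@refresh_token_on_auth_error()` decorators and `AuthError` re-raise blocks below simply disappear.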
@@ -255,32 +270,6 @@ class RestClient:
             raise ValueError(f"Organization '{organization}' not found.")
         return filtered_orgs
 
-    def _run_auth(self, jwt: str | None = None) -> str:
-        auth_payload: APIKeyPayload | None
-        if self.auth_type == AuthType.API_KEY:
-            auth_payload = APIKeyPayload(api_key=self.api_key)
-        elif self.auth_type == AuthType.JWT:
-            auth_payload = None
-        else:
-            assert_never(self.auth_type)
-        try:
-            # Use the unauthenticated client for login
-            if auth_payload:
-                response = self.auth_client.post(
-                    "/auth/login", json=auth_payload.model_dump()
-                )
-                response.raise_for_status()
-                token_data = response.json()
-            elif jwt:
-                token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
-            else:
-                raise ValueError("JWT token required for JWT authentication.")
-
-            return token_data["access_token"]
-        except Exception as e:
-            raise RestClientError(f"Error authenticating: {e!s}") from e
-
-    @refresh_token_on_auth_error()
     def _check_job(self, name: str, organization: str) -> dict[str, Any]:
         try:
             response = self.client.get(
@@ -288,25 +277,113 @@ class RestClient:
             )
             response.raise_for_status()
             return response.json()
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
         except Exception as e:
             raise JobFetchError(f"Error checking job: {e!s}") from e
 
-    @refresh_token_on_auth_error()
     def _fetch_my_orgs(self) -> list[str]:
         response = self.client.get(f"/v0.1/organizations?filter={True}")
         response.raise_for_status()
         orgs = response.json()
         return [org["name"] for org in orgs]
 
+    def _check_assembly_status(
+        self, job_name: str, upload_id: str, file_name: str
+    ) -> dict[str, Any]:
+        """Check the assembly status of an uploaded file.
+
+        Args:
+            job_name: The name of the futurehouse job
+            upload_id: The upload ID
+            file_name: The name of the file
+
+        Returns:
+            Dict containing status information
+
+        Raises:
+            RestClientError: If there's an error checking status
+        """
+        try:
+            url = f"/v0.1/crows/{job_name}/assembly-status/{upload_id}/{file_name}"
+            response = self.client.get(url)
+            response.raise_for_status()
+            return response.json()
+        except Exception as e:
+            raise RestClientError(f"Error checking assembly status: {e}") from e
+
+    def _wait_for_all_assemblies_completion(
+        self,
+        job_name: str,
+        upload_id: str,
+        file_names: list[str],
+        timeout: int = MAX_ASSEMBLY_WAIT_TIME,
+    ) -> bool:
+        """Wait for all file assemblies to complete.
+
+        Args:
+            job_name: The name of the futurehouse job
+            upload_id: The upload ID
+            file_names: List of file names to wait for
+            timeout: Maximum time to wait in seconds
+
+        Returns:
+            True if all assemblies succeeded, False if any failed or timed out
+
+        Raises:
+            RestClientError: If any assembly fails
+        """
+        if not file_names:
+            return True
+
+        start_time = time.time()
+        logger.info(f"Waiting for assembly of {len(file_names)} file(s) to complete...")
+
+        completed_files: set[str] = set()
+
+        while (time.time() - start_time) < timeout and len(completed_files) < len(
+            file_names
+        ):
+            for file_name in file_names:
+                if file_name in completed_files:
+                    continue
+
+                try:
+                    status_data = self._check_assembly_status(
+                        job_name, upload_id, file_name
+                    )
+                    status = status_data.get("status")
+
+                    if status == ExecutionStatus.SUCCESS.value:
+                        logger.info(f"Assembly completed for {file_name}")
+                        completed_files.add(file_name)
+                    elif status == ExecutionStatus.FAIL.value:
+                        error_msg = status_data.get("error", "Unknown assembly error")
+                        raise RestClientError(
+                            f"Assembly failed for {file_name}: {error_msg}"
+                        )
+                    elif status == ExecutionStatus.IN_PROGRESS.value:
+                        logger.debug(f"Assembly in progress for {file_name}...")
+
+                except RestClientError:
+                    raise  # Re-raise assembly errors
+                except Exception as e:
+                    logger.warning(
+                        f"Error checking assembly status for {file_name}: {e}"
+                    )
+
+            # Don't sleep if all files are complete
+            if len(completed_files) < len(file_names):
+                time.sleep(self.ASSEMBLY_POLLING_INTERVAL)
+
+        if len(completed_files) < len(file_names):
+            remaining_files = set(file_names) - completed_files
+            logger.warning(
+                f"Assembly timeout for files: {remaining_files} after {timeout} seconds"
+            )
+            return False
+
+        logger.info(f"All {len(file_names)} file assemblies completed successfully")
+        return True
+
     @staticmethod
     def _validate_module_path(path: Path) -> None:
         """Validates that the given path exists and is a directory.
@@ -358,7 +435,6 @@ class RestClient:
         if not files:
             raise TaskFetchError(f"No files found in {path}")
 
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -368,51 +444,37 @@ class RestClient:
         self, task_id: str | None = None, history: bool = False, verbose: bool = False
     ) -> "TaskResponse":
         """Get details for a specific task."""
-        try:
-            task_id = task_id or self.trajectory_id
-            url = f"/v0.1/trajectories/{task_id}"
-            full_url = f"{self.base_url}{url}"
-
-            with (
-                external_trace(
-                    url=full_url,
-                    method="GET",
-                    library="httpx",
-                    custom_params={
-                        "operation": "get_job",
-                        "job_id": task_id,
-                    },
-                ),
-                self.client.stream("GET", url, params={"history": history}) as response,
-            ):
-                response.raise_for_status()
-                json_data = "".join(response.iter_text(chunk_size=1024))
-                data = json.loads(json_data)
-                if "id" not in data:
-                    data["id"] = task_id
-                verbose_response = TaskResponseVerbose(**data)
+        task_id = task_id or self.trajectory_id
+        url = f"/v0.1/trajectories/{task_id}"
+        full_url = f"{self.base_url}{url}"
 
-            if verbose:
-                return verbose_response
-            if any(
-                JobNames.from_string(job_name) in verbose_response.job_name
-                for job_name in ["crow", "falcon", "owl", "dummy"]
-            ):
-                return PQATaskResponse(**data)
-            return TaskResponse(**data)
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            raise TaskFetchError(f"Error getting task: {e!s}") from e
+        with (
+            external_trace(
+                url=full_url,
+                method="GET",
+                library="httpx",
+                custom_params={
+                    "operation": "get_job",
+                    "job_id": task_id,
+                },
+            ),
+            self.client.stream("GET", url, params={"history": history}) as response,
+        ):
+            if response.status_code in {401, 403}:
+                raise PermissionError(
+                    f"Error getting task: Permission denied for task {task_id}"
+                )
+            response.raise_for_status()
+            json_data = "".join(response.iter_text(chunk_size=1024))
+            data = json.loads(json_data)
+            if "id" not in data:
+                data["id"] = task_id
+            verbose_response = TaskResponseVerbose(**data)
+
+        if verbose:
+            return verbose_response
+        return JobNames.get_response_object_from_job(verbose_response.job_name)(**data)
 
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
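Note the changed error contract here: 401/403 responses now raise a plain `PermissionError` instead of an `AuthError` that the retry machinery would intercept, and the hard-coded job-name check returning `PQATaskResponse` is replaced by `JobNames.get_response_object_from_job`. Caller-side it would look roughly like this, assuming the method shown in this hunk is the client's `get_task` (the name sits just above the hunk window) and using a placeholder task id:

```python
# Sketch of handling the new error contract (task id is a placeholder).
try:
    task = client.get_task("00000000-0000-0000-0000-000000000000")
except PermissionError:
    print("Token lacks access to this trajectory; no auth retry is attempted")
```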
@@ -422,53 +484,37 @@ class RestClient:
         self, task_id: str | None = None, history: bool = False, verbose: bool = False
     ) -> "TaskResponse":
         """Get details for a specific task asynchronously."""
-        try:
-            task_id = task_id or self.trajectory_id
-            url = f"/v0.1/trajectories/{task_id}"
-            full_url = f"{self.base_url}{url}"
+        task_id = task_id or self.trajectory_id
+        url = f"/v0.1/trajectories/{task_id}"
+        full_url = f"{self.base_url}{url}"
+
+        with external_trace(
+            url=full_url,
+            method="GET",
+            library="httpx",
+            custom_params={
+                "operation": "get_job",
+                "job_id": task_id,
+            },
+        ):
+            async with self.async_client.stream(
+                "GET", url, params={"history": history}
+            ) as response:
+                if response.status_code in {401, 403}:
+                    raise PermissionError(
+                        f"Error getting task: Permission denied for task {task_id}"
+                    )
+                response.raise_for_status()
+                json_data = "".join([chunk async for chunk in response.aiter_text()])
+                data = json.loads(json_data)
+                if "id" not in data:
+                    data["id"] = task_id
+                verbose_response = TaskResponseVerbose(**data)
 
-            with external_trace(
-                url=full_url,
-                method="GET",
-                library="httpx",
-                custom_params={
-                    "operation": "get_job",
-                    "job_id": task_id,
-                },
-            ):
-                async with self.async_client.stream(
-                    "GET", url, params={"history": history}
-                ) as response:
-                    response.raise_for_status()
-                    json_data = "".join([
-                        chunk async for chunk in response.aiter_text()
-                    ])
-                    data = json.loads(json_data)
-                    if "id" not in data:
-                        data["id"] = task_id
-                    verbose_response = TaskResponseVerbose(**data)
-
-            if verbose:
-                return verbose_response
-            if any(
-                JobNames.from_string(job_name) in verbose_response.job_name
-                for job_name in ["crow", "falcon", "owl", "dummy"]
-            ):
-                return PQATaskResponse(**data)
-            return TaskResponse(**data)
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            raise TaskFetchError(f"Error getting task: {e!s}") from e
+        if verbose:
+            return verbose_response
+        return JobNames.get_response_object_from_job(verbose_response.job_name)(**data)
 
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -485,27 +531,18 @@ class RestClient:
             self.stage,
         )
 
-        try:
-            response = self.client.post(
-                "/v0.1/crows", json=task_data.model_dump(mode="json")
+        response = self.client.post(
+            "/v0.1/crows", json=task_data.model_dump(mode="json")
+        )
+        if response.status_code in {401, 403}:
+            raise PermissionError(
+                f"Error creating task: Permission denied for task {task_data.name}"
             )
-            response.raise_for_status()
-            trajectory_id = response.json()["trajectory_id"]
-            self.trajectory_id = trajectory_id
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            raise TaskFetchError(f"Error creating task: {e!s}") from e
+        response.raise_for_status()
+        trajectory_id = response.json()["trajectory_id"]
+        self.trajectory_id = trajectory_id
         return trajectory_id
 
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -521,25 +558,16 @@ class RestClient:
             task_data.name.name,
             self.stage,
         )
-
-        try:
-            response = await self.async_client.post(
-                "/v0.1/crows", json=task_data.model_dump(mode="json")
+        response = await self.async_client.post(
+            "/v0.1/crows", json=task_data.model_dump(mode="json")
+        )
+        if response.status_code in {401, 403}:
+            raise PermissionError(
+                f"Error creating task: Permission denied for task {task_data.name}"
             )
-            response.raise_for_status()
-            trajectory_id = response.json()["trajectory_id"]
-            self.trajectory_id = trajectory_id
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            raise TaskFetchError(f"Error creating task: {e!s}") from e
+        response.raise_for_status()
+        trajectory_id = response.json()["trajectory_id"]
+        self.trajectory_id = trajectory_id
         return trajectory_id
 
     async def arun_tasks_until_done(
@@ -683,7 +711,6 @@ class RestClient:
             for task_id in trajectory_ids
         ]
 
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -695,19 +722,11 @@ class RestClient:
             build_id = build_id or self.build_id
             response = self.client.get(f"/v0.1/builds/{build_id}")
             response.raise_for_status()
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
+        except Exception as e:
+            raise JobFetchError(f"Error getting build status: {e!s}") from e
         return response.json()
 
     # TODO: Refactor later so we don't have to ignore PLR0915
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -887,13 +906,6 @@ class RestClient:
             build_context = response.json()
             self.build_id = build_context["build_id"]
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             error_detail = response.json()
             error_message = error_detail.get("detail", str(e))
             raise JobCreationError(
@@ -903,6 +915,8 @@ class RestClient:
             raise JobCreationError(f"Error generating docker image: {e!s}") from e
         return build_context
 
+    # TODO: we should have have an async upload_file, check_assembly_status,
+    # wait_for_assembly_completion, upload_directory, upload_single_file
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -913,6 +927,8 @@ class RestClient:
         job_name: str,
         file_path: str | os.PathLike,
         upload_id: str | None = None,
+        wait_for_assembly: bool = True,
+        assembly_timeout: int = MAX_ASSEMBLY_WAIT_TIME,
     ) -> str:
         """Upload a file or directory to a futurehouse job bucket.
 
@@ -920,29 +936,47 @@ class RestClient:
             job_name: The name of the futurehouse job to upload to.
             file_path: The local path to the file or directory to upload.
             upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
+            wait_for_assembly: After file chunking, wait for the assembly to be processed.
+            assembly_timeout: Maximum time to wait for assembly in seconds.
 
         Returns:
             The upload ID used for the upload.
 
         Raises:
             FileUploadError: If there's an error uploading the file.
+            RestClientError: If assembly fails or times out.
         """
         file_path = Path(file_path)
         if not file_path.exists():
             raise FileNotFoundError(f"File or directory not found: {file_path}")
 
         upload_id = upload_id or str(uuid.uuid4())
+        uploaded_files: list[str] = []
 
         if file_path.is_dir():
             # Process directory recursively
-            self._upload_directory(job_name, file_path, upload_id)
+            uploaded_files = self._upload_directory(job_name, file_path, upload_id)
         else:
             # Process single file
             self._upload_single_file(job_name, file_path, upload_id)
+            uploaded_files = [file_path.name]
+
+        # Wait for all assemblies if requested and we have files
+        if wait_for_assembly and uploaded_files:
+            success = self._wait_for_all_assemblies_completion(
+                job_name, upload_id, uploaded_files, assembly_timeout
+            )
+            if not success:
+                raise RestClientError(
+                    f"Assembly failed or timed out for one or more files: {uploaded_files}"
+                )
+
         logger.info(f"Successfully uploaded {file_path} to {upload_id}")
         return upload_id
 
-    def _upload_directory(self, job_name: str, dir_path: Path, upload_id: str) -> None:
+    def _upload_directory(
+        self, job_name: str, dir_path: Path, upload_id: str
+    ) -> list[str]:
         """Upload all files in a directory recursively.
 
         Args:
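Taken together with the assembly polling above, the public upload flow now blocks by default until the server has stitched the chunks back together. An illustrative call, assuming a configured `RestClient` and a job key `"my-job"` (both placeholders, and the constructor arguments are elided):

```python
client = RestClient(api_key="...")  # other constructor arguments elided
upload_id = client.upload_file(
    "my-job",
    "./dataset",              # directories are walked recursively
    wait_for_assembly=True,   # block until server-side assembly finishes
    assembly_timeout=600,     # override the 30-minute default
)
print(f"Files available under upload {upload_id}")
```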
@@ -950,12 +984,17 @@ class RestClient:
             dir_path: The path to the directory to upload.
             upload_id: The upload ID to use.
 
+        Returns:
+            List of uploaded file names.
+
         Raises:
             FileUploadError: If there's an error uploading any file.
         """
         # Skip common directories that shouldn't be uploaded
         if any(ignore in dir_path.parts for ignore in FILE_UPLOAD_IGNORE_PARTS):
-            return
+            return []
+
+        uploaded_files: list[str] = []
 
         try:
             # Upload all files in the directory recursively
@@ -965,24 +1004,27 @@ class RestClient:
             ):
                 # Use path relative to the original directory as file name
                 rel_path = path.relative_to(dir_path)
+                file_name = str(rel_path)
                 self._upload_single_file(
                     job_name,
                     path,
                     upload_id,
-                    file_name=str(rel_path),
+                    file_name=file_name,
                 )
+                uploaded_files.append(file_name)
         except Exception as e:
             raise FileUploadError(f"Error uploading directory {dir_path}: {e}") from e
 
-    @refresh_token_on_auth_error()
+        return uploaded_files
+
     def _upload_single_file(
         self,
         job_name: str,
         file_path: Path,
         upload_id: str,
         file_name: str | None = None,
-    ) -> None:
-        """Upload a single file in chunks.
+    ) -> str | None:
+        """Upload a single file in chunks using parallel uploads.
 
         Args:
             job_name: The key of the crow to upload to.
@@ -990,6 +1032,9 @@ class RestClient:
             upload_id: The upload ID to use.
             file_name: Optional name to use for the file. If not provided, the file's name will be used.
 
+        Returns:
+            The status URL if this was the last chunk, None otherwise.
+
         Raises:
             FileUploadError: If there's an error uploading the file.
         """
@@ -999,17 +1044,103 @@ class RestClient:
         # Skip empty files
         if file_size == 0:
             logger.warning(f"Skipping upload of empty file: {file_path}")
-            return
+            return None
 
         total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE
 
         logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")
 
+        status_url = None
+
         try:
-            with open(file_path, "rb") as f:
-                for chunk_index in range(total_chunks):
-                    # Read the chunk from the file
-                    f.seek(chunk_index * self.CHUNK_SIZE)
+            status_url = self._upload_chunks_parallel(
+                job_name,
+                file_path,
+                file_name,
+                upload_id,
+                total_chunks,
+            )
+
+            logger.info(f"Successfully uploaded {file_name}")
+        except Exception as e:
+            logger.exception(f"Error uploading file {file_path}")
+            raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
+        return status_url
+
+    def _upload_chunks_parallel(
+        self,
+        job_name: str,
+        file_path: Path,
+        file_name: str,
+        upload_id: str,
+        total_chunks: int,
+    ) -> str | None:
+        """Upload all chunks in parallel batches, including the final chunk.
+
+        Args:
+            job_name: The key of the crow to upload to.
+            file_path: The path to the file to upload.
+            file_name: The name to use for the file.
+            upload_id: The upload ID to use.
+            total_chunks: Total number of chunks.
+
+        Returns:
+            The status URL from the final chunk response, or None if no chunks.
+
+        Raises:
+            FileUploadError: If there's an error uploading any chunk.
+        """
+        if total_chunks <= 0:
+            return None
+
+        if total_chunks > 1:
+            num_regular_chunks = total_chunks - 1
+            for batch_start in range(0, num_regular_chunks, self.MAX_CONCURRENT_CHUNKS):
+                batch_end = min(
+                    batch_start + self.MAX_CONCURRENT_CHUNKS, num_regular_chunks
+                )
+
+                # Upload chunks in this batch concurrently
+                with ThreadPoolExecutor(
+                    max_workers=self.MAX_CONCURRENT_CHUNKS
+                ) as executor:
+                    futures = {
+                        executor.submit(
+                            self._upload_single_chunk,
+                            job_name,
+                            file_path,
+                            file_name,
+                            upload_id,
+                            chunk_index,
+                            total_chunks,
+                        ): chunk_index
+                        for chunk_index in range(batch_start, batch_end)
+                    }
+
+                    for future in as_completed(futures):
+                        chunk_index = futures[future]
+                        try:
+                            future.result()
+                            logger.debug(
+                                f"Uploaded chunk {chunk_index + 1}/{total_chunks} of {file_name}"
+                            )
+                        except Exception as e:
+                            logger.error(f"Error uploading chunk {chunk_index}: {e}")
+                            raise FileUploadError(
+                                f"Error uploading chunk {chunk_index} of {file_name}: {e}"
+                            ) from e
+
+        # Upload the final chunk with retry logic
+        final_chunk_index = total_chunks - 1
+        retries = 0
+        max_retries = 3
+        retry_delay = 2.0
+
+        while retries < max_retries:
+            try:
+                with open(file_path, "rb") as f:
+                    # Read the final chunk from the file
+                    f.seek(final_chunk_index * self.CHUNK_SIZE)
                     chunk_data = f.read(self.CHUNK_SIZE)
 
                     # Prepare and send the chunk
@@ -1028,40 +1159,108 @@ class RestClient:
                     }
                     data = {
                         "file_name": file_name,
-                        "chunk_index": chunk_index,
+                        "chunk_index": final_chunk_index,
                         "total_chunks": total_chunks,
                         "upload_id": upload_id,
                     }
 
-                    # Send the chunk
+                    # Send the final chunk
                     response = self.multipart_client.post(
                         f"/v0.1/crows/{job_name}/upload-chunk",
                         files=files,
                         data=data,
                     )
+
+                    # Handle missing chunks (status 409)
+                    if response.status_code == codes.CONFLICT:
+                        retries += 1
+                        if retries < max_retries:
+                            logger.warning(
+                                f"Missing chunks detected for {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries})"
+                            )
+                            time.sleep(retry_delay)
+                            continue
+
                     response.raise_for_status()
+                    response_data = response.json()
+                    status_url = response_data.get("status_url")
 
-                    # Call progress callback if provided
+                    logger.debug(
+                        f"Uploaded final chunk {final_chunk_index + 1}/{total_chunks} of {file_name}"
+                    )
+                    return status_url
 
-                    logger.debug(
-                        f"Uploaded chunk {chunk_index + 1}/{total_chunks} of {file_name}"
-                    )
+            except Exception as e:
+                if retries >= max_retries - 1:
+                    raise FileUploadError(
+                        f"Error uploading final chunk of {file_name}: {e}"
+                    ) from e
+                retries += 1
+                logger.warning(
+                    f"Error uploading final chunk of {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries}): {e}"
+                )
+                time.sleep(retry_delay)
 
-            logger.info(f"Successfully uploaded {file_name}")
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            logger.exception(f"Error uploading file {file_path}")
-            raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
+        raise FileUploadError(
+            f"Failed to upload final chunk of {file_name} after {max_retries} retries"
+        )
+
+    def _upload_single_chunk(
+        self,
+        job_name: str,
+        file_path: Path,
+        file_name: str,
+        upload_id: str,
+        chunk_index: int,
+        total_chunks: int,
+    ) -> None:
+        """Upload a single chunk.
+
+        Args:
+            job_name: The key of the crow to upload to.
+            file_path: The path to the file to upload.
+            file_name: The name to use for the file.
+            upload_id: The upload ID to use.
+            chunk_index: The index of this chunk.
+            total_chunks: Total number of chunks.
+
+        Raises:
+            Exception: If there's an error uploading the chunk.
+        """
+        with open(file_path, "rb") as f:
+            # Read the chunk from the file
+            f.seek(chunk_index * self.CHUNK_SIZE)
+            chunk_data = f.read(self.CHUNK_SIZE)
+
+        # Prepare and send the chunk
+        with tempfile.NamedTemporaryFile() as temp_file:
+            temp_file.write(chunk_data)
+            temp_file.flush()
+
+            # Create form data
+            with open(temp_file.name, "rb") as chunk_file_obj:
+                files = {
+                    "chunk": (
+                        file_name,
+                        chunk_file_obj,
+                        "application/octet-stream",
+                    )
+                }
+                data = {
+                    "file_name": file_name,
+                    "chunk_index": chunk_index,
+                    "total_chunks": total_chunks,
+                    "upload_id": upload_id,
+                }
+
+                # Send the chunk
+                response = self.multipart_client.post(
+                    f"/v0.1/crows/{job_name}/upload-chunk",
+                    files=files,
+                    data=data,
+                )
+                response.raise_for_status()
 
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
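The sequential chunk loop is replaced by bounded parallelism: all but the last chunk go up in batches of `MAX_CONCURRENT_CHUNKS`, and the final chunk is sent alone with its own 409-aware retry, because it is what triggers assembly server-side. The core pattern distilled into a self-contained sketch; `upload_chunk` is a placeholder for the HTTP call:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable

MAX_CONCURRENT = 12  # mirrors MAX_CONCURRENT_CHUNKS


def upload_all(upload_chunk: Callable[[int], None], total_chunks: int) -> None:
    """Upload chunks 0..n-2 in bounded parallel batches, then the final chunk."""
    last = total_chunks - 1
    for start in range(0, last, MAX_CONCURRENT):
        batch = range(start, min(start + MAX_CONCURRENT, last))
        with ThreadPoolExecutor(max_workers=MAX_CONCURRENT) as pool:
            futures = {pool.submit(upload_chunk, i): i for i in batch}
            for fut in as_completed(futures):
                fut.result()  # re-raises the first chunk failure
    if total_chunks > 0:
        upload_chunk(last)  # the final chunk triggers assembly
```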
@@ -1098,13 +1297,6 @@ class RestClient:
             response.raise_for_status()
             return response.json()
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             logger.exception(
                 f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
             )
@@ -1117,7 +1309,6 @@ class RestClient:
             )
             raise RestClientError(f"Error listing files: {e!s}") from e
 
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1165,13 +1356,6 @@ class RestClient:
 
             logger.info(f"File {file_path} downloaded to {destination_path}")
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             logger.exception(
                 f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
             )