futurehouse-client 0.3.18.dev186__py3-none-any.whl → 0.3.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/__init__.py +11 -1
- futurehouse_client/clients/__init__.py +1 -2
- futurehouse_client/clients/job_client.py +27 -1
- futurehouse_client/clients/rest_client.py +439 -255
- futurehouse_client/models/app.py +70 -2
- futurehouse_client/py.typed +0 -0
- futurehouse_client/utils/auth.py +86 -101
- {futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dist-info}/METADATA +2 -3
- futurehouse_client-0.3.19.dist-info/RECORD +18 -0
- {futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dist-info}/WHEEL +1 -1
- futurehouse_client-0.3.18.dev186.dist-info/RECORD +0 -17
- {futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dist-info}/top_level.txt +0 -0
futurehouse_client/clients/rest_client.py (all hunks below are from this file):

@@ -13,9 +13,10 @@ import tempfile
 import time
 import uuid
 from collections.abc import Collection
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from types import ModuleType
-from typing import Any, ClassVar,
+from typing import Any, ClassVar, cast
 from uuid import UUID

 import cloudpickle
@@ -31,6 +32,7 @@ from httpx import (
     ReadError,
     ReadTimeout,
     RemoteProtocolError,
+    codes,
 )
 from ldp.agent import AgentConfig
 from requests.exceptions import RequestException, Timeout
@@ -45,21 +47,15 @@ from tqdm.asyncio import tqdm

 from futurehouse_client.clients import JobNames
 from futurehouse_client.models.app import (
-    APIKeyPayload,
     AuthType,
     JobDeploymentConfig,
-    PQATaskResponse,
     Stage,
     TaskRequest,
     TaskResponse,
     TaskResponseVerbose,
 )
 from futurehouse_client.models.rest import ExecutionStatus
-from futurehouse_client.utils.auth import (
-    AUTH_ERRORS_TO_RETRY_ON,
-    AuthError,
-    refresh_token_on_auth_error,
-)
+from futurehouse_client.utils.auth import RefreshingJWT
 from futurehouse_client.utils.general import gather_with_concurrency
 from futurehouse_client.utils.module_utils import (
     OrganizationSelector,
@@ -128,8 +124,6 @@ retry_if_connection_error = retry_if_exception_type((
     FileUploadError,
 ))

-# 5 minute default for JWTs
-JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
 DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds


@@ -140,6 +134,9 @@ class RestClient:
     MAX_RETRY_WAIT: ClassVar[int] = 10
     DEFAULT_POLLING_TIME: ClassVar[int] = 5  # seconds
     CHUNK_SIZE: ClassVar[int] = 16 * 1024 * 1024  # 16MB chunks
+    ASSEMBLY_POLLING_INTERVAL: ClassVar[int] = 10  # seconds
+    MAX_ASSEMBLY_WAIT_TIME: ClassVar[int] = 1800  # 30 minutes
+    MAX_CONCURRENT_CHUNKS: ClassVar[int] = 12  # Maximum concurrent chunk uploads

     def __init__(
         self,
@@ -163,69 +160,87 @@ class RestClient:
         self.api_key = api_key
         self._clients: dict[str, Client | AsyncClient] = {}
         self.headers = headers or {}
-        self.
+        self.jwt = jwt
         self.organizations: list[str] = self._filter_orgs(organization)

     @property
     def client(self) -> Client:
-        """
-        return cast(Client, self.get_client("application/json",
+        """Authenticated HTTP client for regular API calls."""
+        return cast(Client, self.get_client("application/json", authenticated=True))

     @property
     def async_client(self) -> AsyncClient:
-        """
+        """Authenticated async HTTP client for regular API calls."""
         return cast(
             AsyncClient,
-            self.get_client("application/json",
+            self.get_client("application/json", authenticated=True, async_client=True),
         )

     @property
-    def
-        """
-        return cast(Client, self.get_client("application/json",
+    def unauthenticated_client(self) -> Client:
+        """Unauthenticated HTTP client for auth operations."""
+        return cast(Client, self.get_client("application/json", authenticated=False))

     @property
     def multipart_client(self) -> Client:
-        """
-        return cast(Client, self.get_client(None,
+        """Authenticated HTTP client for multipart uploads."""
+        return cast(Client, self.get_client(None, authenticated=True))

     def get_client(
         self,
         content_type: str | None = "application/json",
-
-
+        authenticated: bool = True,
+        async_client: bool = False,
     ) -> Client | AsyncClient:
         """Return a cached HTTP client or create one if needed.

         Args:
             content_type: The desired content type header. Use None for multipart uploads.
-
-
+            authenticated: Whether the client should include authentication.
+            async_client: Whether to use an async client.

         Returns:
             An HTTP client configured with the appropriate headers.
         """
-        # Create a composite key based on content type and auth flag
-        key = f"{content_type or 'multipart'}_{
+        # Create a composite key based on content type and auth flag
+        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
+
         if key not in self._clients:
             headers = copy.deepcopy(self.headers)
-
-
+            auth = None
+
+            if authenticated:
+                auth = RefreshingJWT(
+                    # authenticated=False will always return a synchronous client
+                    auth_client=cast(
+                        Client, self.get_client("application/json", authenticated=False)
+                    ),
+                    auth_type=self.auth_type,
+                    api_key=self.api_key,
+                    jwt=self.jwt,
+                )
+
             if content_type:
                 headers["Content-Type"] = content_type
+
+            headers["x-client"] = "sdk"
+
             self._clients[key] = (
                 AsyncClient(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
-                if
+                if async_client
                 else Client(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
             )
+
         return self._clients[key]

     def close(self):
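The `auth=auth` wiring above is the heart of this release's auth rework: instead of decorating every method with `@refresh_token_on_auth_error()`, each cached client delegates token handling to a `RefreshingJWT` httpx auth object (see `futurehouse_client/utils/auth.py`, rewritten in this diff with +86/-101). The sketch below shows the general httpx auth-flow pattern such a class would follow, refreshing on 401 and replaying the request once. It is an illustration of the mechanism, not the actual `RefreshingJWT` source; `fetch_token` is a hypothetical stand-in for the login call.

```python
import httpx


class RefreshOn401Auth(httpx.Auth):
    """Minimal sketch of a self-refreshing bearer-token auth flow."""

    def __init__(self, fetch_token):
        self._fetch_token = fetch_token  # hypothetical: e.g. POSTs an API key to /auth/login
        self._token: str | None = None

    def auth_flow(self, request):
        if self._token is None:
            self._token = self._fetch_token()
        request.headers["Authorization"] = f"Bearer {self._token}"
        response = yield request  # httpx sends the request and hands back the response
        if response.status_code == httpx.codes.UNAUTHORIZED:
            # Token likely expired: refresh once and replay the same request.
            self._token = self._fetch_token()
            request.headers["Authorization"] = f"Bearer {self._token}"
            yield request
```

Note the design choice visible in `get_client`: each authenticated client is built with the always-synchronous unauthenticated client as its `auth_client`, so a token refresh never recurses into an authenticated client.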
@@ -255,32 +270,6 @@ class RestClient:
             raise ValueError(f"Organization '{organization}' not found.")
         return filtered_orgs

-    def _run_auth(self, jwt: str | None = None) -> str:
-        auth_payload: APIKeyPayload | None
-        if self.auth_type == AuthType.API_KEY:
-            auth_payload = APIKeyPayload(api_key=self.api_key)
-        elif self.auth_type == AuthType.JWT:
-            auth_payload = None
-        else:
-            assert_never(self.auth_type)
-        try:
-            # Use the unauthenticated client for login
-            if auth_payload:
-                response = self.auth_client.post(
-                    "/auth/login", json=auth_payload.model_dump()
-                )
-                response.raise_for_status()
-                token_data = response.json()
-            elif jwt:
-                token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
-            else:
-                raise ValueError("JWT token required for JWT authentication.")
-
-            return token_data["access_token"]
-        except Exception as e:
-            raise RestClientError(f"Error authenticating: {e!s}") from e
-
-    @refresh_token_on_auth_error()
     def _check_job(self, name: str, organization: str) -> dict[str, Any]:
         try:
             response = self.client.get(
@@ -288,25 +277,113 @@ class RestClient:
             )
             response.raise_for_status()
             return response.json()
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
         except Exception as e:
             raise JobFetchError(f"Error checking job: {e!s}") from e

-    @refresh_token_on_auth_error()
     def _fetch_my_orgs(self) -> list[str]:
         response = self.client.get(f"/v0.1/organizations?filter={True}")
         response.raise_for_status()
         orgs = response.json()
         return [org["name"] for org in orgs]

+    def _check_assembly_status(
+        self, job_name: str, upload_id: str, file_name: str
+    ) -> dict[str, Any]:
+        """Check the assembly status of an uploaded file.
+
+        Args:
+            job_name: The name of the futurehouse job
+            upload_id: The upload ID
+            file_name: The name of the file
+
+        Returns:
+            Dict containing status information
+
+        Raises:
+            RestClientError: If there's an error checking status
+        """
+        try:
+            url = f"/v0.1/crows/{job_name}/assembly-status/{upload_id}/{file_name}"
+            response = self.client.get(url)
+            response.raise_for_status()
+            return response.json()
+        except Exception as e:
+            raise RestClientError(f"Error checking assembly status: {e}") from e
+
+    def _wait_for_all_assemblies_completion(
+        self,
+        job_name: str,
+        upload_id: str,
+        file_names: list[str],
+        timeout: int = MAX_ASSEMBLY_WAIT_TIME,
+    ) -> bool:
+        """Wait for all file assemblies to complete.
+
+        Args:
+            job_name: The name of the futurehouse job
+            upload_id: The upload ID
+            file_names: List of file names to wait for
+            timeout: Maximum time to wait in seconds
+
+        Returns:
+            True if all assemblies succeeded, False if any failed or timed out
+
+        Raises:
+            RestClientError: If any assembly fails
+        """
+        if not file_names:
+            return True
+
+        start_time = time.time()
+        logger.info(f"Waiting for assembly of {len(file_names)} file(s) to complete...")
+
+        completed_files: set[str] = set()
+
+        while (time.time() - start_time) < timeout and len(completed_files) < len(
+            file_names
+        ):
+            for file_name in file_names:
+                if file_name in completed_files:
+                    continue
+
+                try:
+                    status_data = self._check_assembly_status(
+                        job_name, upload_id, file_name
+                    )
+                    status = status_data.get("status")
+
+                    if status == ExecutionStatus.SUCCESS.value:
+                        logger.info(f"Assembly completed for {file_name}")
+                        completed_files.add(file_name)
+                    elif status == ExecutionStatus.FAIL.value:
+                        error_msg = status_data.get("error", "Unknown assembly error")
+                        raise RestClientError(
+                            f"Assembly failed for {file_name}: {error_msg}"
+                        )
+                    elif status == ExecutionStatus.IN_PROGRESS.value:
+                        logger.debug(f"Assembly in progress for {file_name}...")
+
+                except RestClientError:
+                    raise  # Re-raise assembly errors
+                except Exception as e:
+                    logger.warning(
+                        f"Error checking assembly status for {file_name}: {e}"
+                    )
+
+            # Don't sleep if all files are complete
+            if len(completed_files) < len(file_names):
+                time.sleep(self.ASSEMBLY_POLLING_INTERVAL)
+
+        if len(completed_files) < len(file_names):
+            remaining_files = set(file_names) - completed_files
+            logger.warning(
+                f"Assembly timeout for files: {remaining_files} after {timeout} seconds"
+            )
+            return False
+
+        logger.info(f"All {len(file_names)} file assemblies completed successfully")
+        return True
+
     @staticmethod
     def _validate_module_path(path: Path) -> None:
         """Validates that the given path exists and is a directory.
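`_wait_for_all_assemblies_completion` is a plain bounded polling loop over the new assembly-status endpoint: one pass over the pending files per `ASSEMBLY_POLLING_INTERVAL`, terminal success or failure per file, and a `False` return on timeout. For reference, the endpoint can also be queried directly; a hedged sketch, with the host, token, and identifiers as placeholder values (the real status strings come from `ExecutionStatus` in `futurehouse_client.models.rest`):

```python
import httpx

# All values below are illustrative placeholders.
BASE_URL = "https://api.example.com"
job_name, upload_id, file_name = "my-job", "0b1c2d3e-upload", "paper.pdf"

with httpx.Client(base_url=BASE_URL, headers={"Authorization": "Bearer <jwt>"}) as http:
    resp = http.get(f"/v0.1/crows/{job_name}/assembly-status/{upload_id}/{file_name}")
    resp.raise_for_status()
    status = resp.json().get("status")  # compared against ExecutionStatus values
    print(status)
```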
@@ -358,7 +435,6 @@ class RestClient:
         if not files:
             raise TaskFetchError(f"No files found in {path}")

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -368,51 +444,37 @@ class RestClient:
         self, task_id: str | None = None, history: bool = False, verbose: bool = False
     ) -> "TaskResponse":
         """Get details for a specific task."""
-
-
-
-        full_url = f"{self.base_url}{url}"
-
-        with (
-            external_trace(
-                url=full_url,
-                method="GET",
-                library="httpx",
-                custom_params={
-                    "operation": "get_job",
-                    "job_id": task_id,
-                },
-            ),
-            self.client.stream("GET", url, params={"history": history}) as response,
-        ):
-            response.raise_for_status()
-            json_data = "".join(response.iter_text(chunk_size=1024))
-            data = json.loads(json_data)
-            if "id" not in data:
-                data["id"] = task_id
-            verbose_response = TaskResponseVerbose(**data)
+        task_id = task_id or self.trajectory_id
+        url = f"/v0.1/trajectories/{task_id}"
+        full_url = f"{self.base_url}{url}"

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
+        with (
+            external_trace(
+                url=full_url,
+                method="GET",
+                library="httpx",
+                custom_params={
+                    "operation": "get_job",
+                    "job_id": task_id,
+                },
+            ),
+            self.client.stream("GET", url, params={"history": history}) as response,
+        ):
+            if response.status_code in {401, 403}:
+                raise PermissionError(
+                    f"Error getting task: Permission denied for task {task_id}"
+                )
+            response.raise_for_status()
+            json_data = "".join(response.iter_text(chunk_size=1024))
+            data = json.loads(json_data)
+            if "id" not in data:
+                data["id"] = task_id
+            verbose_response = TaskResponseVerbose(**data)
+
+            if verbose:
+                return verbose_response
+            return JobNames.get_response_object_from_job(verbose_response.job_name)(**data)

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -422,53 +484,37 @@ class RestClient:
         self, task_id: str | None = None, history: bool = False, verbose: bool = False
     ) -> "TaskResponse":
         """Get details for a specific task asynchronously."""
-
-
-
-
+        task_id = task_id or self.trajectory_id
+        url = f"/v0.1/trajectories/{task_id}"
+        full_url = f"{self.base_url}{url}"
+
+        with external_trace(
+            url=full_url,
+            method="GET",
+            library="httpx",
+            custom_params={
+                "operation": "get_job",
+                "job_id": task_id,
+            },
+        ):
+            async with self.async_client.stream(
+                "GET", url, params={"history": history}
+            ) as response:
+                if response.status_code in {401, 403}:
+                    raise PermissionError(
+                        f"Error getting task: Permission denied for task {task_id}"
+                    )
+                response.raise_for_status()
+                json_data = "".join([chunk async for chunk in response.aiter_text()])
+                data = json.loads(json_data)
+                if "id" not in data:
+                    data["id"] = task_id
+                verbose_response = TaskResponseVerbose(**data)

-
-
-
-                library="httpx",
-                custom_params={
-                    "operation": "get_job",
-                    "job_id": task_id,
-                },
-            ):
-                async with self.async_client.stream(
-                    "GET", url, params={"history": history}
-                ) as response:
-                    response.raise_for_status()
-                    json_data = "".join([
-                        chunk async for chunk in response.aiter_text()
-                    ])
-                    data = json.loads(json_data)
-                    if "id" not in data:
-                        data["id"] = task_id
-                    verbose_response = TaskResponseVerbose(**data)
-
-                    if verbose:
-                        return verbose_response
-                    if any(
-                        JobNames.from_string(job_name) in verbose_response.job_name
-                        for job_name in ["crow", "falcon", "owl", "dummy"]
-                    ):
-                        return PQATaskResponse(**data)
-                    return TaskResponse(**data)
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            raise TaskFetchError(f"Error getting task: {e!s}") from e
+                if verbose:
+                    return verbose_response
+                return JobNames.get_response_object_from_job(verbose_response.job_name)(**data)

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -485,27 +531,18 @@ class RestClient:
             self.stage,
         )

-
-
-
+        response = self.client.post(
+            "/v0.1/crows", json=task_data.model_dump(mode="json")
+        )
+        if response.status_code in {401, 403}:
+            raise PermissionError(
+                f"Error creating task: Permission denied for task {task_data.name}"
             )
-
-
-
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            raise TaskFetchError(f"Error creating task: {e!s}") from e
+        response.raise_for_status()
+        trajectory_id = response.json()["trajectory_id"]
+        self.trajectory_id = trajectory_id
         return trajectory_id

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -521,25 +558,16 @@ class RestClient:
             task_data.name.name,
             self.stage,
         )
-
-
-
-
+        response = await self.async_client.post(
+            "/v0.1/crows", json=task_data.model_dump(mode="json")
+        )
+        if response.status_code in {401, 403}:
+            raise PermissionError(
+                f"Error creating task: Permission denied for task {task_data.name}"
             )
-
-
-
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            raise TaskFetchError(f"Error creating task: {e!s}") from e
+        response.raise_for_status()
+        trajectory_id = response.json()["trajectory_id"]
+        self.trajectory_id = trajectory_id
         return trajectory_id

     async def arun_tasks_until_done(
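With the auth decorator and the broad `except`/`AuthError` plumbing removed, task creation now fails fast: 401/403 surfaces as `PermissionError`, and anything else propagates through `raise_for_status()` and the surrounding `@retry` policy. A hedged end-to-end sketch of the resulting call pattern, assuming the methods shown in these hunks are `create_task`/`get_task`; the constructor arguments and `TaskRequest` fields are illustrative, not a verbatim API reference:

```python
from futurehouse_client.clients import JobNames
from futurehouse_client.clients.rest_client import RestClient
from futurehouse_client.models.app import TaskRequest

client = RestClient(api_key="<api-key>")  # assumed constructor; auth wiring is internal
task = TaskRequest(name=JobNames.CROW, query="What is the molecular weight of caffeine?")

try:
    trajectory_id = client.create_task(task)  # POST /v0.1/crows; also stored on self.trajectory_id
    result = client.get_task(trajectory_id)   # GET /v0.1/trajectories/{task_id}
except PermissionError as exc:
    print(f"Not authorized: {exc}")
```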
@@ -683,7 +711,6 @@ class RestClient:
             for task_id in trajectory_ids
         ]

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -695,19 +722,11 @@ class RestClient:
             build_id = build_id or self.build_id
             response = self.client.get(f"/v0.1/builds/{build_id}")
             response.raise_for_status()
-        except
-
-            raise AuthError(
-                e.response.status_code,
-                f"Authentication failed: {e}",
-                request=e.request,
-                response=e.response,
-            ) from e
-            raise
+        except Exception as e:
+            raise JobFetchError(f"Error getting build status: {e!s}") from e
         return response.json()

     # TODO: Refactor later so we don't have to ignore PLR0915
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -887,13 +906,6 @@ class RestClient:
             build_context = response.json()
             self.build_id = build_context["build_id"]
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             error_detail = response.json()
             error_message = error_detail.get("detail", str(e))
             raise JobCreationError(
@@ -903,6 +915,8 @@ class RestClient:
             raise JobCreationError(f"Error generating docker image: {e!s}") from e
         return build_context

+    # TODO: we should have have an async upload_file, check_assembly_status,
+    # wait_for_assembly_completion, upload_directory, upload_single_file
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -913,6 +927,8 @@ class RestClient:
         job_name: str,
         file_path: str | os.PathLike,
         upload_id: str | None = None,
+        wait_for_assembly: bool = True,
+        assembly_timeout: int = MAX_ASSEMBLY_WAIT_TIME,
     ) -> str:
         """Upload a file or directory to a futurehouse job bucket.

@@ -920,29 +936,47 @@ class RestClient:
             job_name: The name of the futurehouse job to upload to.
             file_path: The local path to the file or directory to upload.
             upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
+            wait_for_assembly: After file chunking, wait for the assembly to be processed.
+            assembly_timeout: Maximum time to wait for assembly in seconds.

         Returns:
             The upload ID used for the upload.

         Raises:
             FileUploadError: If there's an error uploading the file.
+            RestClientError: If assembly fails or times out.
         """
         file_path = Path(file_path)
         if not file_path.exists():
             raise FileNotFoundError(f"File or directory not found: {file_path}")

         upload_id = upload_id or str(uuid.uuid4())
+        uploaded_files: list[str] = []

         if file_path.is_dir():
             # Process directory recursively
-            self._upload_directory(job_name, file_path, upload_id)
+            uploaded_files = self._upload_directory(job_name, file_path, upload_id)
         else:
             # Process single file
             self._upload_single_file(job_name, file_path, upload_id)
+            uploaded_files = [file_path.name]
+
+        # Wait for all assemblies if requested and we have files
+        if wait_for_assembly and uploaded_files:
+            success = self._wait_for_all_assemblies_completion(
+                job_name, upload_id, uploaded_files, assembly_timeout
+            )
+            if not success:
+                raise RestClientError(
+                    f"Assembly failed or timed out for one or more files: {uploaded_files}"
+                )
+
         logger.info(f"Successfully uploaded {file_path} to {upload_id}")
         return upload_id

-    def _upload_directory(
+    def _upload_directory(
+        self, job_name: str, dir_path: Path, upload_id: str
+    ) -> list[str]:
         """Upload all files in a directory recursively.

         Args:
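In short, `upload_file` still returns the upload ID, but it now tracks which file names were chunked and, by default, blocks until the server has assembled all of them, raising `RestClientError` on failure or timeout instead of returning silently. A brief usage sketch, reusing the `client` from the earlier sketch (job name and paths are placeholders):

```python
upload_id = client.upload_file(
    job_name="my-job",
    file_path="./data/papers",   # a directory is walked recursively
    wait_for_assembly=True,      # default: poll assembly status until every file is done
    assembly_timeout=600,        # give up after 10 minutes instead of the 30-minute default
)
print(f"Upload {upload_id} fully assembled")
```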
@@ -950,12 +984,17 @@ class RestClient:
             dir_path: The path to the directory to upload.
             upload_id: The upload ID to use.

+        Returns:
+            List of uploaded file names.
+
         Raises:
             FileUploadError: If there's an error uploading any file.
         """
         # Skip common directories that shouldn't be uploaded
         if any(ignore in dir_path.parts for ignore in FILE_UPLOAD_IGNORE_PARTS):
-            return
+            return []
+
+        uploaded_files: list[str] = []

         try:
             # Upload all files in the directory recursively
@@ -965,24 +1004,27 @@ class RestClient:
             ):
                 # Use path relative to the original directory as file name
                 rel_path = path.relative_to(dir_path)
+                file_name = str(rel_path)
                 self._upload_single_file(
                     job_name,
                     path,
                     upload_id,
-                    file_name=
+                    file_name=file_name,
                 )
+                uploaded_files.append(file_name)
         except Exception as e:
             raise FileUploadError(f"Error uploading directory {dir_path}: {e}") from e

-
+        return uploaded_files
+
     def _upload_single_file(
         self,
         job_name: str,
         file_path: Path,
         upload_id: str,
         file_name: str | None = None,
-    ) -> None:
-        """Upload a single file in chunks.
+    ) -> str | None:
+        """Upload a single file in chunks using parallel uploads.

         Args:
             job_name: The key of the crow to upload to.
|
|
990
1032
|
upload_id: The upload ID to use.
|
991
1033
|
file_name: Optional name to use for the file. If not provided, the file's name will be used.
|
992
1034
|
|
1035
|
+
Returns:
|
1036
|
+
The status URL if this was the last chunk, None otherwise.
|
1037
|
+
|
993
1038
|
Raises:
|
994
1039
|
FileUploadError: If there's an error uploading the file.
|
995
1040
|
"""
|
@@ -999,17 +1044,103 @@ class RestClient:
         # Skip empty files
         if file_size == 0:
             logger.warning(f"Skipping upload of empty file: {file_path}")
-            return
+            return None

         total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE

         logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")

+        status_url = None
+
         try:
-
-
-
-
+            status_url = self._upload_chunks_parallel(
+                job_name,
+                file_path,
+                file_name,
+                upload_id,
+                total_chunks,
+            )
+
+            logger.info(f"Successfully uploaded {file_name}")
+        except Exception as e:
+            logger.exception(f"Error uploading file {file_path}")
+            raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
+        return status_url
+
+    def _upload_chunks_parallel(
+        self,
+        job_name: str,
+        file_path: Path,
+        file_name: str,
+        upload_id: str,
+        total_chunks: int,
+    ) -> str | None:
+        """Upload all chunks in parallel batches, including the final chunk.
+
+        Args:
+            job_name: The key of the crow to upload to.
+            file_path: The path to the file to upload.
+            file_name: The name to use for the file.
+            upload_id: The upload ID to use.
+            total_chunks: Total number of chunks.
+
+        Returns:
+            The status URL from the final chunk response, or None if no chunks.
+
+        Raises:
+            FileUploadError: If there's an error uploading any chunk.
+        """
+        if total_chunks <= 0:
+            return None
+
+        if total_chunks > 1:
+            num_regular_chunks = total_chunks - 1
+            for batch_start in range(0, num_regular_chunks, self.MAX_CONCURRENT_CHUNKS):
+                batch_end = min(
+                    batch_start + self.MAX_CONCURRENT_CHUNKS, num_regular_chunks
+                )
+
+                # Upload chunks in this batch concurrently
+                with ThreadPoolExecutor(
+                    max_workers=self.MAX_CONCURRENT_CHUNKS
+                ) as executor:
+                    futures = {
+                        executor.submit(
+                            self._upload_single_chunk,
+                            job_name,
+                            file_path,
+                            file_name,
+                            upload_id,
+                            chunk_index,
+                            total_chunks,
+                        ): chunk_index
+                        for chunk_index in range(batch_start, batch_end)
+                    }
+
+                    for future in as_completed(futures):
+                        chunk_index = futures[future]
+                        try:
+                            future.result()
+                            logger.debug(
+                                f"Uploaded chunk {chunk_index + 1}/{total_chunks} of {file_name}"
+                            )
+                        except Exception as e:
+                            logger.error(f"Error uploading chunk {chunk_index}: {e}")
+                            raise FileUploadError(
+                                f"Error uploading chunk {chunk_index} of {file_name}: {e}"
+                            ) from e
+
+        # Upload the final chunk with retry logic
+        final_chunk_index = total_chunks - 1
+        retries = 0
+        max_retries = 3
+        retry_delay = 2.0
+
+        while retries < max_retries:
+            try:
+                with open(file_path, "rb") as f:
+                    # Read the final chunk from the file
+                    f.seek(final_chunk_index * self.CHUNK_SIZE)
                     chunk_data = f.read(self.CHUNK_SIZE)

                     # Prepare and send the chunk
@@ -1028,40 +1159,108 @@ class RestClient:
                             }
                             data = {
                                 "file_name": file_name,
-                                "chunk_index":
+                                "chunk_index": final_chunk_index,
                                 "total_chunks": total_chunks,
                                 "upload_id": upload_id,
                             }

-                            # Send the chunk
+                            # Send the final chunk
                             response = self.multipart_client.post(
                                 f"/v0.1/crows/{job_name}/upload-chunk",
                                 files=files,
                                 data=data,
                             )
+
+                            # Handle missing chunks (status 409)
+                            if response.status_code == codes.CONFLICT:
+                                retries += 1
+                                if retries < max_retries:
+                                    logger.warning(
+                                        f"Missing chunks detected for {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries})"
+                                    )
+                                    time.sleep(retry_delay)
+                                    continue
+
                             response.raise_for_status()
+                            response_data = response.json()
+                            status_url = response_data.get("status_url")

-
+                            logger.debug(
+                                f"Uploaded final chunk {final_chunk_index + 1}/{total_chunks} of {file_name}"
+                            )
+                            return status_url

-
-
-
+            except Exception as e:
+                if retries >= max_retries - 1:
+                    raise FileUploadError(
+                        f"Error uploading final chunk of {file_name}: {e}"
+                    ) from e
+                retries += 1
+                logger.warning(
+                    f"Error uploading final chunk of {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries}): {e}"
+                )
+                time.sleep(retry_delay)

-
-
-
-
-
-
-
-
-
-
-
-
+        raise FileUploadError(
+            f"Failed to upload final chunk of {file_name} after {max_retries} retries"
+        )
+
+    def _upload_single_chunk(
+        self,
+        job_name: str,
+        file_path: Path,
+        file_name: str,
+        upload_id: str,
+        chunk_index: int,
+        total_chunks: int,
+    ) -> None:
+        """Upload a single chunk.
+
+        Args:
+            job_name: The key of the crow to upload to.
+            file_path: The path to the file to upload.
+            file_name: The name to use for the file.
+            upload_id: The upload ID to use.
+            chunk_index: The index of this chunk.
+            total_chunks: Total number of chunks.
+
+        Raises:
+            Exception: If there's an error uploading the chunk.
+        """
+        with open(file_path, "rb") as f:
+            # Read the chunk from the file
+            f.seek(chunk_index * self.CHUNK_SIZE)
+            chunk_data = f.read(self.CHUNK_SIZE)
+
+        # Prepare and send the chunk
+        with tempfile.NamedTemporaryFile() as temp_file:
+            temp_file.write(chunk_data)
+            temp_file.flush()
+
+            # Create form data
+            with open(temp_file.name, "rb") as chunk_file_obj:
+                files = {
+                    "chunk": (
+                        file_name,
+                        chunk_file_obj,
+                        "application/octet-stream",
+                    )
+                }
+                data = {
+                    "file_name": file_name,
+                    "chunk_index": chunk_index,
+                    "total_chunks": total_chunks,
+                    "upload_id": upload_id,
+                }
+
+                # Send the chunk
+                response = self.multipart_client.post(
+                    f"/v0.1/crows/{job_name}/upload-chunk",
+                    files=files,
+                    data=data,
+                )
+                response.raise_for_status()

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
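The division of labor above: every chunk except the last is uploaded concurrently in batches of at most `MAX_CONCURRENT_CHUNKS`, and the final chunk is sent last (with its own 409-aware retry loop) so the server can detect completeness and kick off assembly. The chunk arithmetic, worked through for a hypothetical 100 MB file:

```python
CHUNK_SIZE = 16 * 1024 * 1024  # 16 MB, matching the class constant
file_size = 100 * 1024 * 1024  # hypothetical 100 MB file

# Same ceiling division as _upload_single_file
total_chunks = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE
regular_chunks = total_chunks - 1  # uploaded in parallel batches (cap: 12 at a time)

print(total_chunks, regular_chunks)  # 7 6 -> six parallel chunks, then the final chunk
```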
@@ -1098,13 +1297,6 @@ class RestClient:
             response.raise_for_status()
             return response.json()
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             logger.exception(
                 f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
             )
@@ -1117,7 +1309,6 @@ class RestClient:
             )
             raise RestClientError(f"Error listing files: {e!s}") from e

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1165,13 +1356,6 @@ class RestClient:

             logger.info(f"File {file_path} downloaded to {destination_path}")
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             logger.exception(
                 f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
             )