futurehouse-client 0.3.17.dev56__py3-none-any.whl → 0.3.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,23 +1,27 @@
1
1
  import ast
2
+ import asyncio
2
3
  import base64
4
+ import contextlib
3
5
  import copy
4
6
  import importlib.metadata
5
7
  import inspect
6
8
  import json
7
9
  import logging
8
10
  import os
11
+ import sys
9
12
  import tempfile
13
+ import time
10
14
  import uuid
11
- from collections.abc import Mapping
12
- from datetime import datetime
15
+ from collections.abc import Collection
13
16
  from pathlib import Path
14
17
  from types import ModuleType
15
- from typing import Any, ClassVar, assert_never, cast
18
+ from typing import Any, ClassVar, cast
16
19
  from uuid import UUID
17
20
 
18
21
  import cloudpickle
19
22
  from aviary.functional import EnvironmentBuilder
20
23
  from httpx import (
24
+ AsyncClient,
21
25
  Client,
22
26
  CloseError,
23
27
  ConnectError,
@@ -29,7 +33,6 @@ from httpx import (
29
33
  RemoteProtocolError,
30
34
  )
31
35
  from ldp.agent import AgentConfig
32
- from pydantic import BaseModel, ConfigDict, model_validator
33
36
  from requests.exceptions import RequestException, Timeout
34
37
  from tenacity import (
35
38
  retry,
@@ -37,15 +40,22 @@ from tenacity import (
37
40
  stop_after_attempt,
38
41
  wait_exponential,
39
42
  )
43
+ from tqdm import tqdm as sync_tqdm
44
+ from tqdm.asyncio import tqdm
40
45
 
41
46
  from futurehouse_client.clients import JobNames
42
47
  from futurehouse_client.models.app import (
43
- APIKeyPayload,
44
48
  AuthType,
45
49
  JobDeploymentConfig,
50
+ PQATaskResponse,
46
51
  Stage,
47
52
  TaskRequest,
53
+ TaskResponse,
54
+ TaskResponseVerbose,
48
55
  )
56
+ from futurehouse_client.models.rest import ExecutionStatus
57
+ from futurehouse_client.utils.auth import RefreshingJWT
58
+ from futurehouse_client.utils.general import gather_with_concurrency
49
59
  from futurehouse_client.utils.module_utils import (
50
60
  OrganizationSelector,
51
61
  fetch_environment_function_docstring,
@@ -55,24 +65,14 @@ from futurehouse_client.utils.monitoring import (
55
65
  )
56
66
 
57
67
  logger = logging.getLogger(__name__)
58
-
68
+ logging.basicConfig(
69
+ level=logging.WARNING,
70
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
71
+ stream=sys.stdout,
72
+ )
73
+ logging.getLogger("httpx").setLevel(logging.WARNING)
59
74
  TaskRequest.model_rebuild()
60
75
 
61
- retry_if_connection_error = retry_if_exception_type((
62
- # From requests
63
- Timeout,
64
- ConnectionError,
65
- RequestException,
66
- # From httpx
67
- ConnectError,
68
- ConnectTimeout,
69
- ReadTimeout,
70
- ReadError,
71
- NetworkError,
72
- RemoteProtocolError,
73
- CloseError,
74
- ))
75
-
76
76
  FILE_UPLOAD_IGNORE_PARTS = {
77
77
  ".ruff_cache",
78
78
  "__pycache__",
@@ -103,114 +103,35 @@ class InvalidTaskDescriptionError(Exception):
103
103
  """Raised when the task description is invalid or empty."""
104
104
 
105
105
 
106
- class SimpleOrganization(BaseModel):
107
- id: int
108
- name: str
109
- display_name: str
110
-
111
-
112
- # 5 minute default for JWTs
113
- JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
114
-
115
-
116
- class TaskResponse(BaseModel):
117
- """Base class for task responses. This holds attributes shared over all futurehouse jobs."""
118
-
119
- model_config = ConfigDict(extra="ignore")
120
-
121
- status: str
122
- query: str
123
- user: str | None = None
124
- created_at: datetime
125
- job_name: str
126
- public: bool
127
- shared_with: list[SimpleOrganization] | None = None
128
- build_owner: str | None = None
129
- environment_name: str | None = None
130
- agent_name: str | None = None
131
- task_id: UUID | None = None
132
-
133
- @model_validator(mode="before")
134
- @classmethod
135
- def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
136
- # Extract fields from environment frame state
137
- if not isinstance(data, dict):
138
- return data
139
- # TODO: We probably want to remove these two once we define the final names.
140
- data["job_name"] = data.get("crow")
141
- data["query"] = data.get("task")
142
- if not (env_frame := data.get("environment_frame", {})):
143
- return data
144
- state = env_frame.get("state", {}).get("state", {})
145
- data["task_id"] = cast(UUID, state.get("id")) if state.get("id") else None
146
- if not (metadata := data.get("metadata", {})):
147
- return data
148
- data["environment_name"] = metadata.get("environment_name")
149
- data["agent_name"] = metadata.get("agent_name")
150
- return data
151
-
152
-
153
- class PQATaskResponse(TaskResponse):
154
- model_config = ConfigDict(extra="ignore")
155
-
156
- answer: str | None = None
157
- formatted_answer: str | None = None
158
- answer_reasoning: str | None = None
159
- has_successful_answer: bool | None = None
160
- total_cost: float | None = None
161
- total_queries: int | None = None
162
-
163
- @model_validator(mode="before")
164
- @classmethod
165
- def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
166
- # Extract fields from environment frame state
167
- if not isinstance(data, dict):
168
- return data
169
- if not (env_frame := data.get("environment_frame", {})):
170
- return data
171
- state = env_frame.get("state", {}).get("state", {})
172
- response = state.get("response", {})
173
- answer = response.get("answer", {})
174
- usage = state.get("info", {}).get("usage", {})
175
-
176
- # Add additional PQA specific fields to data so that pydantic can validate the model
177
- data["answer"] = answer.get("answer")
178
- data["formatted_answer"] = answer.get("formatted_answer")
179
- data["answer_reasoning"] = answer.get("answer_reasoning")
180
- data["has_successful_answer"] = answer.get("has_successful_answer")
181
- data["total_cost"] = cast(float, usage.get("total_cost"))
182
- data["total_queries"] = cast(int, usage.get("total_queries"))
183
-
184
- return data
185
-
186
- def clean_verbose(self) -> "TaskResponse":
187
- """Clean the verbose response from the server."""
188
- self.request = None
189
- self.response = None
190
- return self
191
-
192
-
193
- class TaskResponseVerbose(TaskResponse):
194
- """Class for responses to include all the fields of a task response."""
195
-
196
- model_config = ConfigDict(extra="allow")
197
-
198
- public: bool
199
- agent_state: list[dict[str, Any]] | None = None
200
- environment_frame: dict[str, Any] | None = None
201
- metadata: dict[str, Any] | None = None
202
- shared_with: list[SimpleOrganization] | None = None
203
-
204
-
205
106
  class FileUploadError(RestClientError):
206
107
  """Raised when there's an error uploading a file."""
207
108
 
208
109
 
110
+ retry_if_connection_error = retry_if_exception_type((
111
+ # From requests
112
+ Timeout,
113
+ ConnectionError,
114
+ RequestException,
115
+ # From httpx
116
+ ConnectError,
117
+ ConnectTimeout,
118
+ ReadTimeout,
119
+ ReadError,
120
+ NetworkError,
121
+ RemoteProtocolError,
122
+ CloseError,
123
+ FileUploadError,
124
+ ))
125
+
126
+ DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
127
+
128
+
209
129
  class RestClient:
210
130
  REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec
211
131
  MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
212
132
  RETRY_MULTIPLIER: ClassVar[int] = 1
213
133
  MAX_RETRY_WAIT: ClassVar[int] = 10
134
+ DEFAULT_POLLING_TIME: ClassVar[int] = 5 # seconds
214
135
  CHUNK_SIZE: ClassVar[int] = 16 * 1024 * 1024 # 16MB chunks
215
136
 
216
137
  def __init__(
@@ -222,62 +143,116 @@ class RestClient:
222
143
  api_key: str | None = None,
223
144
  jwt: str | None = None,
224
145
  headers: dict[str, str] | None = None,
146
+ verbose_logging: bool = False,
225
147
  ):
148
+ if verbose_logging:
149
+ logger.setLevel(logging.INFO)
150
+ else:
151
+ logger.setLevel(logging.WARNING)
152
+
226
153
  self.base_url = service_uri or stage.value
227
154
  self.stage = stage
228
155
  self.auth_type = auth_type
229
156
  self.api_key = api_key
230
- self._clients: dict[str, Client] = {}
157
+ self._clients: dict[str, Client | AsyncClient] = {}
231
158
  self.headers = headers or {}
232
- self.auth_jwt = self._run_auth(jwt=jwt)
159
+ self.jwt = jwt
233
160
  self.organizations: list[str] = self._filter_orgs(organization)
234
161
 
235
162
  @property
236
163
  def client(self) -> Client:
237
- """Lazily initialized and cached HTTP client with authentication."""
238
- return self.get_client("application/json", with_auth=True)
164
+ """Authenticated HTTP client for regular API calls."""
165
+ return cast(Client, self.get_client("application/json", authenticated=True))
239
166
 
240
167
  @property
241
- def auth_client(self) -> Client:
242
- """Lazily initialized and cached HTTP client without authentication."""
243
- return self.get_client("application/json", with_auth=False)
168
+ def async_client(self) -> AsyncClient:
169
+ """Authenticated async HTTP client for regular API calls."""
170
+ return cast(
171
+ AsyncClient,
172
+ self.get_client("application/json", authenticated=True, async_client=True),
173
+ )
174
+
175
+ @property
176
+ def unauthenticated_client(self) -> Client:
177
+ """Unauthenticated HTTP client for auth operations."""
178
+ return cast(Client, self.get_client("application/json", authenticated=False))
244
179
 
245
180
  @property
246
181
  def multipart_client(self) -> Client:
247
- """Lazily initialized and cached HTTP client for multipart uploads."""
248
- return self.get_client(None, with_auth=True)
182
+ """Authenticated HTTP client for multipart uploads."""
183
+ return cast(Client, self.get_client(None, authenticated=True))
249
184
 
250
185
  def get_client(
251
- self, content_type: str | None = "application/json", with_auth: bool = True
252
- ) -> Client:
186
+ self,
187
+ content_type: str | None = "application/json",
188
+ authenticated: bool = True,
189
+ async_client: bool = False,
190
+ ) -> Client | AsyncClient:
253
191
  """Return a cached HTTP client or create one if needed.
254
192
 
255
193
  Args:
256
194
  content_type: The desired content type header. Use None for multipart uploads.
257
- with_auth: Whether the client should include an Authorization header.
195
+ authenticated: Whether the client should include authentication.
196
+ async_client: Whether to use an async client.
258
197
 
259
198
  Returns:
260
199
  An HTTP client configured with the appropriate headers.
261
200
  """
262
- # Create a composite key based on content type and auth flag.
263
- key = f"{content_type or 'multipart'}_{with_auth}"
201
+ # Create a composite key based on content type and auth flag
202
+ key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
203
+
264
204
  if key not in self._clients:
265
205
  headers = copy.deepcopy(self.headers)
266
- if with_auth:
267
- headers["Authorization"] = f"Bearer {self.auth_jwt}"
206
+ auth = None
207
+
208
+ if authenticated:
209
+ auth = RefreshingJWT(
210
+ # authenticated=False will always return a synchronous client
211
+ auth_client=cast(
212
+ Client, self.get_client("application/json", authenticated=False)
213
+ ),
214
+ auth_type=self.auth_type,
215
+ api_key=self.api_key,
216
+ jwt=self.jwt,
217
+ )
218
+
268
219
  if content_type:
269
220
  headers["Content-Type"] = content_type
270
- self._clients[key] = Client(
271
- base_url=self.base_url,
272
- headers=headers,
273
- timeout=self.REQUEST_TIMEOUT,
221
+
222
+ self._clients[key] = (
223
+ AsyncClient(
224
+ base_url=self.base_url,
225
+ headers=headers,
226
+ timeout=self.REQUEST_TIMEOUT,
227
+ auth=auth,
228
+ )
229
+ if async_client
230
+ else Client(
231
+ base_url=self.base_url,
232
+ headers=headers,
233
+ timeout=self.REQUEST_TIMEOUT,
234
+ auth=auth,
235
+ )
274
236
  )
237
+
275
238
  return self._clients[key]
276
239
 
277
- def __del__(self):
278
- """Ensure all cached clients are properly closed when the instance is destroyed."""
240
+ def close(self):
241
+ """Explicitly close all cached clients."""
279
242
  for client in self._clients.values():
280
- client.close()
243
+ if isinstance(client, Client):
244
+ with contextlib.suppress(RuntimeError, CloseError):
245
+ client.close()
246
+
247
+ async def aclose(self):
248
+ """Asynchronously close all cached clients."""
249
+ for client in self._clients.values():
250
+ if isinstance(client, AsyncClient):
251
+ with contextlib.suppress(RuntimeError, CloseError):
252
+ await client.aclose()
253
+
254
+ def __del__(self):
255
+ self.close()
281
256
 
282
257
  def _filter_orgs(self, organization: str | None = None) -> list[str]:
283
258
  filtered_orgs = [
@@ -289,31 +264,6 @@ class RestClient:
289
264
  raise ValueError(f"Organization '{organization}' not found.")
290
265
  return filtered_orgs
291
266
 
292
- def _run_auth(self, jwt: str | None = None) -> str:
293
- auth_payload: APIKeyPayload | None
294
- if self.auth_type == AuthType.API_KEY:
295
- auth_payload = APIKeyPayload(api_key=self.api_key)
296
- elif self.auth_type == AuthType.JWT:
297
- auth_payload = None
298
- else:
299
- assert_never(self.auth_type)
300
- try:
301
- # Use the unauthenticated client for login
302
- if auth_payload:
303
- response = self.auth_client.post(
304
- "/auth/login", json=auth_payload.model_dump()
305
- )
306
- response.raise_for_status()
307
- token_data = response.json()
308
- elif jwt:
309
- token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
310
- else:
311
- raise ValueError("JWT token required for JWT authentication.")
312
-
313
- return token_data["access_token"]
314
- except Exception as e:
315
- raise RestClientError(f"Error authenticating: {e!s}") from e
316
-
317
267
  def _check_job(self, name: str, organization: str) -> dict[str, Any]:
318
268
  try:
319
269
  response = self.client.get(
@@ -407,8 +357,11 @@ class RestClient:
407
357
  ),
408
358
  self.client.stream("GET", url, params={"history": history}) as response,
409
359
  ):
360
+ response.raise_for_status()
410
361
  json_data = "".join(response.iter_text(chunk_size=1024))
411
362
  data = json.loads(json_data)
363
+ if "id" not in data:
364
+ data["id"] = task_id
412
365
  verbose_response = TaskResponseVerbose(**data)
413
366
 
414
367
  if verbose:
@@ -419,8 +372,52 @@ class RestClient:
419
372
  ):
420
373
  return PQATaskResponse(**data)
421
374
  return TaskResponse(**data)
422
- except ValueError as e:
423
- raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
375
+ except Exception as e:
376
+ raise TaskFetchError(f"Error getting task: {e!s}") from e
377
+
378
+ @retry(
379
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
380
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
381
+ retry=retry_if_connection_error,
382
+ )
383
+ async def aget_task(
384
+ self, task_id: str | None = None, history: bool = False, verbose: bool = False
385
+ ) -> "TaskResponse":
386
+ """Get details for a specific task asynchronously."""
387
+ try:
388
+ task_id = task_id or self.trajectory_id
389
+ url = f"/v0.1/trajectories/{task_id}"
390
+ full_url = f"{self.base_url}{url}"
391
+
392
+ with external_trace(
393
+ url=full_url,
394
+ method="GET",
395
+ library="httpx",
396
+ custom_params={
397
+ "operation": "get_job",
398
+ "job_id": task_id,
399
+ },
400
+ ):
401
+ async with self.async_client.stream(
402
+ "GET", url, params={"history": history}
403
+ ) as response:
404
+ response.raise_for_status()
405
+ json_data = "".join([
406
+ chunk async for chunk in response.aiter_text()
407
+ ])
408
+ data = json.loads(json_data)
409
+ if "id" not in data:
410
+ data["id"] = task_id
411
+ verbose_response = TaskResponseVerbose(**data)
412
+
413
+ if verbose:
414
+ return verbose_response
415
+ if any(
416
+ JobNames.from_string(job_name) in verbose_response.job_name
417
+ for job_name in ["crow", "falcon", "owl", "dummy"]
418
+ ):
419
+ return PQATaskResponse(**data)
420
+ return TaskResponse(**data)
424
421
  except Exception as e:
425
422
  raise TaskFetchError(f"Error getting task: {e!s}") from e
426
423
 
@@ -445,10 +442,179 @@ class RestClient:
445
442
  "/v0.1/crows", json=task_data.model_dump(mode="json")
446
443
  )
447
444
  response.raise_for_status()
448
- self.trajectory_id = response.json()["trajectory_id"]
445
+ trajectory_id = response.json()["trajectory_id"]
446
+ self.trajectory_id = trajectory_id
449
447
  except Exception as e:
450
448
  raise TaskFetchError(f"Error creating task: {e!s}") from e
451
- return self.trajectory_id
449
+ return trajectory_id
450
+
451
+ @retry(
452
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
453
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
454
+ retry=retry_if_connection_error,
455
+ )
456
+ async def acreate_task(self, task_data: TaskRequest | dict[str, Any]):
457
+ """Create a new futurehouse task."""
458
+ if isinstance(task_data, dict):
459
+ task_data = TaskRequest.model_validate(task_data)
460
+
461
+ if isinstance(task_data.name, JobNames):
462
+ task_data.name = task_data.name.from_stage(
463
+ task_data.name.name,
464
+ self.stage,
465
+ )
466
+
467
+ try:
468
+ response = await self.async_client.post(
469
+ "/v0.1/crows", json=task_data.model_dump(mode="json")
470
+ )
471
+ response.raise_for_status()
472
+ trajectory_id = response.json()["trajectory_id"]
473
+ self.trajectory_id = trajectory_id
474
+ except Exception as e:
475
+ raise TaskFetchError(f"Error creating task: {e!s}") from e
476
+ return trajectory_id
477
+
478
+ async def arun_tasks_until_done(
479
+ self,
480
+ task_data: TaskRequest
481
+ | dict[str, Any]
482
+ | Collection[TaskRequest]
483
+ | Collection[dict[str, Any]],
484
+ verbose: bool = False,
485
+ progress_bar: bool = False,
486
+ concurrency: int = 10,
487
+ timeout: int = DEFAULT_AGENT_TIMEOUT,
488
+ ) -> list[TaskResponse]:
489
+ all_tasks: Collection[TaskRequest | dict[str, Any]] = (
490
+ cast(Collection[TaskRequest | dict[str, Any]], [task_data])
491
+ if (isinstance(task_data, dict) or not isinstance(task_data, Collection))
492
+ else cast(Collection[TaskRequest | dict[str, Any]], task_data)
493
+ )
494
+
495
+ trajectory_ids = await gather_with_concurrency(
496
+ concurrency,
497
+ [self.acreate_task(task) for task in all_tasks],
498
+ progress=progress_bar,
499
+ )
500
+
501
+ start_time = time.monotonic()
502
+ completed_tasks: dict[str, TaskResponse] = {}
503
+
504
+ if progress_bar:
505
+ progress = tqdm(
506
+ total=len(trajectory_ids), desc="Waiting for tasks to finish", ncols=0
507
+ )
508
+
509
+ while (time.monotonic() - start_time) < timeout:
510
+ task_results = await gather_with_concurrency(
511
+ concurrency,
512
+ [
513
+ self.aget_task(task_id, verbose=verbose)
514
+ for task_id in trajectory_ids
515
+ if task_id not in completed_tasks
516
+ ],
517
+ )
518
+
519
+ for task in task_results:
520
+ task_id = str(task.task_id)
521
+ if (
522
+ task_id not in completed_tasks
523
+ and ExecutionStatus(task.status).is_terminal_state()
524
+ ):
525
+ completed_tasks[task_id] = task
526
+ if progress_bar:
527
+ progress.update(1)
528
+
529
+ all_done = len(completed_tasks) == len(trajectory_ids)
530
+
531
+ if all_done:
532
+ break
533
+ await asyncio.sleep(self.DEFAULT_POLLING_TIME)
534
+
535
+ else:
536
+ logger.warning(
537
+ f"Timed out waiting for tasks to finish after {timeout} seconds. Returning with {len(completed_tasks)} completed tasks and {len(trajectory_ids)} total tasks."
538
+ )
539
+
540
+ if progress_bar:
541
+ progress.close()
542
+
543
+ return [
544
+ completed_tasks.get(task_id)
545
+ or (await self.aget_task(task_id, verbose=verbose))
546
+ for task_id in trajectory_ids
547
+ ]
548
+
549
+ def run_tasks_until_done(
550
+ self,
551
+ task_data: TaskRequest
552
+ | dict[str, Any]
553
+ | Collection[TaskRequest]
554
+ | Collection[dict[str, Any]],
555
+ verbose: bool = False,
556
+ progress_bar: bool = False,
557
+ timeout: int = DEFAULT_AGENT_TIMEOUT,
558
+ ) -> list[TaskResponse]:
559
+ """Run multiple tasks and wait for them to complete.
560
+
561
+ Args:
562
+ task_data: A single task or collection of tasks to run
563
+ verbose: Whether to return verbose task responses
564
+ progress_bar: Whether to display a progress bar
565
+ timeout: Maximum time to wait for task completion in seconds
566
+
567
+ Returns:
568
+ A list of completed task responses
569
+ """
570
+ all_tasks: Collection[TaskRequest | dict[str, Any]] = (
571
+ cast(Collection[TaskRequest | dict[str, Any]], [task_data])
572
+ if (isinstance(task_data, dict) or not isinstance(task_data, Collection))
573
+ else cast(Collection[TaskRequest | dict[str, Any]], task_data)
574
+ )
575
+
576
+ trajectory_ids = [self.create_task(task) for task in all_tasks]
577
+
578
+ start_time = time.monotonic()
579
+ completed_tasks: dict[str, TaskResponse] = {}
580
+
581
+ if progress_bar:
582
+ progress = sync_tqdm(
583
+ total=len(trajectory_ids), desc="Waiting for tasks to finish", ncols=0
584
+ )
585
+
586
+ while (time.monotonic() - start_time) < timeout:
587
+ all_done = True
588
+
589
+ for task_id in trajectory_ids:
590
+ if task_id in completed_tasks:
591
+ continue
592
+
593
+ task = self.get_task(task_id, verbose=verbose)
594
+
595
+ if not ExecutionStatus(task.status).is_terminal_state():
596
+ all_done = False
597
+ elif task_id not in completed_tasks:
598
+ completed_tasks[task_id] = task
599
+ if progress_bar:
600
+ progress.update(1)
601
+
602
+ if all_done:
603
+ break
604
+ time.sleep(self.DEFAULT_POLLING_TIME)
605
+
606
+ else:
607
+ logger.warning(
608
+ f"Timed out waiting for tasks to finish after {timeout} seconds. Returning with {len(completed_tasks)} completed tasks and {len(trajectory_ids)} total tasks."
609
+ )
610
+
611
+ if progress_bar:
612
+ progress.close()
613
+
614
+ return [
615
+ completed_tasks.get(task_id) or self.get_task(task_id, verbose=verbose)
616
+ for task_id in trajectory_ids
617
+ ]
452
618
 
453
619
  @retry(
454
620
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
@@ -457,9 +623,12 @@ class RestClient:
457
623
  )
458
624
  def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
459
625
  """Get the status of a build."""
460
- build_id = build_id or self.build_id
461
- response = self.client.get(f"/v0.1/builds/{build_id}")
462
- response.raise_for_status()
626
+ try:
627
+ build_id = build_id or self.build_id
628
+ response = self.client.get(f"/v0.1/builds/{build_id}")
629
+ response.raise_for_status()
630
+ except Exception as e:
631
+ raise JobFetchError(f"Error getting build status: {e!s}") from e
463
632
  return response.json()
464
633
 
465
634
  # TODO: Refactor later so we don't have to ignore PLR0915
@@ -660,14 +829,14 @@ class RestClient:
660
829
  self,
661
830
  job_name: str,
662
831
  file_path: str | os.PathLike,
663
- folder_name: str | None = None,
832
+ upload_id: str | None = None,
664
833
  ) -> str:
665
834
  """Upload a file or directory to a futurehouse job bucket.
666
835
 
667
836
  Args:
668
837
  job_name: The name of the futurehouse job to upload to.
669
838
  file_path: The local path to the file or directory to upload.
670
- folder_name: Optional folder name to use for the upload. If not provided, a random UUID will be used.
839
+ upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
671
840
 
672
841
  Returns:
673
842
  The upload ID used for the upload.
@@ -679,7 +848,7 @@ class RestClient:
679
848
  if not file_path.exists():
680
849
  raise FileNotFoundError(f"File or directory not found: {file_path}")
681
850
 
682
- upload_id = folder_name or str(uuid.uuid4())
851
+ upload_id = upload_id or str(uuid.uuid4())
683
852
 
684
853
  if file_path.is_dir():
685
854
  # Process directory recursively
@@ -742,6 +911,12 @@ class RestClient:
742
911
  """
743
912
  file_name = file_name or file_path.name
744
913
  file_size = file_path.stat().st_size
914
+
915
+ # Skip empty files
916
+ if file_size == 0:
917
+ logger.warning(f"Skipping upload of empty file: {file_path}")
918
+ return
919
+
745
920
  total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE
746
921
 
747
922
  logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")
@@ -789,7 +964,6 @@ class RestClient:
789
964
  )
790
965
 
791
966
  logger.info(f"Successfully uploaded {file_name}")
792
-
793
967
  except Exception as e:
794
968
  logger.exception(f"Error uploading file {file_path}")
795
969
  raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
@@ -799,12 +973,18 @@ class RestClient:
799
973
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
800
974
  retry=retry_if_connection_error,
801
975
  )
802
- def list_files(self, job_name: str, folder_name: str) -> dict[str, list[str]]:
976
+ def list_files(
977
+ self,
978
+ job_name: str,
979
+ trajectory_id: str | None = None,
980
+ upload_id: str | None = None,
981
+ ) -> dict[str, list[str]]:
803
982
  """List files and directories in a GCS location for a given job_name and upload_id.
804
983
 
805
984
  Args:
806
985
  job_name: The name of the futurehouse job.
807
- folder_name: The specific folder name (upload_id) to list files from.
986
+ trajectory_id: The specific trajectory id to list files from.
987
+ upload_id: The specific upload id to list files from.
808
988
 
809
989
  Returns:
810
990
  A list of files in the GCS folder.
@@ -812,22 +992,27 @@ class RestClient:
812
992
  Raises:
813
993
  RestClientError: If there is an error listing the files.
814
994
  """
995
+ if not bool(trajectory_id) ^ bool(upload_id):
996
+ raise RestClientError(
997
+ "Must at least specify one of trajectory_id or upload_id, but not both"
998
+ )
815
999
  try:
816
1000
  url = f"/v0.1/crows/{job_name}/list-files"
817
- params = {"upload_id": folder_name}
1001
+ params = {"trajectory_id": trajectory_id, "upload_id": upload_id}
1002
+ params = {k: v for k, v in params.items() if v is not None}
818
1003
  response = self.client.get(url, params=params)
819
1004
  response.raise_for_status()
820
1005
  return response.json()
821
1006
  except HTTPStatusError as e:
822
1007
  logger.exception(
823
- f"Error listing files for job {job_name}, folder {folder_name}: {e.response.text}"
1008
+ f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
824
1009
  )
825
1010
  raise RestClientError(
826
1011
  f"Error listing files: {e.response.status_code} - {e.response.text}"
827
1012
  ) from e
828
1013
  except Exception as e:
829
1014
  logger.exception(
830
- f"Error listing files for job {job_name}, folder {folder_name}"
1015
+ f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}"
831
1016
  )
832
1017
  raise RestClientError(f"Error listing files: {e!s}") from e
833
1018
 
@@ -839,7 +1024,7 @@ class RestClient:
839
1024
  def download_file(
840
1025
  self,
841
1026
  job_name: str,
842
- folder_name: str,
1027
+ trajectory_id: str,
843
1028
  file_path: str,
844
1029
  destination_path: str | os.PathLike,
845
1030
  ) -> None:
@@ -847,14 +1032,14 @@ class RestClient:
847
1032
 
848
1033
  Args:
849
1034
  job_name: The name of the futurehouse job.
850
- folder_name: The specific folder name (upload_id) the file belongs to.
1035
+ trajectory_id: The specific trajectory id the file belongs to.
851
1036
  file_path: The relative path of the file to download
852
1037
  (e.g., 'data/my_file.csv' or 'my_image.png').
853
1038
  destination_path: The local path where the file should be saved.
854
1039
 
855
1040
  Raises:
856
1041
  RestClientError: If there is an error downloading the file.
857
- FileNotFoundError: If the destination directory does not exist.
1042
+ FileNotFoundError: If the destination directory does not exist or if the file is not found.
858
1043
  """
859
1044
  destination_path = Path(destination_path)
860
1045
  # Ensure the destination directory exists
@@ -862,17 +1047,24 @@ class RestClient:
862
1047
 
863
1048
  try:
864
1049
  url = f"/v0.1/crows/{job_name}/download-file"
865
- params = {"upload_id": folder_name, "file_path": file_path}
1050
+ params = {"trajectory_id": trajectory_id, "file_path": file_path}
866
1051
 
867
1052
  with self.client.stream("GET", url, params=params) as response:
868
1053
  response.raise_for_status() # Check for HTTP errors before streaming
869
1054
  with open(destination_path, "wb") as f:
870
1055
  for chunk in response.iter_bytes(chunk_size=8192):
871
1056
  f.write(chunk)
1057
+
1058
+ # Check if the downloaded file is empty
1059
+ if destination_path.stat().st_size == 0:
1060
+ # Remove the empty file
1061
+ destination_path.unlink()
1062
+ raise FileNotFoundError(f"File not found or is empty: {file_path}")
1063
+
872
1064
  logger.info(f"File {file_path} downloaded to {destination_path}")
873
1065
  except HTTPStatusError as e:
874
1066
  logger.exception(
875
- f"Error downloading file {file_path} for job {job_name}, folder {folder_name}: {e.response.text}"
1067
+ f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
876
1068
  )
877
1069
  # Clean up partially downloaded file if an error occurs
878
1070
  if destination_path.exists():
@@ -880,9 +1072,20 @@ class RestClient:
880
1072
  raise RestClientError(
881
1073
  f"Error downloading file: {e.response.status_code} - {e.response.text}"
882
1074
  ) from e
1075
+ except RemoteProtocolError as e:
1076
+ logger.error(
1077
+ f"Connection error while downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
1078
+ )
1079
+ # Clean up partially downloaded file
1080
+ if destination_path.exists():
1081
+ destination_path.unlink()
1082
+
1083
+ # Often RemoteProtocolError during download means the file wasn't found
1084
+ # or was empty/corrupted on the server side
1085
+ raise FileNotFoundError(f"File not found or corrupted: {file_path}") from e
883
1086
  except Exception as e:
884
1087
  logger.exception(
885
- f"Error downloading file {file_path} for job {job_name}, folder {folder_name}"
1088
+ f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
886
1089
  )
887
1090
  if destination_path.exists():
888
1091
  destination_path.unlink() # Clean up partial file