futurehouse-client 0.3.17.dev56__py3-none-any.whl → 0.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/clients/job_client.py +1 -0
- futurehouse_client/clients/rest_client.py +393 -190
- futurehouse_client/models/__init__.py +10 -0
- futurehouse_client/models/app.py +96 -0
- futurehouse_client/models/rest.py +17 -0
- futurehouse_client/utils/auth.py +92 -0
- futurehouse_client/utils/general.py +29 -0
- {futurehouse_client-0.3.17.dev56.dist-info → futurehouse_client-0.3.18.dist-info}/METADATA +143 -22
- futurehouse_client-0.3.18.dist-info/RECORD +17 -0
- {futurehouse_client-0.3.17.dev56.dist-info → futurehouse_client-0.3.18.dist-info}/WHEEL +1 -1
- futurehouse_client-0.3.17.dev56.dist-info/RECORD +0 -15
- {futurehouse_client-0.3.17.dev56.dist-info → futurehouse_client-0.3.18.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,27 @@
|
|
1
1
|
import ast
|
2
|
+
import asyncio
|
2
3
|
import base64
|
4
|
+
import contextlib
|
3
5
|
import copy
|
4
6
|
import importlib.metadata
|
5
7
|
import inspect
|
6
8
|
import json
|
7
9
|
import logging
|
8
10
|
import os
|
11
|
+
import sys
|
9
12
|
import tempfile
|
13
|
+
import time
|
10
14
|
import uuid
|
11
|
-
from collections.abc import
|
12
|
-
from datetime import datetime
|
15
|
+
from collections.abc import Collection
|
13
16
|
from pathlib import Path
|
14
17
|
from types import ModuleType
|
15
|
-
from typing import Any, ClassVar,
|
18
|
+
from typing import Any, ClassVar, cast
|
16
19
|
from uuid import UUID
|
17
20
|
|
18
21
|
import cloudpickle
|
19
22
|
from aviary.functional import EnvironmentBuilder
|
20
23
|
from httpx import (
|
24
|
+
AsyncClient,
|
21
25
|
Client,
|
22
26
|
CloseError,
|
23
27
|
ConnectError,
|
@@ -29,7 +33,6 @@ from httpx import (
|
|
29
33
|
RemoteProtocolError,
|
30
34
|
)
|
31
35
|
from ldp.agent import AgentConfig
|
32
|
-
from pydantic import BaseModel, ConfigDict, model_validator
|
33
36
|
from requests.exceptions import RequestException, Timeout
|
34
37
|
from tenacity import (
|
35
38
|
retry,
|
@@ -37,15 +40,22 @@ from tenacity import (
|
|
37
40
|
stop_after_attempt,
|
38
41
|
wait_exponential,
|
39
42
|
)
|
43
|
+
from tqdm import tqdm as sync_tqdm
|
44
|
+
from tqdm.asyncio import tqdm
|
40
45
|
|
41
46
|
from futurehouse_client.clients import JobNames
|
42
47
|
from futurehouse_client.models.app import (
|
43
|
-
APIKeyPayload,
|
44
48
|
AuthType,
|
45
49
|
JobDeploymentConfig,
|
50
|
+
PQATaskResponse,
|
46
51
|
Stage,
|
47
52
|
TaskRequest,
|
53
|
+
TaskResponse,
|
54
|
+
TaskResponseVerbose,
|
48
55
|
)
|
56
|
+
from futurehouse_client.models.rest import ExecutionStatus
|
57
|
+
from futurehouse_client.utils.auth import RefreshingJWT
|
58
|
+
from futurehouse_client.utils.general import gather_with_concurrency
|
49
59
|
from futurehouse_client.utils.module_utils import (
|
50
60
|
OrganizationSelector,
|
51
61
|
fetch_environment_function_docstring,
|
@@ -55,24 +65,14 @@ from futurehouse_client.utils.monitoring import (
|
|
55
65
|
)
|
56
66
|
|
57
67
|
logger = logging.getLogger(__name__)
|
58
|
-
|
68
|
+
logging.basicConfig(
|
69
|
+
level=logging.WARNING,
|
70
|
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
71
|
+
stream=sys.stdout,
|
72
|
+
)
|
73
|
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
59
74
|
TaskRequest.model_rebuild()
|
60
75
|
|
61
|
-
retry_if_connection_error = retry_if_exception_type((
|
62
|
-
# From requests
|
63
|
-
Timeout,
|
64
|
-
ConnectionError,
|
65
|
-
RequestException,
|
66
|
-
# From httpx
|
67
|
-
ConnectError,
|
68
|
-
ConnectTimeout,
|
69
|
-
ReadTimeout,
|
70
|
-
ReadError,
|
71
|
-
NetworkError,
|
72
|
-
RemoteProtocolError,
|
73
|
-
CloseError,
|
74
|
-
))
|
75
|
-
|
76
76
|
FILE_UPLOAD_IGNORE_PARTS = {
|
77
77
|
".ruff_cache",
|
78
78
|
"__pycache__",
|
@@ -103,114 +103,35 @@ class InvalidTaskDescriptionError(Exception):
|
|
103
103
|
"""Raised when the task description is invalid or empty."""
|
104
104
|
|
105
105
|
|
106
|
-
class SimpleOrganization(BaseModel):
|
107
|
-
id: int
|
108
|
-
name: str
|
109
|
-
display_name: str
|
110
|
-
|
111
|
-
|
112
|
-
# 5 minute default for JWTs
|
113
|
-
JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
|
114
|
-
|
115
|
-
|
116
|
-
class TaskResponse(BaseModel):
|
117
|
-
"""Base class for task responses. This holds attributes shared over all futurehouse jobs."""
|
118
|
-
|
119
|
-
model_config = ConfigDict(extra="ignore")
|
120
|
-
|
121
|
-
status: str
|
122
|
-
query: str
|
123
|
-
user: str | None = None
|
124
|
-
created_at: datetime
|
125
|
-
job_name: str
|
126
|
-
public: bool
|
127
|
-
shared_with: list[SimpleOrganization] | None = None
|
128
|
-
build_owner: str | None = None
|
129
|
-
environment_name: str | None = None
|
130
|
-
agent_name: str | None = None
|
131
|
-
task_id: UUID | None = None
|
132
|
-
|
133
|
-
@model_validator(mode="before")
|
134
|
-
@classmethod
|
135
|
-
def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
136
|
-
# Extract fields from environment frame state
|
137
|
-
if not isinstance(data, dict):
|
138
|
-
return data
|
139
|
-
# TODO: We probably want to remove these two once we define the final names.
|
140
|
-
data["job_name"] = data.get("crow")
|
141
|
-
data["query"] = data.get("task")
|
142
|
-
if not (env_frame := data.get("environment_frame", {})):
|
143
|
-
return data
|
144
|
-
state = env_frame.get("state", {}).get("state", {})
|
145
|
-
data["task_id"] = cast(UUID, state.get("id")) if state.get("id") else None
|
146
|
-
if not (metadata := data.get("metadata", {})):
|
147
|
-
return data
|
148
|
-
data["environment_name"] = metadata.get("environment_name")
|
149
|
-
data["agent_name"] = metadata.get("agent_name")
|
150
|
-
return data
|
151
|
-
|
152
|
-
|
153
|
-
class PQATaskResponse(TaskResponse):
|
154
|
-
model_config = ConfigDict(extra="ignore")
|
155
|
-
|
156
|
-
answer: str | None = None
|
157
|
-
formatted_answer: str | None = None
|
158
|
-
answer_reasoning: str | None = None
|
159
|
-
has_successful_answer: bool | None = None
|
160
|
-
total_cost: float | None = None
|
161
|
-
total_queries: int | None = None
|
162
|
-
|
163
|
-
@model_validator(mode="before")
|
164
|
-
@classmethod
|
165
|
-
def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
166
|
-
# Extract fields from environment frame state
|
167
|
-
if not isinstance(data, dict):
|
168
|
-
return data
|
169
|
-
if not (env_frame := data.get("environment_frame", {})):
|
170
|
-
return data
|
171
|
-
state = env_frame.get("state", {}).get("state", {})
|
172
|
-
response = state.get("response", {})
|
173
|
-
answer = response.get("answer", {})
|
174
|
-
usage = state.get("info", {}).get("usage", {})
|
175
|
-
|
176
|
-
# Add additional PQA specific fields to data so that pydantic can validate the model
|
177
|
-
data["answer"] = answer.get("answer")
|
178
|
-
data["formatted_answer"] = answer.get("formatted_answer")
|
179
|
-
data["answer_reasoning"] = answer.get("answer_reasoning")
|
180
|
-
data["has_successful_answer"] = answer.get("has_successful_answer")
|
181
|
-
data["total_cost"] = cast(float, usage.get("total_cost"))
|
182
|
-
data["total_queries"] = cast(int, usage.get("total_queries"))
|
183
|
-
|
184
|
-
return data
|
185
|
-
|
186
|
-
def clean_verbose(self) -> "TaskResponse":
|
187
|
-
"""Clean the verbose response from the server."""
|
188
|
-
self.request = None
|
189
|
-
self.response = None
|
190
|
-
return self
|
191
|
-
|
192
|
-
|
193
|
-
class TaskResponseVerbose(TaskResponse):
|
194
|
-
"""Class for responses to include all the fields of a task response."""
|
195
|
-
|
196
|
-
model_config = ConfigDict(extra="allow")
|
197
|
-
|
198
|
-
public: bool
|
199
|
-
agent_state: list[dict[str, Any]] | None = None
|
200
|
-
environment_frame: dict[str, Any] | None = None
|
201
|
-
metadata: dict[str, Any] | None = None
|
202
|
-
shared_with: list[SimpleOrganization] | None = None
|
203
|
-
|
204
|
-
|
205
106
|
class FileUploadError(RestClientError):
|
206
107
|
"""Raised when there's an error uploading a file."""
|
207
108
|
|
208
109
|
|
110
|
+
retry_if_connection_error = retry_if_exception_type((
|
111
|
+
# From requests
|
112
|
+
Timeout,
|
113
|
+
ConnectionError,
|
114
|
+
RequestException,
|
115
|
+
# From httpx
|
116
|
+
ConnectError,
|
117
|
+
ConnectTimeout,
|
118
|
+
ReadTimeout,
|
119
|
+
ReadError,
|
120
|
+
NetworkError,
|
121
|
+
RemoteProtocolError,
|
122
|
+
CloseError,
|
123
|
+
FileUploadError,
|
124
|
+
))
|
125
|
+
|
126
|
+
DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
|
127
|
+
|
128
|
+
|
209
129
|
class RestClient:
|
210
130
|
REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec
|
211
131
|
MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
|
212
132
|
RETRY_MULTIPLIER: ClassVar[int] = 1
|
213
133
|
MAX_RETRY_WAIT: ClassVar[int] = 10
|
134
|
+
DEFAULT_POLLING_TIME: ClassVar[int] = 5 # seconds
|
214
135
|
CHUNK_SIZE: ClassVar[int] = 16 * 1024 * 1024 # 16MB chunks
|
215
136
|
|
216
137
|
def __init__(
|
@@ -222,62 +143,116 @@ class RestClient:
|
|
222
143
|
api_key: str | None = None,
|
223
144
|
jwt: str | None = None,
|
224
145
|
headers: dict[str, str] | None = None,
|
146
|
+
verbose_logging: bool = False,
|
225
147
|
):
|
148
|
+
if verbose_logging:
|
149
|
+
logger.setLevel(logging.INFO)
|
150
|
+
else:
|
151
|
+
logger.setLevel(logging.WARNING)
|
152
|
+
|
226
153
|
self.base_url = service_uri or stage.value
|
227
154
|
self.stage = stage
|
228
155
|
self.auth_type = auth_type
|
229
156
|
self.api_key = api_key
|
230
|
-
self._clients: dict[str, Client] = {}
|
157
|
+
self._clients: dict[str, Client | AsyncClient] = {}
|
231
158
|
self.headers = headers or {}
|
232
|
-
self.
|
159
|
+
self.jwt = jwt
|
233
160
|
self.organizations: list[str] = self._filter_orgs(organization)
|
234
161
|
|
235
162
|
@property
|
236
163
|
def client(self) -> Client:
|
237
|
-
"""
|
238
|
-
return self.get_client("application/json",
|
164
|
+
"""Authenticated HTTP client for regular API calls."""
|
165
|
+
return cast(Client, self.get_client("application/json", authenticated=True))
|
239
166
|
|
240
167
|
@property
|
241
|
-
def
|
242
|
-
"""
|
243
|
-
return
|
168
|
+
def async_client(self) -> AsyncClient:
|
169
|
+
"""Authenticated async HTTP client for regular API calls."""
|
170
|
+
return cast(
|
171
|
+
AsyncClient,
|
172
|
+
self.get_client("application/json", authenticated=True, async_client=True),
|
173
|
+
)
|
174
|
+
|
175
|
+
@property
|
176
|
+
def unauthenticated_client(self) -> Client:
|
177
|
+
"""Unauthenticated HTTP client for auth operations."""
|
178
|
+
return cast(Client, self.get_client("application/json", authenticated=False))
|
244
179
|
|
245
180
|
@property
|
246
181
|
def multipart_client(self) -> Client:
|
247
|
-
"""
|
248
|
-
return self.get_client(None,
|
182
|
+
"""Authenticated HTTP client for multipart uploads."""
|
183
|
+
return cast(Client, self.get_client(None, authenticated=True))
|
249
184
|
|
250
185
|
def get_client(
|
251
|
-
self,
|
252
|
-
|
186
|
+
self,
|
187
|
+
content_type: str | None = "application/json",
|
188
|
+
authenticated: bool = True,
|
189
|
+
async_client: bool = False,
|
190
|
+
) -> Client | AsyncClient:
|
253
191
|
"""Return a cached HTTP client or create one if needed.
|
254
192
|
|
255
193
|
Args:
|
256
194
|
content_type: The desired content type header. Use None for multipart uploads.
|
257
|
-
|
195
|
+
authenticated: Whether the client should include authentication.
|
196
|
+
async_client: Whether to use an async client.
|
258
197
|
|
259
198
|
Returns:
|
260
199
|
An HTTP client configured with the appropriate headers.
|
261
200
|
"""
|
262
|
-
# Create a composite key based on content type and auth flag
|
263
|
-
key = f"{content_type or 'multipart'}_{
|
201
|
+
# Create a composite key based on content type and auth flag
|
202
|
+
key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
|
203
|
+
|
264
204
|
if key not in self._clients:
|
265
205
|
headers = copy.deepcopy(self.headers)
|
266
|
-
|
267
|
-
|
206
|
+
auth = None
|
207
|
+
|
208
|
+
if authenticated:
|
209
|
+
auth = RefreshingJWT(
|
210
|
+
# authenticated=False will always return a synchronous client
|
211
|
+
auth_client=cast(
|
212
|
+
Client, self.get_client("application/json", authenticated=False)
|
213
|
+
),
|
214
|
+
auth_type=self.auth_type,
|
215
|
+
api_key=self.api_key,
|
216
|
+
jwt=self.jwt,
|
217
|
+
)
|
218
|
+
|
268
219
|
if content_type:
|
269
220
|
headers["Content-Type"] = content_type
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
221
|
+
|
222
|
+
self._clients[key] = (
|
223
|
+
AsyncClient(
|
224
|
+
base_url=self.base_url,
|
225
|
+
headers=headers,
|
226
|
+
timeout=self.REQUEST_TIMEOUT,
|
227
|
+
auth=auth,
|
228
|
+
)
|
229
|
+
if async_client
|
230
|
+
else Client(
|
231
|
+
base_url=self.base_url,
|
232
|
+
headers=headers,
|
233
|
+
timeout=self.REQUEST_TIMEOUT,
|
234
|
+
auth=auth,
|
235
|
+
)
|
274
236
|
)
|
237
|
+
|
275
238
|
return self._clients[key]
|
276
239
|
|
277
|
-
def
|
278
|
-
"""
|
240
|
+
def close(self):
|
241
|
+
"""Explicitly close all cached clients."""
|
279
242
|
for client in self._clients.values():
|
280
|
-
client
|
243
|
+
if isinstance(client, Client):
|
244
|
+
with contextlib.suppress(RuntimeError, CloseError):
|
245
|
+
client.close()
|
246
|
+
|
247
|
+
async def aclose(self):
|
248
|
+
"""Asynchronously close all cached clients."""
|
249
|
+
for client in self._clients.values():
|
250
|
+
if isinstance(client, AsyncClient):
|
251
|
+
with contextlib.suppress(RuntimeError, CloseError):
|
252
|
+
await client.aclose()
|
253
|
+
|
254
|
+
def __del__(self):
|
255
|
+
self.close()
|
281
256
|
|
282
257
|
def _filter_orgs(self, organization: str | None = None) -> list[str]:
|
283
258
|
filtered_orgs = [
|
@@ -289,31 +264,6 @@ class RestClient:
|
|
289
264
|
raise ValueError(f"Organization '{organization}' not found.")
|
290
265
|
return filtered_orgs
|
291
266
|
|
292
|
-
def _run_auth(self, jwt: str | None = None) -> str:
|
293
|
-
auth_payload: APIKeyPayload | None
|
294
|
-
if self.auth_type == AuthType.API_KEY:
|
295
|
-
auth_payload = APIKeyPayload(api_key=self.api_key)
|
296
|
-
elif self.auth_type == AuthType.JWT:
|
297
|
-
auth_payload = None
|
298
|
-
else:
|
299
|
-
assert_never(self.auth_type)
|
300
|
-
try:
|
301
|
-
# Use the unauthenticated client for login
|
302
|
-
if auth_payload:
|
303
|
-
response = self.auth_client.post(
|
304
|
-
"/auth/login", json=auth_payload.model_dump()
|
305
|
-
)
|
306
|
-
response.raise_for_status()
|
307
|
-
token_data = response.json()
|
308
|
-
elif jwt:
|
309
|
-
token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
|
310
|
-
else:
|
311
|
-
raise ValueError("JWT token required for JWT authentication.")
|
312
|
-
|
313
|
-
return token_data["access_token"]
|
314
|
-
except Exception as e:
|
315
|
-
raise RestClientError(f"Error authenticating: {e!s}") from e
|
316
|
-
|
317
267
|
def _check_job(self, name: str, organization: str) -> dict[str, Any]:
|
318
268
|
try:
|
319
269
|
response = self.client.get(
|
@@ -407,8 +357,11 @@ class RestClient:
|
|
407
357
|
),
|
408
358
|
self.client.stream("GET", url, params={"history": history}) as response,
|
409
359
|
):
|
360
|
+
response.raise_for_status()
|
410
361
|
json_data = "".join(response.iter_text(chunk_size=1024))
|
411
362
|
data = json.loads(json_data)
|
363
|
+
if "id" not in data:
|
364
|
+
data["id"] = task_id
|
412
365
|
verbose_response = TaskResponseVerbose(**data)
|
413
366
|
|
414
367
|
if verbose:
|
@@ -419,8 +372,52 @@ class RestClient:
|
|
419
372
|
):
|
420
373
|
return PQATaskResponse(**data)
|
421
374
|
return TaskResponse(**data)
|
422
|
-
except
|
423
|
-
raise
|
375
|
+
except Exception as e:
|
376
|
+
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
377
|
+
|
378
|
+
@retry(
|
379
|
+
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
380
|
+
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
381
|
+
retry=retry_if_connection_error,
|
382
|
+
)
|
383
|
+
async def aget_task(
|
384
|
+
self, task_id: str | None = None, history: bool = False, verbose: bool = False
|
385
|
+
) -> "TaskResponse":
|
386
|
+
"""Get details for a specific task asynchronously."""
|
387
|
+
try:
|
388
|
+
task_id = task_id or self.trajectory_id
|
389
|
+
url = f"/v0.1/trajectories/{task_id}"
|
390
|
+
full_url = f"{self.base_url}{url}"
|
391
|
+
|
392
|
+
with external_trace(
|
393
|
+
url=full_url,
|
394
|
+
method="GET",
|
395
|
+
library="httpx",
|
396
|
+
custom_params={
|
397
|
+
"operation": "get_job",
|
398
|
+
"job_id": task_id,
|
399
|
+
},
|
400
|
+
):
|
401
|
+
async with self.async_client.stream(
|
402
|
+
"GET", url, params={"history": history}
|
403
|
+
) as response:
|
404
|
+
response.raise_for_status()
|
405
|
+
json_data = "".join([
|
406
|
+
chunk async for chunk in response.aiter_text()
|
407
|
+
])
|
408
|
+
data = json.loads(json_data)
|
409
|
+
if "id" not in data:
|
410
|
+
data["id"] = task_id
|
411
|
+
verbose_response = TaskResponseVerbose(**data)
|
412
|
+
|
413
|
+
if verbose:
|
414
|
+
return verbose_response
|
415
|
+
if any(
|
416
|
+
JobNames.from_string(job_name) in verbose_response.job_name
|
417
|
+
for job_name in ["crow", "falcon", "owl", "dummy"]
|
418
|
+
):
|
419
|
+
return PQATaskResponse(**data)
|
420
|
+
return TaskResponse(**data)
|
424
421
|
except Exception as e:
|
425
422
|
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
426
423
|
|
@@ -445,10 +442,179 @@ class RestClient:
|
|
445
442
|
"/v0.1/crows", json=task_data.model_dump(mode="json")
|
446
443
|
)
|
447
444
|
response.raise_for_status()
|
448
|
-
|
445
|
+
trajectory_id = response.json()["trajectory_id"]
|
446
|
+
self.trajectory_id = trajectory_id
|
449
447
|
except Exception as e:
|
450
448
|
raise TaskFetchError(f"Error creating task: {e!s}") from e
|
451
|
-
return
|
449
|
+
return trajectory_id
|
450
|
+
|
451
|
+
@retry(
|
452
|
+
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
453
|
+
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
454
|
+
retry=retry_if_connection_error,
|
455
|
+
)
|
456
|
+
async def acreate_task(self, task_data: TaskRequest | dict[str, Any]):
|
457
|
+
"""Create a new futurehouse task."""
|
458
|
+
if isinstance(task_data, dict):
|
459
|
+
task_data = TaskRequest.model_validate(task_data)
|
460
|
+
|
461
|
+
if isinstance(task_data.name, JobNames):
|
462
|
+
task_data.name = task_data.name.from_stage(
|
463
|
+
task_data.name.name,
|
464
|
+
self.stage,
|
465
|
+
)
|
466
|
+
|
467
|
+
try:
|
468
|
+
response = await self.async_client.post(
|
469
|
+
"/v0.1/crows", json=task_data.model_dump(mode="json")
|
470
|
+
)
|
471
|
+
response.raise_for_status()
|
472
|
+
trajectory_id = response.json()["trajectory_id"]
|
473
|
+
self.trajectory_id = trajectory_id
|
474
|
+
except Exception as e:
|
475
|
+
raise TaskFetchError(f"Error creating task: {e!s}") from e
|
476
|
+
return trajectory_id
|
477
|
+
|
478
|
+
async def arun_tasks_until_done(
|
479
|
+
self,
|
480
|
+
task_data: TaskRequest
|
481
|
+
| dict[str, Any]
|
482
|
+
| Collection[TaskRequest]
|
483
|
+
| Collection[dict[str, Any]],
|
484
|
+
verbose: bool = False,
|
485
|
+
progress_bar: bool = False,
|
486
|
+
concurrency: int = 10,
|
487
|
+
timeout: int = DEFAULT_AGENT_TIMEOUT,
|
488
|
+
) -> list[TaskResponse]:
|
489
|
+
all_tasks: Collection[TaskRequest | dict[str, Any]] = (
|
490
|
+
cast(Collection[TaskRequest | dict[str, Any]], [task_data])
|
491
|
+
if (isinstance(task_data, dict) or not isinstance(task_data, Collection))
|
492
|
+
else cast(Collection[TaskRequest | dict[str, Any]], task_data)
|
493
|
+
)
|
494
|
+
|
495
|
+
trajectory_ids = await gather_with_concurrency(
|
496
|
+
concurrency,
|
497
|
+
[self.acreate_task(task) for task in all_tasks],
|
498
|
+
progress=progress_bar,
|
499
|
+
)
|
500
|
+
|
501
|
+
start_time = time.monotonic()
|
502
|
+
completed_tasks: dict[str, TaskResponse] = {}
|
503
|
+
|
504
|
+
if progress_bar:
|
505
|
+
progress = tqdm(
|
506
|
+
total=len(trajectory_ids), desc="Waiting for tasks to finish", ncols=0
|
507
|
+
)
|
508
|
+
|
509
|
+
while (time.monotonic() - start_time) < timeout:
|
510
|
+
task_results = await gather_with_concurrency(
|
511
|
+
concurrency,
|
512
|
+
[
|
513
|
+
self.aget_task(task_id, verbose=verbose)
|
514
|
+
for task_id in trajectory_ids
|
515
|
+
if task_id not in completed_tasks
|
516
|
+
],
|
517
|
+
)
|
518
|
+
|
519
|
+
for task in task_results:
|
520
|
+
task_id = str(task.task_id)
|
521
|
+
if (
|
522
|
+
task_id not in completed_tasks
|
523
|
+
and ExecutionStatus(task.status).is_terminal_state()
|
524
|
+
):
|
525
|
+
completed_tasks[task_id] = task
|
526
|
+
if progress_bar:
|
527
|
+
progress.update(1)
|
528
|
+
|
529
|
+
all_done = len(completed_tasks) == len(trajectory_ids)
|
530
|
+
|
531
|
+
if all_done:
|
532
|
+
break
|
533
|
+
await asyncio.sleep(self.DEFAULT_POLLING_TIME)
|
534
|
+
|
535
|
+
else:
|
536
|
+
logger.warning(
|
537
|
+
f"Timed out waiting for tasks to finish after {timeout} seconds. Returning with {len(completed_tasks)} completed tasks and {len(trajectory_ids)} total tasks."
|
538
|
+
)
|
539
|
+
|
540
|
+
if progress_bar:
|
541
|
+
progress.close()
|
542
|
+
|
543
|
+
return [
|
544
|
+
completed_tasks.get(task_id)
|
545
|
+
or (await self.aget_task(task_id, verbose=verbose))
|
546
|
+
for task_id in trajectory_ids
|
547
|
+
]
|
548
|
+
|
549
|
+
def run_tasks_until_done(
|
550
|
+
self,
|
551
|
+
task_data: TaskRequest
|
552
|
+
| dict[str, Any]
|
553
|
+
| Collection[TaskRequest]
|
554
|
+
| Collection[dict[str, Any]],
|
555
|
+
verbose: bool = False,
|
556
|
+
progress_bar: bool = False,
|
557
|
+
timeout: int = DEFAULT_AGENT_TIMEOUT,
|
558
|
+
) -> list[TaskResponse]:
|
559
|
+
"""Run multiple tasks and wait for them to complete.
|
560
|
+
|
561
|
+
Args:
|
562
|
+
task_data: A single task or collection of tasks to run
|
563
|
+
verbose: Whether to return verbose task responses
|
564
|
+
progress_bar: Whether to display a progress bar
|
565
|
+
timeout: Maximum time to wait for task completion in seconds
|
566
|
+
|
567
|
+
Returns:
|
568
|
+
A list of completed task responses
|
569
|
+
"""
|
570
|
+
all_tasks: Collection[TaskRequest | dict[str, Any]] = (
|
571
|
+
cast(Collection[TaskRequest | dict[str, Any]], [task_data])
|
572
|
+
if (isinstance(task_data, dict) or not isinstance(task_data, Collection))
|
573
|
+
else cast(Collection[TaskRequest | dict[str, Any]], task_data)
|
574
|
+
)
|
575
|
+
|
576
|
+
trajectory_ids = [self.create_task(task) for task in all_tasks]
|
577
|
+
|
578
|
+
start_time = time.monotonic()
|
579
|
+
completed_tasks: dict[str, TaskResponse] = {}
|
580
|
+
|
581
|
+
if progress_bar:
|
582
|
+
progress = sync_tqdm(
|
583
|
+
total=len(trajectory_ids), desc="Waiting for tasks to finish", ncols=0
|
584
|
+
)
|
585
|
+
|
586
|
+
while (time.monotonic() - start_time) < timeout:
|
587
|
+
all_done = True
|
588
|
+
|
589
|
+
for task_id in trajectory_ids:
|
590
|
+
if task_id in completed_tasks:
|
591
|
+
continue
|
592
|
+
|
593
|
+
task = self.get_task(task_id, verbose=verbose)
|
594
|
+
|
595
|
+
if not ExecutionStatus(task.status).is_terminal_state():
|
596
|
+
all_done = False
|
597
|
+
elif task_id not in completed_tasks:
|
598
|
+
completed_tasks[task_id] = task
|
599
|
+
if progress_bar:
|
600
|
+
progress.update(1)
|
601
|
+
|
602
|
+
if all_done:
|
603
|
+
break
|
604
|
+
time.sleep(self.DEFAULT_POLLING_TIME)
|
605
|
+
|
606
|
+
else:
|
607
|
+
logger.warning(
|
608
|
+
f"Timed out waiting for tasks to finish after {timeout} seconds. Returning with {len(completed_tasks)} completed tasks and {len(trajectory_ids)} total tasks."
|
609
|
+
)
|
610
|
+
|
611
|
+
if progress_bar:
|
612
|
+
progress.close()
|
613
|
+
|
614
|
+
return [
|
615
|
+
completed_tasks.get(task_id) or self.get_task(task_id, verbose=verbose)
|
616
|
+
for task_id in trajectory_ids
|
617
|
+
]
|
452
618
|
|
453
619
|
@retry(
|
454
620
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
@@ -457,9 +623,12 @@ class RestClient:
|
|
457
623
|
)
|
458
624
|
def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
|
459
625
|
"""Get the status of a build."""
|
460
|
-
|
461
|
-
|
462
|
-
|
626
|
+
try:
|
627
|
+
build_id = build_id or self.build_id
|
628
|
+
response = self.client.get(f"/v0.1/builds/{build_id}")
|
629
|
+
response.raise_for_status()
|
630
|
+
except Exception as e:
|
631
|
+
raise JobFetchError(f"Error getting build status: {e!s}") from e
|
463
632
|
return response.json()
|
464
633
|
|
465
634
|
# TODO: Refactor later so we don't have to ignore PLR0915
|
@@ -660,14 +829,14 @@ class RestClient:
|
|
660
829
|
self,
|
661
830
|
job_name: str,
|
662
831
|
file_path: str | os.PathLike,
|
663
|
-
|
832
|
+
upload_id: str | None = None,
|
664
833
|
) -> str:
|
665
834
|
"""Upload a file or directory to a futurehouse job bucket.
|
666
835
|
|
667
836
|
Args:
|
668
837
|
job_name: The name of the futurehouse job to upload to.
|
669
838
|
file_path: The local path to the file or directory to upload.
|
670
|
-
|
839
|
+
upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
|
671
840
|
|
672
841
|
Returns:
|
673
842
|
The upload ID used for the upload.
|
@@ -679,7 +848,7 @@ class RestClient:
|
|
679
848
|
if not file_path.exists():
|
680
849
|
raise FileNotFoundError(f"File or directory not found: {file_path}")
|
681
850
|
|
682
|
-
upload_id =
|
851
|
+
upload_id = upload_id or str(uuid.uuid4())
|
683
852
|
|
684
853
|
if file_path.is_dir():
|
685
854
|
# Process directory recursively
|
@@ -742,6 +911,12 @@ class RestClient:
|
|
742
911
|
"""
|
743
912
|
file_name = file_name or file_path.name
|
744
913
|
file_size = file_path.stat().st_size
|
914
|
+
|
915
|
+
# Skip empty files
|
916
|
+
if file_size == 0:
|
917
|
+
logger.warning(f"Skipping upload of empty file: {file_path}")
|
918
|
+
return
|
919
|
+
|
745
920
|
total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE
|
746
921
|
|
747
922
|
logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")
|
@@ -789,7 +964,6 @@ class RestClient:
|
|
789
964
|
)
|
790
965
|
|
791
966
|
logger.info(f"Successfully uploaded {file_name}")
|
792
|
-
|
793
967
|
except Exception as e:
|
794
968
|
logger.exception(f"Error uploading file {file_path}")
|
795
969
|
raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
|
@@ -799,12 +973,18 @@ class RestClient:
|
|
799
973
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
800
974
|
retry=retry_if_connection_error,
|
801
975
|
)
|
802
|
-
def list_files(
|
976
|
+
def list_files(
|
977
|
+
self,
|
978
|
+
job_name: str,
|
979
|
+
trajectory_id: str | None = None,
|
980
|
+
upload_id: str | None = None,
|
981
|
+
) -> dict[str, list[str]]:
|
803
982
|
"""List files and directories in a GCS location for a given job_name and upload_id.
|
804
983
|
|
805
984
|
Args:
|
806
985
|
job_name: The name of the futurehouse job.
|
807
|
-
|
986
|
+
trajectory_id: The specific trajectory id to list files from.
|
987
|
+
upload_id: The specific upload id to list files from.
|
808
988
|
|
809
989
|
Returns:
|
810
990
|
A list of files in the GCS folder.
|
@@ -812,22 +992,27 @@ class RestClient:
|
|
812
992
|
Raises:
|
813
993
|
RestClientError: If there is an error listing the files.
|
814
994
|
"""
|
995
|
+
if not bool(trajectory_id) ^ bool(upload_id):
|
996
|
+
raise RestClientError(
|
997
|
+
"Must at least specify one of trajectory_id or upload_id, but not both"
|
998
|
+
)
|
815
999
|
try:
|
816
1000
|
url = f"/v0.1/crows/{job_name}/list-files"
|
817
|
-
params = {"upload_id":
|
1001
|
+
params = {"trajectory_id": trajectory_id, "upload_id": upload_id}
|
1002
|
+
params = {k: v for k, v in params.items() if v is not None}
|
818
1003
|
response = self.client.get(url, params=params)
|
819
1004
|
response.raise_for_status()
|
820
1005
|
return response.json()
|
821
1006
|
except HTTPStatusError as e:
|
822
1007
|
logger.exception(
|
823
|
-
f"Error listing files for job {job_name},
|
1008
|
+
f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
|
824
1009
|
)
|
825
1010
|
raise RestClientError(
|
826
1011
|
f"Error listing files: {e.response.status_code} - {e.response.text}"
|
827
1012
|
) from e
|
828
1013
|
except Exception as e:
|
829
1014
|
logger.exception(
|
830
|
-
f"Error listing files for job {job_name},
|
1015
|
+
f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}"
|
831
1016
|
)
|
832
1017
|
raise RestClientError(f"Error listing files: {e!s}") from e
|
833
1018
|
|
@@ -839,7 +1024,7 @@ class RestClient:
|
|
839
1024
|
def download_file(
|
840
1025
|
self,
|
841
1026
|
job_name: str,
|
842
|
-
|
1027
|
+
trajectory_id: str,
|
843
1028
|
file_path: str,
|
844
1029
|
destination_path: str | os.PathLike,
|
845
1030
|
) -> None:
|
@@ -847,14 +1032,14 @@ class RestClient:
|
|
847
1032
|
|
848
1033
|
Args:
|
849
1034
|
job_name: The name of the futurehouse job.
|
850
|
-
|
1035
|
+
trajectory_id: The specific trajectory id the file belongs to.
|
851
1036
|
file_path: The relative path of the file to download
|
852
1037
|
(e.g., 'data/my_file.csv' or 'my_image.png').
|
853
1038
|
destination_path: The local path where the file should be saved.
|
854
1039
|
|
855
1040
|
Raises:
|
856
1041
|
RestClientError: If there is an error downloading the file.
|
857
|
-
FileNotFoundError: If the destination directory does not exist.
|
1042
|
+
FileNotFoundError: If the destination directory does not exist or if the file is not found.
|
858
1043
|
"""
|
859
1044
|
destination_path = Path(destination_path)
|
860
1045
|
# Ensure the destination directory exists
|
@@ -862,17 +1047,24 @@ class RestClient:
|
|
862
1047
|
|
863
1048
|
try:
|
864
1049
|
url = f"/v0.1/crows/{job_name}/download-file"
|
865
|
-
params = {"
|
1050
|
+
params = {"trajectory_id": trajectory_id, "file_path": file_path}
|
866
1051
|
|
867
1052
|
with self.client.stream("GET", url, params=params) as response:
|
868
1053
|
response.raise_for_status() # Check for HTTP errors before streaming
|
869
1054
|
with open(destination_path, "wb") as f:
|
870
1055
|
for chunk in response.iter_bytes(chunk_size=8192):
|
871
1056
|
f.write(chunk)
|
1057
|
+
|
1058
|
+
# Check if the downloaded file is empty
|
1059
|
+
if destination_path.stat().st_size == 0:
|
1060
|
+
# Remove the empty file
|
1061
|
+
destination_path.unlink()
|
1062
|
+
raise FileNotFoundError(f"File not found or is empty: {file_path}")
|
1063
|
+
|
872
1064
|
logger.info(f"File {file_path} downloaded to {destination_path}")
|
873
1065
|
except HTTPStatusError as e:
|
874
1066
|
logger.exception(
|
875
|
-
f"Error downloading file {file_path} for job {job_name},
|
1067
|
+
f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
|
876
1068
|
)
|
877
1069
|
# Clean up partially downloaded file if an error occurs
|
878
1070
|
if destination_path.exists():
|
@@ -880,9 +1072,20 @@ class RestClient:
|
|
880
1072
|
raise RestClientError(
|
881
1073
|
f"Error downloading file: {e.response.status_code} - {e.response.text}"
|
882
1074
|
) from e
|
1075
|
+
except RemoteProtocolError as e:
|
1076
|
+
logger.error(
|
1077
|
+
f"Connection error while downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
|
1078
|
+
)
|
1079
|
+
# Clean up partially downloaded file
|
1080
|
+
if destination_path.exists():
|
1081
|
+
destination_path.unlink()
|
1082
|
+
|
1083
|
+
# Often RemoteProtocolError during download means the file wasn't found
|
1084
|
+
# or was empty/corrupted on the server side
|
1085
|
+
raise FileNotFoundError(f"File not found or corrupted: {file_path}") from e
|
883
1086
|
except Exception as e:
|
884
1087
|
logger.exception(
|
885
|
-
f"Error downloading file {file_path} for job {job_name},
|
1088
|
+
f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
|
886
1089
|
)
|
887
1090
|
if destination_path.exists():
|
888
1091
|
destination_path.unlink() # Clean up partial file
|