futurehouse-client 0.3.17.dev94__py3-none-any.whl → 0.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/clients/job_client.py +1 -0
- futurehouse_client/clients/rest_client.py +128 -182
- futurehouse_client/models/__init__.py +10 -0
- futurehouse_client/models/app.py +96 -0
- futurehouse_client/utils/auth.py +92 -0
- {futurehouse_client-0.3.17.dev94.dist-info → futurehouse_client-0.3.18.dist-info}/METADATA +1 -1
- futurehouse_client-0.3.18.dist-info/RECORD +17 -0
- {futurehouse_client-0.3.17.dev94.dist-info → futurehouse_client-0.3.18.dist-info}/WHEEL +1 -1
- futurehouse_client-0.3.17.dev94.dist-info/RECORD +0 -16
- {futurehouse_client-0.3.17.dev94.dist-info → futurehouse_client-0.3.18.dist-info}/top_level.txt +0 -0
@@ -30,6 +30,7 @@ class JobNames(StrEnum):
|
|
30
30
|
OWL = "job-futurehouse-hasanyone"
|
31
31
|
DUMMY = "job-futurehouse-dummy-env"
|
32
32
|
PHOENIX = "job-futurehouse-phoenix"
|
33
|
+
FINCH = "job-futurehouse-data-analysis-crow-high"
|
33
34
|
|
34
35
|
@classmethod
|
35
36
|
def from_stage(cls, job_name: str, stage: Stage | None = None) -> str:
|
@@ -8,14 +8,14 @@ import inspect
|
|
8
8
|
import json
|
9
9
|
import logging
|
10
10
|
import os
|
11
|
+
import sys
|
11
12
|
import tempfile
|
12
13
|
import time
|
13
14
|
import uuid
|
14
|
-
from collections.abc import Collection
|
15
|
-
from datetime import datetime
|
15
|
+
from collections.abc import Collection
|
16
16
|
from pathlib import Path
|
17
17
|
from types import ModuleType
|
18
|
-
from typing import Any, ClassVar,
|
18
|
+
from typing import Any, ClassVar, cast
|
19
19
|
from uuid import UUID
|
20
20
|
|
21
21
|
import cloudpickle
|
@@ -33,7 +33,6 @@ from httpx import (
|
|
33
33
|
RemoteProtocolError,
|
34
34
|
)
|
35
35
|
from ldp.agent import AgentConfig
|
36
|
-
from pydantic import BaseModel, ConfigDict, model_validator
|
37
36
|
from requests.exceptions import RequestException, Timeout
|
38
37
|
from tenacity import (
|
39
38
|
retry,
|
@@ -46,13 +45,16 @@ from tqdm.asyncio import tqdm
|
|
46
45
|
|
47
46
|
from futurehouse_client.clients import JobNames
|
48
47
|
from futurehouse_client.models.app import (
|
49
|
-
APIKeyPayload,
|
50
48
|
AuthType,
|
51
49
|
JobDeploymentConfig,
|
50
|
+
PQATaskResponse,
|
52
51
|
Stage,
|
53
52
|
TaskRequest,
|
53
|
+
TaskResponse,
|
54
|
+
TaskResponseVerbose,
|
54
55
|
)
|
55
56
|
from futurehouse_client.models.rest import ExecutionStatus
|
57
|
+
from futurehouse_client.utils.auth import RefreshingJWT
|
56
58
|
from futurehouse_client.utils.general import gather_with_concurrency
|
57
59
|
from futurehouse_client.utils.module_utils import (
|
58
60
|
OrganizationSelector,
|
@@ -63,24 +65,14 @@ from futurehouse_client.utils.monitoring import (
|
|
63
65
|
)
|
64
66
|
|
65
67
|
logger = logging.getLogger(__name__)
|
66
|
-
|
68
|
+
logging.basicConfig(
|
69
|
+
level=logging.WARNING,
|
70
|
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
71
|
+
stream=sys.stdout,
|
72
|
+
)
|
73
|
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
67
74
|
TaskRequest.model_rebuild()
|
68
75
|
|
69
|
-
retry_if_connection_error = retry_if_exception_type((
|
70
|
-
# From requests
|
71
|
-
Timeout,
|
72
|
-
ConnectionError,
|
73
|
-
RequestException,
|
74
|
-
# From httpx
|
75
|
-
ConnectError,
|
76
|
-
ConnectTimeout,
|
77
|
-
ReadTimeout,
|
78
|
-
ReadError,
|
79
|
-
NetworkError,
|
80
|
-
RemoteProtocolError,
|
81
|
-
CloseError,
|
82
|
-
))
|
83
|
-
|
84
76
|
FILE_UPLOAD_IGNORE_PARTS = {
|
85
77
|
".ruff_cache",
|
86
78
|
"__pycache__",
|
@@ -111,104 +103,27 @@ class InvalidTaskDescriptionError(Exception):
|
|
111
103
|
"""Raised when the task description is invalid or empty."""
|
112
104
|
|
113
105
|
|
114
|
-
class
|
115
|
-
|
116
|
-
name: str
|
117
|
-
display_name: str
|
118
|
-
|
119
|
-
|
120
|
-
# 5 minute default for JWTs
|
121
|
-
JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
|
122
|
-
DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
|
123
|
-
|
106
|
+
class FileUploadError(RestClientError):
|
107
|
+
"""Raised when there's an error uploading a file."""
|
124
108
|
|
125
|
-
class TaskResponse(BaseModel):
|
126
|
-
"""Base class for task responses. This holds attributes shared over all futurehouse jobs."""
|
127
|
-
|
128
|
-
model_config = ConfigDict(extra="ignore")
|
129
|
-
|
130
|
-
status: str
|
131
|
-
query: str
|
132
|
-
user: str | None = None
|
133
|
-
created_at: datetime
|
134
|
-
job_name: str
|
135
|
-
public: bool
|
136
|
-
shared_with: list[SimpleOrganization] | None = None
|
137
|
-
build_owner: str | None = None
|
138
|
-
environment_name: str | None = None
|
139
|
-
agent_name: str | None = None
|
140
|
-
task_id: UUID | None = None
|
141
|
-
|
142
|
-
@model_validator(mode="before")
|
143
|
-
@classmethod
|
144
|
-
def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
145
|
-
# Extract fields from environment frame state
|
146
|
-
if not isinstance(data, dict):
|
147
|
-
return data
|
148
|
-
# TODO: We probably want to remove these two once we define the final names.
|
149
|
-
data["job_name"] = data.get("crow")
|
150
|
-
data["query"] = data.get("task")
|
151
|
-
data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
|
152
|
-
if not (metadata := data.get("metadata", {})):
|
153
|
-
return data
|
154
|
-
data["environment_name"] = metadata.get("environment_name")
|
155
|
-
data["agent_name"] = metadata.get("agent_name")
|
156
|
-
return data
|
157
|
-
|
158
|
-
|
159
|
-
class PQATaskResponse(TaskResponse):
|
160
|
-
model_config = ConfigDict(extra="ignore")
|
161
|
-
|
162
|
-
answer: str | None = None
|
163
|
-
formatted_answer: str | None = None
|
164
|
-
answer_reasoning: str | None = None
|
165
|
-
has_successful_answer: bool | None = None
|
166
|
-
total_cost: float | None = None
|
167
|
-
total_queries: int | None = None
|
168
|
-
|
169
|
-
@model_validator(mode="before")
|
170
|
-
@classmethod
|
171
|
-
def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
172
|
-
if not isinstance(data, dict):
|
173
|
-
return data
|
174
|
-
if not (env_frame := data.get("environment_frame", {})):
|
175
|
-
return data
|
176
|
-
state = env_frame.get("state", {}).get("state", {})
|
177
|
-
response = state.get("response", {})
|
178
|
-
answer = response.get("answer", {})
|
179
|
-
usage = state.get("info", {}).get("usage", {})
|
180
|
-
|
181
|
-
# Add additional PQA specific fields to data so that pydantic can validate the model
|
182
|
-
data["answer"] = answer.get("answer")
|
183
|
-
data["formatted_answer"] = answer.get("formatted_answer")
|
184
|
-
data["answer_reasoning"] = answer.get("answer_reasoning")
|
185
|
-
data["has_successful_answer"] = answer.get("has_successful_answer")
|
186
|
-
data["total_cost"] = cast(float, usage.get("total_cost"))
|
187
|
-
data["total_queries"] = cast(int, usage.get("total_queries"))
|
188
|
-
|
189
|
-
return data
|
190
|
-
|
191
|
-
def clean_verbose(self) -> "TaskResponse":
|
192
|
-
"""Clean the verbose response from the server."""
|
193
|
-
self.request = None
|
194
|
-
self.response = None
|
195
|
-
return self
|
196
|
-
|
197
|
-
|
198
|
-
class TaskResponseVerbose(TaskResponse):
|
199
|
-
"""Class for responses to include all the fields of a task response."""
|
200
|
-
|
201
|
-
model_config = ConfigDict(extra="allow")
|
202
|
-
|
203
|
-
public: bool
|
204
|
-
agent_state: list[dict[str, Any]] | None = None
|
205
|
-
environment_frame: dict[str, Any] | None = None
|
206
|
-
metadata: dict[str, Any] | None = None
|
207
|
-
shared_with: list[SimpleOrganization] | None = None
|
208
109
|
|
110
|
+
retry_if_connection_error = retry_if_exception_type((
|
111
|
+
# From requests
|
112
|
+
Timeout,
|
113
|
+
ConnectionError,
|
114
|
+
RequestException,
|
115
|
+
# From httpx
|
116
|
+
ConnectError,
|
117
|
+
ConnectTimeout,
|
118
|
+
ReadTimeout,
|
119
|
+
ReadError,
|
120
|
+
NetworkError,
|
121
|
+
RemoteProtocolError,
|
122
|
+
CloseError,
|
123
|
+
FileUploadError,
|
124
|
+
))
|
209
125
|
|
210
|
-
|
211
|
-
"""Raised when there's an error uploading a file."""
|
126
|
+
DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
|
212
127
|
|
213
128
|
|
214
129
|
class RestClient:
|
@@ -228,76 +143,98 @@ class RestClient:
|
|
228
143
|
api_key: str | None = None,
|
229
144
|
jwt: str | None = None,
|
230
145
|
headers: dict[str, str] | None = None,
|
146
|
+
verbose_logging: bool = False,
|
231
147
|
):
|
148
|
+
if verbose_logging:
|
149
|
+
logger.setLevel(logging.INFO)
|
150
|
+
else:
|
151
|
+
logger.setLevel(logging.WARNING)
|
152
|
+
|
232
153
|
self.base_url = service_uri or stage.value
|
233
154
|
self.stage = stage
|
234
155
|
self.auth_type = auth_type
|
235
156
|
self.api_key = api_key
|
236
157
|
self._clients: dict[str, Client | AsyncClient] = {}
|
237
158
|
self.headers = headers or {}
|
238
|
-
self.
|
159
|
+
self.jwt = jwt
|
239
160
|
self.organizations: list[str] = self._filter_orgs(organization)
|
240
161
|
|
241
162
|
@property
|
242
163
|
def client(self) -> Client:
|
243
|
-
"""
|
244
|
-
return cast(Client, self.get_client("application/json",
|
164
|
+
"""Authenticated HTTP client for regular API calls."""
|
165
|
+
return cast(Client, self.get_client("application/json", authenticated=True))
|
245
166
|
|
246
167
|
@property
|
247
168
|
def async_client(self) -> AsyncClient:
|
248
|
-
"""
|
169
|
+
"""Authenticated async HTTP client for regular API calls."""
|
249
170
|
return cast(
|
250
171
|
AsyncClient,
|
251
|
-
self.get_client("application/json",
|
172
|
+
self.get_client("application/json", authenticated=True, async_client=True),
|
252
173
|
)
|
253
174
|
|
254
175
|
@property
|
255
|
-
def
|
256
|
-
"""
|
257
|
-
return cast(Client, self.get_client("application/json",
|
176
|
+
def unauthenticated_client(self) -> Client:
|
177
|
+
"""Unauthenticated HTTP client for auth operations."""
|
178
|
+
return cast(Client, self.get_client("application/json", authenticated=False))
|
258
179
|
|
259
180
|
@property
|
260
181
|
def multipart_client(self) -> Client:
|
261
|
-
"""
|
262
|
-
return cast(Client, self.get_client(None,
|
182
|
+
"""Authenticated HTTP client for multipart uploads."""
|
183
|
+
return cast(Client, self.get_client(None, authenticated=True))
|
263
184
|
|
264
185
|
def get_client(
|
265
186
|
self,
|
266
187
|
content_type: str | None = "application/json",
|
267
|
-
|
268
|
-
|
188
|
+
authenticated: bool = True,
|
189
|
+
async_client: bool = False,
|
269
190
|
) -> Client | AsyncClient:
|
270
191
|
"""Return a cached HTTP client or create one if needed.
|
271
192
|
|
272
193
|
Args:
|
273
194
|
content_type: The desired content type header. Use None for multipart uploads.
|
274
|
-
|
275
|
-
|
195
|
+
authenticated: Whether the client should include authentication.
|
196
|
+
async_client: Whether to use an async client.
|
276
197
|
|
277
198
|
Returns:
|
278
199
|
An HTTP client configured with the appropriate headers.
|
279
200
|
"""
|
280
|
-
# Create a composite key based on content type and auth flag
|
281
|
-
key = f"{content_type or 'multipart'}_{
|
201
|
+
# Create a composite key based on content type and auth flag
|
202
|
+
key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
|
203
|
+
|
282
204
|
if key not in self._clients:
|
283
205
|
headers = copy.deepcopy(self.headers)
|
284
|
-
|
285
|
-
|
206
|
+
auth = None
|
207
|
+
|
208
|
+
if authenticated:
|
209
|
+
auth = RefreshingJWT(
|
210
|
+
# authenticated=False will always return a synchronous client
|
211
|
+
auth_client=cast(
|
212
|
+
Client, self.get_client("application/json", authenticated=False)
|
213
|
+
),
|
214
|
+
auth_type=self.auth_type,
|
215
|
+
api_key=self.api_key,
|
216
|
+
jwt=self.jwt,
|
217
|
+
)
|
218
|
+
|
286
219
|
if content_type:
|
287
220
|
headers["Content-Type"] = content_type
|
221
|
+
|
288
222
|
self._clients[key] = (
|
289
223
|
AsyncClient(
|
290
224
|
base_url=self.base_url,
|
291
225
|
headers=headers,
|
292
226
|
timeout=self.REQUEST_TIMEOUT,
|
227
|
+
auth=auth,
|
293
228
|
)
|
294
|
-
if
|
229
|
+
if async_client
|
295
230
|
else Client(
|
296
231
|
base_url=self.base_url,
|
297
232
|
headers=headers,
|
298
233
|
timeout=self.REQUEST_TIMEOUT,
|
234
|
+
auth=auth,
|
299
235
|
)
|
300
236
|
)
|
237
|
+
|
301
238
|
return self._clients[key]
|
302
239
|
|
303
240
|
def close(self):
|
@@ -327,31 +264,6 @@ class RestClient:
|
|
327
264
|
raise ValueError(f"Organization '{organization}' not found.")
|
328
265
|
return filtered_orgs
|
329
266
|
|
330
|
-
def _run_auth(self, jwt: str | None = None) -> str:
|
331
|
-
auth_payload: APIKeyPayload | None
|
332
|
-
if self.auth_type == AuthType.API_KEY:
|
333
|
-
auth_payload = APIKeyPayload(api_key=self.api_key)
|
334
|
-
elif self.auth_type == AuthType.JWT:
|
335
|
-
auth_payload = None
|
336
|
-
else:
|
337
|
-
assert_never(self.auth_type)
|
338
|
-
try:
|
339
|
-
# Use the unauthenticated client for login
|
340
|
-
if auth_payload:
|
341
|
-
response = self.auth_client.post(
|
342
|
-
"/auth/login", json=auth_payload.model_dump()
|
343
|
-
)
|
344
|
-
response.raise_for_status()
|
345
|
-
token_data = response.json()
|
346
|
-
elif jwt:
|
347
|
-
token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
|
348
|
-
else:
|
349
|
-
raise ValueError("JWT token required for JWT authentication.")
|
350
|
-
|
351
|
-
return token_data["access_token"]
|
352
|
-
except Exception as e:
|
353
|
-
raise RestClientError(f"Error authenticating: {e!s}") from e
|
354
|
-
|
355
267
|
def _check_job(self, name: str, organization: str) -> dict[str, Any]:
|
356
268
|
try:
|
357
269
|
response = self.client.get(
|
@@ -445,6 +357,7 @@ class RestClient:
|
|
445
357
|
),
|
446
358
|
self.client.stream("GET", url, params={"history": history}) as response,
|
447
359
|
):
|
360
|
+
response.raise_for_status()
|
448
361
|
json_data = "".join(response.iter_text(chunk_size=1024))
|
449
362
|
data = json.loads(json_data)
|
450
363
|
if "id" not in data:
|
@@ -459,8 +372,6 @@ class RestClient:
|
|
459
372
|
):
|
460
373
|
return PQATaskResponse(**data)
|
461
374
|
return TaskResponse(**data)
|
462
|
-
except ValueError as e:
|
463
|
-
raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
|
464
375
|
except Exception as e:
|
465
376
|
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
466
377
|
|
@@ -507,8 +418,6 @@ class RestClient:
|
|
507
418
|
):
|
508
419
|
return PQATaskResponse(**data)
|
509
420
|
return TaskResponse(**data)
|
510
|
-
except ValueError as e:
|
511
|
-
raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
|
512
421
|
except Exception as e:
|
513
422
|
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
514
423
|
|
@@ -714,9 +623,12 @@ class RestClient:
|
|
714
623
|
)
|
715
624
|
def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
|
716
625
|
"""Get the status of a build."""
|
717
|
-
|
718
|
-
|
719
|
-
|
626
|
+
try:
|
627
|
+
build_id = build_id or self.build_id
|
628
|
+
response = self.client.get(f"/v0.1/builds/{build_id}")
|
629
|
+
response.raise_for_status()
|
630
|
+
except Exception as e:
|
631
|
+
raise JobFetchError(f"Error getting build status: {e!s}") from e
|
720
632
|
return response.json()
|
721
633
|
|
722
634
|
# TODO: Refactor later so we don't have to ignore PLR0915
|
@@ -917,14 +829,14 @@ class RestClient:
|
|
917
829
|
self,
|
918
830
|
job_name: str,
|
919
831
|
file_path: str | os.PathLike,
|
920
|
-
|
832
|
+
upload_id: str | None = None,
|
921
833
|
) -> str:
|
922
834
|
"""Upload a file or directory to a futurehouse job bucket.
|
923
835
|
|
924
836
|
Args:
|
925
837
|
job_name: The name of the futurehouse job to upload to.
|
926
838
|
file_path: The local path to the file or directory to upload.
|
927
|
-
|
839
|
+
upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
|
928
840
|
|
929
841
|
Returns:
|
930
842
|
The upload ID used for the upload.
|
@@ -936,7 +848,7 @@ class RestClient:
|
|
936
848
|
if not file_path.exists():
|
937
849
|
raise FileNotFoundError(f"File or directory not found: {file_path}")
|
938
850
|
|
939
|
-
upload_id =
|
851
|
+
upload_id = upload_id or str(uuid.uuid4())
|
940
852
|
|
941
853
|
if file_path.is_dir():
|
942
854
|
# Process directory recursively
|
@@ -999,6 +911,12 @@ class RestClient:
|
|
999
911
|
"""
|
1000
912
|
file_name = file_name or file_path.name
|
1001
913
|
file_size = file_path.stat().st_size
|
914
|
+
|
915
|
+
# Skip empty files
|
916
|
+
if file_size == 0:
|
917
|
+
logger.warning(f"Skipping upload of empty file: {file_path}")
|
918
|
+
return
|
919
|
+
|
1002
920
|
total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE
|
1003
921
|
|
1004
922
|
logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")
|
@@ -1046,7 +964,6 @@ class RestClient:
|
|
1046
964
|
)
|
1047
965
|
|
1048
966
|
logger.info(f"Successfully uploaded {file_name}")
|
1049
|
-
|
1050
967
|
except Exception as e:
|
1051
968
|
logger.exception(f"Error uploading file {file_path}")
|
1052
969
|
raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
|
@@ -1056,12 +973,18 @@ class RestClient:
|
|
1056
973
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
1057
974
|
retry=retry_if_connection_error,
|
1058
975
|
)
|
1059
|
-
def list_files(
|
976
|
+
def list_files(
|
977
|
+
self,
|
978
|
+
job_name: str,
|
979
|
+
trajectory_id: str | None = None,
|
980
|
+
upload_id: str | None = None,
|
981
|
+
) -> dict[str, list[str]]:
|
1060
982
|
"""List files and directories in a GCS location for a given job_name and upload_id.
|
1061
983
|
|
1062
984
|
Args:
|
1063
985
|
job_name: The name of the futurehouse job.
|
1064
|
-
|
986
|
+
trajectory_id: The specific trajectory id to list files from.
|
987
|
+
upload_id: The specific upload id to list files from.
|
1065
988
|
|
1066
989
|
Returns:
|
1067
990
|
A list of files in the GCS folder.
|
@@ -1069,22 +992,27 @@ class RestClient:
|
|
1069
992
|
Raises:
|
1070
993
|
RestClientError: If there is an error listing the files.
|
1071
994
|
"""
|
995
|
+
if not bool(trajectory_id) ^ bool(upload_id):
|
996
|
+
raise RestClientError(
|
997
|
+
"Must at least specify one of trajectory_id or upload_id, but not both"
|
998
|
+
)
|
1072
999
|
try:
|
1073
1000
|
url = f"/v0.1/crows/{job_name}/list-files"
|
1074
|
-
params = {"upload_id":
|
1001
|
+
params = {"trajectory_id": trajectory_id, "upload_id": upload_id}
|
1002
|
+
params = {k: v for k, v in params.items() if v is not None}
|
1075
1003
|
response = self.client.get(url, params=params)
|
1076
1004
|
response.raise_for_status()
|
1077
1005
|
return response.json()
|
1078
1006
|
except HTTPStatusError as e:
|
1079
1007
|
logger.exception(
|
1080
|
-
f"Error listing files for job {job_name},
|
1008
|
+
f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
|
1081
1009
|
)
|
1082
1010
|
raise RestClientError(
|
1083
1011
|
f"Error listing files: {e.response.status_code} - {e.response.text}"
|
1084
1012
|
) from e
|
1085
1013
|
except Exception as e:
|
1086
1014
|
logger.exception(
|
1087
|
-
f"Error listing files for job {job_name},
|
1015
|
+
f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}"
|
1088
1016
|
)
|
1089
1017
|
raise RestClientError(f"Error listing files: {e!s}") from e
|
1090
1018
|
|
@@ -1096,7 +1024,7 @@ class RestClient:
|
|
1096
1024
|
def download_file(
|
1097
1025
|
self,
|
1098
1026
|
job_name: str,
|
1099
|
-
|
1027
|
+
trajectory_id: str,
|
1100
1028
|
file_path: str,
|
1101
1029
|
destination_path: str | os.PathLike,
|
1102
1030
|
) -> None:
|
@@ -1104,14 +1032,14 @@ class RestClient:
|
|
1104
1032
|
|
1105
1033
|
Args:
|
1106
1034
|
job_name: The name of the futurehouse job.
|
1107
|
-
|
1035
|
+
trajectory_id: The specific trajectory id the file belongs to.
|
1108
1036
|
file_path: The relative path of the file to download
|
1109
1037
|
(e.g., 'data/my_file.csv' or 'my_image.png').
|
1110
1038
|
destination_path: The local path where the file should be saved.
|
1111
1039
|
|
1112
1040
|
Raises:
|
1113
1041
|
RestClientError: If there is an error downloading the file.
|
1114
|
-
FileNotFoundError: If the destination directory does not exist.
|
1042
|
+
FileNotFoundError: If the destination directory does not exist or if the file is not found.
|
1115
1043
|
"""
|
1116
1044
|
destination_path = Path(destination_path)
|
1117
1045
|
# Ensure the destination directory exists
|
@@ -1119,17 +1047,24 @@ class RestClient:
|
|
1119
1047
|
|
1120
1048
|
try:
|
1121
1049
|
url = f"/v0.1/crows/{job_name}/download-file"
|
1122
|
-
params = {"
|
1050
|
+
params = {"trajectory_id": trajectory_id, "file_path": file_path}
|
1123
1051
|
|
1124
1052
|
with self.client.stream("GET", url, params=params) as response:
|
1125
1053
|
response.raise_for_status() # Check for HTTP errors before streaming
|
1126
1054
|
with open(destination_path, "wb") as f:
|
1127
1055
|
for chunk in response.iter_bytes(chunk_size=8192):
|
1128
1056
|
f.write(chunk)
|
1057
|
+
|
1058
|
+
# Check if the downloaded file is empty
|
1059
|
+
if destination_path.stat().st_size == 0:
|
1060
|
+
# Remove the empty file
|
1061
|
+
destination_path.unlink()
|
1062
|
+
raise FileNotFoundError(f"File not found or is empty: {file_path}")
|
1063
|
+
|
1129
1064
|
logger.info(f"File {file_path} downloaded to {destination_path}")
|
1130
1065
|
except HTTPStatusError as e:
|
1131
1066
|
logger.exception(
|
1132
|
-
f"Error downloading file {file_path} for job {job_name},
|
1067
|
+
f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
|
1133
1068
|
)
|
1134
1069
|
# Clean up partially downloaded file if an error occurs
|
1135
1070
|
if destination_path.exists():
|
@@ -1137,9 +1072,20 @@ class RestClient:
|
|
1137
1072
|
raise RestClientError(
|
1138
1073
|
f"Error downloading file: {e.response.status_code} - {e.response.text}"
|
1139
1074
|
) from e
|
1075
|
+
except RemoteProtocolError as e:
|
1076
|
+
logger.error(
|
1077
|
+
f"Connection error while downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
|
1078
|
+
)
|
1079
|
+
# Clean up partially downloaded file
|
1080
|
+
if destination_path.exists():
|
1081
|
+
destination_path.unlink()
|
1082
|
+
|
1083
|
+
# Often RemoteProtocolError during download means the file wasn't found
|
1084
|
+
# or was empty/corrupted on the server side
|
1085
|
+
raise FileNotFoundError(f"File not found or corrupted: {file_path}") from e
|
1140
1086
|
except Exception as e:
|
1141
1087
|
logger.exception(
|
1142
|
-
f"Error downloading file {file_path} for job {job_name},
|
1088
|
+
f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
|
1143
1089
|
)
|
1144
1090
|
if destination_path.exists():
|
1145
1091
|
destination_path.unlink() # Clean up partial file
|
@@ -3,10 +3,15 @@ from .app import (
|
|
3
3
|
DockerContainerConfiguration,
|
4
4
|
FramePath,
|
5
5
|
JobDeploymentConfig,
|
6
|
+
PQATaskResponse,
|
6
7
|
RuntimeConfig,
|
7
8
|
Stage,
|
8
9
|
Step,
|
10
|
+
TaskQueue,
|
11
|
+
TaskQueuesConfig,
|
9
12
|
TaskRequest,
|
13
|
+
TaskResponse,
|
14
|
+
TaskResponseVerbose,
|
10
15
|
)
|
11
16
|
|
12
17
|
__all__ = [
|
@@ -14,8 +19,13 @@ __all__ = [
|
|
14
19
|
"DockerContainerConfiguration",
|
15
20
|
"FramePath",
|
16
21
|
"JobDeploymentConfig",
|
22
|
+
"PQATaskResponse",
|
17
23
|
"RuntimeConfig",
|
18
24
|
"Stage",
|
19
25
|
"Step",
|
26
|
+
"TaskQueue",
|
27
|
+
"TaskQueuesConfig",
|
20
28
|
"TaskRequest",
|
29
|
+
"TaskResponse",
|
30
|
+
"TaskResponseVerbose",
|
21
31
|
]
|
futurehouse_client/models/app.py
CHANGED
@@ -1,6 +1,9 @@
|
|
1
|
+
import copy
|
1
2
|
import json
|
2
3
|
import os
|
3
4
|
import re
|
5
|
+
from collections.abc import Mapping
|
6
|
+
from datetime import datetime
|
4
7
|
from enum import StrEnum, auto
|
5
8
|
from pathlib import Path
|
6
9
|
from typing import TYPE_CHECKING, Any, ClassVar, Self, cast
|
@@ -646,3 +649,96 @@ class TaskRequest(BaseModel):
|
|
646
649
|
runtime_config: RuntimeConfig | None = Field(
|
647
650
|
default=None, description="All optional runtime parameters for the job"
|
648
651
|
)
|
652
|
+
|
653
|
+
|
654
|
+
class SimpleOrganization(BaseModel):
|
655
|
+
id: int
|
656
|
+
name: str
|
657
|
+
display_name: str
|
658
|
+
|
659
|
+
|
660
|
+
class TaskResponse(BaseModel):
|
661
|
+
"""Base class for task responses. This holds attributes shared over all futurehouse jobs."""
|
662
|
+
|
663
|
+
model_config = ConfigDict(extra="ignore")
|
664
|
+
|
665
|
+
status: str
|
666
|
+
query: str
|
667
|
+
user: str | None = None
|
668
|
+
created_at: datetime
|
669
|
+
job_name: str
|
670
|
+
public: bool
|
671
|
+
shared_with: list[SimpleOrganization] | None = None
|
672
|
+
build_owner: str | None = None
|
673
|
+
environment_name: str | None = None
|
674
|
+
agent_name: str | None = None
|
675
|
+
task_id: UUID | None = None
|
676
|
+
|
677
|
+
@model_validator(mode="before")
|
678
|
+
@classmethod
|
679
|
+
def validate_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
|
680
|
+
data = copy.deepcopy(original_data) # Avoid mutating the original data
|
681
|
+
# Extract fields from environment frame state
|
682
|
+
if not isinstance(data, dict):
|
683
|
+
return data
|
684
|
+
# TODO: We probably want to remove these two once we define the final names.
|
685
|
+
data["job_name"] = data.get("crow")
|
686
|
+
data["query"] = data.get("task")
|
687
|
+
data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
|
688
|
+
if not (metadata := data.get("metadata", {})):
|
689
|
+
return data
|
690
|
+
data["environment_name"] = metadata.get("environment_name")
|
691
|
+
data["agent_name"] = metadata.get("agent_name")
|
692
|
+
return data
|
693
|
+
|
694
|
+
|
695
|
+
class PQATaskResponse(TaskResponse):
|
696
|
+
model_config = ConfigDict(extra="ignore")
|
697
|
+
|
698
|
+
answer: str | None = None
|
699
|
+
formatted_answer: str | None = None
|
700
|
+
answer_reasoning: str | None = None
|
701
|
+
has_successful_answer: bool | None = None
|
702
|
+
total_cost: float | None = None
|
703
|
+
total_queries: int | None = None
|
704
|
+
|
705
|
+
@model_validator(mode="before")
|
706
|
+
@classmethod
|
707
|
+
def validate_pqa_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
|
708
|
+
data = copy.deepcopy(original_data) # Avoid mutating the original data
|
709
|
+
if not isinstance(data, dict):
|
710
|
+
return data
|
711
|
+
if not (env_frame := data.get("environment_frame", {})):
|
712
|
+
return data
|
713
|
+
state = env_frame.get("state", {}).get("state", {})
|
714
|
+
response = state.get("response", {})
|
715
|
+
answer = response.get("answer", {})
|
716
|
+
usage = state.get("info", {}).get("usage", {})
|
717
|
+
|
718
|
+
# Add additional PQA specific fields to data so that pydantic can validate the model
|
719
|
+
data["answer"] = answer.get("answer")
|
720
|
+
data["formatted_answer"] = answer.get("formatted_answer")
|
721
|
+
data["answer_reasoning"] = answer.get("answer_reasoning")
|
722
|
+
data["has_successful_answer"] = answer.get("has_successful_answer")
|
723
|
+
data["total_cost"] = cast(float, usage.get("total_cost"))
|
724
|
+
data["total_queries"] = cast(int, usage.get("total_queries"))
|
725
|
+
|
726
|
+
return data
|
727
|
+
|
728
|
+
def clean_verbose(self) -> "TaskResponse":
|
729
|
+
"""Clean the verbose response from the server."""
|
730
|
+
self.request = None
|
731
|
+
self.response = None
|
732
|
+
return self
|
733
|
+
|
734
|
+
|
735
|
+
class TaskResponseVerbose(TaskResponse):
|
736
|
+
"""Class for responses to include all the fields of a task response."""
|
737
|
+
|
738
|
+
model_config = ConfigDict(extra="allow")
|
739
|
+
|
740
|
+
public: bool
|
741
|
+
agent_state: list[dict[str, Any]] | None = None
|
742
|
+
environment_frame: dict[str, Any] | None = None
|
743
|
+
metadata: dict[str, Any] | None = None
|
744
|
+
shared_with: list[SimpleOrganization] | None = None
|
@@ -0,0 +1,92 @@
|
|
1
|
+
import logging
|
2
|
+
from collections.abc import Collection, Generator
|
3
|
+
from typing import ClassVar, Final
|
4
|
+
|
5
|
+
import httpx
|
6
|
+
|
7
|
+
from futurehouse_client.models.app import APIKeyPayload, AuthType
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
|
11
|
+
INVALID_REFRESH_TYPE_MSG: Final[str] = (
|
12
|
+
"API key auth is required to refresh auth tokens."
|
13
|
+
)
|
14
|
+
JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
|
15
|
+
|
16
|
+
|
17
|
+
def _run_auth(
|
18
|
+
client: httpx.Client,
|
19
|
+
auth_type: AuthType = AuthType.API_KEY,
|
20
|
+
api_key: str | None = None,
|
21
|
+
jwt: str | None = None,
|
22
|
+
) -> str:
|
23
|
+
auth_payload: APIKeyPayload | None
|
24
|
+
if auth_type == AuthType.API_KEY:
|
25
|
+
auth_payload = APIKeyPayload(api_key=api_key)
|
26
|
+
elif auth_type == AuthType.JWT:
|
27
|
+
auth_payload = None
|
28
|
+
try:
|
29
|
+
if auth_payload:
|
30
|
+
response = client.post("/auth/login", json=auth_payload.model_dump())
|
31
|
+
response.raise_for_status()
|
32
|
+
token_data = response.json()
|
33
|
+
elif jwt:
|
34
|
+
token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
|
35
|
+
else:
|
36
|
+
raise ValueError("JWT token required for JWT authentication.")
|
37
|
+
|
38
|
+
return token_data["access_token"]
|
39
|
+
except Exception as e:
|
40
|
+
raise Exception("Failed to authenticate") from e # noqa: TRY002
|
41
|
+
|
42
|
+
|
43
|
+
class RefreshingJWT(httpx.Auth):
|
44
|
+
"""Automatically (re-)inject a JWT and transparently retry exactly once when we hit a 401/403."""
|
45
|
+
|
46
|
+
RETRY_STATUSES: ClassVar[Collection[httpx.codes]] = {
|
47
|
+
httpx.codes.UNAUTHORIZED,
|
48
|
+
httpx.codes.FORBIDDEN,
|
49
|
+
}
|
50
|
+
|
51
|
+
def __init__(
|
52
|
+
self,
|
53
|
+
auth_client: httpx.Client,
|
54
|
+
auth_type: AuthType = AuthType.API_KEY,
|
55
|
+
api_key: str | None = None,
|
56
|
+
jwt: str | None = None,
|
57
|
+
):
|
58
|
+
self.auth_type = auth_type
|
59
|
+
self.auth_client = auth_client
|
60
|
+
self.api_key = api_key
|
61
|
+
self._jwt = _run_auth(
|
62
|
+
client=auth_client,
|
63
|
+
jwt=jwt,
|
64
|
+
auth_type=auth_type,
|
65
|
+
api_key=api_key,
|
66
|
+
)
|
67
|
+
|
68
|
+
def refresh_token(self) -> None:
|
69
|
+
if self.auth_type == AuthType.JWT:
|
70
|
+
logger.error(INVALID_REFRESH_TYPE_MSG)
|
71
|
+
raise ValueError(INVALID_REFRESH_TYPE_MSG)
|
72
|
+
self._jwt = _run_auth(
|
73
|
+
client=self.auth_client,
|
74
|
+
auth_type=self.auth_type,
|
75
|
+
api_key=self.api_key,
|
76
|
+
)
|
77
|
+
|
78
|
+
def auth_flow(
|
79
|
+
self, request: httpx.Request
|
80
|
+
) -> Generator[httpx.Request, httpx.Response, None]:
|
81
|
+
request.headers["Authorization"] = f"Bearer {self._jwt}"
|
82
|
+
response = yield request
|
83
|
+
|
84
|
+
# If it failed, refresh once and replay the request
|
85
|
+
if response.status_code in self.RETRY_STATUSES:
|
86
|
+
logger.info(
|
87
|
+
"Received %s, refreshing token and retrying …",
|
88
|
+
response.status_code,
|
89
|
+
)
|
90
|
+
self.refresh_token()
|
91
|
+
request.headers["Authorization"] = f"Bearer {self._jwt}"
|
92
|
+
yield request # second (and final) attempt, again or use a while loop
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: futurehouse-client
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.18
|
4
4
|
Summary: A client for interacting with endpoints of the FutureHouse service.
|
5
5
|
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
6
6
|
Classifier: Operating System :: OS Independent
|
@@ -0,0 +1,17 @@
|
|
1
|
+
futurehouse_client/__init__.py,sha256=ddxO7JE97c6bt7LjNglZZ2Ql8bYCGI9laSFeh9MP6VU,344
|
2
|
+
futurehouse_client/clients/__init__.py,sha256=tFWqwIAY5PvwfOVsCje4imjTpf6xXNRMh_UHIKVI1_0,320
|
3
|
+
futurehouse_client/clients/job_client.py,sha256=uNkqQbeZw7wbA0qDWcIOwOykrosza-jev58paJZ_mbA,11150
|
4
|
+
futurehouse_client/clients/rest_client.py,sha256=CwgyjYj-i6U0aVXb1GkxP6KSxR5tVrlXIDE9WIHYtds,43435
|
5
|
+
futurehouse_client/models/__init__.py,sha256=5x-f9AoM1hGzJBEHcHAXSt7tPeImST5oZLuMdwp0mXc,554
|
6
|
+
futurehouse_client/models/app.py,sha256=24lCnpOtxQNzvcc1HJyP62wfertvTuZ30sX_vrBCUnk,26611
|
7
|
+
futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
|
8
|
+
futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
|
9
|
+
futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
|
+
futurehouse_client/utils/auth.py,sha256=tgWELjKfg8eWme_qdcRmc8TjQN9DVZuHHaVXZNHLchk,2960
|
11
|
+
futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
|
12
|
+
futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
|
13
|
+
futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
|
14
|
+
futurehouse_client-0.3.18.dist-info/METADATA,sha256=AJVzzXq77PTe-Hg31YXy-KWFAb5IXiZhyC08YlbRec4,12760
|
15
|
+
futurehouse_client-0.3.18.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
16
|
+
futurehouse_client-0.3.18.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
|
17
|
+
futurehouse_client-0.3.18.dist-info/RECORD,,
|
@@ -1,16 +0,0 @@
|
|
1
|
-
futurehouse_client/__init__.py,sha256=ddxO7JE97c6bt7LjNglZZ2Ql8bYCGI9laSFeh9MP6VU,344
|
2
|
-
futurehouse_client/clients/__init__.py,sha256=tFWqwIAY5PvwfOVsCje4imjTpf6xXNRMh_UHIKVI1_0,320
|
3
|
-
futurehouse_client/clients/job_client.py,sha256=Fi3YvN4k82AuXCe8vlwxhkK8CXS164NQrs7paj9qIek,11096
|
4
|
-
futurehouse_client/clients/rest_client.py,sha256=dsUmpgV5sfyb4GDv6whWVwRN1z2LOfZsPF8vjoioNfY,45472
|
5
|
-
futurehouse_client/models/__init__.py,sha256=ta3jFLM_LsDz1rKDmx8rja8sT7WtSKoFvMgLF0yFpvA,342
|
6
|
-
futurehouse_client/models/app.py,sha256=yfZ9tyw4VATVAfYrU7aTdCNPSljLEho09_nIbh8oZDY,23174
|
7
|
-
futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
|
8
|
-
futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
|
9
|
-
futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
10
|
-
futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
|
11
|
-
futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
|
12
|
-
futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
|
13
|
-
futurehouse_client-0.3.17.dev94.dist-info/METADATA,sha256=acLPon9oE1ecVZzz8JrpumcSLmhRkqGGG62gjGEW1IQ,12766
|
14
|
-
futurehouse_client-0.3.17.dev94.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
15
|
-
futurehouse_client-0.3.17.dev94.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
|
16
|
-
futurehouse_client-0.3.17.dev94.dist-info/RECORD,,
|
{futurehouse_client-0.3.17.dev94.dist-info → futurehouse_client-0.3.18.dist-info}/top_level.txt
RENAMED
File without changes
|