futurehouse-client 0.3.17.dev94__tar.gz → 0.3.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/PKG-INFO +1 -1
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/clients/job_client.py +1 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/clients/rest_client.py +128 -182
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/__init__.py +10 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/app.py +96 -0
- futurehouse_client-0.3.18/futurehouse_client/utils/auth.py +92 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/PKG-INFO +1 -1
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/SOURCES.txt +2 -0
- futurehouse_client-0.3.18/tests/test_client.py +161 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/tests/test_rest.py +74 -86
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/LICENSE +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/README.md +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/docs/__init__.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/docs/client_notebook.ipynb +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/__init__.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/clients/__init__.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/client.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/rest.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/__init__.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/general.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/module_utils.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/monitoring.py +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/dependency_links.txt +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/requires.txt +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/top_level.txt +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/pyproject.toml +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/setup.cfg +0 -0
- {futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/uv.lock +0 -0
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/PKG-INFO RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: futurehouse-client
-Version: 0.3.17.dev94
+Version: 0.3.18
 Summary: A client for interacting with endpoints of the FutureHouse service.
 Author-email: FutureHouse technical staff <hello@futurehouse.org>
 Classifier: Operating System :: OS Independent
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/clients/job_client.py RENAMED
@@ -30,6 +30,7 @@ class JobNames(StrEnum):
     OWL = "job-futurehouse-hasanyone"
     DUMMY = "job-futurehouse-dummy-env"
     PHOENIX = "job-futurehouse-phoenix"
+    FINCH = "job-futurehouse-data-analysis-crow-high"

     @classmethod
     def from_stage(cls, job_name: str, stage: Stage | None = None) -> str:
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/clients/rest_client.py RENAMED
@@ -8,14 +8,14 @@ import inspect
 import json
 import logging
 import os
+import sys
 import tempfile
 import time
 import uuid
-from collections.abc import Collection
-from datetime import datetime
+from collections.abc import Collection
 from pathlib import Path
 from types import ModuleType
-from typing import Any, ClassVar,
+from typing import Any, ClassVar, cast
 from uuid import UUID

 import cloudpickle
@@ -33,7 +33,6 @@ from httpx import (
     RemoteProtocolError,
 )
 from ldp.agent import AgentConfig
-from pydantic import BaseModel, ConfigDict, model_validator
 from requests.exceptions import RequestException, Timeout
 from tenacity import (
     retry,
@@ -46,13 +45,16 @@ from tqdm.asyncio import tqdm

 from futurehouse_client.clients import JobNames
 from futurehouse_client.models.app import (
-    APIKeyPayload,
     AuthType,
     JobDeploymentConfig,
+    PQATaskResponse,
     Stage,
     TaskRequest,
+    TaskResponse,
+    TaskResponseVerbose,
 )
 from futurehouse_client.models.rest import ExecutionStatus
+from futurehouse_client.utils.auth import RefreshingJWT
 from futurehouse_client.utils.general import gather_with_concurrency
 from futurehouse_client.utils.module_utils import (
     OrganizationSelector,
@@ -63,24 +65,14 @@ from futurehouse_client.utils.monitoring import (
 )

 logger = logging.getLogger(__name__)
-
+logging.basicConfig(
+    level=logging.WARNING,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    stream=sys.stdout,
+)
+logging.getLogger("httpx").setLevel(logging.WARNING)
 TaskRequest.model_rebuild()

-retry_if_connection_error = retry_if_exception_type((
-    # From requests
-    Timeout,
-    ConnectionError,
-    RequestException,
-    # From httpx
-    ConnectError,
-    ConnectTimeout,
-    ReadTimeout,
-    ReadError,
-    NetworkError,
-    RemoteProtocolError,
-    CloseError,
-))
-
 FILE_UPLOAD_IGNORE_PARTS = {
     ".ruff_cache",
     "__pycache__",
@@ -111,104 +103,27 @@ class InvalidTaskDescriptionError(Exception):
     """Raised when the task description is invalid or empty."""


-class SimpleOrganization(BaseModel):
-    id: int
-    name: str
-    display_name: str
-
-
-# 5 minute default for JWTs
-JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
-DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds
-
+class FileUploadError(RestClientError):
+    """Raised when there's an error uploading a file."""

-class TaskResponse(BaseModel):
-    """Base class for task responses. This holds attributes shared over all futurehouse jobs."""
-
-    model_config = ConfigDict(extra="ignore")
-
-    status: str
-    query: str
-    user: str | None = None
-    created_at: datetime
-    job_name: str
-    public: bool
-    shared_with: list[SimpleOrganization] | None = None
-    build_owner: str | None = None
-    environment_name: str | None = None
-    agent_name: str | None = None
-    task_id: UUID | None = None
-
-    @model_validator(mode="before")
-    @classmethod
-    def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
-        # Extract fields from environment frame state
-        if not isinstance(data, dict):
-            return data
-        # TODO: We probably want to remove these two once we define the final names.
-        data["job_name"] = data.get("crow")
-        data["query"] = data.get("task")
-        data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
-        if not (metadata := data.get("metadata", {})):
-            return data
-        data["environment_name"] = metadata.get("environment_name")
-        data["agent_name"] = metadata.get("agent_name")
-        return data
-
-
-class PQATaskResponse(TaskResponse):
-    model_config = ConfigDict(extra="ignore")
-
-    answer: str | None = None
-    formatted_answer: str | None = None
-    answer_reasoning: str | None = None
-    has_successful_answer: bool | None = None
-    total_cost: float | None = None
-    total_queries: int | None = None
-
-    @model_validator(mode="before")
-    @classmethod
-    def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
-        if not isinstance(data, dict):
-            return data
-        if not (env_frame := data.get("environment_frame", {})):
-            return data
-        state = env_frame.get("state", {}).get("state", {})
-        response = state.get("response", {})
-        answer = response.get("answer", {})
-        usage = state.get("info", {}).get("usage", {})
-
-        # Add additional PQA specific fields to data so that pydantic can validate the model
-        data["answer"] = answer.get("answer")
-        data["formatted_answer"] = answer.get("formatted_answer")
-        data["answer_reasoning"] = answer.get("answer_reasoning")
-        data["has_successful_answer"] = answer.get("has_successful_answer")
-        data["total_cost"] = cast(float, usage.get("total_cost"))
-        data["total_queries"] = cast(int, usage.get("total_queries"))
-
-        return data
-
-    def clean_verbose(self) -> "TaskResponse":
-        """Clean the verbose response from the server."""
-        self.request = None
-        self.response = None
-        return self
-
-
-class TaskResponseVerbose(TaskResponse):
-    """Class for responses to include all the fields of a task response."""
-
-    model_config = ConfigDict(extra="allow")
-
-    public: bool
-    agent_state: list[dict[str, Any]] | None = None
-    environment_frame: dict[str, Any] | None = None
-    metadata: dict[str, Any] | None = None
-    shared_with: list[SimpleOrganization] | None = None

+retry_if_connection_error = retry_if_exception_type((
+    # From requests
+    Timeout,
+    ConnectionError,
+    RequestException,
+    # From httpx
+    ConnectError,
+    ConnectTimeout,
+    ReadTimeout,
+    ReadError,
+    NetworkError,
+    RemoteProtocolError,
+    CloseError,
+    FileUploadError,
+))

-class FileUploadError(RestClientError):
-    """Raised when there's an error uploading a file."""
+DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds


 class RestClient:
@@ -228,76 +143,98 @@ class RestClient:
         api_key: str | None = None,
         jwt: str | None = None,
         headers: dict[str, str] | None = None,
+        verbose_logging: bool = False,
     ):
+        if verbose_logging:
+            logger.setLevel(logging.INFO)
+        else:
+            logger.setLevel(logging.WARNING)
+
         self.base_url = service_uri or stage.value
         self.stage = stage
         self.auth_type = auth_type
         self.api_key = api_key
         self._clients: dict[str, Client | AsyncClient] = {}
         self.headers = headers or {}
-        self.
+        self.jwt = jwt
         self.organizations: list[str] = self._filter_orgs(organization)

     @property
     def client(self) -> Client:
-        """
-        return cast(Client, self.get_client("application/json",
+        """Authenticated HTTP client for regular API calls."""
+        return cast(Client, self.get_client("application/json", authenticated=True))

     @property
     def async_client(self) -> AsyncClient:
-        """
+        """Authenticated async HTTP client for regular API calls."""
         return cast(
             AsyncClient,
-            self.get_client("application/json",
+            self.get_client("application/json", authenticated=True, async_client=True),
         )

     @property
-    def auth_client(self) -> Client:
-        """
-        return cast(Client, self.get_client("application/json",
+    def unauthenticated_client(self) -> Client:
+        """Unauthenticated HTTP client for auth operations."""
+        return cast(Client, self.get_client("application/json", authenticated=False))

     @property
     def multipart_client(self) -> Client:
-        """
-        return cast(Client, self.get_client(None,
+        """Authenticated HTTP client for multipart uploads."""
+        return cast(Client, self.get_client(None, authenticated=True))

     def get_client(
         self,
         content_type: str | None = "application/json",
-
-
+        authenticated: bool = True,
+        async_client: bool = False,
     ) -> Client | AsyncClient:
         """Return a cached HTTP client or create one if needed.

         Args:
             content_type: The desired content type header. Use None for multipart uploads.
-
-
+            authenticated: Whether the client should include authentication.
+            async_client: Whether to use an async client.

         Returns:
             An HTTP client configured with the appropriate headers.
         """
-        # Create a composite key based on content type and auth flag
-        key = f"{content_type or 'multipart'}_{
+        # Create a composite key based on content type and auth flag
+        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
+
         if key not in self._clients:
             headers = copy.deepcopy(self.headers)
-
-
+            auth = None
+
+            if authenticated:
+                auth = RefreshingJWT(
+                    # authenticated=False will always return a synchronous client
+                    auth_client=cast(
+                        Client, self.get_client("application/json", authenticated=False)
+                    ),
+                    auth_type=self.auth_type,
+                    api_key=self.api_key,
+                    jwt=self.jwt,
+                )
+
             if content_type:
                 headers["Content-Type"] = content_type
+
             self._clients[key] = (
                 AsyncClient(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
-                if
+                if async_client
                 else Client(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
             )
+
         return self._clients[key]

     def close(self):
@@ -327,31 +264,6 @@ class RestClient:
             raise ValueError(f"Organization '{organization}' not found.")
         return filtered_orgs

-    def _run_auth(self, jwt: str | None = None) -> str:
-        auth_payload: APIKeyPayload | None
-        if self.auth_type == AuthType.API_KEY:
-            auth_payload = APIKeyPayload(api_key=self.api_key)
-        elif self.auth_type == AuthType.JWT:
-            auth_payload = None
-        else:
-            assert_never(self.auth_type)
-        try:
-            # Use the unauthenticated client for login
-            if auth_payload:
-                response = self.auth_client.post(
-                    "/auth/login", json=auth_payload.model_dump()
-                )
-                response.raise_for_status()
-                token_data = response.json()
-            elif jwt:
-                token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
-            else:
-                raise ValueError("JWT token required for JWT authentication.")
-
-            return token_data["access_token"]
-        except Exception as e:
-            raise RestClientError(f"Error authenticating: {e!s}") from e
-
     def _check_job(self, name: str, organization: str) -> dict[str, Any]:
         try:
             response = self.client.get(
@@ -445,6 +357,7 @@ class RestClient:
             ),
             self.client.stream("GET", url, params={"history": history}) as response,
         ):
+            response.raise_for_status()
             json_data = "".join(response.iter_text(chunk_size=1024))
             data = json.loads(json_data)
             if "id" not in data:
@@ -459,8 +372,6 @@ class RestClient:
             ):
                 return PQATaskResponse(**data)
             return TaskResponse(**data)
-        except ValueError as e:
-            raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
         except Exception as e:
             raise TaskFetchError(f"Error getting task: {e!s}") from e

@@ -507,8 +418,6 @@ class RestClient:
             ):
                 return PQATaskResponse(**data)
             return TaskResponse(**data)
-        except ValueError as e:
-            raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
         except Exception as e:
             raise TaskFetchError(f"Error getting task: {e!s}") from e

@@ -714,9 +623,12 @@ class RestClient:
     )
     def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
         """Get the status of a build."""
-
-
-
+        try:
+            build_id = build_id or self.build_id
+            response = self.client.get(f"/v0.1/builds/{build_id}")
+            response.raise_for_status()
+        except Exception as e:
+            raise JobFetchError(f"Error getting build status: {e!s}") from e
         return response.json()

     # TODO: Refactor later so we don't have to ignore PLR0915
@@ -917,14 +829,14 @@ class RestClient:
         self,
         job_name: str,
         file_path: str | os.PathLike,
-
+        upload_id: str | None = None,
     ) -> str:
         """Upload a file or directory to a futurehouse job bucket.

         Args:
             job_name: The name of the futurehouse job to upload to.
             file_path: The local path to the file or directory to upload.
-
+            upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.

         Returns:
             The upload ID used for the upload.
@@ -936,7 +848,7 @@ class RestClient:
         if not file_path.exists():
             raise FileNotFoundError(f"File or directory not found: {file_path}")

-        upload_id =
+        upload_id = upload_id or str(uuid.uuid4())

         if file_path.is_dir():
             # Process directory recursively
@@ -999,6 +911,12 @@ class RestClient:
         """
        file_name = file_name or file_path.name
         file_size = file_path.stat().st_size
+
+        # Skip empty files
+        if file_size == 0:
+            logger.warning(f"Skipping upload of empty file: {file_path}")
+            return
+
         total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE

         logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")
@@ -1046,7 +964,6 @@ class RestClient:
             )

             logger.info(f"Successfully uploaded {file_name}")
-
         except Exception as e:
             logger.exception(f"Error uploading file {file_path}")
             raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
@@ -1056,12 +973,18 @@ class RestClient:
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
         retry=retry_if_connection_error,
     )
-    def list_files(
+    def list_files(
+        self,
+        job_name: str,
+        trajectory_id: str | None = None,
+        upload_id: str | None = None,
+    ) -> dict[str, list[str]]:
         """List files and directories in a GCS location for a given job_name and upload_id.

         Args:
             job_name: The name of the futurehouse job.
-
+            trajectory_id: The specific trajectory id to list files from.
+            upload_id: The specific upload id to list files from.

         Returns:
             A list of files in the GCS folder.
@@ -1069,22 +992,27 @@ class RestClient:
         Raises:
             RestClientError: If there is an error listing the files.
         """
+        if not bool(trajectory_id) ^ bool(upload_id):
+            raise RestClientError(
+                "Must at least specify one of trajectory_id or upload_id, but not both"
+            )
         try:
             url = f"/v0.1/crows/{job_name}/list-files"
-            params = {"upload_id":
+            params = {"trajectory_id": trajectory_id, "upload_id": upload_id}
+            params = {k: v for k, v in params.items() if v is not None}
             response = self.client.get(url, params=params)
             response.raise_for_status()
             return response.json()
         except HTTPStatusError as e:
             logger.exception(
-                f"Error listing files for job {job_name},
+                f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
             )
             raise RestClientError(
                 f"Error listing files: {e.response.status_code} - {e.response.text}"
             ) from e
         except Exception as e:
             logger.exception(
-                f"Error listing files for job {job_name},
+                f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}"
             )
             raise RestClientError(f"Error listing files: {e!s}") from e

@@ -1096,7 +1024,7 @@ class RestClient:
     def download_file(
         self,
         job_name: str,
-
+        trajectory_id: str,
         file_path: str,
         destination_path: str | os.PathLike,
     ) -> None:
@@ -1104,14 +1032,14 @@ class RestClient:

         Args:
             job_name: The name of the futurehouse job.
-
+            trajectory_id: The specific trajectory id the file belongs to.
             file_path: The relative path of the file to download
                 (e.g., 'data/my_file.csv' or 'my_image.png').
             destination_path: The local path where the file should be saved.

         Raises:
             RestClientError: If there is an error downloading the file.
-            FileNotFoundError: If the destination directory does not exist.
+            FileNotFoundError: If the destination directory does not exist or if the file is not found.
         """
         destination_path = Path(destination_path)
         # Ensure the destination directory exists
@@ -1119,17 +1047,24 @@ class RestClient:

         try:
             url = f"/v0.1/crows/{job_name}/download-file"
-            params = {"
+            params = {"trajectory_id": trajectory_id, "file_path": file_path}

             with self.client.stream("GET", url, params=params) as response:
                 response.raise_for_status()  # Check for HTTP errors before streaming
                 with open(destination_path, "wb") as f:
                     for chunk in response.iter_bytes(chunk_size=8192):
                         f.write(chunk)
+
+            # Check if the downloaded file is empty
+            if destination_path.stat().st_size == 0:
+                # Remove the empty file
+                destination_path.unlink()
+                raise FileNotFoundError(f"File not found or is empty: {file_path}")
+
             logger.info(f"File {file_path} downloaded to {destination_path}")
         except HTTPStatusError as e:
             logger.exception(
-                f"Error downloading file {file_path} for job {job_name},
+                f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
             )
             # Clean up partially downloaded file if an error occurs
             if destination_path.exists():
@@ -1137,9 +1072,20 @@ class RestClient:
             raise RestClientError(
                 f"Error downloading file: {e.response.status_code} - {e.response.text}"
             ) from e
+        except RemoteProtocolError as e:
+            logger.error(
+                f"Connection error while downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
+            )
+            # Clean up partially downloaded file
+            if destination_path.exists():
+                destination_path.unlink()
+
+            # Often RemoteProtocolError during download means the file wasn't found
+            # or was empty/corrupted on the server side
+            raise FileNotFoundError(f"File not found or corrupted: {file_path}") from e
         except Exception as e:
             logger.exception(
-                f"Error downloading file {file_path} for job {job_name},
+                f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
             )
             if destination_path.exists():
                 destination_path.unlink()  # Clean up partial file
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/__init__.py RENAMED
@@ -3,10 +3,15 @@ from .app import (
     DockerContainerConfiguration,
     FramePath,
     JobDeploymentConfig,
+    PQATaskResponse,
     RuntimeConfig,
     Stage,
     Step,
+    TaskQueue,
+    TaskQueuesConfig,
     TaskRequest,
+    TaskResponse,
+    TaskResponseVerbose,
 )

 __all__ = [
@@ -14,8 +19,13 @@ __all__ = [
     "DockerContainerConfiguration",
     "FramePath",
     "JobDeploymentConfig",
+    "PQATaskResponse",
     "RuntimeConfig",
     "Stage",
     "Step",
+    "TaskQueue",
+    "TaskQueuesConfig",
     "TaskRequest",
+    "TaskResponse",
+    "TaskResponseVerbose",
 ]
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/app.py RENAMED
@@ -1,6 +1,9 @@
+import copy
 import json
 import os
 import re
+from collections.abc import Mapping
+from datetime import datetime
 from enum import StrEnum, auto
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Self, cast
@@ -646,3 +649,96 @@ class TaskRequest(BaseModel):
     runtime_config: RuntimeConfig | None = Field(
         default=None, description="All optional runtime parameters for the job"
     )
+
+
+class SimpleOrganization(BaseModel):
+    id: int
+    name: str
+    display_name: str
+
+
+class TaskResponse(BaseModel):
+    """Base class for task responses. This holds attributes shared over all futurehouse jobs."""
+
+    model_config = ConfigDict(extra="ignore")
+
+    status: str
+    query: str
+    user: str | None = None
+    created_at: datetime
+    job_name: str
+    public: bool
+    shared_with: list[SimpleOrganization] | None = None
+    build_owner: str | None = None
+    environment_name: str | None = None
+    agent_name: str | None = None
+    task_id: UUID | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
+        data = copy.deepcopy(original_data)  # Avoid mutating the original data
+        # Extract fields from environment frame state
+        if not isinstance(data, dict):
+            return data
+        # TODO: We probably want to remove these two once we define the final names.
+        data["job_name"] = data.get("crow")
+        data["query"] = data.get("task")
+        data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
+        if not (metadata := data.get("metadata", {})):
+            return data
+        data["environment_name"] = metadata.get("environment_name")
+        data["agent_name"] = metadata.get("agent_name")
+        return data
+
+
+class PQATaskResponse(TaskResponse):
+    model_config = ConfigDict(extra="ignore")
+
+    answer: str | None = None
+    formatted_answer: str | None = None
+    answer_reasoning: str | None = None
+    has_successful_answer: bool | None = None
+    total_cost: float | None = None
+    total_queries: int | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_pqa_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
+        data = copy.deepcopy(original_data)  # Avoid mutating the original data
+        if not isinstance(data, dict):
+            return data
+        if not (env_frame := data.get("environment_frame", {})):
+            return data
+        state = env_frame.get("state", {}).get("state", {})
+        response = state.get("response", {})
+        answer = response.get("answer", {})
+        usage = state.get("info", {}).get("usage", {})
+
+        # Add additional PQA specific fields to data so that pydantic can validate the model
+        data["answer"] = answer.get("answer")
+        data["formatted_answer"] = answer.get("formatted_answer")
+        data["answer_reasoning"] = answer.get("answer_reasoning")
+        data["has_successful_answer"] = answer.get("has_successful_answer")
+        data["total_cost"] = cast(float, usage.get("total_cost"))
+        data["total_queries"] = cast(int, usage.get("total_queries"))
+
+        return data
+
+    def clean_verbose(self) -> "TaskResponse":
+        """Clean the verbose response from the server."""
+        self.request = None
+        self.response = None
+        return self
+
+
+class TaskResponseVerbose(TaskResponse):
+    """Class for responses to include all the fields of a task response."""
+
+    model_config = ConfigDict(extra="allow")
+
+    public: bool
+    agent_state: list[dict[str, Any]] | None = None
+    environment_frame: dict[str, Any] | None = None
+    metadata: dict[str, Any] | None = None
+    shared_with: list[SimpleOrganization] | None = None
futurehouse_client-0.3.18/futurehouse_client/utils/auth.py ADDED
@@ -0,0 +1,92 @@
+import logging
+from collections.abc import Collection, Generator
+from typing import ClassVar, Final
+
+import httpx
+
+from futurehouse_client.models.app import APIKeyPayload, AuthType
+
+logger = logging.getLogger(__name__)
+
+INVALID_REFRESH_TYPE_MSG: Final[str] = (
+    "API key auth is required to refresh auth tokens."
+)
+JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
+
+
+def _run_auth(
+    client: httpx.Client,
+    auth_type: AuthType = AuthType.API_KEY,
+    api_key: str | None = None,
+    jwt: str | None = None,
+) -> str:
+    auth_payload: APIKeyPayload | None
+    if auth_type == AuthType.API_KEY:
+        auth_payload = APIKeyPayload(api_key=api_key)
+    elif auth_type == AuthType.JWT:
+        auth_payload = None
+    try:
+        if auth_payload:
+            response = client.post("/auth/login", json=auth_payload.model_dump())
+            response.raise_for_status()
+            token_data = response.json()
+        elif jwt:
+            token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
+        else:
+            raise ValueError("JWT token required for JWT authentication.")
+
+        return token_data["access_token"]
+    except Exception as e:
+        raise Exception("Failed to authenticate") from e  # noqa: TRY002
+
+
+class RefreshingJWT(httpx.Auth):
+    """Automatically (re-)inject a JWT and transparently retry exactly once when we hit a 401/403."""
+
+    RETRY_STATUSES: ClassVar[Collection[httpx.codes]] = {
+        httpx.codes.UNAUTHORIZED,
+        httpx.codes.FORBIDDEN,
+    }
+
+    def __init__(
+        self,
+        auth_client: httpx.Client,
+        auth_type: AuthType = AuthType.API_KEY,
+        api_key: str | None = None,
+        jwt: str | None = None,
+    ):
+        self.auth_type = auth_type
+        self.auth_client = auth_client
+        self.api_key = api_key
+        self._jwt = _run_auth(
+            client=auth_client,
+            jwt=jwt,
+            auth_type=auth_type,
+            api_key=api_key,
+        )
+
+    def refresh_token(self) -> None:
+        if self.auth_type == AuthType.JWT:
+            logger.error(INVALID_REFRESH_TYPE_MSG)
+            raise ValueError(INVALID_REFRESH_TYPE_MSG)
+        self._jwt = _run_auth(
+            client=self.auth_client,
+            auth_type=self.auth_type,
+            api_key=self.api_key,
+        )
+
+    def auth_flow(
+        self, request: httpx.Request
+    ) -> Generator[httpx.Request, httpx.Response, None]:
+        request.headers["Authorization"] = f"Bearer {self._jwt}"
+        response = yield request
+
+        # If it failed, refresh once and replay the request
+        if response.status_code in self.RETRY_STATUSES:
+            logger.info(
+                "Received %s, refreshing token and retrying …",
+                response.status_code,
+            )
+            self.refresh_token()
+            request.headers["Authorization"] = f"Bearer {self._jwt}"
+            yield request  # second (and final) attempt, again or use a while loop
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/PKG-INFO RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: futurehouse-client
-Version: 0.3.17.dev94
+Version: 0.3.18
 Summary: A client for interacting with endpoints of the FutureHouse service.
 Author-email: FutureHouse technical staff <hello@futurehouse.org>
 Classifier: Operating System :: OS Independent
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/SOURCES.txt RENAMED
@@ -18,7 +18,9 @@ futurehouse_client/models/app.py
 futurehouse_client/models/client.py
 futurehouse_client/models/rest.py
 futurehouse_client/utils/__init__.py
+futurehouse_client/utils/auth.py
 futurehouse_client/utils/general.py
 futurehouse_client/utils/module_utils.py
 futurehouse_client/utils/monitoring.py
+tests/test_client.py
 tests/test_rest.py
futurehouse_client-0.3.18/tests/test_client.py ADDED
@@ -0,0 +1,161 @@
+import copy
+from datetime import datetime
+from typing import Any
+from unittest.mock import MagicMock
+
+import httpx
+import pytest
+from futurehouse_client.models.app import AuthType, TaskResponse
+from futurehouse_client.utils.auth import RefreshingJWT
+
+
+@pytest.fixture
+def mock_client():
+    """Create a mock synchronous HTTP client that returns success on first auth attempt."""
+    client = MagicMock(spec=httpx.Client)
+    response = MagicMock()
+    response.raise_for_status.return_value = None
+    response.json.return_value = {
+        "access_token": "test_token_from_api",
+        "expires_in": 300,
+    }
+    client.post.return_value = response
+    return client
+
+
+@pytest.fixture
+def failing_then_success_client():
+    """Create a client that fails with 401 on first call, then succeeds on retry."""
+    client = MagicMock(spec=httpx.Client)
+
+    first_response = MagicMock(status_code=401)
+    success_response = MagicMock()
+    success_response.raise_for_status.return_value = None
+    success_response.json.return_value = {
+        "access_token": "refreshed_token",
+        "expires_in": 300,
+    }
+
+    client.post.return_value = success_response
+
+    return client, first_response
+
+
+def test_refreshing_jwt_with_api_key(mock_client):
+    """Test that RefreshingJWT works with API key authentication."""
+    api_key = "mock_api_key_12345"
+
+    auth = RefreshingJWT(
+        auth_client=mock_client, auth_type=AuthType.API_KEY, api_key=api_key
+    )
+
+    assert auth._jwt == "test_token_from_api"
+
+    mock_client.post.assert_called_once()
+    args, kwargs = mock_client.post.call_args
+    assert args[0] == "/auth/login"
+    assert "json" in kwargs
+    assert kwargs["json"] == {"api_key": api_key}
+
+
+def test_refreshing_jwt_with_jwt_token():
+    """Test that RefreshingJWT works with JWT authentication."""
+    jwt_token = "mock.jwt.token"
+
+    auth = RefreshingJWT(auth_client=MagicMock(), auth_type=AuthType.JWT, jwt=jwt_token)
+
+    assert auth._jwt == jwt_token
+
+
+def test_refreshing_jwt_refresh_token(mock_client):
+    """Test that refresh_token method correctly gets a new token."""
+    api_key = "mock_api_key_12345"
+
+    auth = RefreshingJWT(
+        auth_client=mock_client, auth_type=AuthType.API_KEY, api_key=api_key
+    )
+
+    original_token = auth._jwt
+
+    new_response = MagicMock()
+    new_response.raise_for_status.return_value = None
+    new_response.json.return_value = {
+        "access_token": "new_refreshed_token",
+        "expires_in": 300,
+    }
+    mock_client.post.return_value = new_response
+
+    auth.refresh_token()
+
+    assert auth._jwt == "new_refreshed_token"
+    assert auth._jwt != original_token
+
+    assert mock_client.post.call_count == 2  # Initial auth + refresh
+
+
+def test_refreshing_jwt_refresh_token_jwt_auth_fails():
+    """Test that refresh_token raises an error with JWT auth type."""
+    jwt_token = "mock.jwt.token"
+
+    auth = RefreshingJWT(auth_client=MagicMock(), auth_type=AuthType.JWT, jwt=jwt_token)
+
+    with pytest.raises(ValueError) as excinfo:  # noqa: PT011
+        auth.refresh_token()
+
+    assert "API key auth is required to refresh auth tokens" in str(excinfo.value)
+
+
+def test_auth_flow_with_retry(failing_then_success_client):
+    """Test that auth_flow retries with new token after receiving a 401."""
+    client, first_response = failing_then_success_client
+    api_key = "mock_api_key_12345"
+    auth = RefreshingJWT(
+        auth_client=client, auth_type=AuthType.API_KEY, api_key=api_key
+    )
+    request = httpx.Request("GET", "https://fh.org")
+
+    flow = auth.auth_flow(request)
+    first_request = next(flow)
+    assert first_request.headers["Authorization"] == f"Bearer {auth._jwt}"
+
+    second_request = flow.send(first_response)
+    assert auth._jwt == "refreshed_token"
+    assert second_request.headers["Authorization"] == "Bearer refreshed_token"
+    success_response = httpx.Response(200)
+
+    try:
+        flow.send(success_response)
+        pytest.fail("Generator should have exited after processing the response")
+    except StopIteration:
+        pass
+
+    client.post.assert_called_with("/auth/login", json={"api_key": api_key})
+
+
+def test_task_response_does_not_mutate_original_data():
+    """Test that TaskResponse doesn't mutate the original data when creating an instance."""
+    original_data: dict[str, Any] = {
+        "crow": "test-crow",
+        "task": "test task",
+        "metadata": {
+            "environment_name": "test-env",
+            "agent_name": "test-agent",
+            "some_other_field": "should not be modified",
+        },
+        "status": "success",
+        "created_at": datetime.now(),
+        "public": True,
+    }
+
+    original_data_copy = copy.deepcopy(original_data)
+
+    task_response = TaskResponse(**original_data)
+
+    assert original_data == original_data_copy, "Original data was mutated"
+
+    # Assert the fields are set correctly
+    assert task_response.job_name == original_data["crow"]
+    assert task_response.query == original_data["task"]
+    metadata = original_data.get("metadata", {})
+    assert task_response.environment_name == metadata.get("environment_name")
+    assert task_response.agent_name == metadata.get("agent_name")
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/tests/test_rest.py RENAMED
@@ -1,3 +1,4 @@
+# ruff: noqa: ARG001
 import asyncio
 import os
 import time
@@ -18,87 +19,98 @@ PUBLIC_API_KEY = os.environ["PLAYWRIGHT_PUBLIC_API_KEY"]
 TEST_MAX_POLLS = 100


-@pytest.
-
-
-
+@pytest.fixture
+def admin_client():
+    """Create a RestClient for testing; using an admin key."""
+    return RestClient(
         stage=Stage.DEV,
         api_key=ADMIN_API_KEY,
     )

-
+
+
+@pytest.fixture
+def pub_client():
+    """Create a RestClient for testing; using a public user key with limited access."""
+    return RestClient(
+        stage=Stage.DEV,
+        api_key=PUBLIC_API_KEY,
+    )
+
+
+@pytest.fixture
+def task_req():
+    """Create a sample task request."""
+    return TaskRequest(
         name=JobNames.from_string("dummy"),
         query="How many moons does earth have?",
     )
-    client.create_task(task_data)

-    while (task_status := client.get_task().status) in {"queued", "in progress"}:
-        time.sleep(5)

+@pytest.fixture
+def pqa_task_req():
+    return TaskRequest(
+        name=JobNames.from_string("crow"),
+        query="How many moons does earth have?",
+    )
+
+
+@pytest.mark.timeout(300)
+@pytest.mark.flaky(reruns=3)
+def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_req: TaskRequest):
+    admin_client.create_task(task_req)
+    while (task_status := admin_client.get_task().status) in {"queued", "in progress"}:
+        time.sleep(5)
     assert task_status == "success"


-def test_insufficient_permissions_request(
+def test_insufficient_permissions_request(
+    pub_client: RestClient, task_req: TaskRequest
+):
     # Create a new instance so that cached credentials aren't reused
-    client = RestClient(
-        stage=Stage.DEV,
-        api_key=PUBLIC_API_KEY,
-    )
-    task_data = TaskRequest(
-        name=JobNames.from_string("dummy"),
-        query="How many moons does earth have?",
-    )
-
     with pytest.raises(TaskFetchError) as exc_info:
-
+        pub_client.create_task(task_req)

     assert "Error creating task" in str(exc_info.value)


 @pytest.mark.timeout(300)
 @pytest.mark.asyncio
-async def test_job_response(
-
-
-
-    )
-    task_data = TaskRequest(
-        name=JobNames.from_string("crow"),
-        query="How many moons does earth have?",
-    )
-    task_id = client.create_task(task_data)
-    atask_id = await client.acreate_task(task_data)
+async def test_job_response(  # noqa: PLR0915
+    subtests: SubTests, admin_client: RestClient, pqa_task_req: TaskRequest
+):
+    task_id = admin_client.create_task(pqa_task_req)
+    atask_id = await admin_client.acreate_task(pqa_task_req)

     with subtests.test("Test TaskResponse with queued task"):
-        task_response =
+        task_response = admin_client.get_task(task_id)
         assert task_response.status in {"queued", "in progress"}
-        assert task_response.job_name ==
-        assert task_response.query ==
-        task_response = await
+        assert task_response.job_name == pqa_task_req.name
+        assert task_response.query == pqa_task_req.query
+        task_response = await admin_client.aget_task(atask_id)
         assert task_response.status in {"queued", "in progress"}
-        assert task_response.job_name ==
-        assert task_response.query ==
+        assert task_response.job_name == pqa_task_req.name
+        assert task_response.query == pqa_task_req.query

     for _ in range(TEST_MAX_POLLS):
-        task_response =
+        task_response = admin_client.get_task(task_id)
         if task_response.status in ExecutionStatus.terminal_states():
             break
         await asyncio.sleep(5)

     for _ in range(TEST_MAX_POLLS):
-        task_response = await
+        task_response = await admin_client.aget_task(atask_id)
         if task_response.status in ExecutionStatus.terminal_states():
             break
         await asyncio.sleep(5)

     with subtests.test("Test PQA job response"):
-        task_response =
+        task_response = admin_client.get_task(task_id)
         assert isinstance(task_response, PQATaskResponse)
         # assert it has general fields
         assert task_response.status == "success"
         assert task_response.task_id is not None
-        assert
-        assert
+        assert pqa_task_req.name in task_response.job_name
+        assert pqa_task_req.query in task_response.query
         # assert it has PQA specific fields
         assert task_response.answer is not None
         # assert it's not verbose
@@ -106,13 +118,13 @@ async def test_job_response(subtests: SubTests):  # noqa: PLR0915
         assert not hasattr(task_response, "agent_state")

     with subtests.test("Test async PQA job response"):
-        task_response = await
+        task_response = await admin_client.aget_task(atask_id)
         assert isinstance(task_response, PQATaskResponse)
         # assert it has general fields
         assert task_response.status == "success"
         assert task_response.task_id is not None
-        assert
-        assert
+        assert pqa_task_req.name in task_response.job_name
+        assert pqa_task_req.query in task_response.query
         # assert it has PQA specific fields
         assert task_response.answer is not None
         # assert it's not verbose
@@ -120,14 +132,14 @@ async def test_job_response(subtests: SubTests):  # noqa: PLR0915
         assert not hasattr(task_response, "agent_state")

     with subtests.test("Test task response with verbose"):
-        task_response =
+        task_response = admin_client.get_task(task_id, verbose=True)
         assert isinstance(task_response, TaskResponseVerbose)
         assert task_response.status == "success"
         assert task_response.environment_frame is not None
         assert task_response.agent_state is not None

     with subtests.test("Test task async response with verbose"):
-        task_response = await
+        task_response = await admin_client.aget_task(atask_id, verbose=True)
         assert isinstance(task_response, TaskResponseVerbose)
         assert task_response.status == "success"
         assert task_response.environment_frame is not None
@@ -136,20 +148,12 @@ async def test_job_response(subtests: SubTests):  # noqa: PLR0915

 @pytest.mark.timeout(300)
 @pytest.mark.flaky(reruns=3)
-def test_run_until_done_futurehouse_dummy_env_crow(
-
-
-
-    )
-
-    task_data = TaskRequest(
-        name=JobNames.from_string("dummy"),
-        query="How many moons does earth have?",
-    )
-
-    tasks_to_do = [task_data, task_data]
+def test_run_until_done_futurehouse_dummy_env_crow(
+    admin_client: RestClient, task_req: TaskRequest
+):
+    tasks_to_do = [task_req, task_req]

-    results =
+    results = admin_client.run_tasks_until_done(tasks_to_do)

     assert len(results) == len(tasks_to_do), "Should return 2 tasks."
     assert all(task.status == "success" for task in results)
@@ -158,20 +162,12 @@ def test_run_until_done_futurehouse_dummy_env_crow():
 @pytest.mark.timeout(300)
 @pytest.mark.flaky(reruns=3)
 @pytest.mark.asyncio
-async def test_arun_until_done_futurehouse_dummy_env_crow(
-
-
-
-    )
+async def test_arun_until_done_futurehouse_dummy_env_crow(
+    admin_client: RestClient, task_req: TaskRequest
+):
+    tasks_to_do = [task_req, task_req]

-
-        name=JobNames.from_string("dummy"),
-        query="How many moons does earth have?",
-    )
-
-    tasks_to_do = [task_data, task_data]
-
-    results = await client.arun_tasks_until_done(tasks_to_do)
+    results = await admin_client.arun_tasks_until_done(tasks_to_do)

     assert len(results) == len(tasks_to_do), "Should return 2 tasks."
     assert all(task.status == "success" for task in results)
@@ -180,20 +176,12 @@ async def test_arun_until_done_futurehouse_dummy_env_crow():
 @pytest.mark.timeout(300)
 @pytest.mark.flaky(reruns=3)
 @pytest.mark.asyncio
-async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
-
-
-
-    )
-
-    task_data = TaskRequest(
-        name=JobNames.from_string("dummy"),
-        query="How many moons does earth have?",
-    )
-
-    tasks_to_do = [task_data, task_data]
+async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
+    admin_client: RestClient, task_req: TaskRequest
+):
+    tasks_to_do = [task_req, task_req]

-    results = await
+    results = await admin_client.arun_tasks_until_done(
         tasks_to_do, verbose=True, timeout=5, progress_bar=True
     )
@@ -203,7 +191,7 @@ async def test_timeout_run_until_done_futurehouse_dummy_env_crow():
         "Should be verbose."
     )

-    results =
+    results = admin_client.run_tasks_until_done(
         tasks_to_do, verbose=True, timeout=5, progress_bar=True
     )
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/LICENSE RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/README.md RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/docs/__init__.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/docs/client_notebook.ipynb RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/__init__.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/clients/__init__.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/client.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/models/rest.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/__init__.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/general.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/module_utils.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client/utils/monitoring.py RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/dependency_links.txt RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/requires.txt RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/futurehouse_client.egg-info/top_level.txt RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/pyproject.toml RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/setup.cfg RENAMED
File without changes
{futurehouse_client-0.3.17.dev94 → futurehouse_client-0.3.18}/uv.lock RENAMED
File without changes