futurehouse-client 0.3.18.dev110__tar.gz → 0.3.18.dev185__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/PKG-INFO +1 -1
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/clients/rest_client.py +115 -104
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/models/__init__.py +10 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/models/app.py +93 -0
- futurehouse_client-0.3.18.dev185/futurehouse_client/utils/__init__.py +0 -0
- futurehouse_client-0.3.18.dev185/futurehouse_client/utils/auth.py +107 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client.egg-info/PKG-INFO +1 -1
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client.egg-info/SOURCES.txt +1 -1
- futurehouse_client-0.3.18.dev185/tests/test_rest.py +260 -0
- futurehouse_client-0.3.18.dev110/futurehouse_client/utils/__init__.py +0 -7
- futurehouse_client-0.3.18.dev110/futurehouse_client/utils/context.py +0 -45
- futurehouse_client-0.3.18.dev110/tests/test_rest.py +0 -214
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/LICENSE +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/README.md +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/docs/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/docs/client_notebook.ipynb +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/clients/__init__.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/clients/job_client.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/models/client.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/models/rest.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/utils/general.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/utils/module_utils.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/utils/monitoring.py +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client.egg-info/dependency_links.txt +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client.egg-info/requires.txt +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client.egg-info/top_level.txt +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/pyproject.toml +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/setup.cfg +0 -0
- {futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/uv.lock +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: futurehouse-client
|
3
|
-
Version: 0.3.18.
|
3
|
+
Version: 0.3.18.dev185
|
4
4
|
Summary: A client for interacting with endpoints of the FutureHouse service.
|
5
5
|
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
6
6
|
Classifier: Operating System :: OS Independent
|
@@ -12,8 +12,7 @@ import sys
|
|
12
12
|
import tempfile
|
13
13
|
import time
|
14
14
|
import uuid
|
15
|
-
from collections.abc import Collection
|
16
|
-
from datetime import datetime
|
15
|
+
from collections.abc import Collection
|
17
16
|
from pathlib import Path
|
18
17
|
from types import ModuleType
|
19
18
|
from typing import Any, ClassVar, assert_never, cast
|
@@ -34,7 +33,6 @@ from httpx import (
|
|
34
33
|
RemoteProtocolError,
|
35
34
|
)
|
36
35
|
from ldp.agent import AgentConfig
|
37
|
-
from pydantic import BaseModel, ConfigDict, model_validator
|
38
36
|
from requests.exceptions import RequestException, Timeout
|
39
37
|
from tenacity import (
|
40
38
|
retry,
|
@@ -50,10 +48,18 @@ from futurehouse_client.models.app import (
|
|
50
48
|
APIKeyPayload,
|
51
49
|
AuthType,
|
52
50
|
JobDeploymentConfig,
|
51
|
+
PQATaskResponse,
|
53
52
|
Stage,
|
54
53
|
TaskRequest,
|
54
|
+
TaskResponse,
|
55
|
+
TaskResponseVerbose,
|
55
56
|
)
|
56
57
|
from futurehouse_client.models.rest import ExecutionStatus
|
58
|
+
from futurehouse_client.utils.auth import (
|
59
|
+
AUTH_ERRORS_TO_RETRY_ON,
|
60
|
+
AuthError,
|
61
|
+
refresh_token_on_auth_error,
|
62
|
+
)
|
57
63
|
from futurehouse_client.utils.general import gather_with_concurrency
|
58
64
|
from futurehouse_client.utils.module_utils import (
|
59
65
|
OrganizationSelector,
|
@@ -65,7 +71,7 @@ from futurehouse_client.utils.monitoring import (
|
|
65
71
|
|
66
72
|
logger = logging.getLogger(__name__)
|
67
73
|
logging.basicConfig(
|
68
|
-
level=logging.
|
74
|
+
level=logging.WARNING,
|
69
75
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
70
76
|
stream=sys.stdout,
|
71
77
|
)
|
@@ -122,103 +128,11 @@ retry_if_connection_error = retry_if_exception_type((
|
|
122
128
|
FileUploadError,
|
123
129
|
))
|
124
130
|
|
125
|
-
|
126
|
-
class SimpleOrganization(BaseModel):
|
127
|
-
id: int
|
128
|
-
name: str
|
129
|
-
display_name: str
|
130
|
-
|
131
|
-
|
132
131
|
# 5 minute default for JWTs
|
133
132
|
JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
|
134
133
|
DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
|
135
134
|
|
136
135
|
|
137
|
-
class TaskResponse(BaseModel):
|
138
|
-
"""Base class for task responses. This holds attributes shared over all futurehouse jobs."""
|
139
|
-
|
140
|
-
model_config = ConfigDict(extra="ignore")
|
141
|
-
|
142
|
-
status: str
|
143
|
-
query: str
|
144
|
-
user: str | None = None
|
145
|
-
created_at: datetime
|
146
|
-
job_name: str
|
147
|
-
public: bool
|
148
|
-
shared_with: list[SimpleOrganization] | None = None
|
149
|
-
build_owner: str | None = None
|
150
|
-
environment_name: str | None = None
|
151
|
-
agent_name: str | None = None
|
152
|
-
task_id: UUID | None = None
|
153
|
-
|
154
|
-
@model_validator(mode="before")
|
155
|
-
@classmethod
|
156
|
-
def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
157
|
-
# Extract fields from environment frame state
|
158
|
-
if not isinstance(data, dict):
|
159
|
-
return data
|
160
|
-
# TODO: We probably want to remove these two once we define the final names.
|
161
|
-
data["job_name"] = data.get("crow")
|
162
|
-
data["query"] = data.get("task")
|
163
|
-
data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
|
164
|
-
if not (metadata := data.get("metadata", {})):
|
165
|
-
return data
|
166
|
-
data["environment_name"] = metadata.get("environment_name")
|
167
|
-
data["agent_name"] = metadata.get("agent_name")
|
168
|
-
return data
|
169
|
-
|
170
|
-
|
171
|
-
class PQATaskResponse(TaskResponse):
|
172
|
-
model_config = ConfigDict(extra="ignore")
|
173
|
-
|
174
|
-
answer: str | None = None
|
175
|
-
formatted_answer: str | None = None
|
176
|
-
answer_reasoning: str | None = None
|
177
|
-
has_successful_answer: bool | None = None
|
178
|
-
total_cost: float | None = None
|
179
|
-
total_queries: int | None = None
|
180
|
-
|
181
|
-
@model_validator(mode="before")
|
182
|
-
@classmethod
|
183
|
-
def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
184
|
-
if not isinstance(data, dict):
|
185
|
-
return data
|
186
|
-
if not (env_frame := data.get("environment_frame", {})):
|
187
|
-
return data
|
188
|
-
state = env_frame.get("state", {}).get("state", {})
|
189
|
-
response = state.get("response", {})
|
190
|
-
answer = response.get("answer", {})
|
191
|
-
usage = state.get("info", {}).get("usage", {})
|
192
|
-
|
193
|
-
# Add additional PQA specific fields to data so that pydantic can validate the model
|
194
|
-
data["answer"] = answer.get("answer")
|
195
|
-
data["formatted_answer"] = answer.get("formatted_answer")
|
196
|
-
data["answer_reasoning"] = answer.get("answer_reasoning")
|
197
|
-
data["has_successful_answer"] = answer.get("has_successful_answer")
|
198
|
-
data["total_cost"] = cast(float, usage.get("total_cost"))
|
199
|
-
data["total_queries"] = cast(int, usage.get("total_queries"))
|
200
|
-
|
201
|
-
return data
|
202
|
-
|
203
|
-
def clean_verbose(self) -> "TaskResponse":
|
204
|
-
"""Clean the verbose response from the server."""
|
205
|
-
self.request = None
|
206
|
-
self.response = None
|
207
|
-
return self
|
208
|
-
|
209
|
-
|
210
|
-
class TaskResponseVerbose(TaskResponse):
|
211
|
-
"""Class for responses to include all the fields of a task response."""
|
212
|
-
|
213
|
-
model_config = ConfigDict(extra="allow")
|
214
|
-
|
215
|
-
public: bool
|
216
|
-
agent_state: list[dict[str, Any]] | None = None
|
217
|
-
environment_frame: dict[str, Any] | None = None
|
218
|
-
metadata: dict[str, Any] | None = None
|
219
|
-
shared_with: list[SimpleOrganization] | None = None
|
220
|
-
|
221
|
-
|
222
136
|
class RestClient:
|
223
137
|
REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec
|
224
138
|
MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
|
@@ -236,7 +150,13 @@ class RestClient:
|
|
236
150
|
api_key: str | None = None,
|
237
151
|
jwt: str | None = None,
|
238
152
|
headers: dict[str, str] | None = None,
|
153
|
+
verbose_logging: bool = False,
|
239
154
|
):
|
155
|
+
if verbose_logging:
|
156
|
+
logger.setLevel(logging.INFO)
|
157
|
+
else:
|
158
|
+
logger.setLevel(logging.WARNING)
|
159
|
+
|
240
160
|
self.base_url = service_uri or stage.value
|
241
161
|
self.stage = stage
|
242
162
|
self.auth_type = auth_type
|
@@ -360,6 +280,7 @@ class RestClient:
|
|
360
280
|
except Exception as e:
|
361
281
|
raise RestClientError(f"Error authenticating: {e!s}") from e
|
362
282
|
|
283
|
+
@refresh_token_on_auth_error()
|
363
284
|
def _check_job(self, name: str, organization: str) -> dict[str, Any]:
|
364
285
|
try:
|
365
286
|
response = self.client.get(
|
@@ -367,9 +288,19 @@ class RestClient:
|
|
367
288
|
)
|
368
289
|
response.raise_for_status()
|
369
290
|
return response.json()
|
291
|
+
except HTTPStatusError as e:
|
292
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
293
|
+
raise AuthError(
|
294
|
+
e.response.status_code,
|
295
|
+
f"Authentication failed: {e}",
|
296
|
+
request=e.request,
|
297
|
+
response=e.response,
|
298
|
+
) from e
|
299
|
+
raise
|
370
300
|
except Exception as e:
|
371
301
|
raise JobFetchError(f"Error checking job: {e!s}") from e
|
372
302
|
|
303
|
+
@refresh_token_on_auth_error()
|
373
304
|
def _fetch_my_orgs(self) -> list[str]:
|
374
305
|
response = self.client.get(f"/v0.1/organizations?filter={True}")
|
375
306
|
response.raise_for_status()
|
@@ -427,6 +358,7 @@ class RestClient:
|
|
427
358
|
if not files:
|
428
359
|
raise TaskFetchError(f"No files found in {path}")
|
429
360
|
|
361
|
+
@refresh_token_on_auth_error()
|
430
362
|
@retry(
|
431
363
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
432
364
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -467,11 +399,19 @@ class RestClient:
|
|
467
399
|
):
|
468
400
|
return PQATaskResponse(**data)
|
469
401
|
return TaskResponse(**data)
|
470
|
-
except
|
471
|
-
|
402
|
+
except HTTPStatusError as e:
|
403
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
404
|
+
raise AuthError(
|
405
|
+
e.response.status_code,
|
406
|
+
f"Authentication failed: {e}",
|
407
|
+
request=e.request,
|
408
|
+
response=e.response,
|
409
|
+
) from e
|
410
|
+
raise
|
472
411
|
except Exception as e:
|
473
412
|
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
474
413
|
|
414
|
+
@refresh_token_on_auth_error()
|
475
415
|
@retry(
|
476
416
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
477
417
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -515,11 +455,19 @@ class RestClient:
|
|
515
455
|
):
|
516
456
|
return PQATaskResponse(**data)
|
517
457
|
return TaskResponse(**data)
|
518
|
-
except
|
519
|
-
|
458
|
+
except HTTPStatusError as e:
|
459
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
460
|
+
raise AuthError(
|
461
|
+
e.response.status_code,
|
462
|
+
f"Authentication failed: {e}",
|
463
|
+
request=e.request,
|
464
|
+
response=e.response,
|
465
|
+
) from e
|
466
|
+
raise
|
520
467
|
except Exception as e:
|
521
468
|
raise TaskFetchError(f"Error getting task: {e!s}") from e
|
522
469
|
|
470
|
+
@refresh_token_on_auth_error()
|
523
471
|
@retry(
|
524
472
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
525
473
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -543,10 +491,20 @@ class RestClient:
|
|
543
491
|
response.raise_for_status()
|
544
492
|
trajectory_id = response.json()["trajectory_id"]
|
545
493
|
self.trajectory_id = trajectory_id
|
494
|
+
except HTTPStatusError as e:
|
495
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
496
|
+
raise AuthError(
|
497
|
+
e.response.status_code,
|
498
|
+
f"Authentication failed: {e}",
|
499
|
+
request=e.request,
|
500
|
+
response=e.response,
|
501
|
+
) from e
|
502
|
+
raise
|
546
503
|
except Exception as e:
|
547
504
|
raise TaskFetchError(f"Error creating task: {e!s}") from e
|
548
505
|
return trajectory_id
|
549
506
|
|
507
|
+
@refresh_token_on_auth_error()
|
550
508
|
@retry(
|
551
509
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
552
510
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -570,6 +528,15 @@ class RestClient:
|
|
570
528
|
response.raise_for_status()
|
571
529
|
trajectory_id = response.json()["trajectory_id"]
|
572
530
|
self.trajectory_id = trajectory_id
|
531
|
+
except HTTPStatusError as e:
|
532
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
533
|
+
raise AuthError(
|
534
|
+
e.response.status_code,
|
535
|
+
f"Authentication failed: {e}",
|
536
|
+
request=e.request,
|
537
|
+
response=e.response,
|
538
|
+
) from e
|
539
|
+
raise
|
573
540
|
except Exception as e:
|
574
541
|
raise TaskFetchError(f"Error creating task: {e!s}") from e
|
575
542
|
return trajectory_id
|
@@ -715,6 +682,7 @@ class RestClient:
|
|
715
682
|
for task_id in trajectory_ids
|
716
683
|
]
|
717
684
|
|
685
|
+
@refresh_token_on_auth_error()
|
718
686
|
@retry(
|
719
687
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
720
688
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -722,12 +690,23 @@ class RestClient:
|
|
722
690
|
)
|
723
691
|
def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
|
724
692
|
"""Get the status of a build."""
|
725
|
-
|
726
|
-
|
727
|
-
|
693
|
+
try:
|
694
|
+
build_id = build_id or self.build_id
|
695
|
+
response = self.client.get(f"/v0.1/builds/{build_id}")
|
696
|
+
response.raise_for_status()
|
697
|
+
except HTTPStatusError as e:
|
698
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
699
|
+
raise AuthError(
|
700
|
+
e.response.status_code,
|
701
|
+
f"Authentication failed: {e}",
|
702
|
+
request=e.request,
|
703
|
+
response=e.response,
|
704
|
+
) from e
|
705
|
+
raise
|
728
706
|
return response.json()
|
729
707
|
|
730
708
|
# TODO: Refactor later so we don't have to ignore PLR0915
|
709
|
+
@refresh_token_on_auth_error()
|
731
710
|
@retry(
|
732
711
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
733
712
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -907,6 +886,13 @@ class RestClient:
|
|
907
886
|
build_context = response.json()
|
908
887
|
self.build_id = build_context["build_id"]
|
909
888
|
except HTTPStatusError as e:
|
889
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
890
|
+
raise AuthError(
|
891
|
+
e.response.status_code,
|
892
|
+
f"Authentication failed: {e}",
|
893
|
+
request=e.request,
|
894
|
+
response=e.response,
|
895
|
+
) from e
|
910
896
|
error_detail = response.json()
|
911
897
|
error_message = error_detail.get("detail", str(e))
|
912
898
|
raise JobCreationError(
|
@@ -987,6 +973,7 @@ class RestClient:
|
|
987
973
|
except Exception as e:
|
988
974
|
raise FileUploadError(f"Error uploading directory {dir_path}: {e}") from e
|
989
975
|
|
976
|
+
@refresh_token_on_auth_error()
|
990
977
|
def _upload_single_file(
|
991
978
|
self,
|
992
979
|
job_name: str,
|
@@ -1060,11 +1047,20 @@ class RestClient:
|
|
1060
1047
|
)
|
1061
1048
|
|
1062
1049
|
logger.info(f"Successfully uploaded {file_name}")
|
1063
|
-
|
1050
|
+
except HTTPStatusError as e:
|
1051
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
1052
|
+
raise AuthError(
|
1053
|
+
e.response.status_code,
|
1054
|
+
f"Authentication failed: {e}",
|
1055
|
+
request=e.request,
|
1056
|
+
response=e.response,
|
1057
|
+
) from e
|
1058
|
+
raise
|
1064
1059
|
except Exception as e:
|
1065
1060
|
logger.exception(f"Error uploading file {file_path}")
|
1066
1061
|
raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
|
1067
1062
|
|
1063
|
+
@refresh_token_on_auth_error()
|
1068
1064
|
@retry(
|
1069
1065
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1070
1066
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -1101,6 +1097,13 @@ class RestClient:
|
|
1101
1097
|
response.raise_for_status()
|
1102
1098
|
return response.json()
|
1103
1099
|
except HTTPStatusError as e:
|
1100
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
1101
|
+
raise AuthError(
|
1102
|
+
e.response.status_code,
|
1103
|
+
f"Authentication failed: {e}",
|
1104
|
+
request=e.request,
|
1105
|
+
response=e.response,
|
1106
|
+
) from e
|
1104
1107
|
logger.exception(
|
1105
1108
|
f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
|
1106
1109
|
)
|
@@ -1113,6 +1116,7 @@ class RestClient:
|
|
1113
1116
|
)
|
1114
1117
|
raise RestClientError(f"Error listing files: {e!s}") from e
|
1115
1118
|
|
1119
|
+
@refresh_token_on_auth_error()
|
1116
1120
|
@retry(
|
1117
1121
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1118
1122
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -1160,6 +1164,13 @@ class RestClient:
|
|
1160
1164
|
|
1161
1165
|
logger.info(f"File {file_path} downloaded to {destination_path}")
|
1162
1166
|
except HTTPStatusError as e:
|
1167
|
+
if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
|
1168
|
+
raise AuthError(
|
1169
|
+
e.response.status_code,
|
1170
|
+
f"Authentication failed: {e}",
|
1171
|
+
request=e.request,
|
1172
|
+
response=e.response,
|
1173
|
+
) from e
|
1163
1174
|
logger.exception(
|
1164
1175
|
f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
|
1165
1176
|
)
|
@@ -3,10 +3,15 @@ from .app import (
|
|
3
3
|
DockerContainerConfiguration,
|
4
4
|
FramePath,
|
5
5
|
JobDeploymentConfig,
|
6
|
+
PQATaskResponse,
|
6
7
|
RuntimeConfig,
|
7
8
|
Stage,
|
8
9
|
Step,
|
10
|
+
TaskQueue,
|
11
|
+
TaskQueuesConfig,
|
9
12
|
TaskRequest,
|
13
|
+
TaskResponse,
|
14
|
+
TaskResponseVerbose,
|
10
15
|
)
|
11
16
|
|
12
17
|
__all__ = [
|
@@ -14,8 +19,13 @@ __all__ = [
|
|
14
19
|
"DockerContainerConfiguration",
|
15
20
|
"FramePath",
|
16
21
|
"JobDeploymentConfig",
|
22
|
+
"PQATaskResponse",
|
17
23
|
"RuntimeConfig",
|
18
24
|
"Stage",
|
19
25
|
"Step",
|
26
|
+
"TaskQueue",
|
27
|
+
"TaskQueuesConfig",
|
20
28
|
"TaskRequest",
|
29
|
+
"TaskResponse",
|
30
|
+
"TaskResponseVerbose",
|
21
31
|
]
|
@@ -1,6 +1,8 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
import re
|
4
|
+
from collections.abc import Mapping
|
5
|
+
from datetime import datetime
|
4
6
|
from enum import StrEnum, auto
|
5
7
|
from pathlib import Path
|
6
8
|
from typing import TYPE_CHECKING, Any, ClassVar, Self, cast
|
@@ -646,3 +648,94 @@ class TaskRequest(BaseModel):
|
|
646
648
|
runtime_config: RuntimeConfig | None = Field(
|
647
649
|
default=None, description="All optional runtime parameters for the job"
|
648
650
|
)
|
651
|
+
|
652
|
+
|
653
|
+
class SimpleOrganization(BaseModel):
|
654
|
+
id: int
|
655
|
+
name: str
|
656
|
+
display_name: str
|
657
|
+
|
658
|
+
|
659
|
+
class TaskResponse(BaseModel):
|
660
|
+
"""Base class for task responses. This holds attributes shared over all futurehouse jobs."""
|
661
|
+
|
662
|
+
model_config = ConfigDict(extra="ignore")
|
663
|
+
|
664
|
+
status: str
|
665
|
+
query: str
|
666
|
+
user: str | None = None
|
667
|
+
created_at: datetime
|
668
|
+
job_name: str
|
669
|
+
public: bool
|
670
|
+
shared_with: list[SimpleOrganization] | None = None
|
671
|
+
build_owner: str | None = None
|
672
|
+
environment_name: str | None = None
|
673
|
+
agent_name: str | None = None
|
674
|
+
task_id: UUID | None = None
|
675
|
+
|
676
|
+
@model_validator(mode="before")
|
677
|
+
@classmethod
|
678
|
+
def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
679
|
+
# Extract fields from environment frame state
|
680
|
+
if not isinstance(data, dict):
|
681
|
+
return data
|
682
|
+
# TODO: We probably want to remove these two once we define the final names.
|
683
|
+
data["job_name"] = data.get("crow")
|
684
|
+
data["query"] = data.get("task")
|
685
|
+
data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
|
686
|
+
if not (metadata := data.get("metadata", {})):
|
687
|
+
return data
|
688
|
+
data["environment_name"] = metadata.get("environment_name")
|
689
|
+
data["agent_name"] = metadata.get("agent_name")
|
690
|
+
return data
|
691
|
+
|
692
|
+
|
693
|
+
class PQATaskResponse(TaskResponse):
|
694
|
+
model_config = ConfigDict(extra="ignore")
|
695
|
+
|
696
|
+
answer: str | None = None
|
697
|
+
formatted_answer: str | None = None
|
698
|
+
answer_reasoning: str | None = None
|
699
|
+
has_successful_answer: bool | None = None
|
700
|
+
total_cost: float | None = None
|
701
|
+
total_queries: int | None = None
|
702
|
+
|
703
|
+
@model_validator(mode="before")
|
704
|
+
@classmethod
|
705
|
+
def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
|
706
|
+
if not isinstance(data, dict):
|
707
|
+
return data
|
708
|
+
if not (env_frame := data.get("environment_frame", {})):
|
709
|
+
return data
|
710
|
+
state = env_frame.get("state", {}).get("state", {})
|
711
|
+
response = state.get("response", {})
|
712
|
+
answer = response.get("answer", {})
|
713
|
+
usage = state.get("info", {}).get("usage", {})
|
714
|
+
|
715
|
+
# Add additional PQA specific fields to data so that pydantic can validate the model
|
716
|
+
data["answer"] = answer.get("answer")
|
717
|
+
data["formatted_answer"] = answer.get("formatted_answer")
|
718
|
+
data["answer_reasoning"] = answer.get("answer_reasoning")
|
719
|
+
data["has_successful_answer"] = answer.get("has_successful_answer")
|
720
|
+
data["total_cost"] = cast(float, usage.get("total_cost"))
|
721
|
+
data["total_queries"] = cast(int, usage.get("total_queries"))
|
722
|
+
|
723
|
+
return data
|
724
|
+
|
725
|
+
def clean_verbose(self) -> "TaskResponse":
|
726
|
+
"""Clean the verbose response from the server."""
|
727
|
+
self.request = None
|
728
|
+
self.response = None
|
729
|
+
return self
|
730
|
+
|
731
|
+
|
732
|
+
class TaskResponseVerbose(TaskResponse):
|
733
|
+
"""Class for responses to include all the fields of a task response."""
|
734
|
+
|
735
|
+
model_config = ConfigDict(extra="allow")
|
736
|
+
|
737
|
+
public: bool
|
738
|
+
agent_state: list[dict[str, Any]] | None = None
|
739
|
+
environment_frame: dict[str, Any] | None = None
|
740
|
+
metadata: dict[str, Any] | None = None
|
741
|
+
shared_with: list[SimpleOrganization] | None = None
|
File without changes
|
@@ -0,0 +1,107 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from collections.abc import Callable, Coroutine
|
4
|
+
from functools import wraps
|
5
|
+
from typing import Any, Final, Optional, ParamSpec, TypeVar, overload
|
6
|
+
|
7
|
+
import httpx
|
8
|
+
from httpx import HTTPStatusError
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
T = TypeVar("T")
|
13
|
+
P = ParamSpec("P")
|
14
|
+
|
15
|
+
AUTH_ERRORS_TO_RETRY_ON: Final[set[int]] = {
|
16
|
+
httpx.codes.UNAUTHORIZED,
|
17
|
+
httpx.codes.FORBIDDEN,
|
18
|
+
}
|
19
|
+
|
20
|
+
|
21
|
+
class AuthError(Exception):
|
22
|
+
"""Raised when authentication fails with 401/403 status."""
|
23
|
+
|
24
|
+
def __init__(self, status_code: int, message: str, request=None, response=None):
|
25
|
+
self.status_code = status_code
|
26
|
+
self.request = request
|
27
|
+
self.response = response
|
28
|
+
super().__init__(message)
|
29
|
+
|
30
|
+
|
31
|
+
def is_auth_error(e: Exception) -> bool:
|
32
|
+
if isinstance(e, AuthError):
|
33
|
+
return True
|
34
|
+
if isinstance(e, HTTPStatusError):
|
35
|
+
return e.response.status_code in AUTH_ERRORS_TO_RETRY_ON
|
36
|
+
return False
|
37
|
+
|
38
|
+
|
39
|
+
def get_status_code(e: Exception) -> Optional[int]:
|
40
|
+
if isinstance(e, AuthError):
|
41
|
+
return e.status_code
|
42
|
+
if isinstance(e, HTTPStatusError):
|
43
|
+
return e.response.status_code
|
44
|
+
return None
|
45
|
+
|
46
|
+
|
47
|
+
@overload
|
48
|
+
def refresh_token_on_auth_error(
|
49
|
+
func: Callable[P, Coroutine[Any, Any, T]],
|
50
|
+
) -> Callable[P, Coroutine[Any, Any, T]]: ...
|
51
|
+
|
52
|
+
|
53
|
+
@overload
|
54
|
+
def refresh_token_on_auth_error(
|
55
|
+
func: None = None, *, max_retries: int = ...
|
56
|
+
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...
|
57
|
+
|
58
|
+
|
59
|
+
def refresh_token_on_auth_error(func=None, max_retries=1):
|
60
|
+
"""Decorator that refreshes JWT token on 401/403 auth errors."""
|
61
|
+
|
62
|
+
def decorator(fn):
|
63
|
+
@wraps(fn)
|
64
|
+
def sync_wrapper(self, *args, **kwargs):
|
65
|
+
retries = 0
|
66
|
+
while True:
|
67
|
+
try:
|
68
|
+
return fn(self, *args, **kwargs)
|
69
|
+
except Exception as e:
|
70
|
+
if is_auth_error(e) and retries < max_retries:
|
71
|
+
retries += 1
|
72
|
+
status = get_status_code(e) or "Unknown"
|
73
|
+
logger.info(
|
74
|
+
f"Received auth error {status}, "
|
75
|
+
f"refreshing token and retrying (attempt {retries}/{max_retries})..."
|
76
|
+
)
|
77
|
+
self.auth_jwt = self._run_auth()
|
78
|
+
self._clients = {}
|
79
|
+
continue
|
80
|
+
raise
|
81
|
+
|
82
|
+
@wraps(fn)
|
83
|
+
async def async_wrapper(self, *args, **kwargs):
|
84
|
+
retries = 0
|
85
|
+
while True:
|
86
|
+
try:
|
87
|
+
return await fn(self, *args, **kwargs)
|
88
|
+
except Exception as e:
|
89
|
+
if is_auth_error(e) and retries < max_retries:
|
90
|
+
retries += 1
|
91
|
+
status = get_status_code(e) or "Unknown"
|
92
|
+
logger.info(
|
93
|
+
f"Received auth error {status}, "
|
94
|
+
f"refreshing token and retrying (attempt {retries}/{max_retries})..."
|
95
|
+
)
|
96
|
+
self.auth_jwt = self._run_auth()
|
97
|
+
self._clients = {}
|
98
|
+
continue
|
99
|
+
raise
|
100
|
+
|
101
|
+
if asyncio.iscoroutinefunction(fn):
|
102
|
+
return async_wrapper
|
103
|
+
return sync_wrapper
|
104
|
+
|
105
|
+
if callable(func):
|
106
|
+
return decorator(func)
|
107
|
+
return decorator
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: futurehouse-client
|
3
|
-
Version: 0.3.18.
|
3
|
+
Version: 0.3.18.dev185
|
4
4
|
Summary: A client for interacting with endpoints of the FutureHouse service.
|
5
5
|
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
6
6
|
Classifier: Operating System :: OS Independent
|
@@ -18,7 +18,7 @@ futurehouse_client/models/app.py
|
|
18
18
|
futurehouse_client/models/client.py
|
19
19
|
futurehouse_client/models/rest.py
|
20
20
|
futurehouse_client/utils/__init__.py
|
21
|
-
futurehouse_client/utils/
|
21
|
+
futurehouse_client/utils/auth.py
|
22
22
|
futurehouse_client/utils/general.py
|
23
23
|
futurehouse_client/utils/module_utils.py
|
24
24
|
futurehouse_client/utils/monitoring.py
|
@@ -0,0 +1,260 @@
|
|
1
|
+
# ruff: noqa: ARG001
|
2
|
+
import asyncio
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
from unittest.mock import patch
|
6
|
+
|
7
|
+
import pytest
|
8
|
+
from futurehouse_client.clients import (
|
9
|
+
JobNames,
|
10
|
+
PQATaskResponse,
|
11
|
+
TaskResponseVerbose,
|
12
|
+
)
|
13
|
+
from futurehouse_client.clients.rest_client import RestClient
|
14
|
+
from futurehouse_client.models.app import Stage, TaskRequest
|
15
|
+
from futurehouse_client.models.rest import ExecutionStatus
|
16
|
+
from futurehouse_client.utils.auth import AuthError, refresh_token_on_auth_error
|
17
|
+
from pytest_subtests import SubTests
|
18
|
+
|
19
|
+
ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
|
20
|
+
PUBLIC_API_KEY = os.environ["PLAYWRIGHT_PUBLIC_API_KEY"]
|
21
|
+
TEST_MAX_POLLS = 100
|
22
|
+
|
23
|
+
|
24
|
+
@pytest.fixture
|
25
|
+
def admin_client():
|
26
|
+
"""Create a RestClient for testing."""
|
27
|
+
return RestClient(
|
28
|
+
stage=Stage.DEV,
|
29
|
+
api_key=ADMIN_API_KEY,
|
30
|
+
)
|
31
|
+
|
32
|
+
|
33
|
+
@pytest.fixture
|
34
|
+
def pub_client():
|
35
|
+
"""Create a RestClient for testing."""
|
36
|
+
return RestClient(
|
37
|
+
stage=Stage.DEV,
|
38
|
+
api_key=PUBLIC_API_KEY,
|
39
|
+
)
|
40
|
+
|
41
|
+
|
42
|
+
@pytest.fixture
|
43
|
+
def task_data():
|
44
|
+
"""Create a sample task request."""
|
45
|
+
return TaskRequest(
|
46
|
+
name=JobNames.from_string("dummy"),
|
47
|
+
query="How many moons does earth have?",
|
48
|
+
)
|
49
|
+
|
50
|
+
|
51
|
+
@pytest.fixture
|
52
|
+
def pqa_task_data():
|
53
|
+
return TaskRequest(
|
54
|
+
name=JobNames.from_string("crow"),
|
55
|
+
query="How many moons does earth have?",
|
56
|
+
)
|
57
|
+
|
58
|
+
|
59
|
+
@pytest.mark.timeout(300)
|
60
|
+
@pytest.mark.flaky(reruns=3)
|
61
|
+
def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_data: TaskRequest):
|
62
|
+
admin_client.create_task(task_data)
|
63
|
+
while (task_status := admin_client.get_task().status) in {"queued", "in progress"}:
|
64
|
+
time.sleep(5)
|
65
|
+
assert task_status == "success"
|
66
|
+
|
67
|
+
|
68
|
+
def test_insufficient_permissions_request(
|
69
|
+
pub_client: RestClient, task_data: TaskRequest
|
70
|
+
):
|
71
|
+
# Create a new instance so that cached credentials aren't reused
|
72
|
+
with pytest.raises(AuthError) as exc_info:
|
73
|
+
pub_client.create_task(task_data)
|
74
|
+
|
75
|
+
assert "403 Forbidden" in str(exc_info.value)
|
76
|
+
|
77
|
+
|
78
|
+
@pytest.mark.timeout(300)
@pytest.mark.asyncio
async def test_job_response(  # noqa: PLR0915
    subtests: SubTests, admin_client: RestClient, pqa_task_data: TaskRequest
):
    """End-to-end check of task responses in sync and async flavors.

    Submits the same PQA ("crow") request twice — once via create_task and
    once via acreate_task — then verifies the queued-state response, the
    terminal PQA-specific response, and the verbose response for both ids.
    """
    # One task through the sync API, one through the async API, so both
    # code paths are exercised against the same job definition.
    task_id = admin_client.create_task(pqa_task_data)
    atask_id = await admin_client.acreate_task(pqa_task_data)

    with subtests.test("Test TaskResponse with queued task"):
        task_response = admin_client.get_task(task_id)
        assert task_response.status in {"queued", "in progress"}
        assert task_response.job_name == pqa_task_data.name
        assert task_response.query == pqa_task_data.query
        task_response = await admin_client.aget_task(atask_id)
        assert task_response.status in {"queued", "in progress"}
        assert task_response.job_name == pqa_task_data.name
        assert task_response.query == pqa_task_data.query

    # Bounded polling (TEST_MAX_POLLS x 5s) until the sync-created task
    # reaches a terminal state; @timeout(300) caps the total wait.
    for _ in range(TEST_MAX_POLLS):
        task_response = admin_client.get_task(task_id)
        if task_response.status in ExecutionStatus.terminal_states():
            break
        await asyncio.sleep(5)

    # Same bounded polling for the async-created task.
    for _ in range(TEST_MAX_POLLS):
        task_response = await admin_client.aget_task(atask_id)
        if task_response.status in ExecutionStatus.terminal_states():
            break
        await asyncio.sleep(5)

    with subtests.test("Test PQA job response"):
        task_response = admin_client.get_task(task_id)
        assert isinstance(task_response, PQATaskResponse)
        # assert it has general fields
        assert task_response.status == "success"
        assert task_response.task_id is not None
        assert pqa_task_data.name in task_response.job_name
        assert pqa_task_data.query in task_response.query
        # assert it has PQA specific fields
        assert task_response.answer is not None
        # assert it's not verbose
        assert not hasattr(task_response, "environment_frame")
        assert not hasattr(task_response, "agent_state")

    with subtests.test("Test async PQA job response"):
        task_response = await admin_client.aget_task(atask_id)
        assert isinstance(task_response, PQATaskResponse)
        # assert it has general fields
        assert task_response.status == "success"
        assert task_response.task_id is not None
        assert pqa_task_data.name in task_response.job_name
        assert pqa_task_data.query in task_response.query
        # assert it has PQA specific fields
        assert task_response.answer is not None
        # assert it's not verbose
        assert not hasattr(task_response, "environment_frame")
        assert not hasattr(task_response, "agent_state")

    with subtests.test("Test task response with verbose"):
        # verbose=True should switch the response model and expose frames/state.
        task_response = admin_client.get_task(task_id, verbose=True)
        assert isinstance(task_response, TaskResponseVerbose)
        assert task_response.status == "success"
        assert task_response.environment_frame is not None
        assert task_response.agent_state is not None

    with subtests.test("Test task async response with verbose"):
        task_response = await admin_client.aget_task(atask_id, verbose=True)
        assert isinstance(task_response, TaskResponseVerbose)
        assert task_response.status == "success"
        assert task_response.environment_frame is not None
        assert task_response.agent_state is not None
|
149
|
+
|
150
|
+
|
151
|
+
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
def test_run_until_done_futurehouse_dummy_env_crow(
    admin_client: RestClient, task_data: TaskRequest
):
    """run_tasks_until_done returns one successful result per submitted task."""
    batch = [task_data, task_data]
    results = admin_client.run_tasks_until_done(batch)

    assert len(results) == len(batch), "Should return 2 tasks."
    assert all(task.status == "success" for task in results)
|
162
|
+
|
163
|
+
|
164
|
+
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
async def test_arun_until_done_futurehouse_dummy_env_crow(
    admin_client: RestClient, task_data: TaskRequest
):
    """arun_tasks_until_done returns one successful result per submitted task."""
    batch = [task_data, task_data]
    results = await admin_client.arun_tasks_until_done(batch)

    assert len(results) == len(batch), "Should return 2 tasks."
    assert all(task.status == "success" for task in results)
|
176
|
+
|
177
|
+
|
178
|
+
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
    admin_client: RestClient, task_data: TaskRequest
):
    """With a 5-second timeout neither task can finish, and verbose responses come back."""
    batch = [task_data, task_data]

    def _check_timed_out(results):
        # Shared assertions for both the async and sync runners: every task
        # is returned, none succeeded, and all are verbose (not PQA-shaped).
        assert len(results) == len(batch), "Should return 2 tasks."
        assert all(task.status != "success" for task in results), (
            "Should not be success."
        )
        assert all(not isinstance(task, PQATaskResponse) for task in results), (
            "Should be verbose."
        )

    _check_timed_out(
        await admin_client.arun_tasks_until_done(
            batch, verbose=True, timeout=5, progress_bar=True
        )
    )
    _check_timed_out(
        admin_client.run_tasks_until_done(
            batch, verbose=True, timeout=5, progress_bar=True
        )
    )
|
205
|
+
|
206
|
+
|
207
|
+
def test_auth_refresh_flow(admin_client: RestClient):
    """A 401 from a wrapped call triggers exactly one token refresh and a retry."""
    calls = {"refresh": 0, "func": 0}

    def fake_run_auth(*args, **kwargs):
        # Stand-in for RestClient._run_auth: hand back a distinct token per call.
        calls["refresh"] += 1
        return f"fresh-token-{calls['refresh']}"

    @refresh_token_on_auth_error()
    def guarded(self, *args):
        # Fail with a 401 on the first invocation only, succeed on the retry.
        calls["func"] += 1
        if calls["func"] == 1:
            raise AuthError(401, "Auth failed", None, None)
        return "success"

    with patch.object(admin_client, "_run_auth", fake_run_auth):
        result = guarded(admin_client)

    assert result == "success"
    assert calls["func"] == 2, "Function should be called twice"
    assert calls["refresh"] == 1, "Auth should be refreshed once"
    assert admin_client.auth_jwt == "fresh-token-1"
|
232
|
+
|
233
|
+
|
234
|
+
@pytest.mark.asyncio
async def test_async_auth_refresh_flow(admin_client: RestClient):
    """An async 401 triggers exactly one token refresh before the retry succeeds."""
    calls = {"refresh": 0, "func": 0}

    def fake_run_auth(*args, **kwargs):
        # Stand-in for RestClient._run_auth: hand back a distinct token per call.
        calls["refresh"] += 1
        return f"fresh-token-{calls['refresh']}"

    @refresh_token_on_auth_error()
    async def guarded(self, *args):
        # Fail with a 401 on the first invocation only, succeed on the retry.
        calls["func"] += 1
        if calls["func"] == 1:
            raise AuthError(401, "Auth failed", None, None)
        await asyncio.sleep(1)
        return "success"

    with patch.object(admin_client, "_run_auth", fake_run_auth):
        result = await guarded(admin_client)

    assert result == "success"
    assert calls["func"] == 2, "Function should be called twice"
    assert calls["refresh"] == 1, "Auth should be refreshed once"
    assert admin_client.auth_jwt == "fresh-token-1"
|
@@ -1,45 +0,0 @@
|
|
1
|
-
from typing import Any, ClassVar
|
2
|
-
|
3
|
-
USER_JWT_CONTEXT_KEY = "user_jwt"
|
4
|
-
JOB_ID_CONTEXT_KEY = "job_id"
|
5
|
-
|
6
|
-
|
7
|
-
class RequestContext:
|
8
|
-
"""A context manager for storing information from the initial request."""
|
9
|
-
|
10
|
-
_context: ClassVar[dict[str, Any]] = {}
|
11
|
-
|
12
|
-
@classmethod
|
13
|
-
def set(cls, key: str, value: Any) -> None:
|
14
|
-
"""Set a context variable.
|
15
|
-
|
16
|
-
Args:
|
17
|
-
key: The context variable name
|
18
|
-
value: The value to store
|
19
|
-
"""
|
20
|
-
cls._context[key] = value
|
21
|
-
|
22
|
-
@classmethod
|
23
|
-
def get(cls, key: str) -> Any:
|
24
|
-
"""Get a context variable.
|
25
|
-
|
26
|
-
Args:
|
27
|
-
key: The context variable name
|
28
|
-
default: Default value if key doesn't exist
|
29
|
-
|
30
|
-
Returns:
|
31
|
-
The stored value or default if not found
|
32
|
-
"""
|
33
|
-
return cls._context.get(key, None)
|
34
|
-
|
35
|
-
@classmethod
|
36
|
-
def clear(cls, key: str | None = None) -> None:
|
37
|
-
"""Clear a specific context variable or all variables.
|
38
|
-
|
39
|
-
Args:
|
40
|
-
key: Specific key to clear, or None to clear all
|
41
|
-
"""
|
42
|
-
if key is None:
|
43
|
-
cls._context.clear()
|
44
|
-
elif key in cls._context:
|
45
|
-
del cls._context[key]
|
@@ -1,214 +0,0 @@
|
|
1
|
-
import asyncio
|
2
|
-
import os
|
3
|
-
import time
|
4
|
-
|
5
|
-
import pytest
|
6
|
-
from futurehouse_client.clients import (
|
7
|
-
JobNames,
|
8
|
-
PQATaskResponse,
|
9
|
-
TaskResponseVerbose,
|
10
|
-
)
|
11
|
-
from futurehouse_client.clients.rest_client import RestClient, TaskFetchError
|
12
|
-
from futurehouse_client.models.app import Stage, TaskRequest
|
13
|
-
from futurehouse_client.models.rest import ExecutionStatus
|
14
|
-
from pytest_subtests import SubTests
|
15
|
-
|
16
|
-
ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
|
17
|
-
PUBLIC_API_KEY = os.environ["PLAYWRIGHT_PUBLIC_API_KEY"]
|
18
|
-
TEST_MAX_POLLS = 100
|
19
|
-
|
20
|
-
|
21
|
-
@pytest.mark.timeout(300)
|
22
|
-
@pytest.mark.flaky(reruns=3)
|
23
|
-
def test_futurehouse_dummy_env_crow():
|
24
|
-
client = RestClient(
|
25
|
-
stage=Stage.DEV,
|
26
|
-
api_key=ADMIN_API_KEY,
|
27
|
-
)
|
28
|
-
|
29
|
-
task_data = TaskRequest(
|
30
|
-
name=JobNames.from_string("dummy"),
|
31
|
-
query="How many moons does earth have?",
|
32
|
-
)
|
33
|
-
client.create_task(task_data)
|
34
|
-
|
35
|
-
while (task_status := client.get_task().status) in {"queued", "in progress"}:
|
36
|
-
time.sleep(5)
|
37
|
-
|
38
|
-
assert task_status == "success"
|
39
|
-
|
40
|
-
|
41
|
-
def test_insufficient_permissions_request():
|
42
|
-
# Create a new instance so that cached credentials aren't reused
|
43
|
-
client = RestClient(
|
44
|
-
stage=Stage.DEV,
|
45
|
-
api_key=PUBLIC_API_KEY,
|
46
|
-
)
|
47
|
-
task_data = TaskRequest(
|
48
|
-
name=JobNames.from_string("dummy"),
|
49
|
-
query="How many moons does earth have?",
|
50
|
-
)
|
51
|
-
|
52
|
-
with pytest.raises(TaskFetchError) as exc_info:
|
53
|
-
client.create_task(task_data)
|
54
|
-
|
55
|
-
assert "Error creating task" in str(exc_info.value)
|
56
|
-
|
57
|
-
|
58
|
-
@pytest.mark.timeout(300)
|
59
|
-
@pytest.mark.asyncio
|
60
|
-
async def test_job_response(subtests: SubTests): # noqa: PLR0915
|
61
|
-
client = RestClient(
|
62
|
-
stage=Stage.DEV,
|
63
|
-
api_key=ADMIN_API_KEY,
|
64
|
-
)
|
65
|
-
task_data = TaskRequest(
|
66
|
-
name=JobNames.from_string("crow"),
|
67
|
-
query="How many moons does earth have?",
|
68
|
-
)
|
69
|
-
task_id = client.create_task(task_data)
|
70
|
-
atask_id = await client.acreate_task(task_data)
|
71
|
-
|
72
|
-
with subtests.test("Test TaskResponse with queued task"):
|
73
|
-
task_response = client.get_task(task_id)
|
74
|
-
assert task_response.status in {"queued", "in progress"}
|
75
|
-
assert task_response.job_name == task_data.name
|
76
|
-
assert task_response.query == task_data.query
|
77
|
-
task_response = await client.aget_task(atask_id)
|
78
|
-
assert task_response.status in {"queued", "in progress"}
|
79
|
-
assert task_response.job_name == task_data.name
|
80
|
-
assert task_response.query == task_data.query
|
81
|
-
|
82
|
-
for _ in range(TEST_MAX_POLLS):
|
83
|
-
task_response = client.get_task(task_id)
|
84
|
-
if task_response.status in ExecutionStatus.terminal_states():
|
85
|
-
break
|
86
|
-
await asyncio.sleep(5)
|
87
|
-
|
88
|
-
for _ in range(TEST_MAX_POLLS):
|
89
|
-
task_response = await client.aget_task(atask_id)
|
90
|
-
if task_response.status in ExecutionStatus.terminal_states():
|
91
|
-
break
|
92
|
-
await asyncio.sleep(5)
|
93
|
-
|
94
|
-
with subtests.test("Test PQA job response"):
|
95
|
-
task_response = client.get_task(task_id)
|
96
|
-
assert isinstance(task_response, PQATaskResponse)
|
97
|
-
# assert it has general fields
|
98
|
-
assert task_response.status == "success"
|
99
|
-
assert task_response.task_id is not None
|
100
|
-
assert task_data.name in task_response.job_name
|
101
|
-
assert task_data.query in task_response.query
|
102
|
-
# assert it has PQA specific fields
|
103
|
-
assert task_response.answer is not None
|
104
|
-
# assert it's not verbose
|
105
|
-
assert not hasattr(task_response, "environment_frame")
|
106
|
-
assert not hasattr(task_response, "agent_state")
|
107
|
-
|
108
|
-
with subtests.test("Test async PQA job response"):
|
109
|
-
task_response = await client.aget_task(atask_id)
|
110
|
-
assert isinstance(task_response, PQATaskResponse)
|
111
|
-
# assert it has general fields
|
112
|
-
assert task_response.status == "success"
|
113
|
-
assert task_response.task_id is not None
|
114
|
-
assert task_data.name in task_response.job_name
|
115
|
-
assert task_data.query in task_response.query
|
116
|
-
# assert it has PQA specific fields
|
117
|
-
assert task_response.answer is not None
|
118
|
-
# assert it's not verbose
|
119
|
-
assert not hasattr(task_response, "environment_frame")
|
120
|
-
assert not hasattr(task_response, "agent_state")
|
121
|
-
|
122
|
-
with subtests.test("Test task response with verbose"):
|
123
|
-
task_response = client.get_task(task_id, verbose=True)
|
124
|
-
assert isinstance(task_response, TaskResponseVerbose)
|
125
|
-
assert task_response.status == "success"
|
126
|
-
assert task_response.environment_frame is not None
|
127
|
-
assert task_response.agent_state is not None
|
128
|
-
|
129
|
-
with subtests.test("Test task async response with verbose"):
|
130
|
-
task_response = await client.aget_task(atask_id, verbose=True)
|
131
|
-
assert isinstance(task_response, TaskResponseVerbose)
|
132
|
-
assert task_response.status == "success"
|
133
|
-
assert task_response.environment_frame is not None
|
134
|
-
assert task_response.agent_state is not None
|
135
|
-
|
136
|
-
|
137
|
-
@pytest.mark.timeout(300)
|
138
|
-
@pytest.mark.flaky(reruns=3)
|
139
|
-
def test_run_until_done_futurehouse_dummy_env_crow():
|
140
|
-
client = RestClient(
|
141
|
-
stage=Stage.DEV,
|
142
|
-
api_key=ADMIN_API_KEY,
|
143
|
-
)
|
144
|
-
|
145
|
-
task_data = TaskRequest(
|
146
|
-
name=JobNames.from_string("dummy"),
|
147
|
-
query="How many moons does earth have?",
|
148
|
-
)
|
149
|
-
|
150
|
-
tasks_to_do = [task_data, task_data]
|
151
|
-
|
152
|
-
results = client.run_tasks_until_done(tasks_to_do)
|
153
|
-
|
154
|
-
assert len(results) == len(tasks_to_do), "Should return 2 tasks."
|
155
|
-
assert all(task.status == "success" for task in results)
|
156
|
-
|
157
|
-
|
158
|
-
@pytest.mark.timeout(300)
|
159
|
-
@pytest.mark.flaky(reruns=3)
|
160
|
-
@pytest.mark.asyncio
|
161
|
-
async def test_arun_until_done_futurehouse_dummy_env_crow():
|
162
|
-
client = RestClient(
|
163
|
-
stage=Stage.DEV,
|
164
|
-
api_key=ADMIN_API_KEY,
|
165
|
-
)
|
166
|
-
|
167
|
-
task_data = TaskRequest(
|
168
|
-
name=JobNames.from_string("dummy"),
|
169
|
-
query="How many moons does earth have?",
|
170
|
-
)
|
171
|
-
|
172
|
-
tasks_to_do = [task_data, task_data]
|
173
|
-
|
174
|
-
results = await client.arun_tasks_until_done(tasks_to_do)
|
175
|
-
|
176
|
-
assert len(results) == len(tasks_to_do), "Should return 2 tasks."
|
177
|
-
assert all(task.status == "success" for task in results)
|
178
|
-
|
179
|
-
|
180
|
-
@pytest.mark.timeout(300)
|
181
|
-
@pytest.mark.flaky(reruns=3)
|
182
|
-
@pytest.mark.asyncio
|
183
|
-
async def test_timeout_run_until_done_futurehouse_dummy_env_crow():
|
184
|
-
client = RestClient(
|
185
|
-
stage=Stage.DEV,
|
186
|
-
api_key=ADMIN_API_KEY,
|
187
|
-
)
|
188
|
-
|
189
|
-
task_data = TaskRequest(
|
190
|
-
name=JobNames.from_string("dummy"),
|
191
|
-
query="How many moons does earth have?",
|
192
|
-
)
|
193
|
-
|
194
|
-
tasks_to_do = [task_data, task_data]
|
195
|
-
|
196
|
-
results = await client.arun_tasks_until_done(
|
197
|
-
tasks_to_do, verbose=True, timeout=5, progress_bar=True
|
198
|
-
)
|
199
|
-
|
200
|
-
assert len(results) == len(tasks_to_do), "Should return 2 tasks."
|
201
|
-
assert all(task.status != "success" for task in results), "Should not be success."
|
202
|
-
assert all(not isinstance(task, PQATaskResponse) for task in results), (
|
203
|
-
"Should be verbose."
|
204
|
-
)
|
205
|
-
|
206
|
-
results = client.run_tasks_until_done(
|
207
|
-
tasks_to_do, verbose=True, timeout=5, progress_bar=True
|
208
|
-
)
|
209
|
-
|
210
|
-
assert len(results) == len(tasks_to_do), "Should return 2 tasks."
|
211
|
-
assert all(task.status != "success" for task in results), "Should not be success."
|
212
|
-
assert all(not isinstance(task, PQATaskResponse) for task in results), (
|
213
|
-
"Should be verbose."
|
214
|
-
)
|
File without changes
|
File without changes
|
File without changes
|
{futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/docs/client_notebook.ipynb
RENAMED
File without changes
|
{futurehouse_client-0.3.18.dev110 → futurehouse_client-0.3.18.dev185}/futurehouse_client/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|