futurehouse-client 0.4.2.dev11__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/clients/data_storage_methods.py +2649 -0
- futurehouse_client/clients/job_client.py +50 -0
- futurehouse_client/clients/rest_client.py +148 -70
- futurehouse_client/models/__init__.py +2 -1
- futurehouse_client/models/client.py +5 -1
- futurehouse_client/models/data_storage_methods.py +355 -0
- futurehouse_client/models/rest.py +48 -7
- futurehouse_client/utils/general.py +64 -1
- futurehouse_client/utils/world_model_tools.py +21 -2
- futurehouse_client/version.py +2 -2
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dist-info}/METADATA +8 -1
- futurehouse_client-0.4.3.dist-info/RECORD +23 -0
- futurehouse_client-0.4.2.dev11.dist-info/RECORD +0 -21
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dist-info}/WHEEL +0 -0
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ from futurehouse_client.models.rest import (
|
|
19
19
|
FinalEnvironmentRequest,
|
20
20
|
StoreAgentStatePostRequest,
|
21
21
|
StoreEnvironmentFrameRequest,
|
22
|
+
TrajectoryPatchRequest,
|
22
23
|
)
|
23
24
|
from futurehouse_client.utils.monitoring import (
|
24
25
|
external_trace,
|
@@ -318,3 +319,52 @@ class JobClient:
|
|
318
319
|
f"Unexpected error storing environment frame for state {state_identifier}",
|
319
320
|
)
|
320
321
|
raise
|
322
|
+
|
323
|
+
async def patch_trajectory(
    self,
    public: bool | None = None,
    shared_with: list[int] | None = None,
    notification_enabled: bool | None = None,
    notification_type: str | None = None,
    min_estimated_time: float | None = None,
    max_estimated_time: float | None = None,
) -> None:
    """Send a partial (PATCH) update for this client's trajectory.

    Each argument maps one-to-one onto the same-named field of
    ``TrajectoryPatchRequest``. The request model is dumped with
    ``exclude_none=True``, so any argument left as ``None`` is omitted from
    the JSON body entirely.

    The request goes to ``{base_uri}/v0.1/trajectories/{trajectory_id}`` with
    a bearer token taken from ``self.oauth_jwt`` and an ``x-trajectory-id``
    header, using a fresh ``httpx.AsyncClient`` bounded by
    ``self.REQUEST_TIMEOUT``.

    Raises:
        httpx.TimeoutException: logged, then re-raised.
        httpx.NetworkError: logged, then re-raised.
        Exception: any other unexpected error is logged, then re-raised.

    Note:
        A non-2xx response (``httpx.HTTPStatusError``) is logged but NOT
        re-raised, so this coroutine returns ``None`` even when the server
        rejects the update.
    """
    # Build (and validate) the request model up front, outside the try,
    # exactly as before: validation errors surface without the error logging.
    patch_request = TrajectoryPatchRequest(
        public=public,
        shared_with=shared_with,
        notification_enabled=notification_enabled,
        notification_type=notification_type,
        min_estimated_time=min_estimated_time,
        max_estimated_time=max_estimated_time,
    )
    try:
        endpoint = f"{self.base_uri}/v0.1/trajectories/{self.trajectory_id}"
        request_headers = {
            "Authorization": f"Bearer {self.oauth_jwt}",
            "x-trajectory-id": self.trajectory_id,
        }
        body = patch_request.model_dump(mode="json", exclude_none=True)
        async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as http:
            resp = await http.patch(url=endpoint, json=body, headers=request_headers)
            resp.raise_for_status()
            logger.debug("Trajectory updated successfully")
    except httpx.HTTPStatusError as err:
        # Deliberately not re-raised (matches the original behavior).
        logger.exception(
            "HTTP error while patching trajectory. "
            f"Status code: {err.response.status_code}, "
            f"Response: {err.response.text}",
        )
    except httpx.TimeoutException:
        logger.exception(
            f"Timeout while patching trajectory after {self.REQUEST_TIMEOUT}s",
        )
        raise
    except httpx.NetworkError:
        logger.exception("Network error while patching trajectory")
        raise
    except Exception:
        logger.exception("Unexpected error while patching trajectory")
        raise
|
@@ -1,3 +1,4 @@
|
|
1
|
+
# ruff: noqa: PLR0915
|
1
2
|
import ast
|
2
3
|
import asyncio
|
3
4
|
import base64
|
@@ -26,21 +27,14 @@ from httpx import (
|
|
26
27
|
AsyncClient,
|
27
28
|
Client,
|
28
29
|
CloseError,
|
29
|
-
ConnectError,
|
30
|
-
ConnectTimeout,
|
31
30
|
HTTPStatusError,
|
32
|
-
NetworkError,
|
33
|
-
ReadError,
|
34
|
-
ReadTimeout,
|
35
31
|
RemoteProtocolError,
|
36
32
|
codes,
|
37
33
|
)
|
38
34
|
from ldp.agent import AgentConfig
|
39
|
-
from requests.exceptions import RequestException, Timeout
|
40
35
|
from tenacity import (
|
41
36
|
before_sleep_log,
|
42
37
|
retry,
|
43
|
-
retry_if_exception_type,
|
44
38
|
stop_after_attempt,
|
45
39
|
wait_exponential,
|
46
40
|
)
|
@@ -48,6 +42,7 @@ from tqdm import tqdm as sync_tqdm
|
|
48
42
|
from tqdm.asyncio import tqdm
|
49
43
|
|
50
44
|
from futurehouse_client.clients import JobNames
|
45
|
+
from futurehouse_client.clients.data_storage_methods import DataStorageMethods
|
51
46
|
from futurehouse_client.models.app import (
|
52
47
|
AuthType,
|
53
48
|
JobDeploymentConfig,
|
@@ -60,15 +55,20 @@ from futurehouse_client.models.app import (
|
|
60
55
|
from futurehouse_client.models.rest import (
|
61
56
|
DiscoveryResponse,
|
62
57
|
ExecutionStatus,
|
58
|
+
SearchCriterion,
|
63
59
|
UserAgentRequest,
|
64
60
|
UserAgentRequestPostPayload,
|
65
61
|
UserAgentRequestStatus,
|
66
62
|
UserAgentResponsePayload,
|
67
63
|
WorldModel,
|
68
64
|
WorldModelResponse,
|
65
|
+
WorldModelSearchPayload,
|
69
66
|
)
|
70
67
|
from futurehouse_client.utils.auth import RefreshingJWT
|
71
|
-
from futurehouse_client.utils.general import
|
68
|
+
from futurehouse_client.utils.general import (
|
69
|
+
create_retry_if_connection_error,
|
70
|
+
gather_with_concurrency,
|
71
|
+
)
|
72
72
|
from futurehouse_client.utils.module_utils import (
|
73
73
|
OrganizationSelector,
|
74
74
|
fetch_environment_function_docstring,
|
@@ -160,28 +160,14 @@ class FileUploadError(RestClientError):
|
|
160
160
|
"""Raised when there's an error uploading a file."""
|
161
161
|
|
162
162
|
|
163
|
-
retry_if_connection_error =
|
164
|
-
# From requests
|
165
|
-
Timeout,
|
166
|
-
ConnectionError,
|
167
|
-
RequestException,
|
168
|
-
# From httpx
|
169
|
-
ConnectError,
|
170
|
-
ConnectTimeout,
|
171
|
-
ReadTimeout,
|
172
|
-
ReadError,
|
173
|
-
NetworkError,
|
174
|
-
RemoteProtocolError,
|
175
|
-
CloseError,
|
176
|
-
FileUploadError,
|
177
|
-
))
|
163
|
+
retry_if_connection_error = create_retry_if_connection_error(FileUploadError)
|
178
164
|
|
179
165
|
DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
|
180
166
|
|
181
167
|
|
182
|
-
|
183
|
-
|
184
|
-
|
168
|
+
class RestClient(DataStorageMethods):
|
169
|
+
REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec - for general API calls
|
170
|
+
FILE_UPLOAD_TIMEOUT: ClassVar[float] = 600.0 # 10 minutes - for file uploads
|
185
171
|
MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
|
186
172
|
RETRY_MULTIPLIER: ClassVar[int] = 1
|
187
173
|
MAX_RETRY_WAIT: ClassVar[int] = 10
|
@@ -239,11 +225,35 @@ class RestClient:
|
|
239
225
|
"""Authenticated HTTP client for multipart uploads."""
|
240
226
|
return cast(Client, self.get_client(None, authenticated=True))
|
241
227
|
|
228
|
+
@property
def file_upload_client(self) -> Client:
    """Synchronous authenticated client whose timeout is FILE_UPLOAD_TIMEOUT.

    The instance comes from ``get_client``, which caches clients by content
    type, auth flag, sync/async flavour and timeout. Repeated access
    therefore returns the same cached client.

    NOTE(review): this client carries a ``Content-Type: application/json``
    default header, whereas ``multipart_client`` passes ``content_type=None``
    for multipart bodies. Confirm callers do not send multipart uploads
    through this client.
    """
    upload_client = self.get_client(
        content_type="application/json",
        authenticated=True,
        timeout=self.FILE_UPLOAD_TIMEOUT,
    )
    return cast(Client, upload_client)
|
237
|
+
|
238
|
+
@property
def async_file_upload_client(self) -> AsyncClient:
    """Asynchronous authenticated client whose timeout is FILE_UPLOAD_TIMEOUT.

    The async counterpart of ``file_upload_client``. The instance comes from
    ``get_client(..., async_client=True)``, which caches per
    (content type, auth flag, sync/async flavour, timeout). Repeated access
    therefore returns the same cached client.

    NOTE(review): like the sync variant, this sends
    ``Content-Type: application/json`` by default. Confirm it is not used
    for multipart bodies.
    """
    upload_client = self.get_client(
        content_type="application/json",
        authenticated=True,
        async_client=True,
        timeout=self.FILE_UPLOAD_TIMEOUT,
    )
    return cast(AsyncClient, upload_client)
|
250
|
+
|
242
251
|
def get_client(
|
243
252
|
self,
|
244
253
|
content_type: str | None = "application/json",
|
245
254
|
authenticated: bool = True,
|
246
255
|
async_client: bool = False,
|
256
|
+
timeout: float | None = None,
|
247
257
|
) -> Client | AsyncClient:
|
248
258
|
"""Return a cached HTTP client or create one if needed.
|
249
259
|
|
@@ -251,12 +261,13 @@ class RestClient:
|
|
251
261
|
content_type: The desired content type header. Use None for multipart uploads.
|
252
262
|
authenticated: Whether the client should include authentication.
|
253
263
|
async_client: Whether to use an async client.
|
264
|
+
timeout: Custom timeout in seconds. Uses REQUEST_TIMEOUT if not provided.
|
254
265
|
|
255
266
|
Returns:
|
256
267
|
An HTTP client configured with the appropriate headers.
|
257
268
|
"""
|
258
|
-
|
259
|
-
key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
|
269
|
+
client_timeout = timeout or self.REQUEST_TIMEOUT
|
270
|
+
key = f"{content_type or 'multipart'}_{authenticated}_{async_client}_{client_timeout}"
|
260
271
|
|
261
272
|
if key not in self._clients:
|
262
273
|
headers = copy.deepcopy(self.headers)
|
@@ -282,14 +293,14 @@ class RestClient:
|
|
282
293
|
AsyncClient(
|
283
294
|
base_url=self.base_url,
|
284
295
|
headers=headers,
|
285
|
-
timeout=
|
296
|
+
timeout=client_timeout,
|
286
297
|
auth=auth,
|
287
298
|
)
|
288
299
|
if async_client
|
289
300
|
else Client(
|
290
301
|
base_url=self.base_url,
|
291
302
|
headers=headers,
|
292
|
-
timeout=
|
303
|
+
timeout=client_timeout,
|
293
304
|
auth=auth,
|
294
305
|
)
|
295
306
|
)
|
@@ -323,16 +334,23 @@ class RestClient:
|
|
323
334
|
raise ValueError(f"Organization '{organization}' not found.")
|
324
335
|
return filtered_orgs
|
325
336
|
|
337
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def _check_job(self, name: str, organization: str) -> dict[str, Any]:
    """Fetch the record for job ``name`` within ``organization``.

    Errors are intentionally NOT wrapped. That lets the ``@retry`` policy
    see connection errors and retry them, and lets ``HTTPStatusError``
    propagate on a non-2xx response.

    Returns:
        The decoded JSON body of ``GET /v0.1/crows/{name}/organizations/{organization}``.
    """
    endpoint = f"/v0.1/crows/{name}/organizations/{organization}"
    resp = self.client.get(endpoint)
    resp.raise_for_status()
    return resp.json()
|
335
347
|
|
348
|
+
@retry(
|
349
|
+
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
350
|
+
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
351
|
+
retry=retry_if_connection_error,
|
352
|
+
before_sleep=before_sleep_log(logger, logging.WARNING),
|
353
|
+
)
|
336
354
|
def _fetch_my_orgs(self) -> list[str]:
|
337
355
|
response = self.client.get(f"/v0.1/organizations?filter={True}")
|
338
356
|
response.raise_for_status()
|
@@ -690,10 +708,12 @@ class RestClient:
|
|
690
708
|
|
691
709
|
async def arun_tasks_until_done(
|
692
710
|
self,
|
693
|
-
task_data:
|
694
|
-
|
695
|
-
|
696
|
-
|
711
|
+
task_data: (
|
712
|
+
TaskRequest
|
713
|
+
| dict[str, Any]
|
714
|
+
| Collection[TaskRequest]
|
715
|
+
| Collection[dict[str, Any]]
|
716
|
+
),
|
697
717
|
verbose: bool = False,
|
698
718
|
progress_bar: bool = False,
|
699
719
|
concurrency: int = 10,
|
@@ -761,10 +781,12 @@ class RestClient:
|
|
761
781
|
|
762
782
|
def run_tasks_until_done(
|
763
783
|
self,
|
764
|
-
task_data:
|
765
|
-
|
766
|
-
|
767
|
-
|
784
|
+
task_data: (
|
785
|
+
TaskRequest
|
786
|
+
| dict[str, Any]
|
787
|
+
| Collection[TaskRequest]
|
788
|
+
| Collection[dict[str, Any]]
|
789
|
+
),
|
768
790
|
verbose: bool = False,
|
769
791
|
progress_bar: bool = False,
|
770
792
|
timeout: int = DEFAULT_AGENT_TIMEOUT,
|
@@ -837,12 +859,9 @@ class RestClient:
|
|
837
859
|
)
|
838
860
|
def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
|
839
861
|
"""Get the status of a build."""
|
840
|
-
|
841
|
-
|
842
|
-
|
843
|
-
response.raise_for_status()
|
844
|
-
except Exception as e:
|
845
|
-
raise JobFetchError(f"Error getting build status: {e!r}.") from e
|
862
|
+
build_id = build_id or self.build_id
|
863
|
+
response = self.client.get(f"/v0.1/builds/{build_id}")
|
864
|
+
response.raise_for_status()
|
846
865
|
return response.json()
|
847
866
|
|
848
867
|
# TODO: Refactor later so we don't have to ignore PLR0915
|
@@ -852,7 +871,7 @@ class RestClient:
|
|
852
871
|
retry=retry_if_connection_error,
|
853
872
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
854
873
|
)
|
855
|
-
def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
|
874
|
+
def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
|
856
875
|
"""Creates a futurehouse job deployment from the environment and environment files.
|
857
876
|
|
858
877
|
Args:
|
@@ -1597,6 +1616,56 @@ class RestClient:
|
|
1597
1616
|
except Exception as e:
|
1598
1617
|
raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
|
1599
1618
|
|
1619
|
+
def list_world_models(
    self,
    name: str | None = None,
    project_id: UUID | str | None = None,
    limit: int = 150,
    offset: int = 0,
    sort_order: str = "asc",
) -> list[WorldModelResponse]:
    """List world models with different behavior based on filters.

    When filtering by name: returns only the latest version for that name.
    When filtering by project_id (without name): returns all versions for that project.
    When no filters: returns latest version of each world model.

    Connection-level failures are retried (up to ``MAX_RETRY_ATTEMPTS``, with
    exponential backoff) before being reported.

    Args:
        name: Filter by world model name (ignored when empty/None).
        project_id: Filter by project ID (ignored when empty/None).
        limit: The maximum number of models to return.
        offset: Number of results to skip for pagination.
        sort_order: Sort order 'asc' or 'desc'.

    Returns:
        The decoded JSON body of ``GET /v0.1/world-models``, returned as-is.
        It is not validated into ``WorldModelResponse`` instances, so the
        elements are presumably plain dicts.

    Raises:
        WorldModelFetchError: on a non-2xx response, once connection retries
            are exhausted, or on any other unexpected error.
    """
    try:
        params: dict[str, str | int] = {
            "limit": limit,
            "offset": offset,
            "sort_order": sort_order,
        }
        if name:
            params["name"] = name
        if project_id:
            params["project_id"] = str(project_id)

        # The retry policy lives on this inner helper, not on the method itself.
        # The except-clauses below convert every error into WorldModelFetchError,
        # which (assuming it is not one of retry_if_connection_error's types)
        # the predicate never matches. A decorator on the outer method would
        # therefore never actually retry. reraise=True hands the last real error
        # to the wrappers below instead of a tenacity RetryError, so callers
        # still only ever see WorldModelFetchError.
        @retry(
            stop=stop_after_attempt(self.MAX_RETRY_ATTEMPTS),
            wait=wait_exponential(
                multiplier=self.RETRY_MULTIPLIER, max=self.MAX_RETRY_WAIT
            ),
            retry=retry_if_connection_error,
            before_sleep=before_sleep_log(logger, logging.WARNING),
            reraise=True,
        )
        def _request() -> list[WorldModelResponse]:
            response = self.client.get("/v0.1/world-models", params=params)
            response.raise_for_status()
            return response.json()

        return _request()
    except HTTPStatusError as e:
        raise WorldModelFetchError(
            f"Error listing world models: {e.response.status_code} - {e.response.text}"
        ) from e
    except Exception as e:
        raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
|
1668
|
+
|
1600
1669
|
@retry(
|
1601
1670
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1602
1671
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -1604,34 +1673,43 @@ class RestClient:
|
|
1604
1673
|
)
|
1605
1674
|
def search_world_models(
|
1606
1675
|
self,
|
1607
|
-
|
1676
|
+
criteria: list[SearchCriterion] | None = None,
|
1608
1677
|
size: int = 10,
|
1609
|
-
|
1678
|
+
project_id: UUID | str | None = None,
|
1610
1679
|
search_all_versions: bool = False,
|
1611
|
-
) -> list[
|
1612
|
-
"""Search
|
1680
|
+
) -> list[WorldModelResponse]:
|
1681
|
+
"""Search world models using structured criteria.
|
1613
1682
|
|
1614
1683
|
Args:
|
1615
|
-
|
1684
|
+
criteria: List of SearchCriterion objects with field, operator, and value.
|
1616
1685
|
size: The number of results to return.
|
1617
|
-
|
1618
|
-
search_all_versions: Whether to search all versions
|
1686
|
+
project_id: Optional filter by project ID.
|
1687
|
+
search_all_versions: Whether to search all versions or just latest.
|
1619
1688
|
|
1620
1689
|
Returns:
|
1621
|
-
A list of world model
|
1690
|
+
A list of world model responses.
|
1691
|
+
|
1692
|
+
Example:
|
1693
|
+
from futurehouse_client.models.rest import SearchCriterion, SearchOperator
|
1694
|
+
criteria = [
|
1695
|
+
SearchCriterion(field="name", operator=SearchOperator.CONTAINS, value="chemistry"),
|
1696
|
+
SearchCriterion(field="email", operator=SearchOperator.CONTAINS, value="tyler"),
|
1697
|
+
]
|
1698
|
+
results = client.search_world_models(criteria=criteria, size=20)
|
1622
1699
|
"""
|
1623
1700
|
try:
|
1624
|
-
|
1625
|
-
|
1626
|
-
|
1627
|
-
|
1628
|
-
|
1629
|
-
|
1630
|
-
|
1631
|
-
|
1701
|
+
payload = WorldModelSearchPayload(
|
1702
|
+
criteria=criteria or [],
|
1703
|
+
size=size,
|
1704
|
+
project_id=project_id,
|
1705
|
+
search_all_versions=search_all_versions,
|
1706
|
+
)
|
1707
|
+
|
1708
|
+
response = self.client.post(
|
1709
|
+
"/v0.1/world-models/search",
|
1710
|
+
json=payload.model_dump(mode="json"),
|
1632
1711
|
)
|
1633
1712
|
response.raise_for_status()
|
1634
|
-
# The new endpoint returns a list of models directly
|
1635
1713
|
return response.json()
|
1636
1714
|
except HTTPStatusError as e:
|
1637
1715
|
raise WorldModelFetchError(
|
@@ -13,7 +13,7 @@ from .app import (
|
|
13
13
|
TaskResponse,
|
14
14
|
TaskResponseVerbose,
|
15
15
|
)
|
16
|
-
from .rest import WorldModel, WorldModelResponse
|
16
|
+
from .rest import TrajectoryPatchRequest, WorldModel, WorldModelResponse
|
17
17
|
|
18
18
|
__all__ = [
|
19
19
|
"AuthType",
|
@@ -29,6 +29,7 @@ __all__ = [
|
|
29
29
|
"TaskRequest",
|
30
30
|
"TaskResponse",
|
31
31
|
"TaskResponseVerbose",
|
32
|
+
"TrajectoryPatchRequest",
|
32
33
|
"WorldModel",
|
33
34
|
"WorldModelResponse",
|
34
35
|
]
|
@@ -27,13 +27,17 @@ class InitialState(BaseState):
|
|
27
27
|
|
28
28
|
# Action/state/value record: the action taken, the state it produced, and a
# float value. Both action and next_state get custom serializers below.
class ASVState(BaseState, Generic[T]):
    # Result of the action; serialized via OpResult.to_dict().
    action: OpResult[T] = Field()
    # Typed Any, so arbitrary objects are accepted; serialized with str().
    next_state: Any = Field()
    value: float = Field()

    @field_serializer("action")
    def serialize_action(self, action: OpResult[T]) -> dict:
        # Delegate to OpResult's own dict conversion.
        return action.to_dict()

    @field_serializer("next_state")
    def serialize_next_state(self, state: Any) -> str:
        # next_state may be any object, so fall back to its string form
        # to keep the serialized output representable.
        return str(state)
|
40
|
+
|
37
41
|
|
38
42
|
class EnvResetState(BaseState):
|
39
43
|
observations: list[Message] = Field()
|