futurehouse-client 0.4.2.dev11__py3-none-any.whl → 0.4.3.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/clients/data_storage_methods.py +2585 -0
- futurehouse_client/clients/job_client.py +50 -0
- futurehouse_client/clients/rest_client.py +148 -69
- futurehouse_client/models/__init__.py +2 -1
- futurehouse_client/models/client.py +5 -1
- futurehouse_client/models/data_storage_methods.py +355 -0
- futurehouse_client/models/rest.py +48 -7
- futurehouse_client/utils/general.py +64 -1
- futurehouse_client/utils/world_model_tools.py +21 -2
- futurehouse_client/version.py +2 -2
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dev3.dist-info}/METADATA +8 -1
- futurehouse_client-0.4.3.dev3.dist-info/RECORD +23 -0
- futurehouse_client-0.4.2.dev11.dist-info/RECORD +0 -21
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dev3.dist-info}/WHEEL +0 -0
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dev3.dist-info}/licenses/LICENSE +0 -0
- {futurehouse_client-0.4.2.dev11.dist-info → futurehouse_client-0.4.3.dev3.dist-info}/top_level.txt +0 -0
futurehouse_client/clients/job_client.py
@@ -19,6 +19,7 @@ from futurehouse_client.models.rest import (
    FinalEnvironmentRequest,
    StoreAgentStatePostRequest,
    StoreEnvironmentFrameRequest,
+   TrajectoryPatchRequest,
)
from futurehouse_client.utils.monitoring import (
    external_trace,
@@ -318,3 +319,52 @@ class JobClient:
                f"Unexpected error storing environment frame for state {state_identifier}",
            )
            raise
+
+    async def patch_trajectory(
+        self,
+        public: bool | None = None,
+        shared_with: list[int] | None = None,
+        notification_enabled: bool | None = None,
+        notification_type: str | None = None,
+        min_estimated_time: float | None = None,
+        max_estimated_time: float | None = None,
+    ) -> None:
+        data = TrajectoryPatchRequest(
+            public=public,
+            shared_with=shared_with,
+            notification_enabled=notification_enabled,
+            notification_type=notification_type,
+            min_estimated_time=min_estimated_time,
+            max_estimated_time=max_estimated_time,
+        )
+        try:
+            async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
+                url = f"{self.base_uri}/v0.1/trajectories/{self.trajectory_id}"
+                headers = {
+                    "Authorization": f"Bearer {self.oauth_jwt}",
+                    "x-trajectory-id": self.trajectory_id,
+                }
+                response = await client.patch(
+                    url=url,
+                    json=data.model_dump(mode="json", exclude_none=True),
+                    headers=headers,
+                )
+                response.raise_for_status()
+                logger.debug("Trajectory updated successfully")
+        except httpx.HTTPStatusError as e:
+            logger.exception(
+                "HTTP error while patching trajectory. "
+                f"Status code: {e.response.status_code}, "
+                f"Response: {e.response.text}",
+            )
+        except httpx.TimeoutException:
+            logger.exception(
+                f"Timeout while patching trajectory after {self.REQUEST_TIMEOUT}s",
+            )
+            raise
+        except httpx.NetworkError:
+            logger.exception("Network error while patching trajectory")
+            raise
+        except Exception:
+            logger.exception("Unexpected error while patching trajectory")
+            raise
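The new `JobClient.patch_trajectory` coroutine lets a running job update sharing and notification settings on its own trajectory. A minimal usage sketch, assuming `job_client` is an already-initialized `JobClient` inside a running job (the values are illustrative):

```python
import asyncio


async def share_trajectory(job_client) -> None:
    # Only the fields passed explicitly are sent; None values are dropped by
    # model_dump(exclude_none=True) before the PATCH request.
    await job_client.patch_trajectory(
        public=True,
        shared_with=[42],  # illustrative user id
        notification_enabled=True,
    )


# asyncio.run(share_trajectory(job_client))
```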
futurehouse_client/clients/rest_client.py
@@ -1,3 +1,4 @@
+# ruff: noqa: PLR0915
import ast
import asyncio
import base64
@@ -26,21 +27,14 @@ from httpx import (
    AsyncClient,
    Client,
    CloseError,
-   ConnectError,
-   ConnectTimeout,
    HTTPStatusError,
-   NetworkError,
-   ReadError,
-   ReadTimeout,
    RemoteProtocolError,
    codes,
)
from ldp.agent import AgentConfig
- from requests.exceptions import RequestException, Timeout
from tenacity import (
    before_sleep_log,
    retry,
-   retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
@@ -48,6 +42,7 @@ from tqdm import tqdm as sync_tqdm
from tqdm.asyncio import tqdm

from futurehouse_client.clients import JobNames
+ from futurehouse_client.clients.data_storage_methods import DataStorageMethods
from futurehouse_client.models.app import (
    AuthType,
    JobDeploymentConfig,
@@ -60,15 +55,20 @@ from futurehouse_client.models.app import (
from futurehouse_client.models.rest import (
    DiscoveryResponse,
    ExecutionStatus,
+   SearchCriterion,
    UserAgentRequest,
    UserAgentRequestPostPayload,
    UserAgentRequestStatus,
    UserAgentResponsePayload,
    WorldModel,
    WorldModelResponse,
+   WorldModelSearchPayload,
)
from futurehouse_client.utils.auth import RefreshingJWT
- from futurehouse_client.utils.general import
+ from futurehouse_client.utils.general import (
+     create_retry_if_connection_error,
+     gather_with_concurrency,
+ )
from futurehouse_client.utils.module_utils import (
    OrganizationSelector,
    fetch_environment_function_docstring,
@@ -160,28 +160,15 @@ class FileUploadError(RestClientError):
    """Raised when there's an error uploading a file."""


- retry_if_connection_error =
-     # From requests
-     Timeout,
-     ConnectionError,
-     RequestException,
-     # From httpx
-     ConnectError,
-     ConnectTimeout,
-     ReadTimeout,
-     ReadError,
-     NetworkError,
-     RemoteProtocolError,
-     CloseError,
-     FileUploadError,
- ))
+ retry_if_connection_error = create_retry_if_connection_error(FileUploadError)

DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds


# pylint: disable=too-many-public-methods
- class RestClient:
-     REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec
+ class RestClient(DataStorageMethods):
+     REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec - for general API calls
+     FILE_UPLOAD_TIMEOUT: ClassVar[float] = 600.0  # 10 minutes - for file uploads
    MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
    RETRY_MULTIPLIER: ClassVar[int] = 1
    MAX_RETRY_WAIT: ClassVar[int] = 10
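The removed inline `retry_if_exception_type((...))` block is replaced by a `create_retry_if_connection_error` factory in `futurehouse_client.utils.general`, whose body is not part of this diff. A hypothetical sketch of such a factory, assuming it wraps tenacity's `retry_if_exception_type` over the same requests/httpx exception set the removed block listed plus any caller-supplied extras such as `FileUploadError`:

```python
# Hypothetical reconstruction; the real helper in futurehouse_client/utils/general.py may differ.
import httpx
from requests.exceptions import RequestException, Timeout
from tenacity import retry_if_exception_type


def create_retry_if_connection_error(*extra_exceptions: type[BaseException]):
    """Build a tenacity retry predicate for transient connection failures."""
    return retry_if_exception_type((
        # From requests
        Timeout,
        ConnectionError,
        RequestException,
        # From httpx
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.ReadError,
        httpx.NetworkError,
        httpx.RemoteProtocolError,
        httpx.CloseError,
        *extra_exceptions,
    ))
```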
@@ -239,11 +226,35 @@ class RestClient:
        """Authenticated HTTP client for multipart uploads."""
        return cast(Client, self.get_client(None, authenticated=True))

+    @property
+    def file_upload_client(self) -> Client:
+        """Authenticated HTTP client with extended timeout for file uploads."""
+        return cast(
+            Client,
+            self.get_client(
+                "application/json", authenticated=True, timeout=self.FILE_UPLOAD_TIMEOUT
+            ),
+        )
+
+    @property
+    def async_file_upload_client(self) -> AsyncClient:
+        """Authenticated async HTTP client with extended timeout for file uploads."""
+        return cast(
+            AsyncClient,
+            self.get_client(
+                "application/json",
+                authenticated=True,
+                async_client=True,
+                timeout=self.FILE_UPLOAD_TIMEOUT,
+            ),
+        )
+
    def get_client(
        self,
        content_type: str | None = "application/json",
        authenticated: bool = True,
        async_client: bool = False,
+        timeout: float | None = None,
    ) -> Client | AsyncClient:
        """Return a cached HTTP client or create one if needed.

@@ -251,12 +262,13 @@ class RestClient:
            content_type: The desired content type header. Use None for multipart uploads.
            authenticated: Whether the client should include authentication.
            async_client: Whether to use an async client.
+            timeout: Custom timeout in seconds. Uses REQUEST_TIMEOUT if not provided.

        Returns:
            An HTTP client configured with the appropriate headers.
        """
-
-        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
+        client_timeout = timeout or self.REQUEST_TIMEOUT
+        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}_{client_timeout}"

        if key not in self._clients:
            headers = copy.deepcopy(self.headers)
@@ -282,14 +294,14 @@ class RestClient:
                AsyncClient(
                    base_url=self.base_url,
                    headers=headers,
-                    timeout=
+                    timeout=client_timeout,
                    auth=auth,
                )
                if async_client
                else Client(
                    base_url=self.base_url,
                    headers=headers,
-                    timeout=
+                    timeout=client_timeout,
                    auth=auth,
                )
            )
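Taken together, these hunks thread a per-call `timeout` through client construction and caching. An illustrative sketch, assuming `client` is an authenticated `RestClient`:

```python
# Clients are cached per (content_type, authenticated, async_client, timeout),
# so each call below can return a distinct cached httpx client.
default_http = client.get_client("application/json", authenticated=True)  # 30 s REQUEST_TIMEOUT
upload_http = client.file_upload_client                                    # 600 s FILE_UPLOAD_TIMEOUT
custom_http = client.get_client(
    "application/json", authenticated=True, timeout=120.0                  # caller-chosen timeout
)
```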
@@ -323,16 +335,23 @@ class RestClient:
            raise ValueError(f"Organization '{organization}' not found.")
        return filtered_orgs

+    @retry(
+        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
+        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
+        retry=retry_if_connection_error,
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
    def _check_job(self, name: str, organization: str) -> dict[str, Any]:
-
-
-
-            )
-            response.raise_for_status()
-            return response.json()
-        except Exception as e:
-            raise JobFetchError(f"Error checking job: {e!r}.") from e
+        response = self.client.get(f"/v0.1/crows/{name}/organizations/{organization}")
+        response.raise_for_status()
+        return response.json()

+    @retry(
+        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
+        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
+        retry=retry_if_connection_error,
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
    def _fetch_my_orgs(self) -> list[str]:
        response = self.client.get(f"/v0.1/organizations?filter={True}")
        response.raise_for_status()
@@ -690,10 +709,12 @@ class RestClient:

    async def arun_tasks_until_done(
        self,
-        task_data:
-
-
-
+        task_data: (
+            TaskRequest
+            | dict[str, Any]
+            | Collection[TaskRequest]
+            | Collection[dict[str, Any]]
+        ),
        verbose: bool = False,
        progress_bar: bool = False,
        concurrency: int = 10,
@@ -761,10 +782,12 @@ class RestClient:

    def run_tasks_until_done(
        self,
-        task_data:
-
-
-
+        task_data: (
+            TaskRequest
+            | dict[str, Any]
+            | Collection[TaskRequest]
+            | Collection[dict[str, Any]]
+        ),
        verbose: bool = False,
        progress_bar: bool = False,
        timeout: int = DEFAULT_AGENT_TIMEOUT,
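Both task runners now spell out that `task_data` may be a single task or a collection of tasks. An illustrative call, assuming `client` is an authenticated `RestClient`; the job name and queries are placeholders:

```python
# A TaskRequest, a plain dict, or a collection of either is accepted.
tasks = [
    {"name": "job-futurehouse-example", "query": "What is PFOA?"},  # placeholder job name
    {"name": "job-futurehouse-example", "query": "What is PFOS?"},
]
results = client.run_tasks_until_done(tasks, progress_bar=True)
```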
@@ -837,12 +860,9 @@ class RestClient:
    )
    def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
        """Get the status of a build."""
-
-
-
-            response.raise_for_status()
-        except Exception as e:
-            raise JobFetchError(f"Error getting build status: {e!r}.") from e
+        build_id = build_id or self.build_id
+        response = self.client.get(f"/v0.1/builds/{build_id}")
+        response.raise_for_status()
        return response.json()

    # TODO: Refactor later so we don't have to ignore PLR0915
@@ -852,7 +872,7 @@ class RestClient:
        retry=retry_if_connection_error,
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
-    def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
+    def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
        """Creates a futurehouse job deployment from the environment and environment files.

        Args:
|
|
1597
1617
|
except Exception as e:
|
1598
1618
|
raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
|
1599
1619
|
|
1620
|
+
@retry(
|
1621
|
+
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1622
|
+
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
1623
|
+
retry=retry_if_connection_error,
|
1624
|
+
)
|
1625
|
+
def list_world_models(
|
1626
|
+
self,
|
1627
|
+
name: str | None = None,
|
1628
|
+
project_id: UUID | str | None = None,
|
1629
|
+
limit: int = 150,
|
1630
|
+
offset: int = 0,
|
1631
|
+
sort_order: str = "asc",
|
1632
|
+
) -> list[WorldModelResponse]:
|
1633
|
+
"""List world models with different behavior based on filters.
|
1634
|
+
|
1635
|
+
When filtering by name: returns only the latest version for that name.
|
1636
|
+
When filtering by project_id (without name): returns all versions for that project.
|
1637
|
+
When no filters: returns latest version of each world model.
|
1638
|
+
|
1639
|
+
Args:
|
1640
|
+
name: Filter by world model name.
|
1641
|
+
project_id: Filter by project ID.
|
1642
|
+
limit: The maximum number of models to return.
|
1643
|
+
offset: Number of results to skip for pagination.
|
1644
|
+
sort_order: Sort order 'asc' or 'desc'.
|
1645
|
+
|
1646
|
+
Returns:
|
1647
|
+
A list of world model dictionaries.
|
1648
|
+
"""
|
1649
|
+
try:
|
1650
|
+
params: dict[str, str | int] = {
|
1651
|
+
"limit": limit,
|
1652
|
+
"offset": offset,
|
1653
|
+
"sort_order": sort_order,
|
1654
|
+
}
|
1655
|
+
if name:
|
1656
|
+
params["name"] = name
|
1657
|
+
if project_id:
|
1658
|
+
params["project_id"] = str(project_id)
|
1659
|
+
|
1660
|
+
response = self.client.get("/v0.1/world-models", params=params)
|
1661
|
+
response.raise_for_status()
|
1662
|
+
return response.json()
|
1663
|
+
except HTTPStatusError as e:
|
1664
|
+
raise WorldModelFetchError(
|
1665
|
+
f"Error listing world models: {e.response.status_code} - {e.response.text}"
|
1666
|
+
) from e
|
1667
|
+
except Exception as e:
|
1668
|
+
raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
|
1669
|
+
|
1600
1670
|
@retry(
|
1601
1671
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1602
1672
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
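A short usage sketch of the new `list_world_models`, assuming `client` is an authenticated `RestClient`; the name and project ID are placeholders:

```python
latest_models = client.list_world_models(limit=50)         # latest version of each model
named_latest = client.list_world_models(name="chemistry")  # latest version for that name
project_versions = client.list_world_models(               # all versions within one project
    project_id="00000000-0000-0000-0000-000000000000",
    sort_order="desc",
)
```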
@@ -1604,34 +1674,43 @@ class RestClient:
    )
    def search_world_models(
        self,
-
+        criteria: list[SearchCriterion] | None = None,
        size: int = 10,
-
+        project_id: UUID | str | None = None,
        search_all_versions: bool = False,
-    ) -> list[
-        """Search
+    ) -> list[WorldModelResponse]:
+        """Search world models using structured criteria.

        Args:
-
+            criteria: List of SearchCriterion objects with field, operator, and value.
            size: The number of results to return.
-
-            search_all_versions: Whether to search all versions
+            project_id: Optional filter by project ID.
+            search_all_versions: Whether to search all versions or just latest.

        Returns:
-            A list of world model
+            A list of world model responses.
+
+        Example:
+            from futurehouse_client.models.rest import SearchCriterion, SearchOperator
+            criteria = [
+                SearchCriterion(field="name", operator=SearchOperator.CONTAINS, value="chemistry"),
+                SearchCriterion(field="email", operator=SearchOperator.CONTAINS, value="tyler"),
+            ]
+            results = client.search_world_models(criteria=criteria, size=20)
        """
        try:
-
-
-
-
-
-
-
-
+            payload = WorldModelSearchPayload(
+                criteria=criteria or [],
+                size=size,
+                project_id=project_id,
+                search_all_versions=search_all_versions,
+            )
+
+            response = self.client.post(
+                "/v0.1/world-models/search",
+                json=payload.model_dump(mode="json"),
            )
            response.raise_for_status()
-            # The new endpoint returns a list of models directly
            return response.json()
        except HTTPStatusError as e:
            raise WorldModelFetchError(
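Beyond the docstring example, the new `project_id` and `search_all_versions` parameters can scope a search. An illustrative sketch, assuming `client` is an authenticated `RestClient` and the project ID is a placeholder:

```python
from futurehouse_client.models.rest import SearchCriterion, SearchOperator

criteria = [
    SearchCriterion(field="name", operator=SearchOperator.CONTAINS, value="chemistry"),
]
all_versions = client.search_world_models(
    criteria=criteria,
    size=25,
    project_id="00000000-0000-0000-0000-000000000000",
    search_all_versions=True,
)
```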
futurehouse_client/models/__init__.py
@@ -13,7 +13,7 @@ from .app import (
    TaskResponse,
    TaskResponseVerbose,
)
- from .rest import WorldModel, WorldModelResponse
+ from .rest import TrajectoryPatchRequest, WorldModel, WorldModelResponse

__all__ = [
    "AuthType",
@@ -29,6 +29,7 @@ __all__ = [
    "TaskRequest",
    "TaskResponse",
    "TaskResponseVerbose",
+   "TrajectoryPatchRequest",
    "WorldModel",
    "WorldModelResponse",
]
futurehouse_client/models/client.py
@@ -27,13 +27,17 @@ class InitialState(BaseState):

class ASVState(BaseState, Generic[T]):
    action: OpResult[T] = Field()
-
+    next_state: Any = Field()
    value: float = Field()

    @field_serializer("action")
    def serialize_action(self, action: OpResult[T]) -> dict:
        return action.to_dict()

+    @field_serializer("next_state")
+    def serialize_next_state(self, state: Any) -> str:
+        return str(state)
+

class EnvResetState(BaseState):
    observations: list[Message] = Field()
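`ASVState.next_state` now accepts any object and is flattened to its string representation during serialization. A simplified, self-contained sketch of the same `field_serializer` pattern (the model name here is illustrative, not part of the package):

```python
from typing import Any

from pydantic import BaseModel, Field, field_serializer


class StateSnapshot(BaseModel):  # simplified stand-in for ASVState
    next_state: Any = Field()

    @field_serializer("next_state")
    def serialize_next_state(self, state: Any) -> str:
        # Arbitrary objects are flattened to their string representation.
        return str(state)


print(StateSnapshot(next_state={"step": 3}).model_dump())
# -> {'next_state': "{'step': 3}"}
```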