futurehouse-client 0.4.1.dev95__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/clients/data_storage_methods.py +725 -139
- futurehouse_client/clients/job_client.py +50 -0
- futurehouse_client/clients/rest_client.py +126 -56
- futurehouse_client/models/__init__.py +2 -1
- futurehouse_client/models/app.py +4 -7
- futurehouse_client/models/data_storage_methods.py +31 -10
- futurehouse_client/models/rest.py +48 -7
- futurehouse_client/utils/general.py +35 -6
- futurehouse_client/utils/world_model_tools.py +23 -3
- futurehouse_client/version.py +16 -3
- {futurehouse_client-0.4.1.dev95.dist-info → futurehouse_client-0.4.2.dist-info}/METADATA +2 -1
- futurehouse_client-0.4.2.dist-info/RECORD +23 -0
- futurehouse_client-0.4.1.dev95.dist-info/RECORD +0 -23
- {futurehouse_client-0.4.1.dev95.dist-info → futurehouse_client-0.4.2.dist-info}/WHEEL +0 -0
- {futurehouse_client-0.4.1.dev95.dist-info → futurehouse_client-0.4.2.dist-info}/licenses/LICENSE +0 -0
- {futurehouse_client-0.4.1.dev95.dist-info → futurehouse_client-0.4.2.dist-info}/top_level.txt +0 -0
futurehouse_client/clients/job_client.py
CHANGED
@@ -19,6 +19,7 @@ from futurehouse_client.models.rest import (
|
|
19
19
|
FinalEnvironmentRequest,
|
20
20
|
StoreAgentStatePostRequest,
|
21
21
|
StoreEnvironmentFrameRequest,
|
22
|
+
TrajectoryPatchRequest,
|
22
23
|
)
|
23
24
|
from futurehouse_client.utils.monitoring import (
|
24
25
|
external_trace,
|
@@ -318,3 +319,52 @@ class JobClient:
|
|
318
319
|
f"Unexpected error storing environment frame for state {state_identifier}",
|
319
320
|
)
|
320
321
|
raise
|
322
|
+
|
323
|
+
async def patch_trajectory(
|
324
|
+
self,
|
325
|
+
public: bool | None = None,
|
326
|
+
shared_with: list[int] | None = None,
|
327
|
+
notification_enabled: bool | None = None,
|
328
|
+
notification_type: str | None = None,
|
329
|
+
min_estimated_time: float | None = None,
|
330
|
+
max_estimated_time: float | None = None,
|
331
|
+
) -> None:
|
332
|
+
data = TrajectoryPatchRequest(
|
333
|
+
public=public,
|
334
|
+
shared_with=shared_with,
|
335
|
+
notification_enabled=notification_enabled,
|
336
|
+
notification_type=notification_type,
|
337
|
+
min_estimated_time=min_estimated_time,
|
338
|
+
max_estimated_time=max_estimated_time,
|
339
|
+
)
|
340
|
+
try:
|
341
|
+
async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
|
342
|
+
url = f"{self.base_uri}/v0.1/trajectories/{self.trajectory_id}"
|
343
|
+
headers = {
|
344
|
+
"Authorization": f"Bearer {self.oauth_jwt}",
|
345
|
+
"x-trajectory-id": self.trajectory_id,
|
346
|
+
}
|
347
|
+
response = await client.patch(
|
348
|
+
url=url,
|
349
|
+
json=data.model_dump(mode="json", exclude_none=True),
|
350
|
+
headers=headers,
|
351
|
+
)
|
352
|
+
response.raise_for_status()
|
353
|
+
logger.debug("Trajectory updated successfully")
|
354
|
+
except httpx.HTTPStatusError as e:
|
355
|
+
logger.exception(
|
356
|
+
"HTTP error while patching trajectory. "
|
357
|
+
f"Status code: {e.response.status_code}, "
|
358
|
+
f"Response: {e.response.text}",
|
359
|
+
)
|
360
|
+
except httpx.TimeoutException:
|
361
|
+
logger.exception(
|
362
|
+
f"Timeout while patching trajectory after {self.REQUEST_TIMEOUT}s",
|
363
|
+
)
|
364
|
+
raise
|
365
|
+
except httpx.NetworkError:
|
366
|
+
logger.exception("Network error while patching trajectory")
|
367
|
+
raise
|
368
|
+
except Exception:
|
369
|
+
logger.exception("Unexpected error while patching trajectory")
|
370
|
+
raise
|
futurehouse_client/clients/rest_client.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1
|
+
# ruff: noqa: PLR0915
|
1
2
|
import ast
|
2
3
|
import asyncio
|
3
4
|
import base64
|
@@ -54,12 +55,14 @@ from futurehouse_client.models.app import (
|
|
54
55
|
from futurehouse_client.models.rest import (
|
55
56
|
DiscoveryResponse,
|
56
57
|
ExecutionStatus,
|
58
|
+
SearchCriterion,
|
57
59
|
UserAgentRequest,
|
58
60
|
UserAgentRequestPostPayload,
|
59
61
|
UserAgentRequestStatus,
|
60
62
|
UserAgentResponsePayload,
|
61
63
|
WorldModel,
|
62
64
|
WorldModelResponse,
|
65
|
+
WorldModelSearchPayload,
|
63
66
|
)
|
64
67
|
from futurehouse_client.utils.auth import RefreshingJWT
|
65
68
|
from futurehouse_client.utils.general import (
|
@@ -332,16 +335,23 @@ class RestClient(DataStorageMethods):
|
|
332
335
|
raise ValueError(f"Organization '{organization}' not found.")
|
333
336
|
return filtered_orgs
|
334
337
|
|
338
|
+
@retry(
|
339
|
+
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
340
|
+
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
341
|
+
retry=retry_if_connection_error,
|
342
|
+
before_sleep=before_sleep_log(logger, logging.WARNING),
|
343
|
+
)
|
335
344
|
def _check_job(self, name: str, organization: str) -> dict[str, Any]:
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
)
|
340
|
-
response.raise_for_status()
|
341
|
-
return response.json()
|
342
|
-
except Exception as e:
|
343
|
-
raise JobFetchError(f"Error checking job: {e!r}.") from e
|
345
|
+
response = self.client.get(f"/v0.1/crows/{name}/organizations/{organization}")
|
346
|
+
response.raise_for_status()
|
347
|
+
return response.json()
|
344
348
|
|
349
|
+
@retry(
|
350
|
+
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
351
|
+
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
352
|
+
retry=retry_if_connection_error,
|
353
|
+
before_sleep=before_sleep_log(logger, logging.WARNING),
|
354
|
+
)
|
345
355
|
def _fetch_my_orgs(self) -> list[str]:
|
346
356
|
response = self.client.get(f"/v0.1/organizations?filter={True}")
|
347
357
|
response.raise_for_status()
|
@@ -699,10 +709,12 @@ class RestClient(DataStorageMethods):
|
|
699
709
|
|
700
710
|
async def arun_tasks_until_done(
|
701
711
|
self,
|
702
|
-
task_data:
|
703
|
-
|
704
|
-
|
705
|
-
|
712
|
+
task_data: (
|
713
|
+
TaskRequest
|
714
|
+
| dict[str, Any]
|
715
|
+
| Collection[TaskRequest]
|
716
|
+
| Collection[dict[str, Any]]
|
717
|
+
),
|
706
718
|
verbose: bool = False,
|
707
719
|
progress_bar: bool = False,
|
708
720
|
concurrency: int = 10,
|
@@ -770,10 +782,12 @@ class RestClient(DataStorageMethods):
|
|
770
782
|
|
771
783
|
def run_tasks_until_done(
|
772
784
|
self,
|
773
|
-
task_data:
|
774
|
-
|
775
|
-
|
776
|
-
|
785
|
+
task_data: (
|
786
|
+
TaskRequest
|
787
|
+
| dict[str, Any]
|
788
|
+
| Collection[TaskRequest]
|
789
|
+
| Collection[dict[str, Any]]
|
790
|
+
),
|
777
791
|
verbose: bool = False,
|
778
792
|
progress_bar: bool = False,
|
779
793
|
timeout: int = DEFAULT_AGENT_TIMEOUT,
|
@@ -846,12 +860,9 @@ class RestClient(DataStorageMethods):
|
|
846
860
|
)
|
847
861
|
def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
|
848
862
|
"""Get the status of a build."""
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
response.raise_for_status()
|
853
|
-
except Exception as e:
|
854
|
-
raise JobFetchError(f"Error getting build status: {e!r}.") from e
|
863
|
+
build_id = build_id or self.build_id
|
864
|
+
response = self.client.get(f"/v0.1/builds/{build_id}")
|
865
|
+
response.raise_for_status()
|
855
866
|
return response.json()
|
856
867
|
|
857
868
|
# TODO: Refactor later so we don't have to ignore PLR0915
|
@@ -861,7 +872,7 @@ class RestClient(DataStorageMethods):
|
|
861
872
|
retry=retry_if_connection_error,
|
862
873
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
863
874
|
)
|
864
|
-
def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
|
875
|
+
def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
|
865
876
|
"""Creates a futurehouse job deployment from the environment and environment files.
|
866
877
|
|
867
878
|
Args:
|
@@ -1606,6 +1617,56 @@ class RestClient(DataStorageMethods):
|
|
1606
1617
|
except Exception as e:
|
1607
1618
|
raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
|
1608
1619
|
|
1620
|
+
@retry(
|
1621
|
+
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1622
|
+
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
1623
|
+
retry=retry_if_connection_error,
|
1624
|
+
)
|
1625
|
+
def list_world_models(
|
1626
|
+
self,
|
1627
|
+
name: str | None = None,
|
1628
|
+
project_id: UUID | str | None = None,
|
1629
|
+
limit: int = 150,
|
1630
|
+
offset: int = 0,
|
1631
|
+
sort_order: str = "asc",
|
1632
|
+
) -> list[WorldModelResponse]:
|
1633
|
+
"""List world models with different behavior based on filters.
|
1634
|
+
|
1635
|
+
When filtering by name: returns only the latest version for that name.
|
1636
|
+
When filtering by project_id (without name): returns all versions for that project.
|
1637
|
+
When no filters: returns latest version of each world model.
|
1638
|
+
|
1639
|
+
Args:
|
1640
|
+
name: Filter by world model name.
|
1641
|
+
project_id: Filter by project ID.
|
1642
|
+
limit: The maximum number of models to return.
|
1643
|
+
offset: Number of results to skip for pagination.
|
1644
|
+
sort_order: Sort order 'asc' or 'desc'.
|
1645
|
+
|
1646
|
+
Returns:
|
1647
|
+
A list of world model dictionaries.
|
1648
|
+
"""
|
1649
|
+
try:
|
1650
|
+
params: dict[str, str | int] = {
|
1651
|
+
"limit": limit,
|
1652
|
+
"offset": offset,
|
1653
|
+
"sort_order": sort_order,
|
1654
|
+
}
|
1655
|
+
if name:
|
1656
|
+
params["name"] = name
|
1657
|
+
if project_id:
|
1658
|
+
params["project_id"] = str(project_id)
|
1659
|
+
|
1660
|
+
response = self.client.get("/v0.1/world-models", params=params)
|
1661
|
+
response.raise_for_status()
|
1662
|
+
return response.json()
|
1663
|
+
except HTTPStatusError as e:
|
1664
|
+
raise WorldModelFetchError(
|
1665
|
+
f"Error listing world models: {e.response.status_code} - {e.response.text}"
|
1666
|
+
) from e
|
1667
|
+
except Exception as e:
|
1668
|
+
raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
|
1669
|
+
|
1609
1670
|
@retry(
|
1610
1671
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1611
1672
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -1613,31 +1674,41 @@ class RestClient(DataStorageMethods):
|
|
1613
1674
|
)
|
1614
1675
|
def search_world_models(
|
1615
1676
|
self,
|
1616
|
-
|
1677
|
+
criteria: list[SearchCriterion] | None = None,
|
1617
1678
|
size: int = 10,
|
1618
|
-
|
1679
|
+
project_id: UUID | str | None = None,
|
1619
1680
|
search_all_versions: bool = False,
|
1620
|
-
) -> list[
|
1621
|
-
"""Search
|
1681
|
+
) -> list[WorldModelResponse]:
|
1682
|
+
"""Search world models using structured criteria.
|
1622
1683
|
|
1623
1684
|
Args:
|
1624
|
-
|
1685
|
+
criteria: List of SearchCriterion objects with field, operator, and value.
|
1625
1686
|
size: The number of results to return.
|
1626
|
-
|
1627
|
-
search_all_versions: Whether to search all versions
|
1687
|
+
project_id: Optional filter by project ID.
|
1688
|
+
search_all_versions: Whether to search all versions or just latest.
|
1628
1689
|
|
1629
1690
|
Returns:
|
1630
|
-
A list of world model
|
1691
|
+
A list of world model responses.
|
1692
|
+
|
1693
|
+
Example:
|
1694
|
+
from futurehouse_client.models.rest import SearchCriterion, SearchOperator
|
1695
|
+
criteria = [
|
1696
|
+
SearchCriterion(field="name", operator=SearchOperator.CONTAINS, value="chemistry"),
|
1697
|
+
SearchCriterion(field="email", operator=SearchOperator.CONTAINS, value="tyler"),
|
1698
|
+
]
|
1699
|
+
results = client.search_world_models(criteria=criteria, size=20)
|
1631
1700
|
"""
|
1632
1701
|
try:
|
1633
|
-
|
1634
|
-
|
1635
|
-
|
1636
|
-
|
1637
|
-
|
1638
|
-
|
1639
|
-
|
1640
|
-
|
1702
|
+
payload = WorldModelSearchPayload(
|
1703
|
+
criteria=criteria or [],
|
1704
|
+
size=size,
|
1705
|
+
project_id=project_id,
|
1706
|
+
search_all_versions=search_all_versions,
|
1707
|
+
)
|
1708
|
+
|
1709
|
+
response = self.client.post(
|
1710
|
+
"/v0.1/world-models/search",
|
1711
|
+
json=payload.model_dump(mode="json"),
|
1641
1712
|
)
|
1642
1713
|
response.raise_for_status()
|
1643
1714
|
return response.json()
|
@@ -1754,22 +1825,19 @@ class RestClient(DataStorageMethods):
|
|
1754
1825
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
1755
1826
|
retry=retry_if_connection_error,
|
1756
1827
|
)
|
1757
|
-
def get_project_by_name(self, name: str) -> UUID:
|
1828
|
+
def get_project_by_name(self, name: str, limit: int = 2) -> UUID | list[UUID]:
|
1758
1829
|
"""Get a project UUID by name.
|
1759
1830
|
|
1760
1831
|
Args:
|
1761
1832
|
name: The name of the project to find
|
1833
|
+
limit: Maximum number of projects to return
|
1762
1834
|
|
1763
1835
|
Returns:
|
1764
|
-
UUID of the project as a string
|
1765
|
-
|
1766
|
-
Raises:
|
1767
|
-
ProjectError: If no project is found, multiple projects are found, or there's an error
|
1836
|
+
UUID of the project as a string or a list of UUIDs if multiple projects are found
|
1768
1837
|
"""
|
1769
1838
|
try:
|
1770
|
-
# Get projects filtered by name (backend now filters by name and owner)
|
1771
1839
|
response = self.client.get(
|
1772
|
-
"/v0.1/projects", params={"limit":
|
1840
|
+
"/v0.1/projects", params={"limit": limit, "name": name}
|
1773
1841
|
)
|
1774
1842
|
response.raise_for_status()
|
1775
1843
|
projects = response.json()
|
@@ -1782,32 +1850,33 @@ class RestClient(DataStorageMethods):
|
|
1782
1850
|
if len(projects) == 0:
|
1783
1851
|
raise ProjectError(f"No project found with name '{name}'")
|
1784
1852
|
if len(projects) > 1:
|
1785
|
-
|
1853
|
+
logger.warning(
|
1786
1854
|
f"Multiple projects found with name '{name}'. Found {len(projects)} projects."
|
1787
1855
|
)
|
1788
1856
|
|
1789
|
-
|
1857
|
+
ids = [UUID(project["id"]) for project in projects]
|
1858
|
+
return ids[0] if len(ids) == 1 else ids
|
1790
1859
|
|
1791
1860
|
@retry(
|
1792
1861
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
1793
1862
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
1794
1863
|
retry=retry_if_connection_error,
|
1795
1864
|
)
|
1796
|
-
async def aget_project_by_name(
|
1865
|
+
async def aget_project_by_name(
|
1866
|
+
self, name: str, limit: int = 2
|
1867
|
+
) -> UUID | list[UUID]:
|
1797
1868
|
"""Asynchronously get a project UUID by name.
|
1798
1869
|
|
1799
1870
|
Args:
|
1800
1871
|
name: The name of the project to find
|
1872
|
+
limit: Maximum number of projects to return
|
1801
1873
|
|
1802
1874
|
Returns:
|
1803
|
-
UUID of the project as a string
|
1804
|
-
|
1805
|
-
Raises:
|
1806
|
-
ProjectError: If no project is found, multiple projects are found, or there's an error
|
1875
|
+
UUID of the project as a string or a list of UUIDs if multiple projects are found
|
1807
1876
|
"""
|
1808
1877
|
try:
|
1809
1878
|
response = await self.async_client.get(
|
1810
|
-
"/v0.1/projects", params={"limit":
|
1879
|
+
"/v0.1/projects", params={"limit": limit, "name": name}
|
1811
1880
|
)
|
1812
1881
|
response.raise_for_status()
|
1813
1882
|
projects = response.json()
|
@@ -1816,11 +1885,12 @@ class RestClient(DataStorageMethods):
|
|
1816
1885
|
if len(projects) == 0:
|
1817
1886
|
raise ProjectError(f"No project found with name '{name}'")
|
1818
1887
|
if len(projects) > 1:
|
1819
|
-
|
1888
|
+
logger.warning(
|
1820
1889
|
f"Multiple projects found with name '{name}'. Found {len(projects)} projects."
|
1821
1890
|
)
|
1822
1891
|
|
1823
|
-
|
1892
|
+
ids = [UUID(project["id"]) for project in projects]
|
1893
|
+
return ids[0] if len(ids) == 1 else ids
|
1824
1894
|
|
1825
1895
|
@retry(
|
1826
1896
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
futurehouse_client/models/__init__.py
CHANGED
@@ -13,7 +13,7 @@ from .app import (
|
|
13
13
|
TaskResponse,
|
14
14
|
TaskResponseVerbose,
|
15
15
|
)
|
16
|
-
from .rest import WorldModel, WorldModelResponse
|
16
|
+
from .rest import TrajectoryPatchRequest, WorldModel, WorldModelResponse
|
17
17
|
|
18
18
|
__all__ = [
|
19
19
|
"AuthType",
|
@@ -29,6 +29,7 @@ __all__ = [
|
|
29
29
|
"TaskRequest",
|
30
30
|
"TaskResponse",
|
31
31
|
"TaskResponseVerbose",
|
32
|
+
"TrajectoryPatchRequest",
|
32
33
|
"WorldModel",
|
33
34
|
"WorldModelResponse",
|
34
35
|
]
|
futurehouse_client/models/app.py
CHANGED
@@ -27,10 +27,7 @@ if TYPE_CHECKING:
|
|
27
27
|
MAX_CROW_JOB_RUN_TIMEOUT = 60 * 60 * 24 # 24 hours in sec
|
28
28
|
MIN_CROW_JOB_RUN_TIMEOUT = 0 # sec
|
29
29
|
|
30
|
-
|
31
|
-
class PythonVersion(StrEnum):
|
32
|
-
V3_11 = "3.11"
|
33
|
-
V3_12 = "3.12"
|
30
|
+
DEFAULT_PYTHON_VERSION_USED_FOR_JOB_BUILDS = "3.13"
|
34
31
|
|
35
32
|
|
36
33
|
class AuthType(StrEnum):
|
@@ -420,9 +417,9 @@ class JobDeploymentConfig(BaseModel):
|
|
420
417
|
description="The configuration for the cloud run container.",
|
421
418
|
)
|
422
419
|
|
423
|
-
python_version:
|
424
|
-
default=
|
425
|
-
description="The python version your docker image should build with.",
|
420
|
+
python_version: str = Field(
|
421
|
+
default=DEFAULT_PYTHON_VERSION_USED_FOR_JOB_BUILDS,
|
422
|
+
description="The python version your docker image should build with (e.g., '3.11', '3.12', '3.13').",
|
426
423
|
)
|
427
424
|
|
428
425
|
agent: Agent | AgentConfig | str = Field(
|
futurehouse_client/models/data_storage_methods.py
CHANGED
@@ -34,6 +34,10 @@ class DataStorageEntry(BaseModel):
|
|
34
34
|
default=None,
|
35
35
|
description="ID of the parent entry if this is a sub-entry for hierarchical storage",
|
36
36
|
)
|
37
|
+
project_id: UUID | None = Field(
|
38
|
+
default=None,
|
39
|
+
description="ID of the project this data storage entry belongs to",
|
40
|
+
)
|
37
41
|
dataset_id: UUID | None = Field(
|
38
42
|
default=None,
|
39
43
|
description="ID of the dataset this entry belongs to",
|
@@ -79,8 +83,8 @@ class DataStorageLocationPayload(BaseModel):
|
|
79
83
|
location: str | None = None
|
80
84
|
|
81
85
|
|
82
|
-
class
|
83
|
-
"""Model representing the location
|
86
|
+
class DataStorageLocationConfig(BaseModel):
|
87
|
+
"""Model representing the location configuration within a DataStorageLocations object."""
|
84
88
|
|
85
89
|
storage_type: str = Field(description="Type of storage (e.g., 'gcs', 'pg_table')")
|
86
90
|
content_type: str = Field(description="Type of content stored")
|
@@ -89,15 +93,19 @@ class DataStorageLocationDetails(BaseModel):
|
|
89
93
|
location: str | None = Field(
|
90
94
|
default=None, description="Location path or identifier"
|
91
95
|
)
|
96
|
+
signed_url: str | None = Field(
|
97
|
+
default=None,
|
98
|
+
description="Signed URL for uploading/downloading the file to/from GCS",
|
99
|
+
)
|
92
100
|
|
93
101
|
|
94
|
-
class
|
102
|
+
class DataStorageLocation(BaseModel):
|
95
103
|
"""Model representing storage locations for a data storage entry."""
|
96
104
|
|
97
105
|
id: UUID = Field(description="Unique identifier for the storage locations")
|
98
106
|
data_storage_id: UUID = Field(description="ID of the associated data storage entry")
|
99
|
-
storage_config:
|
100
|
-
description="Storage configuration
|
107
|
+
storage_config: DataStorageLocationConfig = Field(
|
108
|
+
description="Storage location configuration"
|
101
109
|
)
|
102
110
|
created_at: datetime = Field(description="Timestamp when the location was created")
|
103
111
|
|
@@ -106,13 +114,9 @@ class DataStorageResponse(BaseModel):
|
|
106
114
|
"""Response model for data storage operations."""
|
107
115
|
|
108
116
|
data_storage: DataStorageEntry = Field(description="The created data storage entry")
|
109
|
-
|
117
|
+
storage_locations: list[DataStorageLocation] = Field(
|
110
118
|
description="Storage location for this data entry"
|
111
119
|
)
|
112
|
-
signed_url: str | None = Field(
|
113
|
-
default=None,
|
114
|
-
description="Signed URL for uploading/downloading the file to/from GCS",
|
115
|
-
)
|
116
120
|
|
117
121
|
|
118
122
|
class DataStorageRequestPayload(BaseModel):
|
@@ -131,6 +135,10 @@ class DataStorageRequestPayload(BaseModel):
|
|
131
135
|
parent_id: UUID | None = Field(
|
132
136
|
default=None, description="ID of the parent entry for hierarchical storage"
|
133
137
|
)
|
138
|
+
project_id: UUID | None = Field(
|
139
|
+
default=None,
|
140
|
+
description="ID of the project this data storage entry belongs to",
|
141
|
+
)
|
134
142
|
dataset_id: UUID | None = Field(
|
135
143
|
default=None,
|
136
144
|
description="ID of existing dataset to add entry to, or None to create new dataset",
|
@@ -144,6 +152,19 @@ class DataStorageRequestPayload(BaseModel):
|
|
144
152
|
)
|
145
153
|
|
146
154
|
|
155
|
+
class CreateDatasetPayload(BaseModel):
|
156
|
+
"""Payload for creating a dataset."""
|
157
|
+
|
158
|
+
id: UUID | None = Field(
|
159
|
+
default=None,
|
160
|
+
description="ID of the dataset to create, or None to create a new dataset",
|
161
|
+
)
|
162
|
+
name: str = Field(description="Name of the dataset")
|
163
|
+
description: str | None = Field(
|
164
|
+
default=None, description="Description of the dataset"
|
165
|
+
)
|
166
|
+
|
167
|
+
|
147
168
|
class ManifestEntry(BaseModel):
|
148
169
|
"""Model representing a single entry in a manifest file."""
|
149
170
|
|
futurehouse_client/models/rest.py
CHANGED
@@ -23,6 +23,15 @@ class StoreEnvironmentFrameRequest(BaseModel):
|
|
23
23
|
trajectory_timestep: int
|
24
24
|
|
25
25
|
|
26
|
+
class TrajectoryPatchRequest(BaseModel):
|
27
|
+
public: bool | None = None
|
28
|
+
shared_with: list[int] | None = None
|
29
|
+
notification_enabled: bool | None = None
|
30
|
+
notification_type: str | None = None
|
31
|
+
min_estimated_time: float | None = None
|
32
|
+
max_estimated_time: float | None = None
|
33
|
+
|
34
|
+
|
26
35
|
class ExecutionStatus(StrEnum):
|
27
36
|
QUEUED = auto()
|
28
37
|
IN_PROGRESS = "in progress"
|
@@ -54,7 +63,37 @@ class WorldModel(BaseModel):
|
|
54
63
|
project_id: UUID | str | None = None
|
55
64
|
|
56
65
|
|
57
|
-
class
|
66
|
+
class SearchOperator(StrEnum):
|
67
|
+
"""Operators for structured search criteria."""
|
68
|
+
|
69
|
+
EQUALS = "equals"
|
70
|
+
CONTAINS = "contains"
|
71
|
+
STARTS_WITH = "starts_with"
|
72
|
+
ENDS_WITH = "ends_with"
|
73
|
+
GREATER_THAN = "greater_than"
|
74
|
+
LESS_THAN = "less_than"
|
75
|
+
BETWEEN = "between"
|
76
|
+
IN = "in"
|
77
|
+
|
78
|
+
|
79
|
+
class SearchCriterion(BaseModel):
|
80
|
+
"""A single search criterion with field, operator, and value."""
|
81
|
+
|
82
|
+
field: str
|
83
|
+
operator: SearchOperator
|
84
|
+
value: str | list[str] | bool
|
85
|
+
|
86
|
+
|
87
|
+
class WorldModelSearchPayload(BaseModel):
|
88
|
+
"""Payload for structured world model search."""
|
89
|
+
|
90
|
+
criteria: list[SearchCriterion]
|
91
|
+
size: int = 10
|
92
|
+
project_id: UUID | str | None = None
|
93
|
+
search_all_versions: bool = False
|
94
|
+
|
95
|
+
|
96
|
+
class WorldModelResponse(WorldModel):
|
58
97
|
"""
|
59
98
|
Response model for a world model snapshot.
|
60
99
|
|
@@ -62,13 +101,8 @@ class WorldModelResponse(BaseModel):
|
|
62
101
|
"""
|
63
102
|
|
64
103
|
id: UUID | str
|
65
|
-
|
66
|
-
name: str
|
67
|
-
description: str | None
|
68
|
-
content: str
|
69
|
-
trajectory_id: UUID | str | None
|
104
|
+
name: str # type: ignore[mutable-override] # The API always returns a non-optional name, overriding the base model's optional field.
|
70
105
|
email: str | None
|
71
|
-
model_metadata: JsonValue | None
|
72
106
|
enabled: bool
|
73
107
|
created_at: datetime
|
74
108
|
|
@@ -132,3 +166,10 @@ class DiscoveryResponse(BaseModel):
|
|
132
166
|
associated_trajectories: list[UUID | str]
|
133
167
|
validation_level: int
|
134
168
|
created_at: datetime
|
169
|
+
|
170
|
+
|
171
|
+
class DataStorageSearchPayload(BaseModel):
|
172
|
+
"""Payload for structured data storage search."""
|
173
|
+
|
174
|
+
criteria: list[SearchCriterion]
|
175
|
+
size: int = 10
|
futurehouse_client/utils/general.py
CHANGED
@@ -1,22 +1,31 @@
|
|
1
1
|
import asyncio
|
2
|
-
from collections.abc import Awaitable, Iterable
|
2
|
+
from collections.abc import Awaitable, Callable, Iterable
|
3
3
|
from typing import TypeVar
|
4
4
|
|
5
5
|
from httpx import (
|
6
6
|
CloseError,
|
7
7
|
ConnectError,
|
8
8
|
ConnectTimeout,
|
9
|
+
HTTPStatusError,
|
9
10
|
NetworkError,
|
10
11
|
ReadError,
|
11
12
|
ReadTimeout,
|
12
13
|
RemoteProtocolError,
|
14
|
+
codes,
|
13
15
|
)
|
14
16
|
from requests.exceptions import RequestException, Timeout
|
15
|
-
from tenacity import
|
17
|
+
from tenacity import RetryCallState
|
16
18
|
from tqdm.asyncio import tqdm
|
17
19
|
|
18
20
|
T = TypeVar("T")
|
19
21
|
|
22
|
+
RETRYABLE_HTTP_STATUS_CODES = {
|
23
|
+
codes.TOO_MANY_REQUESTS,
|
24
|
+
codes.INTERNAL_SERVER_ERROR,
|
25
|
+
codes.BAD_GATEWAY,
|
26
|
+
codes.SERVICE_UNAVAILABLE,
|
27
|
+
codes.GATEWAY_TIMEOUT,
|
28
|
+
}
|
20
29
|
|
21
30
|
_BASE_CONNECTION_ERRORS = (
|
22
31
|
# From requests
|
@@ -33,12 +42,32 @@ _BASE_CONNECTION_ERRORS = (
|
|
33
42
|
CloseError,
|
34
43
|
)
|
35
44
|
|
36
|
-
retry_if_connection_error = retry_if_exception_type(_BASE_CONNECTION_ERRORS)
|
37
45
|
|
46
|
+
def create_retry_if_connection_error(
|
47
|
+
*additional_exceptions,
|
48
|
+
) -> Callable[[RetryCallState], bool]:
|
49
|
+
"""Create a retry condition with base connection errors, HTTP status errors, plus additional exceptions."""
|
38
50
|
|
39
|
-
def
|
40
|
-
|
41
|
-
|
51
|
+
def status_retries_with_exceptions(retry_state: RetryCallState) -> bool:
|
52
|
+
if retry_state.outcome is not None and hasattr(
|
53
|
+
retry_state.outcome, "exception"
|
54
|
+
):
|
55
|
+
exception = retry_state.outcome.exception()
|
56
|
+
# connection errors
|
57
|
+
if isinstance(exception, _BASE_CONNECTION_ERRORS):
|
58
|
+
return True
|
59
|
+
# custom exceptions provided
|
60
|
+
if additional_exceptions and isinstance(exception, additional_exceptions):
|
61
|
+
return True
|
62
|
+
# any http exceptions
|
63
|
+
if isinstance(exception, HTTPStatusError):
|
64
|
+
return exception.response.status_code in RETRYABLE_HTTP_STATUS_CODES
|
65
|
+
return False
|
66
|
+
|
67
|
+
return status_retries_with_exceptions
|
68
|
+
|
69
|
+
|
70
|
+
retry_if_connection_error = create_retry_if_connection_error()
|
42
71
|
|
43
72
|
|
44
73
|
async def gather_with_concurrency(
|