futurehouse-client 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
@@ -19,6 +19,7 @@ from futurehouse_client.models.rest import (
     FinalEnvironmentRequest,
     StoreAgentStatePostRequest,
     StoreEnvironmentFrameRequest,
+    TrajectoryPatchRequest,
 )
 from futurehouse_client.utils.monitoring import (
     external_trace,
@@ -318,3 +319,52 @@ class JobClient:
                 f"Unexpected error storing environment frame for state {state_identifier}",
             )
             raise
+
+    async def patch_trajectory(
+        self,
+        public: bool | None = None,
+        shared_with: list[int] | None = None,
+        notification_enabled: bool | None = None,
+        notification_type: str | None = None,
+        min_estimated_time: float | None = None,
+        max_estimated_time: float | None = None,
+    ) -> None:
+        data = TrajectoryPatchRequest(
+            public=public,
+            shared_with=shared_with,
+            notification_enabled=notification_enabled,
+            notification_type=notification_type,
+            min_estimated_time=min_estimated_time,
+            max_estimated_time=max_estimated_time,
+        )
+        try:
+            async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
+                url = f"{self.base_uri}/v0.1/trajectories/{self.trajectory_id}"
+                headers = {
+                    "Authorization": f"Bearer {self.oauth_jwt}",
+                    "x-trajectory-id": self.trajectory_id,
+                }
+                response = await client.patch(
+                    url=url,
+                    json=data.model_dump(mode="json", exclude_none=True),
+                    headers=headers,
+                )
+                response.raise_for_status()
+                logger.debug("Trajectory updated successfully")
+        except httpx.HTTPStatusError as e:
+            logger.exception(
+                "HTTP error while patching trajectory. "
+                f"Status code: {e.response.status_code}, "
+                f"Response: {e.response.text}",
+            )
+        except httpx.TimeoutException:
+            logger.exception(
+                f"Timeout while patching trajectory after {self.REQUEST_TIMEOUT}s",
+            )
+            raise
+        except httpx.NetworkError:
+            logger.exception("Network error while patching trajectory")
+            raise
+        except Exception:
+            logger.exception("Unexpected error while patching trajectory")
+            raise
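
For context, a minimal usage sketch of the new JobClient.patch_trajectory coroutine added above (the `job_client` instance and the id passed to shared_with are illustrative; only the fields you pass are sent, since the body is serialized with exclude_none=True):

    # Hypothetical call from an async context with a configured JobClient.
    await job_client.patch_trajectory(
        public=True,
        shared_with=[42],            # illustrative user id
        notification_enabled=True,
        notification_type="email",   # illustrative notification type
    )
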
@@ -1,3 +1,4 @@
+# ruff: noqa: PLR0915
 import ast
 import asyncio
 import base64
@@ -26,21 +27,14 @@ from httpx import (
     AsyncClient,
     Client,
     CloseError,
-    ConnectError,
-    ConnectTimeout,
     HTTPStatusError,
-    NetworkError,
-    ReadError,
-    ReadTimeout,
     RemoteProtocolError,
     codes,
 )
 from ldp.agent import AgentConfig
-from requests.exceptions import RequestException, Timeout
 from tenacity import (
     before_sleep_log,
     retry,
-    retry_if_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
@@ -48,6 +42,7 @@ from tqdm import tqdm as sync_tqdm
 from tqdm.asyncio import tqdm

 from futurehouse_client.clients import JobNames
+from futurehouse_client.clients.data_storage_methods import DataStorageMethods
 from futurehouse_client.models.app import (
     AuthType,
     JobDeploymentConfig,
@@ -60,15 +55,20 @@ from futurehouse_client.models.app import (
 from futurehouse_client.models.rest import (
     DiscoveryResponse,
     ExecutionStatus,
+    SearchCriterion,
     UserAgentRequest,
     UserAgentRequestPostPayload,
     UserAgentRequestStatus,
     UserAgentResponsePayload,
     WorldModel,
     WorldModelResponse,
+    WorldModelSearchPayload,
 )
 from futurehouse_client.utils.auth import RefreshingJWT
-from futurehouse_client.utils.general import gather_with_concurrency
+from futurehouse_client.utils.general import (
+    create_retry_if_connection_error,
+    gather_with_concurrency,
+)
 from futurehouse_client.utils.module_utils import (
     OrganizationSelector,
     fetch_environment_function_docstring,
@@ -160,28 +160,15 @@ class FileUploadError(RestClientError):
     """Raised when there's an error uploading a file."""


-retry_if_connection_error = retry_if_exception_type((
-    # From requests
-    Timeout,
-    ConnectionError,
-    RequestException,
-    # From httpx
-    ConnectError,
-    ConnectTimeout,
-    ReadTimeout,
-    ReadError,
-    NetworkError,
-    RemoteProtocolError,
-    CloseError,
-    FileUploadError,
-))
+retry_if_connection_error = create_retry_if_connection_error(FileUploadError)

 DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds


 # pylint: disable=too-many-public-methods
-class RestClient:
-    REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec
+class RestClient(DataStorageMethods):
+    REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec - for general API calls
+    FILE_UPLOAD_TIMEOUT: ClassVar[float] = 600.0  # 10 minutes - for file uploads
     MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
     RETRY_MULTIPLIER: ClassVar[int] = 1
     MAX_RETRY_WAIT: ClassVar[int] = 10
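
create_retry_if_connection_error itself is not shown in this diff, but judging from the inline predicate it replaces, it presumably rebuilds the same tenacity retry condition and appends any caller-supplied exception types. A hedged reconstruction (the *extra_exceptions signature is a guess):

    from httpx import (
        CloseError, ConnectError, ConnectTimeout, NetworkError,
        ReadError, ReadTimeout, RemoteProtocolError,
    )
    from requests.exceptions import RequestException, Timeout
    from tenacity import retry_if_exception_type

    def create_retry_if_connection_error(*extra_exceptions: type[BaseException]):
        # Same connection-error tuple the removed block listed, plus extras.
        return retry_if_exception_type((
            Timeout, ConnectionError, RequestException,            # requests
            ConnectError, ConnectTimeout, ReadTimeout, ReadError,  # httpx
            NetworkError, RemoteProtocolError, CloseError,
            *extra_exceptions,
        ))
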
@@ -239,11 +226,35 @@ class RestClient:
         """Authenticated HTTP client for multipart uploads."""
         return cast(Client, self.get_client(None, authenticated=True))

+    @property
+    def file_upload_client(self) -> Client:
+        """Authenticated HTTP client with extended timeout for file uploads."""
+        return cast(
+            Client,
+            self.get_client(
+                "application/json", authenticated=True, timeout=self.FILE_UPLOAD_TIMEOUT
+            ),
+        )
+
+    @property
+    def async_file_upload_client(self) -> AsyncClient:
+        """Authenticated async HTTP client with extended timeout for file uploads."""
+        return cast(
+            AsyncClient,
+            self.get_client(
+                "application/json",
+                authenticated=True,
+                async_client=True,
+                timeout=self.FILE_UPLOAD_TIMEOUT,
+            ),
+        )
+
     def get_client(
         self,
         content_type: str | None = "application/json",
         authenticated: bool = True,
         async_client: bool = False,
+        timeout: float | None = None,
     ) -> Client | AsyncClient:
         """Return a cached HTTP client or create one if needed.

@@ -251,12 +262,13 @@
             content_type: The desired content type header. Use None for multipart uploads.
             authenticated: Whether the client should include authentication.
             async_client: Whether to use an async client.
+            timeout: Custom timeout in seconds. Uses REQUEST_TIMEOUT if not provided.

         Returns:
             An HTTP client configured with the appropriate headers.
         """
-        # Create a composite key based on content type and auth flag
-        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
+        client_timeout = timeout or self.REQUEST_TIMEOUT
+        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}_{client_timeout}"

         if key not in self._clients:
             headers = copy.deepcopy(self.headers)
@@ -282,14 +294,14 @@
                 AsyncClient(
                     base_url=self.base_url,
                     headers=headers,
-                    timeout=self.REQUEST_TIMEOUT,
+                    timeout=client_timeout,
                     auth=auth,
                 )
                 if async_client
                 else Client(
                     base_url=self.base_url,
                     headers=headers,
-                    timeout=self.REQUEST_TIMEOUT,
+                    timeout=client_timeout,
                     auth=auth,
                 )
             )
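
In practice the timeout is now part of the client cache key, so the default client and the long-timeout upload client coexist; a short sketch (the `rest_client` variable is illustrative):

    api_client = rest_client.client                  # default REQUEST_TIMEOUT (30 s)
    upload_client = rest_client.file_upload_client   # FILE_UPLOAD_TIMEOUT (600 s)

    # A one-off timeout also gets its own cache entry via the new parameter.
    slow_client = rest_client.get_client(timeout=120.0)
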
@@ -323,16 +335,23 @@
             raise ValueError(f"Organization '{organization}' not found.")
         return filtered_orgs

+    @retry(
+        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
+        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
+        retry=retry_if_connection_error,
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
     def _check_job(self, name: str, organization: str) -> dict[str, Any]:
-        try:
-            response = self.client.get(
-                f"/v0.1/crows/{name}/organizations/{organization}"
-            )
-            response.raise_for_status()
-            return response.json()
-        except Exception as e:
-            raise JobFetchError(f"Error checking job: {e!r}.") from e
+        response = self.client.get(f"/v0.1/crows/{name}/organizations/{organization}")
+        response.raise_for_status()
+        return response.json()

+    @retry(
+        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
+        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
+        retry=retry_if_connection_error,
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+    )
     def _fetch_my_orgs(self) -> list[str]:
         response = self.client.get(f"/v0.1/organizations?filter={True}")
         response.raise_for_status()
@@ -690,10 +709,12 @@

     async def arun_tasks_until_done(
         self,
-        task_data: TaskRequest
-        | dict[str, Any]
-        | Collection[TaskRequest]
-        | Collection[dict[str, Any]],
+        task_data: (
+            TaskRequest
+            | dict[str, Any]
+            | Collection[TaskRequest]
+            | Collection[dict[str, Any]]
+        ),
         verbose: bool = False,
         progress_bar: bool = False,
         concurrency: int = 10,
@@ -761,10 +782,12 @@

     def run_tasks_until_done(
         self,
-        task_data: TaskRequest
-        | dict[str, Any]
-        | Collection[TaskRequest]
-        | Collection[dict[str, Any]],
+        task_data: (
+            TaskRequest
+            | dict[str, Any]
+            | Collection[TaskRequest]
+            | Collection[dict[str, Any]]
+        ),
         verbose: bool = False,
         progress_bar: bool = False,
         timeout: int = DEFAULT_AGENT_TIMEOUT,
@@ -837,12 +860,9 @@
     )
     def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
         """Get the status of a build."""
-        try:
-            build_id = build_id or self.build_id
-            response = self.client.get(f"/v0.1/builds/{build_id}")
-            response.raise_for_status()
-        except Exception as e:
-            raise JobFetchError(f"Error getting build status: {e!r}.") from e
+        build_id = build_id or self.build_id
+        response = self.client.get(f"/v0.1/builds/{build_id}")
+        response.raise_for_status()
         return response.json()

     # TODO: Refactor later so we don't have to ignore PLR0915
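
One behavioral note for callers: with the try/except removed, a failed lookup now propagates the underlying httpx error (e.g. HTTPStatusError from raise_for_status) instead of being wrapped in JobFetchError. A hedged sketch of the adjusted handling (`rest_client` is illustrative):

    from httpx import HTTPStatusError

    try:
        build = rest_client.get_build_status()
    except HTTPStatusError as exc:  # previously surfaced as JobFetchError
        print(f"Build status lookup failed: {exc.response.status_code}")
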
@@ -852,7 +872,7 @@
         retry=retry_if_connection_error,
         before_sleep=before_sleep_log(logger, logging.WARNING),
     )
-    def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:  # noqa: PLR0915
+    def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
         """Creates a futurehouse job deployment from the environment and environment files.

         Args:
@@ -1597,6 +1617,56 @@
         except Exception as e:
             raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e

+    @retry(
+        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
+        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
+        retry=retry_if_connection_error,
+    )
+    def list_world_models(
+        self,
+        name: str | None = None,
+        project_id: UUID | str | None = None,
+        limit: int = 150,
+        offset: int = 0,
+        sort_order: str = "asc",
+    ) -> list[WorldModelResponse]:
+        """List world models with different behavior based on filters.
+
+        When filtering by name: returns only the latest version for that name.
+        When filtering by project_id (without name): returns all versions for that project.
+        When no filters: returns latest version of each world model.
+
+        Args:
+            name: Filter by world model name.
+            project_id: Filter by project ID.
+            limit: The maximum number of models to return.
+            offset: Number of results to skip for pagination.
+            sort_order: Sort order 'asc' or 'desc'.
+
+        Returns:
+            A list of world model dictionaries.
+        """
+        try:
+            params: dict[str, str | int] = {
+                "limit": limit,
+                "offset": offset,
+                "sort_order": sort_order,
+            }
+            if name:
+                params["name"] = name
+            if project_id:
+                params["project_id"] = str(project_id)
+
+            response = self.client.get("/v0.1/world-models", params=params)
+            response.raise_for_status()
+            return response.json()
+        except HTTPStatusError as e:
+            raise WorldModelFetchError(
+                f"Error listing world models: {e.response.status_code} - {e.response.text}"
+            ) from e
+        except Exception as e:
+            raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
+
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
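
A brief usage sketch for the new list_world_models method, following the docstring above (client construction and the filter values are illustrative):

    # Latest version of every world model, newest first.
    latest = rest_client.list_world_models(limit=50, sort_order="desc")

    # All versions belonging to one project (illustrative UUID).
    project_versions = rest_client.list_world_models(
        project_id="00000000-0000-0000-0000-000000000000"
    )

    # Only the latest version with a matching name.
    chemistry_model = rest_client.list_world_models(name="chemistry")
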
@@ -1604,34 +1674,43 @@
     )
     def search_world_models(
         self,
-        query: str,
+        criteria: list[SearchCriterion] | None = None,
         size: int = 10,
-        total_search_size: int = 50,
+        project_id: UUID | str | None = None,
         search_all_versions: bool = False,
-    ) -> list[str]:
-        """Search for world models.
+    ) -> list[WorldModelResponse]:
+        """Search world models using structured criteria.

         Args:
-            query: The search query.
+            criteria: List of SearchCriterion objects with field, operator, and value.
             size: The number of results to return.
-            total_search_size: The number of results to search for.
-            search_all_versions: Whether to search all versions of the world model or just the latest one.
+            project_id: Optional filter by project ID.
+            search_all_versions: Whether to search all versions or just latest.

         Returns:
-            A list of world model names.
+            A list of world model responses.
+
+        Example:
+            from futurehouse_client.models.rest import SearchCriterion, SearchOperator
+            criteria = [
+                SearchCriterion(field="name", operator=SearchOperator.CONTAINS, value="chemistry"),
+                SearchCriterion(field="email", operator=SearchOperator.CONTAINS, value="tyler"),
+            ]
+            results = client.search_world_models(criteria=criteria, size=20)
         """
         try:
-            # Use the consolidated endpoint with search parameters
-            response = self.client.get(
-                "/v0.1/world-models",
-                params={
-                    "q": query,
-                    "size": size,
-                    "search_all_versions": search_all_versions,
-                },
+            payload = WorldModelSearchPayload(
+                criteria=criteria or [],
+                size=size,
+                project_id=project_id,
+                search_all_versions=search_all_versions,
+            )
+
+            response = self.client.post(
+                "/v0.1/world-models/search",
+                json=payload.model_dump(mode="json"),
             )
             response.raise_for_status()
-            # The new endpoint returns a list of models directly
             return response.json()
         except HTTPStatusError as e:
             raise WorldModelFetchError(
@@ -1746,22 +1825,19 @@
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
         retry=retry_if_connection_error,
     )
-    def get_project_by_name(self, name: str) -> UUID:
+    def get_project_by_name(self, name: str, limit: int = 2) -> UUID | list[UUID]:
        """Get a project UUID by name.

         Args:
             name: The name of the project to find
+            limit: Maximum number of projects to return

         Returns:
-            UUID of the project as a string
-
-        Raises:
-            ProjectError: If no project is found, multiple projects are found, or there's an error
+            UUID of the project as a string or a list of UUIDs if multiple projects are found
         """
         try:
-            # Get projects filtered by name (backend now filters by name and owner)
             response = self.client.get(
-                "/v0.1/projects", params={"limit": 2, "name": name}
+                "/v0.1/projects", params={"limit": limit, "name": name}
             )
             response.raise_for_status()
             projects = response.json()
@@ -1774,32 +1850,33 @@
             if len(projects) == 0:
                 raise ProjectError(f"No project found with name '{name}'")
             if len(projects) > 1:
-                raise ProjectError(
+                logger.warning(
                     f"Multiple projects found with name '{name}'. Found {len(projects)} projects."
                 )

-            return UUID(projects[0]["id"])
+            ids = [UUID(project["id"]) for project in projects]
+            return ids[0] if len(ids) == 1 else ids

     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
         retry=retry_if_connection_error,
     )
-    async def aget_project_by_name(self, name: str) -> UUID:
+    async def aget_project_by_name(
+        self, name: str, limit: int = 2
+    ) -> UUID | list[UUID]:
         """Asynchronously get a project UUID by name.

         Args:
             name: The name of the project to find
+            limit: Maximum number of projects to return

         Returns:
-            UUID of the project as a string
-
-        Raises:
-            ProjectError: If no project is found, multiple projects are found, or there's an error
+            UUID of the project as a string or a list of UUIDs if multiple projects are found
         """
         try:
             response = await self.async_client.get(
-                "/v0.1/projects", params={"limit": 2, "name": name}
+                "/v0.1/projects", params={"limit": limit, "name": name}
             )
             response.raise_for_status()
             projects = response.json()
@@ -1808,11 +1885,12 @@
             if len(projects) == 0:
                 raise ProjectError(f"No project found with name '{name}'")
             if len(projects) > 1:
-                raise ProjectError(
+                logger.warning(
                     f"Multiple projects found with name '{name}'. Found {len(projects)} projects."
                 )

-            return UUID(projects[0]["id"])
+            ids = [UUID(project["id"]) for project in projects]
+            return ids[0] if len(ids) == 1 else ids

     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
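
Because get_project_by_name and aget_project_by_name can now return either a single UUID or a list of UUIDs when several projects share a name, callers may want to normalize the result; a sketch with illustrative names:

    from uuid import UUID

    result = rest_client.get_project_by_name("my-project", limit=10)
    project_ids = [result] if isinstance(result, UUID) else result  # always a list
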
@@ -13,7 +13,7 @@ from .app import (
     TaskResponse,
     TaskResponseVerbose,
 )
-from .rest import WorldModel, WorldModelResponse
+from .rest import TrajectoryPatchRequest, WorldModel, WorldModelResponse

 __all__ = [
     "AuthType",
@@ -29,6 +29,7 @@ __all__ = [
     "TaskRequest",
     "TaskResponse",
     "TaskResponseVerbose",
+    "TrajectoryPatchRequest",
     "WorldModel",
     "WorldModelResponse",
 ]
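
With the new export, the request model is importable directly from the models package (presumably futurehouse_client.models, given the relative imports above); the fields mirror those passed by patch_trajectory:

    from futurehouse_client.models import TrajectoryPatchRequest

    patch = TrajectoryPatchRequest(public=True, notification_enabled=True)
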