futurehouse-client 0.4.2.dev11__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ from futurehouse_client.models.rest import (
19
19
  FinalEnvironmentRequest,
20
20
  StoreAgentStatePostRequest,
21
21
  StoreEnvironmentFrameRequest,
22
+ TrajectoryPatchRequest,
22
23
  )
23
24
  from futurehouse_client.utils.monitoring import (
24
25
  external_trace,
@@ -318,3 +319,52 @@ class JobClient:
318
319
  f"Unexpected error storing environment frame for state {state_identifier}",
319
320
  )
320
321
  raise
322
+
323
+ async def patch_trajectory(
324
+ self,
325
+ public: bool | None = None,
326
+ shared_with: list[int] | None = None,
327
+ notification_enabled: bool | None = None,
328
+ notification_type: str | None = None,
329
+ min_estimated_time: float | None = None,
330
+ max_estimated_time: float | None = None,
331
+ ) -> None:
332
+ data = TrajectoryPatchRequest(
333
+ public=public,
334
+ shared_with=shared_with,
335
+ notification_enabled=notification_enabled,
336
+ notification_type=notification_type,
337
+ min_estimated_time=min_estimated_time,
338
+ max_estimated_time=max_estimated_time,
339
+ )
340
+ try:
341
+ async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
342
+ url = f"{self.base_uri}/v0.1/trajectories/{self.trajectory_id}"
343
+ headers = {
344
+ "Authorization": f"Bearer {self.oauth_jwt}",
345
+ "x-trajectory-id": self.trajectory_id,
346
+ }
347
+ response = await client.patch(
348
+ url=url,
349
+ json=data.model_dump(mode="json", exclude_none=True),
350
+ headers=headers,
351
+ )
352
+ response.raise_for_status()
353
+ logger.debug("Trajectory updated successfully")
354
+ except httpx.HTTPStatusError as e:
355
+ logger.exception(
356
+ "HTTP error while patching trajectory. "
357
+ f"Status code: {e.response.status_code}, "
358
+ f"Response: {e.response.text}",
359
+ )
360
+ except httpx.TimeoutException:
361
+ logger.exception(
362
+ f"Timeout while patching trajectory after {self.REQUEST_TIMEOUT}s",
363
+ )
364
+ raise
365
+ except httpx.NetworkError:
366
+ logger.exception("Network error while patching trajectory")
367
+ raise
368
+ except Exception:
369
+ logger.exception("Unexpected error while patching trajectory")
370
+ raise
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: PLR0915
1
2
  import ast
2
3
  import asyncio
3
4
  import base64
@@ -26,21 +27,14 @@ from httpx import (
26
27
  AsyncClient,
27
28
  Client,
28
29
  CloseError,
29
- ConnectError,
30
- ConnectTimeout,
31
30
  HTTPStatusError,
32
- NetworkError,
33
- ReadError,
34
- ReadTimeout,
35
31
  RemoteProtocolError,
36
32
  codes,
37
33
  )
38
34
  from ldp.agent import AgentConfig
39
- from requests.exceptions import RequestException, Timeout
40
35
  from tenacity import (
41
36
  before_sleep_log,
42
37
  retry,
43
- retry_if_exception_type,
44
38
  stop_after_attempt,
45
39
  wait_exponential,
46
40
  )
@@ -48,6 +42,7 @@ from tqdm import tqdm as sync_tqdm
48
42
  from tqdm.asyncio import tqdm
49
43
 
50
44
  from futurehouse_client.clients import JobNames
45
+ from futurehouse_client.clients.data_storage_methods import DataStorageMethods
51
46
  from futurehouse_client.models.app import (
52
47
  AuthType,
53
48
  JobDeploymentConfig,
@@ -60,15 +55,20 @@ from futurehouse_client.models.app import (
60
55
  from futurehouse_client.models.rest import (
61
56
  DiscoveryResponse,
62
57
  ExecutionStatus,
58
+ SearchCriterion,
63
59
  UserAgentRequest,
64
60
  UserAgentRequestPostPayload,
65
61
  UserAgentRequestStatus,
66
62
  UserAgentResponsePayload,
67
63
  WorldModel,
68
64
  WorldModelResponse,
65
+ WorldModelSearchPayload,
69
66
  )
70
67
  from futurehouse_client.utils.auth import RefreshingJWT
71
- from futurehouse_client.utils.general import gather_with_concurrency
68
+ from futurehouse_client.utils.general import (
69
+ create_retry_if_connection_error,
70
+ gather_with_concurrency,
71
+ )
72
72
  from futurehouse_client.utils.module_utils import (
73
73
  OrganizationSelector,
74
74
  fetch_environment_function_docstring,
@@ -160,28 +160,14 @@ class FileUploadError(RestClientError):
160
160
  """Raised when there's an error uploading a file."""
161
161
 
162
162
 
163
- retry_if_connection_error = retry_if_exception_type((
164
- # From requests
165
- Timeout,
166
- ConnectionError,
167
- RequestException,
168
- # From httpx
169
- ConnectError,
170
- ConnectTimeout,
171
- ReadTimeout,
172
- ReadError,
173
- NetworkError,
174
- RemoteProtocolError,
175
- CloseError,
176
- FileUploadError,
177
- ))
163
+ retry_if_connection_error = create_retry_if_connection_error(FileUploadError)
178
164
 
179
165
  DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
180
166
 
181
167
 
182
- # pylint: disable=too-many-public-methods
183
- class RestClient:
184
- REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec
168
+ class RestClient(DataStorageMethods):
169
+ REQUEST_TIMEOUT: ClassVar[float] = 30.0 # sec - for general API calls
170
+ FILE_UPLOAD_TIMEOUT: ClassVar[float] = 600.0 # 10 minutes - for file uploads
185
171
  MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
186
172
  RETRY_MULTIPLIER: ClassVar[int] = 1
187
173
  MAX_RETRY_WAIT: ClassVar[int] = 10
@@ -239,11 +225,35 @@ class RestClient:
239
225
  """Authenticated HTTP client for multipart uploads."""
240
226
  return cast(Client, self.get_client(None, authenticated=True))
241
227
 
228
+ @property
229
+ def file_upload_client(self) -> Client:
230
+ """Authenticated HTTP client with extended timeout for file uploads."""
231
+ return cast(
232
+ Client,
233
+ self.get_client(
234
+ "application/json", authenticated=True, timeout=self.FILE_UPLOAD_TIMEOUT
235
+ ),
236
+ )
237
+
238
+ @property
239
+ def async_file_upload_client(self) -> AsyncClient:
240
+ """Authenticated async HTTP client with extended timeout for file uploads."""
241
+ return cast(
242
+ AsyncClient,
243
+ self.get_client(
244
+ "application/json",
245
+ authenticated=True,
246
+ async_client=True,
247
+ timeout=self.FILE_UPLOAD_TIMEOUT,
248
+ ),
249
+ )
250
+
242
251
  def get_client(
243
252
  self,
244
253
  content_type: str | None = "application/json",
245
254
  authenticated: bool = True,
246
255
  async_client: bool = False,
256
+ timeout: float | None = None,
247
257
  ) -> Client | AsyncClient:
248
258
  """Return a cached HTTP client or create one if needed.
249
259
 
@@ -251,12 +261,13 @@ class RestClient:
251
261
  content_type: The desired content type header. Use None for multipart uploads.
252
262
  authenticated: Whether the client should include authentication.
253
263
  async_client: Whether to use an async client.
264
+ timeout: Custom timeout in seconds. Uses REQUEST_TIMEOUT if not provided (a value of 0 is falsy and also falls back to REQUEST_TIMEOUT).
254
265
 
255
266
  Returns:
256
267
  An HTTP client configured with the appropriate headers.
257
268
  """
258
- # Create a composite key based on content type and auth flag
259
- key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
269
+ client_timeout = timeout or self.REQUEST_TIMEOUT
270
+ key = f"{content_type or 'multipart'}_{authenticated}_{async_client}_{client_timeout}"
260
271
 
261
272
  if key not in self._clients:
262
273
  headers = copy.deepcopy(self.headers)
@@ -282,14 +293,14 @@ class RestClient:
282
293
  AsyncClient(
283
294
  base_url=self.base_url,
284
295
  headers=headers,
285
- timeout=self.REQUEST_TIMEOUT,
296
+ timeout=client_timeout,
286
297
  auth=auth,
287
298
  )
288
299
  if async_client
289
300
  else Client(
290
301
  base_url=self.base_url,
291
302
  headers=headers,
292
- timeout=self.REQUEST_TIMEOUT,
303
+ timeout=client_timeout,
293
304
  auth=auth,
294
305
  )
295
306
  )
@@ -323,16 +334,23 @@ class RestClient:
323
334
  raise ValueError(f"Organization '{organization}' not found.")
324
335
  return filtered_orgs
325
336
 
337
+ @retry(
338
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
339
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
340
+ retry=retry_if_connection_error,
341
+ before_sleep=before_sleep_log(logger, logging.WARNING),
342
+ )
326
343
  def _check_job(self, name: str, organization: str) -> dict[str, Any]:
327
- try:
328
- response = self.client.get(
329
- f"/v0.1/crows/{name}/organizations/{organization}"
330
- )
331
- response.raise_for_status()
332
- return response.json()
333
- except Exception as e:
334
- raise JobFetchError(f"Error checking job: {e!r}.") from e
344
+ response = self.client.get(f"/v0.1/crows/{name}/organizations/{organization}")
345
+ response.raise_for_status()
346
+ return response.json()
335
347
 
348
+ @retry(
349
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
350
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
351
+ retry=retry_if_connection_error,
352
+ before_sleep=before_sleep_log(logger, logging.WARNING),
353
+ )
336
354
  def _fetch_my_orgs(self) -> list[str]:
337
355
  response = self.client.get(f"/v0.1/organizations?filter={True}")
338
356
  response.raise_for_status()
@@ -690,10 +708,12 @@ class RestClient:
690
708
 
691
709
  async def arun_tasks_until_done(
692
710
  self,
693
- task_data: TaskRequest
694
- | dict[str, Any]
695
- | Collection[TaskRequest]
696
- | Collection[dict[str, Any]],
711
+ task_data: (
712
+ TaskRequest
713
+ | dict[str, Any]
714
+ | Collection[TaskRequest]
715
+ | Collection[dict[str, Any]]
716
+ ),
697
717
  verbose: bool = False,
698
718
  progress_bar: bool = False,
699
719
  concurrency: int = 10,
@@ -761,10 +781,12 @@ class RestClient:
761
781
 
762
782
  def run_tasks_until_done(
763
783
  self,
764
- task_data: TaskRequest
765
- | dict[str, Any]
766
- | Collection[TaskRequest]
767
- | Collection[dict[str, Any]],
784
+ task_data: (
785
+ TaskRequest
786
+ | dict[str, Any]
787
+ | Collection[TaskRequest]
788
+ | Collection[dict[str, Any]]
789
+ ),
768
790
  verbose: bool = False,
769
791
  progress_bar: bool = False,
770
792
  timeout: int = DEFAULT_AGENT_TIMEOUT,
@@ -837,12 +859,9 @@ class RestClient:
837
859
  )
838
860
  def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
839
861
  """Get the status of a build."""
840
- try:
841
- build_id = build_id or self.build_id
842
- response = self.client.get(f"/v0.1/builds/{build_id}")
843
- response.raise_for_status()
844
- except Exception as e:
845
- raise JobFetchError(f"Error getting build status: {e!r}.") from e
862
+ build_id = build_id or self.build_id
863
+ response = self.client.get(f"/v0.1/builds/{build_id}")
864
+ response.raise_for_status()
846
865
  return response.json()
847
866
 
848
867
  # TODO: Refactor later so we don't have to ignore PLR0915
@@ -852,7 +871,7 @@ class RestClient:
852
871
  retry=retry_if_connection_error,
853
872
  before_sleep=before_sleep_log(logger, logging.WARNING),
854
873
  )
855
- def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]: # noqa: PLR0915
874
+ def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
856
875
  """Creates a futurehouse job deployment from the environment and environment files.
857
876
 
858
877
  Args:
@@ -1597,6 +1616,56 @@ class RestClient:
1597
1616
  except Exception as e:
1598
1617
  raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
1599
1618
 
1619
+ @retry(
1620
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
1621
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
1622
+ retry=retry_if_connection_error,
1623
+ )
1624
+ def list_world_models(
1625
+ self,
1626
+ name: str | None = None,
1627
+ project_id: UUID | str | None = None,
1628
+ limit: int = 150,
1629
+ offset: int = 0,
1630
+ sort_order: str = "asc",
1631
+ ) -> list[WorldModelResponse]:
1632
+ """List world models with different behavior based on filters.
1633
+
1634
+ When filtering by name: returns only the latest version for that name.
1635
+ When filtering by project_id (without name): returns all versions for that project.
1636
+ When no filters: returns latest version of each world model.
1637
+
1638
+ Args:
1639
+ name: Filter by world model name.
1640
+ project_id: Filter by project ID.
1641
+ limit: The maximum number of models to return.
1642
+ offset: Number of results to skip for pagination.
1643
+ sort_order: Sort order 'asc' or 'desc'.
1644
+
1645
+ Returns:
1646
+ A list of world model responses.
1647
+ """
1648
+ try:
1649
+ params: dict[str, str | int] = {
1650
+ "limit": limit,
1651
+ "offset": offset,
1652
+ "sort_order": sort_order,
1653
+ }
1654
+ if name:
1655
+ params["name"] = name
1656
+ if project_id:
1657
+ params["project_id"] = str(project_id)
1658
+
1659
+ response = self.client.get("/v0.1/world-models", params=params)
1660
+ response.raise_for_status()
1661
+ return response.json()
1662
+ except HTTPStatusError as e:
1663
+ raise WorldModelFetchError(
1664
+ f"Error listing world models: {e.response.status_code} - {e.response.text}"
1665
+ ) from e
1666
+ except Exception as e:
1667
+ raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
1668
+
1600
1669
  @retry(
1601
1670
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
1602
1671
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1604,34 +1673,43 @@ class RestClient:
1604
1673
  )
1605
1674
  def search_world_models(
1606
1675
  self,
1607
- query: str,
1676
+ criteria: list[SearchCriterion] | None = None,
1608
1677
  size: int = 10,
1609
- total_search_size: int = 50,
1678
+ project_id: UUID | str | None = None,
1610
1679
  search_all_versions: bool = False,
1611
- ) -> list[str]:
1612
- """Search for world models.
1680
+ ) -> list[WorldModelResponse]:
1681
+ """Search world models using structured criteria.
1613
1682
 
1614
1683
  Args:
1615
- query: The search query.
1684
+ criteria: List of SearchCriterion objects with field, operator, and value.
1616
1685
  size: The number of results to return.
1617
- total_search_size: The number of results to search for.
1618
- search_all_versions: Whether to search all versions of the world model or just the latest one.
1686
+ project_id: Optional filter by project ID.
1687
+ search_all_versions: Whether to search all versions or just latest.
1619
1688
 
1620
1689
  Returns:
1621
- A list of world model names.
1690
+ A list of world model responses.
1691
+
1692
+ Example:
1693
+ from futurehouse_client.models.rest import SearchCriterion, SearchOperator
1694
+ criteria = [
1695
+ SearchCriterion(field="name", operator=SearchOperator.CONTAINS, value="chemistry"),
1696
+ SearchCriterion(field="email", operator=SearchOperator.CONTAINS, value="tyler"),
1697
+ ]
1698
+ results = client.search_world_models(criteria=criteria, size=20)
1622
1699
  """
1623
1700
  try:
1624
- # Use the consolidated endpoint with search parameters
1625
- response = self.client.get(
1626
- "/v0.1/world-models",
1627
- params={
1628
- "q": query,
1629
- "size": size,
1630
- "search_all_versions": search_all_versions,
1631
- },
1701
+ payload = WorldModelSearchPayload(
1702
+ criteria=criteria or [],
1703
+ size=size,
1704
+ project_id=project_id,
1705
+ search_all_versions=search_all_versions,
1706
+ )
1707
+
1708
+ response = self.client.post(
1709
+ "/v0.1/world-models/search",
1710
+ json=payload.model_dump(mode="json"),
1632
1711
  )
1633
1712
  response.raise_for_status()
1634
- # The new endpoint returns a list of models directly
1635
1713
  return response.json()
1636
1714
  except HTTPStatusError as e:
1637
1715
  raise WorldModelFetchError(
@@ -13,7 +13,7 @@ from .app import (
13
13
  TaskResponse,
14
14
  TaskResponseVerbose,
15
15
  )
16
- from .rest import WorldModel, WorldModelResponse
16
+ from .rest import TrajectoryPatchRequest, WorldModel, WorldModelResponse
17
17
 
18
18
  __all__ = [
19
19
  "AuthType",
@@ -29,6 +29,7 @@ __all__ = [
29
29
  "TaskRequest",
30
30
  "TaskResponse",
31
31
  "TaskResponseVerbose",
32
+ "TrajectoryPatchRequest",
32
33
  "WorldModel",
33
34
  "WorldModelResponse",
34
35
  ]
@@ -27,13 +27,17 @@ class InitialState(BaseState):
27
27
 
28
28
  class ASVState(BaseState, Generic[T]):
29
29
  action: OpResult[T] = Field()
30
- next_agent_state: Any = Field()
30
+ next_state: Any = Field()
31
31
  value: float = Field()
32
32
 
33
33
  @field_serializer("action")
34
34
  def serialize_action(self, action: OpResult[T]) -> dict:
35
35
  return action.to_dict()
36
36
 
37
+ @field_serializer("next_state")
38
+ def serialize_next_state(self, state: Any) -> str:
39
+ return str(state)
40
+
37
41
 
38
42
  class EnvResetState(BaseState):
39
43
  observations: list[Message] = Field()