futurehouse-client 0.4.1.dev95__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff is generated from the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ from futurehouse_client.models.rest import (
19
19
  FinalEnvironmentRequest,
20
20
  StoreAgentStatePostRequest,
21
21
  StoreEnvironmentFrameRequest,
22
+ TrajectoryPatchRequest,
22
23
  )
23
24
  from futurehouse_client.utils.monitoring import (
24
25
  external_trace,
@@ -318,3 +319,52 @@ class JobClient:
318
319
  f"Unexpected error storing environment frame for state {state_identifier}",
319
320
  )
320
321
  raise
322
+
323
+ async def patch_trajectory(
324
+ self,
325
+ public: bool | None = None,
326
+ shared_with: list[int] | None = None,
327
+ notification_enabled: bool | None = None,
328
+ notification_type: str | None = None,
329
+ min_estimated_time: float | None = None,
330
+ max_estimated_time: float | None = None,
331
+ ) -> None:
332
+ data = TrajectoryPatchRequest(
333
+ public=public,
334
+ shared_with=shared_with,
335
+ notification_enabled=notification_enabled,
336
+ notification_type=notification_type,
337
+ min_estimated_time=min_estimated_time,
338
+ max_estimated_time=max_estimated_time,
339
+ )
340
+ try:
341
+ async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
342
+ url = f"{self.base_uri}/v0.1/trajectories/{self.trajectory_id}"
343
+ headers = {
344
+ "Authorization": f"Bearer {self.oauth_jwt}",
345
+ "x-trajectory-id": self.trajectory_id,
346
+ }
347
+ response = await client.patch(
348
+ url=url,
349
+ json=data.model_dump(mode="json", exclude_none=True),
350
+ headers=headers,
351
+ )
352
+ response.raise_for_status()
353
+ logger.debug("Trajectory updated successfully")
354
+ except httpx.HTTPStatusError as e:
355
+ logger.exception(
356
+ "HTTP error while patching trajectory. "
357
+ f"Status code: {e.response.status_code}, "
358
+ f"Response: {e.response.text}",
359
+ )
360
+ except httpx.TimeoutException:
361
+ logger.exception(
362
+ f"Timeout while patching trajectory after {self.REQUEST_TIMEOUT}s",
363
+ )
364
+ raise
365
+ except httpx.NetworkError:
366
+ logger.exception("Network error while patching trajectory")
367
+ raise
368
+ except Exception:
369
+ logger.exception("Unexpected error while patching trajectory")
370
+ raise
@@ -1,3 +1,4 @@
1
+ # ruff: noqa: PLR0915
1
2
  import ast
2
3
  import asyncio
3
4
  import base64
@@ -54,12 +55,14 @@ from futurehouse_client.models.app import (
54
55
  from futurehouse_client.models.rest import (
55
56
  DiscoveryResponse,
56
57
  ExecutionStatus,
58
+ SearchCriterion,
57
59
  UserAgentRequest,
58
60
  UserAgentRequestPostPayload,
59
61
  UserAgentRequestStatus,
60
62
  UserAgentResponsePayload,
61
63
  WorldModel,
62
64
  WorldModelResponse,
65
+ WorldModelSearchPayload,
63
66
  )
64
67
  from futurehouse_client.utils.auth import RefreshingJWT
65
68
  from futurehouse_client.utils.general import (
@@ -332,16 +335,23 @@ class RestClient(DataStorageMethods):
332
335
  raise ValueError(f"Organization '{organization}' not found.")
333
336
  return filtered_orgs
334
337
 
338
+ @retry(
339
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
340
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
341
+ retry=retry_if_connection_error,
342
+ before_sleep=before_sleep_log(logger, logging.WARNING),
343
+ )
335
344
  def _check_job(self, name: str, organization: str) -> dict[str, Any]:
336
- try:
337
- response = self.client.get(
338
- f"/v0.1/crows/{name}/organizations/{organization}"
339
- )
340
- response.raise_for_status()
341
- return response.json()
342
- except Exception as e:
343
- raise JobFetchError(f"Error checking job: {e!r}.") from e
345
+ response = self.client.get(f"/v0.1/crows/{name}/organizations/{organization}")
346
+ response.raise_for_status()
347
+ return response.json()
344
348
 
349
+ @retry(
350
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
351
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
352
+ retry=retry_if_connection_error,
353
+ before_sleep=before_sleep_log(logger, logging.WARNING),
354
+ )
345
355
  def _fetch_my_orgs(self) -> list[str]:
346
356
  response = self.client.get(f"/v0.1/organizations?filter={True}")
347
357
  response.raise_for_status()
@@ -699,10 +709,12 @@ class RestClient(DataStorageMethods):
699
709
 
700
710
  async def arun_tasks_until_done(
701
711
  self,
702
- task_data: TaskRequest
703
- | dict[str, Any]
704
- | Collection[TaskRequest]
705
- | Collection[dict[str, Any]],
712
+ task_data: (
713
+ TaskRequest
714
+ | dict[str, Any]
715
+ | Collection[TaskRequest]
716
+ | Collection[dict[str, Any]]
717
+ ),
706
718
  verbose: bool = False,
707
719
  progress_bar: bool = False,
708
720
  concurrency: int = 10,
@@ -770,10 +782,12 @@ class RestClient(DataStorageMethods):
770
782
 
771
783
  def run_tasks_until_done(
772
784
  self,
773
- task_data: TaskRequest
774
- | dict[str, Any]
775
- | Collection[TaskRequest]
776
- | Collection[dict[str, Any]],
785
+ task_data: (
786
+ TaskRequest
787
+ | dict[str, Any]
788
+ | Collection[TaskRequest]
789
+ | Collection[dict[str, Any]]
790
+ ),
777
791
  verbose: bool = False,
778
792
  progress_bar: bool = False,
779
793
  timeout: int = DEFAULT_AGENT_TIMEOUT,
@@ -846,12 +860,9 @@ class RestClient(DataStorageMethods):
846
860
  )
847
861
  def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
848
862
  """Get the status of a build."""
849
- try:
850
- build_id = build_id or self.build_id
851
- response = self.client.get(f"/v0.1/builds/{build_id}")
852
- response.raise_for_status()
853
- except Exception as e:
854
- raise JobFetchError(f"Error getting build status: {e!r}.") from e
863
+ build_id = build_id or self.build_id
864
+ response = self.client.get(f"/v0.1/builds/{build_id}")
865
+ response.raise_for_status()
855
866
  return response.json()
856
867
 
857
868
  # TODO: Refactor later so we don't have to ignore PLR0915
@@ -861,7 +872,7 @@ class RestClient(DataStorageMethods):
861
872
  retry=retry_if_connection_error,
862
873
  before_sleep=before_sleep_log(logger, logging.WARNING),
863
874
  )
864
- def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]: # noqa: PLR0915
875
+ def create_job(self, config: JobDeploymentConfig) -> dict[str, Any]:
865
876
  """Creates a futurehouse job deployment from the environment and environment files.
866
877
 
867
878
  Args:
@@ -1606,6 +1617,56 @@ class RestClient(DataStorageMethods):
1606
1617
  except Exception as e:
1607
1618
  raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
1608
1619
 
1620
+ @retry(
1621
+ stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
1622
+ wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
1623
+ retry=retry_if_connection_error,
1624
+ )
1625
+ def list_world_models(
1626
+ self,
1627
+ name: str | None = None,
1628
+ project_id: UUID | str | None = None,
1629
+ limit: int = 150,
1630
+ offset: int = 0,
1631
+ sort_order: str = "asc",
1632
+ ) -> list[WorldModelResponse]:
1633
+ """List world models with different behavior based on filters.
1634
+
1635
+ When filtering by name: returns only the latest version for that name.
1636
+ When filtering by project_id (without name): returns all versions for that project.
1637
+ When no filters: returns latest version of each world model.
1638
+
1639
+ Args:
1640
+ name: Filter by world model name.
1641
+ project_id: Filter by project ID.
1642
+ limit: The maximum number of models to return.
1643
+ offset: Number of results to skip for pagination.
1644
+ sort_order: Sort order 'asc' or 'desc'.
1645
+
1646
+ Returns:
1647
+ A list of world model dictionaries.
1648
+ """
1649
+ try:
1650
+ params: dict[str, str | int] = {
1651
+ "limit": limit,
1652
+ "offset": offset,
1653
+ "sort_order": sort_order,
1654
+ }
1655
+ if name:
1656
+ params["name"] = name
1657
+ if project_id:
1658
+ params["project_id"] = str(project_id)
1659
+
1660
+ response = self.client.get("/v0.1/world-models", params=params)
1661
+ response.raise_for_status()
1662
+ return response.json()
1663
+ except HTTPStatusError as e:
1664
+ raise WorldModelFetchError(
1665
+ f"Error listing world models: {e.response.status_code} - {e.response.text}"
1666
+ ) from e
1667
+ except Exception as e:
1668
+ raise WorldModelFetchError(f"An unexpected error occurred: {e!r}.") from e
1669
+
1609
1670
  @retry(
1610
1671
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
1611
1672
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1613,31 +1674,41 @@ class RestClient(DataStorageMethods):
1613
1674
  )
1614
1675
  def search_world_models(
1615
1676
  self,
1616
- query: str,
1677
+ criteria: list[SearchCriterion] | None = None,
1617
1678
  size: int = 10,
1618
- total_search_size: int = 50,
1679
+ project_id: UUID | str | None = None,
1619
1680
  search_all_versions: bool = False,
1620
- ) -> list[str]:
1621
- """Search for world models.
1681
+ ) -> list[WorldModelResponse]:
1682
+ """Search world models using structured criteria.
1622
1683
 
1623
1684
  Args:
1624
- query: The search query.
1685
+ criteria: List of SearchCriterion objects with field, operator, and value.
1625
1686
  size: The number of results to return.
1626
- total_search_size: The number of results to search for.
1627
- search_all_versions: Whether to search all versions of the world model or just the latest one.
1687
+ project_id: Optional filter by project ID.
1688
+ search_all_versions: Whether to search all versions or just latest.
1628
1689
 
1629
1690
  Returns:
1630
- A list of world model names.
1691
+ A list of world model responses.
1692
+
1693
+ Example:
1694
+ from futurehouse_client.models.rest import SearchCriterion, SearchOperator
1695
+ criteria = [
1696
+ SearchCriterion(field="name", operator=SearchOperator.CONTAINS, value="chemistry"),
1697
+ SearchCriterion(field="email", operator=SearchOperator.CONTAINS, value="tyler"),
1698
+ ]
1699
+ results = client.search_world_models(criteria=criteria, size=20)
1631
1700
  """
1632
1701
  try:
1633
- response = self.client.get(
1634
- "/v0.1/world-models/search/",
1635
- params={
1636
- "query": query,
1637
- "size": size,
1638
- "total_search_size": total_search_size,
1639
- "search_all_versions": search_all_versions,
1640
- },
1702
+ payload = WorldModelSearchPayload(
1703
+ criteria=criteria or [],
1704
+ size=size,
1705
+ project_id=project_id,
1706
+ search_all_versions=search_all_versions,
1707
+ )
1708
+
1709
+ response = self.client.post(
1710
+ "/v0.1/world-models/search",
1711
+ json=payload.model_dump(mode="json"),
1641
1712
  )
1642
1713
  response.raise_for_status()
1643
1714
  return response.json()
@@ -1754,22 +1825,19 @@ class RestClient(DataStorageMethods):
1754
1825
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
1755
1826
  retry=retry_if_connection_error,
1756
1827
  )
1757
- def get_project_by_name(self, name: str) -> UUID:
1828
+ def get_project_by_name(self, name: str, limit: int = 2) -> UUID | list[UUID]:
1758
1829
  """Get a project UUID by name.
1759
1830
 
1760
1831
  Args:
1761
1832
  name: The name of the project to find
1833
+ limit: Maximum number of projects to return
1762
1834
 
1763
1835
  Returns:
1764
- UUID of the project as a string
1765
-
1766
- Raises:
1767
- ProjectError: If no project is found, multiple projects are found, or there's an error
1836
+ UUID of the project as a string or a list of UUIDs if multiple projects are found
1768
1837
  """
1769
1838
  try:
1770
- # Get projects filtered by name (backend now filters by name and owner)
1771
1839
  response = self.client.get(
1772
- "/v0.1/projects", params={"limit": 2, "name": name}
1840
+ "/v0.1/projects", params={"limit": limit, "name": name}
1773
1841
  )
1774
1842
  response.raise_for_status()
1775
1843
  projects = response.json()
@@ -1782,32 +1850,33 @@ class RestClient(DataStorageMethods):
1782
1850
  if len(projects) == 0:
1783
1851
  raise ProjectError(f"No project found with name '{name}'")
1784
1852
  if len(projects) > 1:
1785
- raise ProjectError(
1853
+ logger.warning(
1786
1854
  f"Multiple projects found with name '{name}'. Found {len(projects)} projects."
1787
1855
  )
1788
1856
 
1789
- return UUID(projects[0]["id"])
1857
+ ids = [UUID(project["id"]) for project in projects]
1858
+ return ids[0] if len(ids) == 1 else ids
1790
1859
 
1791
1860
  @retry(
1792
1861
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
1793
1862
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
1794
1863
  retry=retry_if_connection_error,
1795
1864
  )
1796
- async def aget_project_by_name(self, name: str) -> UUID:
1865
+ async def aget_project_by_name(
1866
+ self, name: str, limit: int = 2
1867
+ ) -> UUID | list[UUID]:
1797
1868
  """Asynchronously get a project UUID by name.
1798
1869
 
1799
1870
  Args:
1800
1871
  name: The name of the project to find
1872
+ limit: Maximum number of projects to return
1801
1873
 
1802
1874
  Returns:
1803
- UUID of the project as a string
1804
-
1805
- Raises:
1806
- ProjectError: If no project is found, multiple projects are found, or there's an error
1875
+ UUID of the project as a string or a list of UUIDs if multiple projects are found
1807
1876
  """
1808
1877
  try:
1809
1878
  response = await self.async_client.get(
1810
- "/v0.1/projects", params={"limit": 2, "name": name}
1879
+ "/v0.1/projects", params={"limit": limit, "name": name}
1811
1880
  )
1812
1881
  response.raise_for_status()
1813
1882
  projects = response.json()
@@ -1816,11 +1885,12 @@ class RestClient(DataStorageMethods):
1816
1885
  if len(projects) == 0:
1817
1886
  raise ProjectError(f"No project found with name '{name}'")
1818
1887
  if len(projects) > 1:
1819
- raise ProjectError(
1888
+ logger.warning(
1820
1889
  f"Multiple projects found with name '{name}'. Found {len(projects)} projects."
1821
1890
  )
1822
1891
 
1823
- return UUID(projects[0]["id"])
1892
+ ids = [UUID(project["id"]) for project in projects]
1893
+ return ids[0] if len(ids) == 1 else ids
1824
1894
 
1825
1895
  @retry(
1826
1896
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
@@ -13,7 +13,7 @@ from .app import (
13
13
  TaskResponse,
14
14
  TaskResponseVerbose,
15
15
  )
16
- from .rest import WorldModel, WorldModelResponse
16
+ from .rest import TrajectoryPatchRequest, WorldModel, WorldModelResponse
17
17
 
18
18
  __all__ = [
19
19
  "AuthType",
@@ -29,6 +29,7 @@ __all__ = [
29
29
  "TaskRequest",
30
30
  "TaskResponse",
31
31
  "TaskResponseVerbose",
32
+ "TrajectoryPatchRequest",
32
33
  "WorldModel",
33
34
  "WorldModelResponse",
34
35
  ]
@@ -27,10 +27,7 @@ if TYPE_CHECKING:
27
27
  MAX_CROW_JOB_RUN_TIMEOUT = 60 * 60 * 24 # 24 hours in sec
28
28
  MIN_CROW_JOB_RUN_TIMEOUT = 0 # sec
29
29
 
30
-
31
- class PythonVersion(StrEnum):
32
- V3_11 = "3.11"
33
- V3_12 = "3.12"
30
+ DEFAULT_PYTHON_VERSION_USED_FOR_JOB_BUILDS = "3.13"
34
31
 
35
32
 
36
33
  class AuthType(StrEnum):
@@ -420,9 +417,9 @@ class JobDeploymentConfig(BaseModel):
420
417
  description="The configuration for the cloud run container.",
421
418
  )
422
419
 
423
- python_version: PythonVersion = Field(
424
- default=PythonVersion.V3_12,
425
- description="The python version your docker image should build with.",
420
+ python_version: str = Field(
421
+ default=DEFAULT_PYTHON_VERSION_USED_FOR_JOB_BUILDS,
422
+ description="The python version your docker image should build with (e.g., '3.11', '3.12', '3.13').",
426
423
  )
427
424
 
428
425
  agent: Agent | AgentConfig | str = Field(
@@ -34,6 +34,10 @@ class DataStorageEntry(BaseModel):
34
34
  default=None,
35
35
  description="ID of the parent entry if this is a sub-entry for hierarchical storage",
36
36
  )
37
+ project_id: UUID | None = Field(
38
+ default=None,
39
+ description="ID of the project this data storage entry belongs to",
40
+ )
37
41
  dataset_id: UUID | None = Field(
38
42
  default=None,
39
43
  description="ID of the dataset this entry belongs to",
@@ -79,8 +83,8 @@ class DataStorageLocationPayload(BaseModel):
79
83
  location: str | None = None
80
84
 
81
85
 
82
- class DataStorageLocationDetails(BaseModel):
83
- """Model representing the location details within a DataStorageLocations object."""
86
+ class DataStorageLocationConfig(BaseModel):
87
+ """Model representing the location configuration within a DataStorageLocations object."""
84
88
 
85
89
  storage_type: str = Field(description="Type of storage (e.g., 'gcs', 'pg_table')")
86
90
  content_type: str = Field(description="Type of content stored")
@@ -89,15 +93,19 @@ class DataStorageLocationDetails(BaseModel):
89
93
  location: str | None = Field(
90
94
  default=None, description="Location path or identifier"
91
95
  )
96
+ signed_url: str | None = Field(
97
+ default=None,
98
+ description="Signed URL for uploading/downloading the file to/from GCS",
99
+ )
92
100
 
93
101
 
94
- class DataStorageLocations(BaseModel):
102
+ class DataStorageLocation(BaseModel):
95
103
  """Model representing storage locations for a data storage entry."""
96
104
 
97
105
  id: UUID = Field(description="Unique identifier for the storage locations")
98
106
  data_storage_id: UUID = Field(description="ID of the associated data storage entry")
99
- storage_config: DataStorageLocationDetails = Field(
100
- description="Storage configuration details"
107
+ storage_config: DataStorageLocationConfig = Field(
108
+ description="Storage location configuration"
101
109
  )
102
110
  created_at: datetime = Field(description="Timestamp when the location was created")
103
111
 
@@ -106,13 +114,9 @@ class DataStorageResponse(BaseModel):
106
114
  """Response model for data storage operations."""
107
115
 
108
116
  data_storage: DataStorageEntry = Field(description="The created data storage entry")
109
- storage_location: DataStorageLocations = Field(
117
+ storage_locations: list[DataStorageLocation] = Field(
110
118
  description="Storage location for this data entry"
111
119
  )
112
- signed_url: str | None = Field(
113
- default=None,
114
- description="Signed URL for uploading/downloading the file to/from GCS",
115
- )
116
120
 
117
121
 
118
122
  class DataStorageRequestPayload(BaseModel):
@@ -131,6 +135,10 @@ class DataStorageRequestPayload(BaseModel):
131
135
  parent_id: UUID | None = Field(
132
136
  default=None, description="ID of the parent entry for hierarchical storage"
133
137
  )
138
+ project_id: UUID | None = Field(
139
+ default=None,
140
+ description="ID of the project this data storage entry belongs to",
141
+ )
134
142
  dataset_id: UUID | None = Field(
135
143
  default=None,
136
144
  description="ID of existing dataset to add entry to, or None to create new dataset",
@@ -144,6 +152,19 @@ class DataStorageRequestPayload(BaseModel):
144
152
  )
145
153
 
146
154
 
155
+ class CreateDatasetPayload(BaseModel):
156
+ """Payload for creating a dataset."""
157
+
158
+ id: UUID | None = Field(
159
+ default=None,
160
+ description="ID of the dataset to create, or None to create a new dataset",
161
+ )
162
+ name: str = Field(description="Name of the dataset")
163
+ description: str | None = Field(
164
+ default=None, description="Description of the dataset"
165
+ )
166
+
167
+
147
168
  class ManifestEntry(BaseModel):
148
169
  """Model representing a single entry in a manifest file."""
149
170
 
@@ -23,6 +23,15 @@ class StoreEnvironmentFrameRequest(BaseModel):
23
23
  trajectory_timestep: int
24
24
 
25
25
 
26
+ class TrajectoryPatchRequest(BaseModel):
27
+ public: bool | None = None
28
+ shared_with: list[int] | None = None
29
+ notification_enabled: bool | None = None
30
+ notification_type: str | None = None
31
+ min_estimated_time: float | None = None
32
+ max_estimated_time: float | None = None
33
+
34
+
26
35
  class ExecutionStatus(StrEnum):
27
36
  QUEUED = auto()
28
37
  IN_PROGRESS = "in progress"
@@ -54,7 +63,37 @@ class WorldModel(BaseModel):
54
63
  project_id: UUID | str | None = None
55
64
 
56
65
 
57
- class WorldModelResponse(BaseModel):
66
+ class SearchOperator(StrEnum):
67
+ """Operators for structured search criteria."""
68
+
69
+ EQUALS = "equals"
70
+ CONTAINS = "contains"
71
+ STARTS_WITH = "starts_with"
72
+ ENDS_WITH = "ends_with"
73
+ GREATER_THAN = "greater_than"
74
+ LESS_THAN = "less_than"
75
+ BETWEEN = "between"
76
+ IN = "in"
77
+
78
+
79
+ class SearchCriterion(BaseModel):
80
+ """A single search criterion with field, operator, and value."""
81
+
82
+ field: str
83
+ operator: SearchOperator
84
+ value: str | list[str] | bool
85
+
86
+
87
+ class WorldModelSearchPayload(BaseModel):
88
+ """Payload for structured world model search."""
89
+
90
+ criteria: list[SearchCriterion]
91
+ size: int = 10
92
+ project_id: UUID | str | None = None
93
+ search_all_versions: bool = False
94
+
95
+
96
+ class WorldModelResponse(WorldModel):
58
97
  """
59
98
  Response model for a world model snapshot.
60
99
 
@@ -62,13 +101,8 @@ class WorldModelResponse(BaseModel):
62
101
  """
63
102
 
64
103
  id: UUID | str
65
- prior: UUID | str | None
66
- name: str
67
- description: str | None
68
- content: str
69
- trajectory_id: UUID | str | None
104
+ name: str # type: ignore[mutable-override] # The API always returns a non-optional name, overriding the base model's optional field.
70
105
  email: str | None
71
- model_metadata: JsonValue | None
72
106
  enabled: bool
73
107
  created_at: datetime
74
108
 
@@ -132,3 +166,10 @@ class DiscoveryResponse(BaseModel):
132
166
  associated_trajectories: list[UUID | str]
133
167
  validation_level: int
134
168
  created_at: datetime
169
+
170
+
171
+ class DataStorageSearchPayload(BaseModel):
172
+ """Payload for structured data storage search."""
173
+
174
+ criteria: list[SearchCriterion]
175
+ size: int = 10
@@ -1,22 +1,31 @@
1
1
  import asyncio
2
- from collections.abc import Awaitable, Iterable
2
+ from collections.abc import Awaitable, Callable, Iterable
3
3
  from typing import TypeVar
4
4
 
5
5
  from httpx import (
6
6
  CloseError,
7
7
  ConnectError,
8
8
  ConnectTimeout,
9
+ HTTPStatusError,
9
10
  NetworkError,
10
11
  ReadError,
11
12
  ReadTimeout,
12
13
  RemoteProtocolError,
14
+ codes,
13
15
  )
14
16
  from requests.exceptions import RequestException, Timeout
15
- from tenacity import retry_if_exception_type
17
+ from tenacity import RetryCallState
16
18
  from tqdm.asyncio import tqdm
17
19
 
18
20
  T = TypeVar("T")
19
21
 
22
+ RETRYABLE_HTTP_STATUS_CODES = {
23
+ codes.TOO_MANY_REQUESTS,
24
+ codes.INTERNAL_SERVER_ERROR,
25
+ codes.BAD_GATEWAY,
26
+ codes.SERVICE_UNAVAILABLE,
27
+ codes.GATEWAY_TIMEOUT,
28
+ }
20
29
 
21
30
  _BASE_CONNECTION_ERRORS = (
22
31
  # From requests
@@ -33,12 +42,32 @@ _BASE_CONNECTION_ERRORS = (
33
42
  CloseError,
34
43
  )
35
44
 
36
- retry_if_connection_error = retry_if_exception_type(_BASE_CONNECTION_ERRORS)
37
45
 
46
+ def create_retry_if_connection_error(
47
+ *additional_exceptions,
48
+ ) -> Callable[[RetryCallState], bool]:
49
+ """Create a retry condition with base connection errors, HTTP status errors, plus additional exceptions."""
38
50
 
39
- def create_retry_if_connection_error(*additional_exceptions):
40
- """Create a retry condition with base connection errors plus additional exceptions."""
41
- return retry_if_exception_type(_BASE_CONNECTION_ERRORS + additional_exceptions)
51
+ def status_retries_with_exceptions(retry_state: RetryCallState) -> bool:
52
+ if retry_state.outcome is not None and hasattr(
53
+ retry_state.outcome, "exception"
54
+ ):
55
+ exception = retry_state.outcome.exception()
56
+ # connection errors
57
+ if isinstance(exception, _BASE_CONNECTION_ERRORS):
58
+ return True
59
+ # custom exceptions provided
60
+ if additional_exceptions and isinstance(exception, additional_exceptions):
61
+ return True
62
+ # any http exceptions
63
+ if isinstance(exception, HTTPStatusError):
64
+ return exception.response.status_code in RETRYABLE_HTTP_STATUS_CODES
65
+ return False
66
+
67
+ return status_retries_with_exceptions
68
+
69
+
70
+ retry_if_connection_error = create_retry_if_connection_error()
42
71
 
43
72
 
44
73
  async def gather_with_concurrency(