futurehouse-client 0.3.20.dev225__tar.gz → 0.3.20.dev295__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {futurehouse_client-0.3.20.dev225/futurehouse_client.egg-info → futurehouse_client-0.3.20.dev295}/PKG-INFO +1 -1
  2. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/clients/rest_client.py +539 -0
  3. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/models/app.py +79 -0
  4. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/models/client.py +5 -1
  5. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/models/rest.py +49 -1
  6. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/version.py +2 -2
  7. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295/futurehouse_client.egg-info}/PKG-INFO +1 -1
  8. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/tests/test_rest.py +283 -13
  9. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/LICENSE +0 -0
  10. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/README.md +0 -0
  11. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/docs/__init__.py +0 -0
  12. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/docs/client_notebook.ipynb +0 -0
  13. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/__init__.py +0 -0
  14. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/clients/__init__.py +0 -0
  15. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/clients/job_client.py +0 -0
  16. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/models/__init__.py +0 -0
  17. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/py.typed +0 -0
  18. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/utils/__init__.py +0 -0
  19. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/utils/auth.py +0 -0
  20. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/utils/general.py +0 -0
  21. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/utils/module_utils.py +0 -0
  22. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client/utils/monitoring.py +0 -0
  23. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client.egg-info/SOURCES.txt +0 -0
  24. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client.egg-info/dependency_links.txt +0 -0
  25. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client.egg-info/requires.txt +0 -0
  26. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/futurehouse_client.egg-info/top_level.txt +0 -0
  27. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/pyproject.toml +0 -0
  28. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/setup.cfg +0 -0
  29. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev295}/tests/test_client.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: futurehouse-client
3
- Version: 0.3.20.dev225
3
+ Version: 0.3.20.dev295
4
4
  Summary: A client for interacting with endpoints of the FutureHouse service.
5
5
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
6
  License: Apache License
@@ -14,6 +14,7 @@ import time
14
14
  import uuid
15
15
  from collections.abc import Collection
16
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
+ from http import HTTPStatus
17
18
  from pathlib import Path
18
19
  from types import ModuleType
19
20
  from typing import Any, ClassVar, cast
@@ -54,9 +55,14 @@ from futurehouse_client.models.app import (
54
55
  TaskRequest,
55
56
  TaskResponse,
56
57
  TaskResponseVerbose,
58
+ TrajectoryQueryParams,
57
59
  )
58
60
  from futurehouse_client.models.rest import (
59
61
  ExecutionStatus,
62
+ UserAgentRequest,
63
+ UserAgentRequestPostPayload,
64
+ UserAgentRequestStatus,
65
+ UserAgentResponsePayload,
60
66
  WorldModel,
61
67
  WorldModelResponse,
62
68
  )
@@ -105,6 +111,22 @@ class JobCreationError(RestClientError):
105
111
  """Raised when there's an error creating a job."""
106
112
 
107
113
 
114
class UserAgentRequestError(RestClientError):
    """Root of the error hierarchy for user agent request operations."""


class UserAgentRequestFetchError(UserAgentRequestError):
    """Signals that one or more user agent requests could not be retrieved."""


class UserAgentRequestCreationError(UserAgentRequestError):
    """Signals that a new user agent request could not be created."""


class UserAgentRequestResponseError(UserAgentRequestError):
    """Signals that a response to a user agent request could not be submitted."""
128
+
129
+
108
130
  class WorldModelFetchError(RestClientError):
109
131
  """Raised when there's an error fetching a world model."""
110
132
 
@@ -535,6 +557,52 @@ class RestClient:
535
557
  return verbose_response
536
558
  return JobNames.get_response_object_from_job(verbose_response.job_name)(**data)
537
559
 
560
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def cancel_task(self, task_id: str | None = None) -> bool:
    """Cancel a specific task/trajectory if it is currently in progress.

    Args:
        task_id: ID of the task to cancel. Falls back to ``self.trajectory_id``
            when omitted or empty.

    Returns:
        True if, after the cancel request was accepted, the task reports the
        ``CANCELLED`` status. False if the task was not ``IN_PROGRESS`` (no
        cancel request is sent in that case) or if it does not report
        ``CANCELLED`` on the follow-up fetch.

    Raises:
        ValueError: If neither ``task_id`` nor ``self.trajectory_id`` is set.
        PermissionError: If the cancel endpoint answers 401 or 403.
        TaskFetchError: If the cancel endpoint answers 404.
        HTTPStatusError: For any other non-success status from the cancel endpoint.
    """
    task_id = task_id or self.trajectory_id
    if not task_id:
        # Without an ID the URL would silently become ".../trajectories/None/cancel".
        raise ValueError(
            "cancel_task requires a task_id or a configured trajectory_id"
        )
    url = f"/v0.1/trajectories/{task_id}/cancel"
    full_url = f"{self.base_url}{url}"

    with external_trace(
        url=full_url,
        method="POST",
        library="httpx",
        custom_params={
            "operation": "cancel_job",
            "job_id": task_id,
        },
    ):
        # Only in-progress tasks are cancelled; anything else is a no-op.
        if self.get_task(task_id).status != ExecutionStatus.IN_PROGRESS.value:
            return False

        response = self.client.post(url)
        try:
            response.raise_for_status()
        except HTTPStatusError as e:
            if e.response.status_code in {
                HTTPStatus.UNAUTHORIZED,
                HTTPStatus.FORBIDDEN,
            }:
                raise PermissionError(
                    f"Error canceling task: Permission denied for task {task_id}"
                ) from e
            if e.response.status_code == HTTPStatus.NOT_FOUND:
                raise TaskFetchError(
                    f"Error canceling task: Trajectory not found for task {task_id}"
                ) from e
            raise

        # NOTE(review): this assumes the server applies the cancellation before
        # answering the POST; if cancellation is asynchronous, this may return
        # False even though the cancel request was accepted. Confirm.
        return self.get_task(task_id).status == ExecutionStatus.CANCELLED.value
605
+
538
606
  @retry(
539
607
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
540
608
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -927,6 +995,13 @@ class RestClient:
927
995
  if config.task_queues_config
928
996
  else None
929
997
  ),
998
+ "user_input_config": (
999
+ json.dumps([
1000
+ entity.model_dump() for entity in config.user_input_config
1001
+ ])
1002
+ if config.user_input_config
1003
+ else None
1004
+ ),
930
1005
  }
931
1006
  response = self.multipart_client.post(
932
1007
  "/v0.1/builds",
@@ -1787,6 +1862,470 @@ class RestClient:
1787
1862
  except Exception as e:
1788
1863
  raise ProjectError(f"Error adding trajectory to project: {e}") from e
1789
1864
 
1865
@staticmethod
def _resolve_trajectory_query_params(
    query_params: TrajectoryQueryParams | None,
    project_id: UUID | None,
    name: str | None,
    user: str | None,
    limit: int,
    offset: int,
    sort_by: str,
    sort_order: str,
) -> dict[str, str | int]:
    """Build the query-string params shared by ``get_tasks`` and ``aget_tasks``.

    An explicit ``query_params`` model takes precedence. The loose keyword
    filters are used (and validated via ``TrajectoryQueryParams``) only when
    ``query_params`` is ``None``.
    """
    if query_params is None:
        query_params = TrajectoryQueryParams(
            project_id=project_id,
            name=name,
            user=user,
            limit=limit,
            offset=offset,
            sort_by=sort_by,
            sort_order=sort_order,
        )
    return query_params.to_query_params()

@staticmethod
def _trajectory_list_error(e: HTTPStatusError) -> Exception:
    """Map an HTTP error from ``/v0.1/trajectories`` to the exception raised to callers."""
    if e.response.status_code in {codes.UNAUTHORIZED, codes.FORBIDDEN}:
        return PermissionError("Error getting trajectories: Permission denied")
    return TaskFetchError(
        f"Error getting trajectories: {e.response.status_code} - {e.response.text}"
    )

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def get_tasks(
    self,
    query_params: TrajectoryQueryParams | None = None,
    *,
    project_id: UUID | None = None,
    name: str | None = None,
    user: str | None = None,
    limit: int = 50,
    offset: int = 0,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> list[dict[str, Any]]:
    """Fetches trajectories with applied filtering.

    Args:
        query_params: Optional TrajectoryQueryParams model with all parameters.
            When given, the keyword filters below are ignored.
        project_id: Optional project ID to filter trajectories by
        name: Optional name filter for trajectories
        user: Optional user email filter for trajectories
        limit: Maximum number of trajectories to return (default: 50, max: 200)
        offset: Number of trajectories to skip for pagination (default: 0)
        sort_by: Field to sort by, either "created_at" or "name" (default: "created_at")
        sort_order: Sort order, either "asc" or "desc" (default: "desc")

    Returns:
        List of trajectory dictionaries

    Raises:
        PermissionError: If the server answers 401 or 403.
        TaskFetchError: For any other HTTP error, for invalid keyword filter
            values (e.g. ``limit`` outside 1-200), or any other failure.
    """
    try:
        params = self._resolve_trajectory_query_params(
            query_params, project_id, name, user, limit, offset, sort_by, sort_order
        )
        response = self.client.get("/v0.1/trajectories", params=params)
        response.raise_for_status()
        return response.json()
    except HTTPStatusError as e:
        raise self._trajectory_list_error(e) from e
    except Exception as e:
        # NOTE(review): this also wraps transport errors in TaskFetchError, which
        # likely stops the @retry(retry_if_connection_error) decorator from ever
        # retrying. Confirm against retry_if_connection_error's definition.
        raise TaskFetchError(f"Error getting trajectories: {e!r}") from e

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def aget_tasks(
    self,
    query_params: TrajectoryQueryParams | None = None,
    *,
    project_id: UUID | None = None,
    name: str | None = None,
    user: str | None = None,
    limit: int = 50,
    offset: int = 0,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> list[dict[str, Any]]:
    """Asynchronously fetch trajectories with applied filtering.

    Args:
        query_params: Optional TrajectoryQueryParams model with all parameters.
            When given, the keyword filters below are ignored.
        project_id: Optional project ID to filter trajectories by
        name: Optional name filter for trajectories
        user: Optional user email filter for trajectories
        limit: Maximum number of trajectories to return (default: 50, max: 200)
        offset: Number of trajectories to skip for pagination (default: 0)
        sort_by: Field to sort by, either "created_at" or "name" (default: "created_at")
        sort_order: Sort order, either "asc" or "desc" (default: "desc")

    Returns:
        List of trajectory dictionaries

    Raises:
        PermissionError: If the server answers 401 or 403.
        TaskFetchError: For any other HTTP error, for invalid keyword filter
            values (e.g. ``limit`` outside 1-200), or any other failure.
    """
    try:
        params = self._resolve_trajectory_query_params(
            query_params, project_id, name, user, limit, offset, sort_by, sort_order
        )
        response = await self.async_client.get("/v0.1/trajectories", params=params)
        response.raise_for_status()
        return response.json()
    except HTTPStatusError as e:
        raise self._trajectory_list_error(e) from e
    except Exception as e:
        # NOTE(review): see get_tasks; transport errors are wrapped here too,
        # which likely defeats the retry decorator. Confirm.
        raise TaskFetchError(f"Error getting trajectories: {e!r}") from e
1994
+
1995
@staticmethod
def _user_agent_request_list_params(
    user_id: str | None,
    trajectory_id: UUID | None,
    request_status: UserAgentRequestStatus | None,
    limit: int,
    offset: int,
) -> dict[str, str | int]:
    """Build query params for ``/v0.1/user-agent-requests``, dropping unset filters."""
    params = {
        "user_id": user_id,
        "trajectory_id": str(trajectory_id) if trajectory_id is not None else None,
        "request_status": (
            request_status.value if request_status is not None else None
        ),
        "limit": limit,
        "offset": offset,
    }
    return {k: v for k, v in params.items() if v is not None}

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def list_user_agent_requests(
    self,
    user_id: str | None = None,
    trajectory_id: UUID | None = None,
    request_status: UserAgentRequestStatus | None = None,
    limit: int = 50,
    offset: int = 0,
) -> list[UserAgentRequest]:
    """List user agent requests with optional filters.

    Args:
        user_id: Filter requests by user ID. When omitted, no user filter is
            sent and the server applies its own default (presumably the
            authenticated user; confirm against the API).
        trajectory_id: Filter requests by trajectory ID.
        request_status: Filter requests by status (e.g., PENDING).
        limit: Maximum number of requests to return.
        offset: Offset for pagination.

    Returns:
        A list of user agent requests.

    Raises:
        UserAgentRequestFetchError: If the API call fails.
    """
    params = self._user_agent_request_list_params(
        user_id, trajectory_id, request_status, limit, offset
    )
    try:
        response = self.client.get("/v0.1/user-agent-requests", params=params)
        response.raise_for_status()
        return [UserAgentRequest.model_validate(item) for item in response.json()]
    except HTTPStatusError as e:
        raise UserAgentRequestFetchError(
            f"Error listing user agent requests: {e.response.status_code} - {e.response.text}"
        ) from e
    except Exception as e:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {e!r}"
        ) from e

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def alist_user_agent_requests(
    self,
    user_id: str | None = None,
    trajectory_id: UUID | None = None,
    request_status: UserAgentRequestStatus | None = None,
    limit: int = 50,
    offset: int = 0,
) -> list[UserAgentRequest]:
    """Asynchronously list user agent requests with optional filters.

    Args:
        user_id: Filter requests by user ID. When omitted, no user filter is
            sent and the server applies its own default (presumably the
            authenticated user; confirm against the API).
        trajectory_id: Filter requests by trajectory ID.
        request_status: Filter requests by status (e.g., PENDING).
        limit: Maximum number of requests to return.
        offset: Offset for pagination.

    Returns:
        A list of user agent requests.

    Raises:
        UserAgentRequestFetchError: If the API call fails.
    """
    params = self._user_agent_request_list_params(
        user_id, trajectory_id, request_status, limit, offset
    )
    try:
        response = await self.async_client.get(
            "/v0.1/user-agent-requests", params=params
        )
        response.raise_for_status()
        return [UserAgentRequest.model_validate(item) for item in response.json()]
    except HTTPStatusError as e:
        raise UserAgentRequestFetchError(
            f"Error listing user agent requests: {e.response.status_code} - {e.response.text}"
        ) from e
    except Exception as e:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {e!r}"
        ) from e
2098
+
2099
@staticmethod
def _user_agent_request_fetch_error(
    e: HTTPStatusError, request_id: UUID
) -> UserAgentRequestFetchError:
    """Translate an HTTP error from the single-request endpoint into a fetch error."""
    if e.response.status_code == codes.NOT_FOUND:
        return UserAgentRequestFetchError(
            f"User agent request with ID {request_id} not found."
        )
    return UserAgentRequestFetchError(
        f"Error fetching user agent request: {e.response.status_code} - {e.response.text}"
    )

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def get_user_agent_request(self, request_id: UUID) -> UserAgentRequest:
    """Retrieve a single user agent request by its unique ID.

    Args:
        request_id: The unique ID of the request.

    Returns:
        The user agent request.

    Raises:
        UserAgentRequestFetchError: If the API call fails, the request is not
            found (404), or the response cannot be parsed.
    """
    try:
        response = self.client.get(f"/v0.1/user-agent-requests/{request_id}")
        response.raise_for_status()
        return UserAgentRequest.model_validate(response.json())
    except HTTPStatusError as e:
        raise self._user_agent_request_fetch_error(e, request_id) from e
    except Exception as e:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {e!r}"
        ) from e

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def aget_user_agent_request(self, request_id: UUID) -> UserAgentRequest:
    """Asynchronously retrieve a single user agent request by its unique ID.

    Args:
        request_id: The unique ID of the request.

    Returns:
        The user agent request.

    Raises:
        UserAgentRequestFetchError: If the API call fails, the request is not
            found (404), or the response cannot be parsed.
    """
    try:
        response = await self.async_client.get(
            f"/v0.1/user-agent-requests/{request_id}"
        )
        response.raise_for_status()
        return UserAgentRequest.model_validate(response.json())
    except HTTPStatusError as e:
        raise self._user_agent_request_fetch_error(e, request_id) from e
    except Exception as e:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {e!r}"
        ) from e
2170
+
2171
@staticmethod
def _user_agent_request_creation_error(
    e: HTTPStatusError,
) -> UserAgentRequestCreationError:
    """Translate an HTTP error from the create endpoint into a creation error."""
    if e.response.status_code == codes.UNPROCESSABLE_ENTITY:
        return UserAgentRequestCreationError(
            f"Invalid payload for user agent request creation: {e.response.text}."
        )
    return UserAgentRequestCreationError(
        f"Error creating user agent request: {e.response.status_code} - {e.response.text}"
    )

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def create_user_agent_request(self, payload: UserAgentRequestPostPayload) -> UUID:
    """Creates a new request from an agent to a user.

    Args:
        payload: An instance of UserAgentRequestPostPayload with the request data.

    Returns:
        The UUID of the newly created user agent request.

    Raises:
        UserAgentRequestCreationError: If the API call fails, the payload is
            rejected (422), or the response is not a parseable UUID.
    """
    try:
        response = self.client.post(
            "/v0.1/user-agent-requests", json=payload.model_dump(mode="json")
        )
        response.raise_for_status()
        # Assumes the endpoint returns the new ID as a bare JSON string; confirm.
        return UUID(response.json())
    except HTTPStatusError as e:
        raise self._user_agent_request_creation_error(e) from e
    except Exception as e:
        raise UserAgentRequestCreationError(
            f"An unexpected error occurred during user agent request creation: {e!r}."
        ) from e

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def acreate_user_agent_request(
    self, payload: UserAgentRequestPostPayload
) -> UUID:
    """Asynchronously creates a new request from an agent to a user.

    Args:
        payload: An instance of UserAgentRequestPostPayload with the request data.

    Returns:
        The UUID of the newly created user agent request.

    Raises:
        UserAgentRequestCreationError: If the API call fails, the payload is
            rejected (422), or the response is not a parseable UUID.
    """
    try:
        response = await self.async_client.post(
            "/v0.1/user-agent-requests", json=payload.model_dump(mode="json")
        )
        response.raise_for_status()
        # Assumes the endpoint returns the new ID as a bare JSON string; confirm.
        return UUID(response.json())
    except HTTPStatusError as e:
        raise self._user_agent_request_creation_error(e) from e
    except Exception as e:
        raise UserAgentRequestCreationError(
            f"An unexpected error occurred during user agent request creation: {e!r}."
        ) from e
2246
+
2247
@staticmethod
def _user_agent_response_error(
    e: HTTPStatusError, request_id: UUID
) -> UserAgentRequestResponseError:
    """Translate an HTTP error from the response endpoint into a response error."""
    status_code = e.response.status_code
    if status_code == codes.NOT_FOUND:
        return UserAgentRequestResponseError(
            f"User agent request with ID {request_id} not found."
        )
    if status_code == codes.UNPROCESSABLE_ENTITY:
        return UserAgentRequestResponseError(
            f"Invalid response payload: {e.response.text}."
        )
    return UserAgentRequestResponseError(
        f"Error responding to user agent request: {e.response.status_code} - {e.response.text}"
    )

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def respond_to_user_agent_request(
    self, request_id: UUID, payload: UserAgentResponsePayload
) -> None:
    """Submit a user's response to a pending agent request.

    Args:
        request_id: The unique ID of the request to respond to.
        payload: An instance of UserAgentResponsePayload with the response data.

    Raises:
        UserAgentRequestResponseError: If the API call fails, the request is
            not found (404), or the payload is rejected (422).
    """
    try:
        response = self.client.post(
            f"/v0.1/user-agent-requests/{request_id}/response",
            json=payload.model_dump(mode="json"),
        )
        response.raise_for_status()
    except HTTPStatusError as e:
        raise self._user_agent_response_error(e, request_id) from e
    except Exception as e:
        raise UserAgentRequestResponseError(
            f"An unexpected error occurred while responding to the request: {e!r}."
        ) from e

@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def arespond_to_user_agent_request(
    self, request_id: UUID, payload: UserAgentResponsePayload
) -> None:
    """Asynchronously submit a user's response to a pending agent request.

    Args:
        request_id: The unique ID of the request to respond to.
        payload: An instance of UserAgentResponsePayload with the response data.

    Raises:
        UserAgentRequestResponseError: If the API call fails, the request is
            not found (404), or the payload is rejected (422).
    """
    try:
        response = await self.async_client.post(
            f"/v0.1/user-agent-requests/{request_id}/response",
            json=payload.model_dump(mode="json"),
        )
        response.raise_for_status()
    except HTTPStatusError as e:
        raise self._user_agent_response_error(e, request_id) from e
    except Exception as e:
        raise UserAgentRequestResponseError(
            f"An unexpected error occurred while responding to the request: {e!r}."
        ) from e
2328
+
1790
2329
 
1791
2330
  def get_installed_packages() -> dict[str, str]:
1792
2331
  """Returns a dictionary of installed packages and their versions."""
@@ -278,6 +278,20 @@ class FramePath(BaseModel):
278
278
  )
279
279
 
280
280
 
281
class NamedEntity(BaseModel):
    # A named placeholder a user fills in at query-submission time. The field
    # descriptions below are emitted into the model's JSON schema, so they act
    # as the user-facing documentation (no class docstring, to keep the schema
    # unchanged).
    name: str = Field(
        ...,
        description=(
            "Name of an entity for a user to provide a value to during query submission. "
            "This will be used as a key to prompt users for a value. "
            "Example: 'pdb' would result in <pdb>user input here</pdb> in the task string."
        ),
    )
    description: str | None = Field(
        description="Helper text to provide the user context to what the name or format needs to be.",
        default=None,
    )
293
+
294
+
281
295
  class DockerContainerConfiguration(BaseModel):
282
296
  cpu: str = Field(description="CPU allotment for the container")
283
297
  memory: str = Field(description="Memory allotment for the container")
@@ -458,6 +472,15 @@ class JobDeploymentConfig(BaseModel):
458
472
  description="The configuration for the task queue(s) that will be created for this deployment.",
459
473
  )
460
474
 
475
+ user_input_config: list[NamedEntity] | None = Field(
476
+ default=None,
477
+ description=(
478
+ "List of NamedEntity objects that represent user input fields "
479
+ "to be included in the task string. "
480
+ "These will be used to prompt users for values during query submission."
481
+ ),
482
+ )
483
+
461
484
  @field_validator("markdown_template_path")
462
485
  @classmethod
463
486
  def validate_markdown_path(
@@ -646,6 +669,62 @@ class RuntimeConfig(BaseModel):
646
669
  return value
647
670
 
648
671
 
672
class TrajectoryQueryParams(BaseModel):
    """Params for trajectories with filtering."""

    # Reject unknown keyword arguments instead of silently ignoring them.
    model_config = ConfigDict(extra="forbid")

    # Optional filters: only sent to the server when set.
    project_id: UUID | None = Field(
        description="Optional project ID to filter trajectories by", default=None
    )
    name: str | None = Field(
        description="Optional name filter for trajectories", default=None
    )
    user: str | None = Field(
        description="Optional user email filter for trajectories", default=None
    )
    # Paging and ordering: always sent.
    limit: int = Field(
        description="Maximum number of trajectories to return (max: 200)",
        default=50,
        ge=1,
        le=200,
    )
    offset: int = Field(
        description="Number of trajectories to skip for pagination", default=0, ge=0
    )
    sort_by: str = Field(description="Field to sort by", default="created_at")
    sort_order: str = Field(description="Sort order", default="desc")

    @field_validator("sort_by")
    @classmethod
    def validate_sort_by(cls, v: str) -> str:
        # Only these two columns are accepted as sort keys.
        if v in ("created_at", "name"):
            return v
        raise ValueError("sort_by must be either 'created_at' or 'name'")

    @field_validator("sort_order")
    @classmethod
    def validate_sort_order(cls, v: str) -> str:
        if v in ("asc", "desc"):
            return v
        raise ValueError("sort_order must be either 'asc' or 'desc'")

    def to_query_params(self) -> dict[str, str | int]:
        """Render this model as an HTTP query-parameter mapping.

        Paging/ordering keys always appear first; optional filters follow in
        the order project_id, name, user, and only when they are not None.
        """
        query: dict[str, str | int] = dict(
            limit=self.limit,
            offset=self.offset,
            sort_by=self.sort_by,
            sort_order=self.sort_order,
        )
        optional_filters = (
            ("project_id", None if self.project_id is None else str(self.project_id)),
            ("name", self.name),
            ("user", self.user),
        )
        query.update(
            (key, value) for key, value in optional_filters if value is not None
        )
        return query
726
+
727
+
649
728
  class TaskRequest(BaseModel):
650
729
  model_config = ConfigDict(extra="forbid")
651
730
 
@@ -27,13 +27,17 @@ class InitialState(BaseState):
27
27
 
28
28
  class ASVState(BaseState, Generic[T]):
29
29
  action: OpResult[T] = Field()
30
- next_agent_state: Any = Field()
30
+ next_state: Any = Field()
31
31
  value: float = Field()
32
32
 
33
33
  @field_serializer("action")
34
34
  def serialize_action(self, action: OpResult[T]) -> dict:
35
35
  return action.to_dict()
36
36
 
37
+ @field_serializer("next_state")
38
+ def serialize_next_state(self, state: Any) -> str:
39
+ return str(state)
40
+
37
41
 
38
42
  class EnvResetState(BaseState):
39
43
  observations: list[Message] = Field()
@@ -2,7 +2,7 @@ from datetime import datetime
2
2
  from enum import StrEnum, auto
3
3
  from uuid import UUID
4
4
 
5
- from pydantic import BaseModel, JsonValue
5
+ from pydantic import BaseModel, ConfigDict, JsonValue
6
6
 
7
7
 
8
8
  class FinalEnvironmentRequest(BaseModel):
@@ -51,6 +51,7 @@ class WorldModel(BaseModel):
51
51
  description: str | None = None
52
52
  trajectory_id: UUID | str | None = None
53
53
  model_metadata: JsonValue | None = None
54
+ project_id: UUID | str | None = None
54
55
 
55
56
 
56
57
  class WorldModelResponse(BaseModel):
@@ -70,3 +71,50 @@ class WorldModelResponse(BaseModel):
70
71
  model_metadata: JsonValue | None
71
72
  enabled: bool
72
73
  created_at: datetime
74
+
75
+
76
+ class UserAgentRequestStatus(StrEnum):
77
+ """Enum for the status of a user agent request."""
78
+
79
+ PENDING = auto()
80
+ RESPONDED = auto()
81
+ EXPIRED = auto()
82
+ CANCELLED = auto()
83
+
84
+
85
class UserAgentRequest(BaseModel):
    """Sister model for UserAgentRequestsDB."""

    # from_attributes lets model_validate() read fields off arbitrary objects
    # (e.g. ORM rows), not only dicts.
    model_config = ConfigDict(from_attributes=True)

    # Identifiers.
    id: UUID
    user_id: str
    trajectory_id: UUID
    response_trajectory_id: UUID | None = None
    # Free-form JSON payloads for the request and (optional) response.
    request: JsonValue
    response: JsonValue | None = None
    # Optional world-model-edit IDs attached to either side (per field names).
    request_world_model_edit_id: UUID | None = None
    response_world_model_edit_id: UUID | None = None
    expires_at: datetime | None = None
    user_response_task: JsonValue | None = None
    status: UserAgentRequestStatus
    # Timestamps; optional on the client side.
    created_at: datetime | None = None
    modified_at: datetime | None = None


class UserAgentRequestPostPayload(BaseModel):
    """Payload to create a new user agent request."""

    trajectory_id: UUID
    request: JsonValue
    request_world_model_edit_id: UUID | None = None
    # New requests start as PENDING unless the caller overrides it.
    status: UserAgentRequestStatus = UserAgentRequestStatus.PENDING
    expires_in_seconds: int | None = None
    user_response_task: JsonValue | None = None


class UserAgentResponsePayload(BaseModel):
    """Payload for a user to submit a response to a request."""

    response: JsonValue
    response_world_model_edit_id: UUID | None = None
+ response_world_model_edit_id: UUID | None = None
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.3.20.dev225'
21
- __version_tuple__ = version_tuple = (0, 3, 20, 'dev225')
20
+ __version__ = version = '0.3.20.dev295'
21
+ __version_tuple__ = version_tuple = (0, 3, 20, 'dev295')
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: futurehouse-client
3
- Version: 0.3.20.dev225
3
+ Version: 0.3.20.dev295
4
4
  Summary: A client for interacting with endpoints of the FutureHouse service.
5
5
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
6
  License: Apache License
@@ -7,6 +7,7 @@ import tempfile
7
7
  import time
8
8
  import types
9
9
  from pathlib import Path
10
+ from typing import cast
10
11
  from unittest.mock import MagicMock, mock_open, patch
11
12
  from uuid import UUID, uuid4
12
13
 
@@ -19,6 +20,8 @@ from futurehouse_client.clients.rest_client import (
19
20
  ProjectError,
20
21
  RestClient,
21
22
  RestClientError,
23
+ UserAgentRequestCreationError,
24
+ UserAgentRequestFetchError,
22
25
  )
23
26
  from futurehouse_client.models.app import (
24
27
  PhoenixTaskResponse,
@@ -27,7 +30,13 @@ from futurehouse_client.models.app import (
27
30
  TaskRequest,
28
31
  TaskResponseVerbose,
29
32
  )
30
- from futurehouse_client.models.rest import ExecutionStatus, WorldModel
33
+ from futurehouse_client.models.rest import (
34
+ ExecutionStatus,
35
+ UserAgentRequestPostPayload,
36
+ UserAgentRequestStatus,
37
+ UserAgentResponsePayload,
38
+ WorldModel,
39
+ )
31
40
  from pytest_subtests import SubTests
32
41
 
33
42
  ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
@@ -78,6 +87,11 @@ def phoenix_task_req():
78
87
  )
79
88
 
80
89
 
90
@pytest.fixture
def running_trajectory_id(admin_client: RestClient, task_req: TaskRequest) -> str:
    """Submit ``task_req`` and hand the new trajectory id to the test.

    The task is only created here, not awaited, so it may or may not still
    be running by the time the test body uses the id.
    """
    trajectory_id = admin_client.create_task(task_req)
    return trajectory_id
93
+
94
+
81
95
  @pytest.mark.timeout(300)
82
96
  @pytest.mark.flaky(reruns=3)
83
97
  def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_req: TaskRequest):
@@ -247,6 +261,41 @@ async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
247
261
  )
248
262
 
249
263
 
264
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
def test_cancel_task(admin_client: RestClient):
    """Cancel a task with the HTTP POST and status lookups mocked out.

    ``get_task`` reports "in progress" first and then CANCELLED, so
    ``cancel_task`` is expected to return True after exactly two lookups.
    """
    task_id = "1c28bb94-efbf-442f-954c-f0d8ddb0cff5"

    still_running = MagicMock()
    still_running.status = "in progress"
    now_cancelled = MagicMock()
    now_cancelled.status = ExecutionStatus.CANCELLED.value

    with (
        patch.object(admin_client.client, "post") as mock_post,
        patch.object(admin_client, "get_task") as mock_get_task,
    ):
        # The POST succeeds: raise_for_status() on its response is a no-op.
        mock_post.return_value.raise_for_status.return_value = None
        mock_get_task.side_effect = [still_running, now_cancelled]

        assert admin_client.cancel_task(task_id) is True

        mock_post.assert_called_once_with(f"/v0.1/trajectories/{task_id}/cancel")
        assert mock_get_task.call_count == 2
297
+
298
+
250
299
  class TestParallelChunking:
251
300
  """Test suite for parallel chunk upload functionality."""
252
301
 
@@ -697,10 +746,9 @@ async def test_world_model_acreate_and_aget(admin_client: RestClient):
697
746
 
698
747
  # try getting the newly created model by id
699
748
  model_by_id = await admin_client.aget_world_model(model_id)
700
- model_by_name = await admin_client.aget_world_model(name=model.name)
701
749
 
702
- assert str(model_by_id.id) == str(model_id) == str(model_by_name.id)
703
- assert model_by_id.content == model.content == model_by_name.content
750
+ assert str(model_by_id.id) == str(model_id)
751
+ assert model_by_id.content == model.content
704
752
 
705
753
  updated_model = WorldModel(
706
754
  content="updated test content",
@@ -711,11 +759,9 @@ async def test_world_model_acreate_and_aget(admin_client: RestClient):
711
759
 
712
760
  # try getting the newly created model by id
713
761
  updated_model_by_id = await admin_client.aget_world_model(updated_model_id)
714
- updated_model_by_name = await admin_client.aget_world_model(name=model.name)
715
762
 
716
763
  assert updated_model_by_id.name == model.name
717
764
  assert updated_model_by_id.content != model.content
718
- assert updated_model_by_name.content != model.content
719
765
 
720
766
 
721
767
  def test_world_model_create_and_get(admin_client: RestClient):
@@ -728,10 +774,9 @@ def test_world_model_create_and_get(admin_client: RestClient):
728
774
 
729
775
  # try getting the newly created model by id
730
776
  model_by_id = admin_client.get_world_model(model_id)
731
- model_by_name = admin_client.get_world_model(name=model.name)
732
777
 
733
- assert str(model_by_id.id) == str(model_id) == str(model_by_name.id)
734
- assert model_by_id.content == model.content == model_by_name.content
778
+ assert str(model_by_id.id) == str(model_id)
779
+ assert model_by_id.content == model.content
735
780
 
736
781
  updated_model = WorldModel(
737
782
  content="updated test content",
@@ -742,17 +787,15 @@ def test_world_model_create_and_get(admin_client: RestClient):
742
787
 
743
788
  # try getting the newly created model by id
744
789
  updated_model_by_id = admin_client.get_world_model(updated_model_id)
745
- updated_model_by_name = admin_client.get_world_model(name=model.name)
746
790
 
747
791
  assert updated_model_by_id.name == model.name
748
792
  assert updated_model_by_id.content != model.content
749
- assert updated_model_by_name.content != model.content
750
793
 
751
794
 
752
795
  class TestProjectOperations:
753
796
  @pytest.fixture
754
797
  def test_project_name(self):
755
- return f"test-project-{int(time.time())}"
798
+ return f"test-project-{uuid4()}"
756
799
 
757
800
  def test_create_project_success(
758
801
  self, admin_client: RestClient, test_project_name: str
@@ -821,7 +864,7 @@ class TestProjectOperations:
821
864
  class TestAsyncProjectOperations:
822
865
  @pytest.fixture
823
866
  def test_project_name(self):
824
- return f"test-async-project-{int(time.time())}"
867
+ return f"test-async-project-{uuid4()}"
825
868
 
826
869
  @pytest.mark.asyncio
827
870
  async def test_acreate_project_success(
@@ -873,3 +916,230 @@ class TestAsyncProjectOperations:
873
916
  await admin_client.aadd_task_to_project(
874
917
  fake_project_id, fake_trajectory_id
875
918
  )
919
+
920
+
921
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
def test_get_tasks_with_project_filter(admin_client: RestClient, task_req: TaskRequest):
    """Test retrieving trajectories filtered by project_id using real API calls.

    Creates a fresh project, attaches a new task to it, waits until the task
    is no longer queued / in progress, then checks that listing tasks with
    the project filter returns that trajectory.
    """
    project_name = f"e2e-trajectories-fetch-{uuid4()}"
    # The project APIs below take a UUID, so parse the returned id once.
    project_uuid = UUID(admin_client.create_project(project_name))

    trajectory_id = admin_client.create_task(task_req)
    admin_client.add_task_to_project(project_uuid, trajectory_id)

    # Poll until the task settles; the pytest timeout above bounds the wait.
    unfinished = {"queued", "in progress"}
    while admin_client.get_task(trajectory_id).status in unfinished:
        time.sleep(5)

    trajectories = admin_client.get_tasks(project_id=project_uuid)

    trajectory_ids = [t["id"] for t in trajectories]
    assert trajectory_id in trajectory_ids
941
+
942
+
943
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
async def test_aget_tasks_with_project_filter(
    admin_client: RestClient, task_req: TaskRequest
):
    """Test async retrieving trajectories filtered by project_id using real API calls.

    Async counterpart of the synchronous project-filter test: the project's
    task must appear in ``aget_tasks(project_id=...)`` once it has settled.
    """
    project_id = await admin_client.acreate_project(
        f"e2e-trajectories-async-fetch-{uuid4()}"
    )
    trajectory_id = await admin_client.acreate_task(task_req)
    await admin_client.aadd_task_to_project(UUID(project_id), trajectory_id)

    # Re-check every 5 seconds until the task leaves the queued / in-progress states.
    unfinished = {"queued", "in progress"}
    while (await admin_client.aget_task(trajectory_id)).status in unfinished:
        await asyncio.sleep(5)

    listed = await admin_client.aget_tasks(project_id=UUID(project_id))
    assert trajectory_id in [t["id"] for t in listed]
966
+
967
+
968
class TestUserAgentRequestOperations:
    """Test suite for synchronous User Agent Request operations."""

    @pytest.mark.flaky(reruns=3)
    def test_e2e_user_agent_request_flow(
        self,
        admin_client: RestClient,
        running_trajectory_id: str,
    ):
        """Tests the full lifecycle: create, get, list, and respond."""
        # 1. CREATE a request against the trajectory
        payload = UserAgentRequestPostPayload(
            trajectory_id=running_trajectory_id,
            request={"question": "Do you approve?"},
        )
        request_id = admin_client.create_user_agent_request(payload)
        assert isinstance(request_id, UUID)

        # 2. GET the created request
        retrieved_req = admin_client.get_user_agent_request(request_id)
        assert retrieved_req.id == request_id
        assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
        assert retrieved_req.status == UserAgentRequestStatus.PENDING
        assert retrieved_req.request == payload.request

        # 3. LIST requests and find the created one
        request_list = admin_client.list_user_agent_requests(
            trajectory_id=UUID(running_trajectory_id),
            request_status=UserAgentRequestStatus.PENDING,
        )
        assert isinstance(request_list, list)
        assert any(req.id == request_id for req in request_list)

        # 4. RESPOND to the request
        response_payload = UserAgentResponsePayload(response={"answer": "Yes"})
        admin_client.respond_to_user_agent_request(request_id, response_payload)

        # 5. GET the request again to verify the response
        responded_req = admin_client.get_user_agent_request(request_id)
        assert responded_req.status == UserAgentRequestStatus.RESPONDED
        assert responded_req.response == response_payload.response

    def test_get_nonexistent_request_fails(self, admin_client: RestClient):
        """Verifies that fetching a non-existent request raises an error."""
        non_existent_id = uuid4()
        with pytest.raises(UserAgentRequestFetchError):
            admin_client.get_user_agent_request(non_existent_id)

    def test_unauthorized_access_fails(self, pub_client: RestClient):
        """Ensures a client with insufficient permissions cannot perform actions."""
        # Build the payload outside pytest.raises so that only the API call is
        # expected to raise.
        payload = UserAgentRequestPostPayload(
            trajectory_id=uuid4(),
            request={"data": "test"},
        )
        # NOTE(review): both calls below use random uuid4() ids, and
        # test_get_nonexistent_request_fails shows even the admin client gets
        # UserAgentRequestFetchError for such ids -- so these checks also pass
        # on plain "not found" and do not by themselves prove an authorization
        # check. Consider targeting a resource the admin owns.
        with pytest.raises((UserAgentRequestCreationError, PermissionError)):
            pub_client.create_user_agent_request(payload)

        # Attempt to fetch a request that the user doesn't own
        with pytest.raises((UserAgentRequestFetchError, PermissionError)):
            pub_client.get_user_agent_request(uuid4())
1028
+
1029
+
1030
class TestAsyncUserAgentRequestOperations:
    """Test suite for asynchronous User Agent Request operations."""

    @pytest.mark.asyncio
    async def test_async_expiring_e2e_user_agent_request_flow(
        self,
        admin_client: RestClient,
        running_trajectory_id: str,
    ):
        """Tests async expiry: an unanswered request expires and starts its follow-up task.

        Creates a request with a short ``expires_in_seconds`` and a
        ``user_response_task``, waits for it to become EXPIRED with a response
        trajectory recorded, checks that trajectory's job was started, and
        verifies a late user response neither changes the status nor
        overwrites the response.
        """
        expires_in_seconds = 10
        payload = UserAgentRequestPostPayload(
            trajectory_id=running_trajectory_id,
            request={"question": "Async: Do you approve?"},
            user_response_task=TaskRequest(
                name=JobNames.from_string("dummy"),
                query="Why would I follow up on this query?",
            ).model_dump(mode="json"),
            expires_in_seconds=expires_in_seconds,
        )

        request_id = await admin_client.acreate_user_agent_request(payload)
        assert isinstance(request_id, UUID)

        retrieved_req = await admin_client.aget_user_agent_request(request_id)
        assert retrieved_req.id == request_id
        assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
        assert retrieved_req.status == UserAgentRequestStatus.PENDING

        request_list = await admin_client.alist_user_agent_requests(
            trajectory_id=UUID(running_trajectory_id)
        )
        assert isinstance(request_list, list)
        assert any(req.id == request_id for req in request_list)

        # Poll instead of sleeping exactly ``expires_in_seconds``: a fixed sleep
        # equal to the expiry window races with however the server applies the
        # expiry and launches the follow-up task. Give it a generous deadline.
        deadline = time.monotonic() + expires_in_seconds + 60
        while True:
            retrieved_req = await admin_client.aget_user_agent_request(request_id)
            settled = (
                retrieved_req.status != UserAgentRequestStatus.PENDING
                and retrieved_req.response_trajectory_id is not None
            )
            if settled or time.monotonic() > deadline:
                break
            await asyncio.sleep(1)

        # now this should be expired
        assert retrieved_req.status == UserAgentRequestStatus.EXPIRED

        # Expiry should also have started the user_response_task job, with its
        # trajectory id registered on the request as response_trajectory_id.
        assert retrieved_req.response_trajectory_id is not None
        job_data = await admin_client.aget_task(
            cast(str, retrieved_req.response_trajectory_id)
        )
        assert job_data.status in {"queued", "in progress"}

        # RESPOND to the already-expired request -- ensure nothing changes
        ignored_response = {"answer": "Async Yes"}
        response_payload = UserAgentResponsePayload(response=ignored_response)
        await admin_client.arespond_to_user_agent_request(request_id, response_payload)

        retrieved_req = await admin_client.aget_user_agent_request(request_id)
        assert retrieved_req.status == UserAgentRequestStatus.EXPIRED
        assert retrieved_req.response != ignored_response

    @pytest.mark.asyncio
    async def test_async_e2e_user_agent_request_flow(
        self,
        admin_client: RestClient,
        running_trajectory_id: str,
    ):
        """Tests the full async lifecycle: acreate, aget, alist, and arespond."""
        # 1. CREATE a request
        payload = UserAgentRequestPostPayload(
            trajectory_id=running_trajectory_id,
            request={"question": "Async: Do you approve?"},
            user_response_task=TaskRequest(
                name=JobNames.from_string("dummy"),
                query="Why would I follow up on this query?",
            ).model_dump(mode="json"),
        )

        request_id = await admin_client.acreate_user_agent_request(payload)
        assert isinstance(request_id, UUID)

        # 2. GET the created request
        retrieved_req = await admin_client.aget_user_agent_request(request_id)
        assert retrieved_req.id == request_id
        assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
        assert retrieved_req.status == UserAgentRequestStatus.PENDING

        # 3. LIST requests and find the created one
        request_list = await admin_client.alist_user_agent_requests(
            trajectory_id=UUID(running_trajectory_id)
        )
        assert isinstance(request_list, list)
        assert any(req.id == request_id for req in request_list)

        # 4. RESPOND to the request
        response_payload = UserAgentResponsePayload(response={"answer": "Async Yes"})
        await admin_client.arespond_to_user_agent_request(request_id, response_payload)

        # 5. GET the request again to verify the response
        responded_req = await admin_client.aget_user_agent_request(request_id)
        assert responded_req.status == UserAgentRequestStatus.RESPONDED
        assert responded_req.response == response_payload.response

    @pytest.mark.asyncio
    async def test_aget_nonexistent_request_fails(self, admin_client: RestClient):
        """Verifies fetching a non-existent request asynchronously raises an error."""
        non_existent_id = uuid4()
        with pytest.raises(UserAgentRequestFetchError):
            await admin_client.aget_user_agent_request(non_existent_id)

    @pytest.mark.asyncio
    async def test_async_unauthorized_access_fails(self, pub_client: RestClient):
        """Ensures an unauthorized client fails on async methods."""
        # Build the payload outside pytest.raises so that only the API call is
        # expected to raise.
        payload = UserAgentRequestPostPayload(
            trajectory_id=uuid4(),
            request={"data": "test"},
        )
        # NOTE(review): random uuid4() ids also fail with "not found" for the
        # admin client (see test_aget_nonexistent_request_fails), so these
        # checks do not by themselves prove an authorization check.
        with pytest.raises((UserAgentRequestCreationError, PermissionError)):
            await pub_client.acreate_user_agent_request(payload)

        with pytest.raises((UserAgentRequestFetchError, PermissionError)):
            await pub_client.aget_user_agent_request(uuid4())