futurehouse-client 0.3.20.dev225__tar.gz → 0.3.20.dev266__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29):
  1. {futurehouse_client-0.3.20.dev225/futurehouse_client.egg-info → futurehouse_client-0.3.20.dev266}/PKG-INFO +1 -1
  2. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/clients/rest_client.py +539 -0
  3. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/app.py +79 -0
  4. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/client.py +5 -1
  5. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/rest.py +48 -1
  6. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/version.py +2 -2
  7. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266/futurehouse_client.egg-info}/PKG-INFO +1 -1
  8. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/tests/test_rest.py +279 -3
  9. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/LICENSE +0 -0
  10. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/README.md +0 -0
  11. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/docs/__init__.py +0 -0
  12. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/docs/client_notebook.ipynb +0 -0
  13. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/__init__.py +0 -0
  14. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/clients/__init__.py +0 -0
  15. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/clients/job_client.py +0 -0
  16. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/__init__.py +0 -0
  17. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/py.typed +0 -0
  18. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/__init__.py +0 -0
  19. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/auth.py +0 -0
  20. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/general.py +0 -0
  21. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/module_utils.py +0 -0
  22. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/monitoring.py +0 -0
  23. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/SOURCES.txt +0 -0
  24. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/dependency_links.txt +0 -0
  25. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/requires.txt +0 -0
  26. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/top_level.txt +0 -0
  27. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/pyproject.toml +0 -0
  28. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/setup.cfg +0 -0
  29. {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/tests/test_client.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: futurehouse-client
3
- Version: 0.3.20.dev225
3
+ Version: 0.3.20.dev266
4
4
  Summary: A client for interacting with endpoints of the FutureHouse service.
5
5
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
6
  License: Apache License
@@ -14,6 +14,7 @@ import time
14
14
  import uuid
15
15
  from collections.abc import Collection
16
16
  from concurrent.futures import ThreadPoolExecutor, as_completed
17
+ from http import HTTPStatus
17
18
  from pathlib import Path
18
19
  from types import ModuleType
19
20
  from typing import Any, ClassVar, cast
@@ -54,9 +55,14 @@ from futurehouse_client.models.app import (
54
55
  TaskRequest,
55
56
  TaskResponse,
56
57
  TaskResponseVerbose,
58
+ TrajectoryQueryParams,
57
59
  )
58
60
  from futurehouse_client.models.rest import (
59
61
  ExecutionStatus,
62
+ UserAgentRequest,
63
+ UserAgentRequestPostPayload,
64
+ UserAgentRequestStatus,
65
+ UserAgentResponsePayload,
60
66
  WorldModel,
61
67
  WorldModelResponse,
62
68
  )
@@ -105,6 +111,22 @@ class JobCreationError(RestClientError):
105
111
  """Raised when there's an error creating a job."""
106
112
 
107
113
 
114
class UserAgentRequestError(RestClientError):
    """Common base for every failure raised by the user agent request helpers."""


class UserAgentRequestFetchError(UserAgentRequestError):
    """Signals that listing or retrieving user agent requests failed."""


class UserAgentRequestCreationError(UserAgentRequestError):
    """Signals that a new user agent request could not be created."""


class UserAgentRequestResponseError(UserAgentRequestError):
    """Signals that submitting a response to a user agent request failed."""
128
+
129
+
108
130
  class WorldModelFetchError(RestClientError):
109
131
  """Raised when there's an error fetching a world model."""
110
132
 
@@ -535,6 +557,52 @@ class RestClient:
535
557
  return verbose_response
536
558
  return JobNames.get_response_object_from_job(verbose_response.job_name)(**data)
537
559
 
560
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def cancel_task(self, task_id: str | None = None) -> bool:
    """Cancel a running task/trajectory.

    Args:
        task_id: Trajectory to cancel. Falls back to ``self.trajectory_id``
            when omitted (or falsy).

    Returns:
        True when the task was in progress, the cancel call succeeded, and a
        follow-up fetch reports the task as cancelled. False otherwise. If the
        task is not in progress, no cancel request is sent at all.

    Raises:
        PermissionError: The cancel endpoint answered 401 or 403.
        TaskFetchError: The cancel endpoint answered 404.
        HTTPStatusError: Any other error status from the cancel endpoint.
    """
    target_id = task_id or self.trajectory_id
    endpoint = f"/v0.1/trajectories/{target_id}/cancel"
    trace_params = {
        "operation": "cancel_job",
        "job_id": target_id,
    }

    with external_trace(
        url=f"{self.base_url}{endpoint}",
        method="POST",
        library="httpx",
        custom_params=trace_params,
    ):
        # Only tasks that are currently running are eligible for cancellation.
        current = self.get_task(target_id)
        if current.status != ExecutionStatus.IN_PROGRESS.value:
            return False

        cancel_response = self.client.post(endpoint)
        try:
            cancel_response.raise_for_status()
        except HTTPStatusError as err:
            code = err.response.status_code
            if code in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
                raise PermissionError(
                    f"Error canceling task: Permission denied for task {target_id}"
                ) from err
            if code == HTTPStatus.NOT_FOUND:
                raise TaskFetchError(
                    f"Error canceling task: Trajectory not found for task {target_id}"
                ) from err
            raise

        # NOTE(review): if the server processes cancellation asynchronously,
        # this immediate re-fetch may not yet report CANCELLED — confirm.
        refreshed = self.get_task(target_id)
        return refreshed.status == ExecutionStatus.CANCELLED.value
605
+
538
606
  @retry(
539
607
  stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
540
608
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -927,6 +995,13 @@ class RestClient:
927
995
  if config.task_queues_config
928
996
  else None
929
997
  ),
998
+ "user_input_config": (
999
+ json.dumps([
1000
+ entity.model_dump() for entity in config.user_input_config
1001
+ ])
1002
+ if config.user_input_config
1003
+ else None
1004
+ ),
930
1005
  }
931
1006
  response = self.multipart_client.post(
932
1007
  "/v0.1/builds",
@@ -1787,6 +1862,470 @@ class RestClient:
1787
1862
  except Exception as e:
1788
1863
  raise ProjectError(f"Error adding trajectory to project: {e}") from e
1789
1864
 
1865
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def get_tasks(
    self,
    query_params: TrajectoryQueryParams | None = None,
    *,
    project_id: UUID | None = None,
    name: str | None = None,
    user: str | None = None,
    limit: int = 50,
    offset: int = 0,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> list[dict[str, Any]]:
    """Fetch trajectories from ``GET /v0.1/trajectories`` with optional filtering.

    Filters can be given either as a ready-made ``query_params`` model or as
    the individual keyword arguments below. When ``query_params`` is supplied,
    the keyword arguments are ignored.

    Args:
        query_params: Pre-built TrajectoryQueryParams carrying every filter.
        project_id: Only return trajectories belonging to this project.
        name: Name filter for trajectories.
        user: User email filter for trajectories.
        limit: Page size (default 50, max 200).
        offset: Number of trajectories to skip, for pagination (default 0).
        sort_by: ``"created_at"`` (default) or ``"name"``.
        sort_order: ``"desc"`` (default) or ``"asc"``.

    Returns:
        The decoded JSON response body, expected to be a list of trajectory dicts.

    Raises:
        PermissionError: The server responded 401 or 403.
        TaskFetchError: Any other error status, or any other exception raised
            while building the query or performing the request. This includes
            validation failures for unsupported argument values.
    """
    try:
        filters = query_params
        if filters is None:
            filters = TrajectoryQueryParams(
                project_id=project_id,
                name=name,
                user=user,
                limit=limit,
                offset=offset,
                sort_by=sort_by,
                sort_order=sort_order,
            )
        response = self.client.get(
            "/v0.1/trajectories", params=filters.to_query_params()
        )
        response.raise_for_status()
        return response.json()
    except HTTPStatusError as exc:
        status = exc.response.status_code
        if status in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
            raise PermissionError(
                "Error getting trajectories: Permission denied"
            ) from exc
        raise TaskFetchError(
            f"Error getting trajectories: {status} - {exc.response.text}"
        ) from exc
    except Exception as exc:
        raise TaskFetchError(f"Error getting trajectories: {exc!r}") from exc
1929
+
1930
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def aget_tasks(
    self,
    query_params: TrajectoryQueryParams | None = None,
    *,
    project_id: UUID | None = None,
    name: str | None = None,
    user: str | None = None,
    limit: int = 50,
    offset: int = 0,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> list[dict[str, Any]]:
    """Asynchronously fetch trajectories from ``GET /v0.1/trajectories``.

    Filters can be given either as a ready-made ``query_params`` model or as
    the individual keyword arguments below. When ``query_params`` is supplied,
    the keyword arguments are ignored.

    Args:
        query_params: Pre-built TrajectoryQueryParams carrying every filter.
        project_id: Only return trajectories belonging to this project.
        name: Name filter for trajectories.
        user: User email filter for trajectories.
        limit: Page size (default 50, max 200).
        offset: Number of trajectories to skip, for pagination (default 0).
        sort_by: ``"created_at"`` (default) or ``"name"``.
        sort_order: ``"desc"`` (default) or ``"asc"``.

    Returns:
        The decoded JSON response body, expected to be a list of trajectory dicts.

    Raises:
        PermissionError: The server responded 401 or 403.
        TaskFetchError: Any other error status, or any other exception raised
            while building the query or performing the request. This includes
            validation failures for unsupported argument values.
    """
    try:
        filters = query_params
        if filters is None:
            filters = TrajectoryQueryParams(
                project_id=project_id,
                name=name,
                user=user,
                limit=limit,
                offset=offset,
                sort_by=sort_by,
                sort_order=sort_order,
            )
        response = await self.async_client.get(
            "/v0.1/trajectories", params=filters.to_query_params()
        )
        response.raise_for_status()
        return response.json()
    except HTTPStatusError as exc:
        status = exc.response.status_code
        if status in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
            raise PermissionError(
                "Error getting trajectories: Permission denied"
            ) from exc
        raise TaskFetchError(
            f"Error getting trajectories: {status} - {exc.response.text}"
        ) from exc
    except Exception as exc:
        raise TaskFetchError(f"Error getting trajectories: {exc!r}") from exc
1994
+
1995
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def list_user_agent_requests(
    self,
    user_id: str | None = None,
    trajectory_id: UUID | None = None,
    request_status: UserAgentRequestStatus | None = None,
    limit: int = 50,
    offset: int = 0,
) -> list[UserAgentRequest]:
    """List user agent requests, optionally filtered.

    Args:
        user_id: Only return requests for this user. When omitted, the
            parameter is not sent (the server presumably falls back to the
            authenticated user — confirm).
        trajectory_id: Only return requests tied to this trajectory.
        request_status: Only return requests in this status (e.g. PENDING).
        limit: Maximum number of requests to return.
        offset: Number of requests to skip, for pagination.

    Returns:
        The matching requests, each validated into a UserAgentRequest.

    Raises:
        UserAgentRequestFetchError: The server returned an error status, or
            the request/response handling failed in any other way.
    """
    raw_params = (
        ("user_id", user_id),
        ("trajectory_id", str(trajectory_id) if trajectory_id else None),
        ("request_status", request_status.value if request_status else None),
        ("limit", limit),
        ("offset", offset),
    )
    # Drop unset filters so only the ones the caller chose reach the server.
    params = {key: value for key, value in raw_params if value is not None}
    try:
        reply = self.client.get("/v0.1/user-agent-requests", params=params)
        reply.raise_for_status()
        return [UserAgentRequest.model_validate(entry) for entry in reply.json()]
    except HTTPStatusError as exc:
        raise UserAgentRequestFetchError(
            f"Error listing user agent requests: {exc.response.status_code} - {exc.response.text}"
        ) from exc
    except Exception as exc:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {exc!r}"
        ) from exc
2045
+
2046
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def alist_user_agent_requests(
    self,
    user_id: str | None = None,
    trajectory_id: UUID | None = None,
    request_status: UserAgentRequestStatus | None = None,
    limit: int = 50,
    offset: int = 0,
) -> list[UserAgentRequest]:
    """Asynchronously list user agent requests, optionally filtered.

    Args:
        user_id: Only return requests for this user. When omitted, the
            parameter is not sent (the server presumably falls back to the
            authenticated user — confirm).
        trajectory_id: Only return requests tied to this trajectory.
        request_status: Only return requests in this status (e.g. PENDING).
        limit: Maximum number of requests to return.
        offset: Number of requests to skip, for pagination.

    Returns:
        The matching requests, each validated into a UserAgentRequest.

    Raises:
        UserAgentRequestFetchError: The server returned an error status, or
            the request/response handling failed in any other way.
    """
    raw_params = (
        ("user_id", user_id),
        ("trajectory_id", str(trajectory_id) if trajectory_id else None),
        ("request_status", request_status.value if request_status else None),
        ("limit", limit),
        ("offset", offset),
    )
    # Drop unset filters so only the ones the caller chose reach the server.
    params = {key: value for key, value in raw_params if value is not None}
    try:
        reply = await self.async_client.get(
            "/v0.1/user-agent-requests", params=params
        )
        reply.raise_for_status()
        return [UserAgentRequest.model_validate(entry) for entry in reply.json()]
    except HTTPStatusError as exc:
        raise UserAgentRequestFetchError(
            f"Error listing user agent requests: {exc.response.status_code} - {exc.response.text}"
        ) from exc
    except Exception as exc:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {exc!r}"
        ) from exc
2098
+
2099
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def get_user_agent_request(self, request_id: UUID) -> UserAgentRequest:
    """Look up one user agent request by ID.

    Args:
        request_id: Identifier of the request to fetch.

    Returns:
        The request, validated into a UserAgentRequest.

    Raises:
        UserAgentRequestFetchError: The request does not exist (404), the
            server returned another error status, or anything else failed.
    """
    try:
        reply = self.client.get(f"/v0.1/user-agent-requests/{request_id}")
        reply.raise_for_status()
        return UserAgentRequest.model_validate(reply.json())
    except HTTPStatusError as exc:
        status = exc.response.status_code
        if status == codes.NOT_FOUND:
            message = f"User agent request with ID {request_id} not found."
        else:
            message = f"Error fetching user agent request: {status} - {exc.response.text}"
        raise UserAgentRequestFetchError(message) from exc
    except Exception as exc:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {exc!r}"
        ) from exc
2133
+
2134
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def aget_user_agent_request(self, request_id: UUID) -> UserAgentRequest:
    """Asynchronously look up one user agent request by ID.

    Args:
        request_id: Identifier of the request to fetch.

    Returns:
        The request, validated into a UserAgentRequest.

    Raises:
        UserAgentRequestFetchError: The request does not exist (404), the
            server returned another error status, or anything else failed.
    """
    try:
        reply = await self.async_client.get(
            f"/v0.1/user-agent-requests/{request_id}"
        )
        reply.raise_for_status()
        return UserAgentRequest.model_validate(reply.json())
    except HTTPStatusError as exc:
        status = exc.response.status_code
        if status == codes.NOT_FOUND:
            message = f"User agent request with ID {request_id} not found."
        else:
            message = f"Error fetching user agent request: {status} - {exc.response.text}"
        raise UserAgentRequestFetchError(message) from exc
    except Exception as exc:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {exc!r}"
        ) from exc
2170
+
2171
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def create_user_agent_request(self, payload: UserAgentRequestPostPayload) -> UUID:
    """Open a new agent-to-user request.

    NOTE(review): this is a non-idempotent POST under a retry decorator. If a
    connection error ever reaches the decorator after the server accepted the
    first attempt, a duplicate request could be created — confirm this is
    acceptable.

    Args:
        payload: Request data, serialized in pydantic JSON mode.

    Returns:
        ID of the newly created request, parsed from the JSON response body.

    Raises:
        UserAgentRequestCreationError: The payload was rejected (422), the
            server returned another error status, or anything else failed.
    """
    try:
        reply = self.client.post(
            "/v0.1/user-agent-requests", json=payload.model_dump(mode="json")
        )
        reply.raise_for_status()
        created_id = reply.json()
        return UUID(created_id)
    except HTTPStatusError as exc:
        if exc.response.status_code != codes.UNPROCESSABLE_ENTITY:
            raise UserAgentRequestCreationError(
                f"Error creating user agent request: {exc.response.status_code} - {exc.response.text}"
            ) from exc
        raise UserAgentRequestCreationError(
            f"Invalid payload for user agent request creation: {exc.response.text}."
        ) from exc
    except Exception as exc:
        raise UserAgentRequestCreationError(
            f"An unexpected error occurred during user agent request creation: {exc!r}."
        ) from exc
2207
+
2208
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def acreate_user_agent_request(
    self, payload: UserAgentRequestPostPayload
) -> UUID:
    """Asynchronously open a new agent-to-user request.

    NOTE(review): this is a non-idempotent POST under a retry decorator. If a
    connection error ever reaches the decorator after the server accepted the
    first attempt, a duplicate request could be created — confirm this is
    acceptable.

    Args:
        payload: Request data, serialized in pydantic JSON mode.

    Returns:
        ID of the newly created request, parsed from the JSON response body.

    Raises:
        UserAgentRequestCreationError: The payload was rejected (422), the
            server returned another error status, or anything else failed.
    """
    try:
        reply = await self.async_client.post(
            "/v0.1/user-agent-requests", json=payload.model_dump(mode="json")
        )
        reply.raise_for_status()
        created_id = reply.json()
        return UUID(created_id)
    except HTTPStatusError as exc:
        if exc.response.status_code != codes.UNPROCESSABLE_ENTITY:
            raise UserAgentRequestCreationError(
                f"Error creating user agent request: {exc.response.status_code} - {exc.response.text}"
            ) from exc
        raise UserAgentRequestCreationError(
            f"Invalid payload for user agent request creation: {exc.response.text}."
        ) from exc
    except Exception as exc:
        raise UserAgentRequestCreationError(
            f"An unexpected error occurred during user agent request creation: {exc!r}."
        ) from exc
2246
+
2247
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def respond_to_user_agent_request(
    self, request_id: UUID, payload: UserAgentResponsePayload
) -> None:
    """Post the user's answer to an outstanding agent request.

    Args:
        request_id: Identifier of the request being answered.
        payload: Response data, serialized in pydantic JSON mode.

    Raises:
        UserAgentRequestResponseError: The request does not exist (404), the
            payload was rejected (422), the server returned another error
            status, or anything else failed.
    """
    try:
        reply = self.client.post(
            f"/v0.1/user-agent-requests/{request_id}/response",
            json=payload.model_dump(mode="json"),
        )
        reply.raise_for_status()
    except HTTPStatusError as exc:
        status = exc.response.status_code
        if status == codes.NOT_FOUND:
            message = f"User agent request with ID {request_id} not found."
        elif status == codes.UNPROCESSABLE_ENTITY:
            message = f"Invalid response payload: {exc.response.text}."
        else:
            message = f"Error responding to user agent request: {status} - {exc.response.text}"
        raise UserAgentRequestResponseError(message) from exc
    except Exception as exc:
        raise UserAgentRequestResponseError(
            f"An unexpected error occurred while responding to the request: {exc!r}."
        ) from exc
2287
+
2288
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def arespond_to_user_agent_request(
    self, request_id: UUID, payload: UserAgentResponsePayload
) -> None:
    """Asynchronously post the user's answer to an outstanding agent request.

    Args:
        request_id: Identifier of the request being answered.
        payload: Response data, serialized in pydantic JSON mode.

    Raises:
        UserAgentRequestResponseError: The request does not exist (404), the
            payload was rejected (422), the server returned another error
            status, or anything else failed.
    """
    try:
        reply = await self.async_client.post(
            f"/v0.1/user-agent-requests/{request_id}/response",
            json=payload.model_dump(mode="json"),
        )
        reply.raise_for_status()
    except HTTPStatusError as exc:
        status = exc.response.status_code
        if status == codes.NOT_FOUND:
            message = f"User agent request with ID {request_id} not found."
        elif status == codes.UNPROCESSABLE_ENTITY:
            message = f"Invalid response payload: {exc.response.text}."
        else:
            message = f"Error responding to user agent request: {status} - {exc.response.text}"
        raise UserAgentRequestResponseError(message) from exc
    except Exception as exc:
        raise UserAgentRequestResponseError(
            f"An unexpected error occurred while responding to the request: {exc!r}."
        ) from exc
2328
+
1790
2329
 
1791
2330
  def get_installed_packages() -> dict[str, str]:
1792
2331
  """Returns a dictionary of installed packages and their versions."""
@@ -278,6 +278,20 @@ class FramePath(BaseModel):
278
278
  )
279
279
 
280
280
 
281
class NamedEntity(BaseModel):
    # One user-supplied input slot for a deployment. ``name`` is the key the
    # user is prompted for; ``description`` is optional helper text.
    # (Intentionally no class docstring: pydantic would publish it as the
    # model's JSON-schema description.)

    name: str = Field(
        description=(
            "Name of an entity for a user to provide a value to during query submission. "
            "This will be used as a key to prompt users for a value. "
            "Example: 'pdb' would result in <pdb>user input here</pdb> in the task string."
        )
    )

    # Optional; omitted entries default to None.
    description: str | None = Field(
        default=None,
        description=(
            "Helper text to provide the user context to what the name or format needs to be."
        ),
    )
293
+
294
+
281
295
  class DockerContainerConfiguration(BaseModel):
282
296
  cpu: str = Field(description="CPU allotment for the container")
283
297
  memory: str = Field(description="Memory allotment for the container")
@@ -458,6 +472,15 @@ class JobDeploymentConfig(BaseModel):
458
472
  description="The configuration for the task queue(s) that will be created for this deployment.",
459
473
  )
460
474
 
475
+ user_input_config: list[NamedEntity] | None = Field(
476
+ default=None,
477
+ description=(
478
+ "List of NamedEntity objects that represent user input fields "
479
+ "to be included in the task string. "
480
+ "These will be used to prompt users for values during query submission."
481
+ ),
482
+ )
483
+
461
484
  @field_validator("markdown_template_path")
462
485
  @classmethod
463
486
  def validate_markdown_path(
@@ -646,6 +669,62 @@ class RuntimeConfig(BaseModel):
646
669
  return value
647
670
 
648
671
 
672
class TrajectoryQueryParams(BaseModel):
    """Params for trajectories with filtering."""

    # Unknown keyword arguments are rejected rather than silently dropped.
    model_config = ConfigDict(extra="forbid")

    # Optional filters: only sent to the server when set.
    project_id: UUID | None = Field(
        default=None, description="Optional project ID to filter trajectories by"
    )
    name: str | None = Field(
        default=None, description="Optional name filter for trajectories"
    )
    user: str | None = Field(
        default=None, description="Optional user email filter for trajectories"
    )

    # Pagination and ordering: always sent.
    limit: int = Field(
        default=50,
        ge=1,
        le=200,
        description="Maximum number of trajectories to return (max: 200)",
    )
    offset: int = Field(
        default=0, ge=0, description="Number of trajectories to skip for pagination"
    )
    sort_by: str = Field(default="created_at", description="Field to sort by")
    sort_order: str = Field(default="desc", description="Sort order")

    @field_validator("sort_by")
    @classmethod
    def validate_sort_by(cls, v: str) -> str:
        # Only these two columns are accepted as sort keys.
        if v in {"created_at", "name"}:
            return v
        raise ValueError("sort_by must be either 'created_at' or 'name'")

    @field_validator("sort_order")
    @classmethod
    def validate_sort_order(cls, v: str) -> str:
        # Ascending or descending only.
        if v in {"asc", "desc"}:
            return v
        raise ValueError("sort_order must be either 'asc' or 'desc'")

    def to_query_params(self) -> dict[str, str | int]:
        # Key order is kept stable: pagination/ordering first, then whichever
        # optional filters are set, in declaration order.
        params: dict[str, str | int] = {
            "limit": self.limit,
            "offset": self.offset,
            "sort_by": self.sort_by,
            "sort_order": self.sort_order,
        }
        optional_filters = (
            ("project_id", None if self.project_id is None else str(self.project_id)),
            ("name", self.name),
            ("user", self.user),
        )
        params.update(
            (key, value) for key, value in optional_filters if value is not None
        )
        return params
726
+
727
+
649
728
  class TaskRequest(BaseModel):
650
729
  model_config = ConfigDict(extra="forbid")
651
730
 
@@ -27,13 +27,17 @@ class InitialState(BaseState):
27
27
 
28
28
class ASVState(BaseState, Generic[T]):
    # Bundles an action result, the following state and a float value.
    # NOTE(review): this field was previously named ``next_agent_state``.
    # ``next_state`` is required, so payloads that only carry the old key will
    # now fail validation — confirm no producer still emits it.
    action: OpResult[T] = Field()
    next_state: Any = Field()
    value: float = Field()

    @field_serializer("action")
    def serialize_action(self, action: OpResult[T]) -> dict:
        # Delegate to the op result's own dict conversion.
        serialized = action.to_dict()
        return serialized

    @field_serializer("next_state")
    def serialize_next_state(self, state: Any) -> str:
        # Dumped as str(state): a lossy, display-oriented form that cannot be
        # parsed back into the original object.
        rendered = str(state)
        return rendered
40
+
37
41
 
38
42
  class EnvResetState(BaseState):
39
43
  observations: list[Message] = Field()
@@ -2,7 +2,7 @@ from datetime import datetime
2
2
  from enum import StrEnum, auto
3
3
  from uuid import UUID
4
4
 
5
- from pydantic import BaseModel, JsonValue
5
+ from pydantic import BaseModel, ConfigDict, JsonValue
6
6
 
7
7
 
8
8
  class FinalEnvironmentRequest(BaseModel):
@@ -70,3 +70,50 @@ class WorldModelResponse(BaseModel):
70
70
  model_metadata: JsonValue | None
71
71
  enabled: bool
72
72
  created_at: datetime
73
+
74
+
75
+ class UserAgentRequestStatus(StrEnum):
76
+ """Enum for the status of a user agent request."""
77
+
78
+ PENDING = auto()
79
+ RESPONDED = auto()
80
+ EXPIRED = auto()
81
+ CANCELLED = auto()
82
+
83
+
84
class UserAgentRequest(BaseModel):
    """Sister model for UserAgentRequestsDB."""

    # Allow model_validate() to read attribute-bearing objects, not only dicts.
    model_config = ConfigDict(from_attributes=True)

    # Always present.
    id: UUID
    user_id: str
    trajectory_id: UUID
    response_trajectory_id: UUID | None = None
    # Free-form JSON payloads exchanged between agent and user.
    request: JsonValue
    response: JsonValue | None = None
    # Optional links to world-model edits on either side of the exchange.
    request_world_model_edit_id: UUID | None = None
    response_world_model_edit_id: UUID | None = None
    expires_at: datetime | None = None
    user_response_task: JsonValue | None = None
    status: UserAgentRequestStatus
    # Bookkeeping timestamps.
    created_at: datetime | None = None
    modified_at: datetime | None = None


class UserAgentRequestPostPayload(BaseModel):
    """Payload to create a new user agent request."""

    trajectory_id: UUID
    request: JsonValue
    request_world_model_edit_id: UUID | None = None
    # New requests start out PENDING unless the caller says otherwise.
    status: UserAgentRequestStatus = UserAgentRequestStatus.PENDING
    # Presumably a lifetime relative to creation time — confirm server semantics.
    expires_in_seconds: int | None = None
    user_response_task: JsonValue | None = None


class UserAgentResponsePayload(BaseModel):
    """Payload for a user to submit a response to a request."""

    response: JsonValue
    response_world_model_edit_id: UUID | None = None
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
version = "0.3.20.dev266"
__version__ = version
version_tuple = (0, 3, 20, "dev266")
__version_tuple__ = version_tuple
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: futurehouse-client
3
- Version: 0.3.20.dev225
3
+ Version: 0.3.20.dev266
4
4
  Summary: A client for interacting with endpoints of the FutureHouse service.
5
5
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
6
  License: Apache License
@@ -7,6 +7,7 @@ import tempfile
7
7
  import time
8
8
  import types
9
9
  from pathlib import Path
10
+ from typing import cast
10
11
  from unittest.mock import MagicMock, mock_open, patch
11
12
  from uuid import UUID, uuid4
12
13
 
@@ -19,6 +20,8 @@ from futurehouse_client.clients.rest_client import (
19
20
  ProjectError,
20
21
  RestClient,
21
22
  RestClientError,
23
+ UserAgentRequestCreationError,
24
+ UserAgentRequestFetchError,
22
25
  )
23
26
  from futurehouse_client.models.app import (
24
27
  PhoenixTaskResponse,
@@ -27,7 +30,13 @@ from futurehouse_client.models.app import (
27
30
  TaskRequest,
28
31
  TaskResponseVerbose,
29
32
  )
30
- from futurehouse_client.models.rest import ExecutionStatus, WorldModel
33
+ from futurehouse_client.models.rest import (
34
+ ExecutionStatus,
35
+ UserAgentRequestPostPayload,
36
+ UserAgentRequestStatus,
37
+ UserAgentResponsePayload,
38
+ WorldModel,
39
+ )
31
40
  from pytest_subtests import SubTests
32
41
 
33
42
  ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
@@ -78,6 +87,11 @@ def phoenix_task_req():
78
87
  )
79
88
 
80
89
 
90
@pytest.fixture
def running_trajectory_id(admin_client: RestClient, task_req: TaskRequest) -> str:
    """Submit ``task_req`` with the admin client and return the new trajectory id.

    The task is only submitted here; the fixture neither waits for it nor
    cleans it up afterwards.
    """
    new_trajectory_id = admin_client.create_task(task_req)
    return new_trajectory_id
93
+
94
+
81
95
  @pytest.mark.timeout(300)
82
96
  @pytest.mark.flaky(reruns=3)
83
97
  def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_req: TaskRequest):
@@ -247,6 +261,41 @@ async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
247
261
  )
248
262
 
249
263
 
264
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
def test_cancel_task(admin_client: RestClient):
    """cancel_task POSTs to the cancel endpoint, then polls until the task is cancelled.

    Both the HTTP ``post`` and ``get_task`` are mocked, so no real trajectory
    is touched.
    """
    task_id = "1c28bb94-efbf-442f-954c-f0d8ddb0cff5"

    # The POST succeeds (raise_for_status is a no-op).
    post_response = MagicMock(**{"raise_for_status.return_value": None})
    # get_task first reports the task as running, then as cancelled.
    status_sequence = [
        MagicMock(status="in progress"),
        MagicMock(status=ExecutionStatus.CANCELLED.value),
    ]

    with (
        patch.object(
            admin_client.client, "post", return_value=post_response
        ) as mock_post,
        patch.object(
            admin_client, "get_task", side_effect=status_sequence
        ) as mock_get_task,
    ):
        outcome = admin_client.cancel_task(task_id)

        assert outcome is True
        mock_post.assert_called_once_with(f"/v0.1/trajectories/{task_id}/cancel")
        assert mock_get_task.call_count == 2
297
+
298
+
250
299
  class TestParallelChunking:
251
300
  """Test suite for parallel chunk upload functionality."""
252
301
 
@@ -752,7 +801,7 @@ def test_world_model_create_and_get(admin_client: RestClient):
752
801
  class TestProjectOperations:
753
802
  @pytest.fixture
754
803
  def test_project_name(self):
755
- return f"test-project-{int(time.time())}"
804
+ return f"test-project-{uuid4()}"
756
805
 
757
806
  def test_create_project_success(
758
807
  self, admin_client: RestClient, test_project_name: str
@@ -821,7 +870,7 @@ class TestProjectOperations:
821
870
  class TestAsyncProjectOperations:
822
871
  @pytest.fixture
823
872
  def test_project_name(self):
824
- return f"test-async-project-{int(time.time())}"
873
+ return f"test-async-project-{uuid4()}"
825
874
 
826
875
  @pytest.mark.asyncio
827
876
  async def test_acreate_project_success(
@@ -873,3 +922,230 @@ class TestAsyncProjectOperations:
873
922
  await admin_client.aadd_task_to_project(
874
923
  fake_project_id, fake_trajectory_id
875
924
  )
925
+
926
+
927
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
def test_get_tasks_with_project_filter(admin_client: RestClient, task_req: TaskRequest):
    """Test retrieving trajectories filtered by project_id using real API calls.

    Creates a fresh project and attaches a new task to it. It then waits until
    the task is no longer queued or in progress, and checks that listing tasks
    filtered by that project returns the trajectory.
    """
    pending_statuses = {"queued", "in progress"}

    project_name = f"e2e-trajectories-fetch-{uuid4()}"
    project_id = admin_client.create_project(project_name)

    trajectory_id = admin_client.create_task(task_req)
    project_uuid = UUID(project_id)
    admin_client.add_task_to_project(project_uuid, trajectory_id)

    # Poll until the task leaves the pending states; the timeout marker bounds
    # the total wait.
    while admin_client.get_task(trajectory_id).status in pending_statuses:
        time.sleep(5)

    trajectories = admin_client.get_tasks(project_id=project_uuid)

    trajectory_ids = [t["id"] for t in trajectories]
    assert trajectory_id in trajectory_ids, (
        f"trajectory {trajectory_id} not listed for project {project_id}"
    )
947
+
948
+
949
@pytest.mark.timeout(300)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
async def test_aget_tasks_with_project_filter(
    admin_client: RestClient, task_req: TaskRequest
):
    """Async variant: a task added to a project is listed when filtering by that project.

    Uses real API calls. The task is awaited until it is neither queued nor
    in progress before the filtered listing is checked.
    """
    project_id = await admin_client.acreate_project(
        f"e2e-trajectories-async-fetch-{uuid4()}"
    )
    trajectory_id = await admin_client.acreate_task(task_req)
    await admin_client.aadd_task_to_project(UUID(project_id), trajectory_id)

    waiting_states = {"queued", "in progress"}
    while (await admin_client.aget_task(trajectory_id)).status in waiting_states:
        await asyncio.sleep(5)

    listed = await admin_client.aget_tasks(project_id=UUID(project_id))
    listed_ids = [entry["id"] for entry in listed]
    assert trajectory_id in listed_ids
972
+
973
+
974
class TestUserAgentRequestOperations:
    """Test suite for synchronous User Agent Request operations."""

    @pytest.mark.flaky(reruns=3)
    def test_e2e_user_agent_request_flow(
        self,
        admin_client: RestClient,
        running_trajectory_id: str,
    ):
        """Tests the full lifecycle: create, get, list, and respond."""
        trajectory_uuid = UUID(running_trajectory_id)

        # 1. CREATE a request attached to the running trajectory
        payload = UserAgentRequestPostPayload(
            trajectory_id=trajectory_uuid,
            request={"question": "Do you approve?"},
        )
        request_id = admin_client.create_user_agent_request(payload)
        assert isinstance(request_id, UUID)

        # 2. GET the created request
        retrieved_req = admin_client.get_user_agent_request(request_id)
        assert retrieved_req.id == request_id
        assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
        assert retrieved_req.status == UserAgentRequestStatus.PENDING
        assert retrieved_req.request == payload.request

        # 3. LIST requests and find the created one
        request_list = admin_client.list_user_agent_requests(
            trajectory_id=trajectory_uuid,
            request_status=UserAgentRequestStatus.PENDING,
        )
        assert isinstance(request_list, list)
        assert any(req.id == request_id for req in request_list)

        # 4. RESPOND to the request
        response_payload = UserAgentResponsePayload(response={"answer": "Yes"})
        admin_client.respond_to_user_agent_request(request_id, response_payload)

        # 5. GET the request again to verify the response
        responded_req = admin_client.get_user_agent_request(request_id)
        assert responded_req.status == UserAgentRequestStatus.RESPONDED
        assert responded_req.response == response_payload.response

    def test_get_nonexistent_request_fails(self, admin_client: RestClient):
        """Verifies that fetching a non-existent request raises an error."""
        non_existent_id = uuid4()
        with pytest.raises(UserAgentRequestFetchError):
            admin_client.get_user_agent_request(non_existent_id)

    def test_unauthorized_access_fails(self, pub_client: RestClient):
        """Ensures a client with insufficient permissions cannot perform actions."""
        # Build the payload outside pytest.raises so the raises block covers
        # only the API call under test.
        payload = UserAgentRequestPostPayload(
            trajectory_id=uuid4(),
            request={"data": "test"},
        )

        # Using a public client that shouldn't have access
        with pytest.raises((UserAgentRequestCreationError, PermissionError)):
            pub_client.create_user_agent_request(payload)

        # Fetch an arbitrary id: either a permission error or a fetch error
        # counts as a rejection.
        with pytest.raises((UserAgentRequestFetchError, PermissionError)):
            pub_client.get_user_agent_request(uuid4())
1034
+
1035
+
1036
class TestAsyncUserAgentRequestOperations:
    """Test suite for asynchronous User Agent Request operations."""

    @pytest.mark.asyncio
    async def test_async_expiring_e2e_user_agent_request_flow(
        self,
        admin_client: RestClient,
        running_trajectory_id: str,
    ):
        """Tests that an unanswered request expires and starts its follow-up task.

        Creates a request with a short expiry and a ``user_response_task``, then
        waits for it to reach EXPIRED. It checks that a response trajectory was
        started and that a late response does not alter the expired request.
        """
        expires_in_seconds = 10
        # Extra time allowed past the expiry for the server to process it.
        expiry_grace_seconds = 20

        trajectory_uuid = UUID(running_trajectory_id)

        # 1. CREATE a request that expires quickly and carries a follow-up task
        payload = UserAgentRequestPostPayload(
            trajectory_id=trajectory_uuid,
            request={"question": "Async: Do you approve?"},
            user_response_task=TaskRequest(
                name=JobNames.from_string("dummy"),
                query="Why would I follow up on this query?",
            ).model_dump(mode="json"),
            expires_in_seconds=expires_in_seconds,
        )

        request_id = await admin_client.acreate_user_agent_request(payload)
        assert isinstance(request_id, UUID)

        # 2. GET the created request: still pending
        retrieved_req = await admin_client.aget_user_agent_request(request_id)
        assert retrieved_req.id == request_id
        assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
        assert retrieved_req.status == UserAgentRequestStatus.PENDING

        # 3. LIST requests and find the created one
        request_list = await admin_client.alist_user_agent_requests(
            trajectory_id=trajectory_uuid
        )
        assert isinstance(request_list, list)
        assert any(req.id == request_id for req in request_list)

        # 4. Let the request expire. Poll with a deadline instead of sleeping
        # exactly expires_in_seconds, so server-side processing delay cannot
        # make the status check race the expiry.
        deadline = time.monotonic() + expires_in_seconds + expiry_grace_seconds
        retrieved_req = await admin_client.aget_user_agent_request(request_id)
        while (
            retrieved_req.status == UserAgentRequestStatus.PENDING
            and time.monotonic() < deadline
        ):
            await asyncio.sleep(1)
            retrieved_req = await admin_client.aget_user_agent_request(request_id)
        assert retrieved_req.status == UserAgentRequestStatus.EXPIRED

        # 5. Expiry should have started the user_response_task as a trajectory.
        response_trajectory_id = retrieved_req.response_trajectory_id
        assert response_trajectory_id is not None, (
            "expired request has no response trajectory"
        )
        # cast only satisfies the type checker; the UUID is passed through as-is.
        job_data = await admin_client.aget_task(cast(str, response_trajectory_id))
        assert job_data.status == "in progress"

        # 6. RESPOND after expiry -- the request must stay unchanged
        ignored_response = {"answer": "Async Yes"}
        response_payload = UserAgentResponsePayload(response=ignored_response)
        await admin_client.arespond_to_user_agent_request(request_id, response_payload)

        retrieved_req = await admin_client.aget_user_agent_request(request_id)
        assert retrieved_req.status == UserAgentRequestStatus.EXPIRED
        assert retrieved_req.response != ignored_response

    @pytest.mark.asyncio
    async def test_async_e2e_user_agent_request_flow(
        self,
        admin_client: RestClient,
        running_trajectory_id: str,
    ):
        """Tests the full async lifecycle: acreate, aget, alist, and arespond."""
        trajectory_uuid = UUID(running_trajectory_id)

        # 1. CREATE a request
        payload = UserAgentRequestPostPayload(
            trajectory_id=trajectory_uuid,
            request={"question": "Async: Do you approve?"},
            user_response_task=TaskRequest(
                name=JobNames.from_string("dummy"),
                query="Why would I follow up on this query?",
            ).model_dump(mode="json"),
        )

        request_id = await admin_client.acreate_user_agent_request(payload)
        assert isinstance(request_id, UUID)

        # 2. GET the created request
        retrieved_req = await admin_client.aget_user_agent_request(request_id)
        assert retrieved_req.id == request_id
        assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
        assert retrieved_req.status == UserAgentRequestStatus.PENDING

        # 3. LIST requests and find the created one
        request_list = await admin_client.alist_user_agent_requests(
            trajectory_id=trajectory_uuid
        )
        assert isinstance(request_list, list)
        assert any(req.id == request_id for req in request_list)

        # 4. RESPOND to the request
        response_payload = UserAgentResponsePayload(response={"answer": "Async Yes"})
        await admin_client.arespond_to_user_agent_request(request_id, response_payload)

        # 5. GET the request again to verify the response
        responded_req = await admin_client.aget_user_agent_request(request_id)
        assert responded_req.status == UserAgentRequestStatus.RESPONDED
        assert responded_req.response == response_payload.response

    @pytest.mark.asyncio
    async def test_aget_nonexistent_request_fails(self, admin_client: RestClient):
        """Verifies fetching a non-existent request asynchronously raises an error."""
        non_existent_id = uuid4()
        with pytest.raises(UserAgentRequestFetchError):
            await admin_client.aget_user_agent_request(non_existent_id)

    @pytest.mark.asyncio
    async def test_async_unauthorized_access_fails(self, pub_client: RestClient):
        """Ensures an unauthorized client fails on async methods."""
        # Build the payload outside pytest.raises so the raises block covers
        # only the API call under test.
        payload = UserAgentRequestPostPayload(
            trajectory_id=uuid4(),
            request={"data": "test"},
        )
        with pytest.raises((UserAgentRequestCreationError, PermissionError)):
            await pub_client.acreate_user_agent_request(payload)

        with pytest.raises((UserAgentRequestFetchError, PermissionError)):
            await pub_client.aget_user_agent_request(uuid4())