futurehouse-client 0.3.20.dev225__tar.gz → 0.3.20.dev266__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {futurehouse_client-0.3.20.dev225/futurehouse_client.egg-info → futurehouse_client-0.3.20.dev266}/PKG-INFO +1 -1
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/clients/rest_client.py +539 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/app.py +79 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/client.py +5 -1
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/rest.py +48 -1
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/version.py +2 -2
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266/futurehouse_client.egg-info}/PKG-INFO +1 -1
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/tests/test_rest.py +279 -3
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/LICENSE +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/README.md +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/docs/__init__.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/docs/client_notebook.ipynb +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/__init__.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/clients/__init__.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/clients/job_client.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/models/__init__.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/py.typed +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/__init__.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/auth.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/general.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/module_utils.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/utils/monitoring.py +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/SOURCES.txt +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/dependency_links.txt +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/requires.txt +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client.egg-info/top_level.txt +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/pyproject.toml +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/setup.cfg +0 -0
- {futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/tests/test_client.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: futurehouse-client
|
3
|
-
Version: 0.3.20.dev225
|
3
|
+
Version: 0.3.20.dev266
|
4
4
|
Summary: A client for interacting with endpoints of the FutureHouse service.
|
5
5
|
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
6
6
|
License: Apache License
|
@@ -14,6 +14,7 @@ import time
|
|
14
14
|
import uuid
|
15
15
|
from collections.abc import Collection
|
16
16
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
17
|
+
from http import HTTPStatus
|
17
18
|
from pathlib import Path
|
18
19
|
from types import ModuleType
|
19
20
|
from typing import Any, ClassVar, cast
|
@@ -54,9 +55,14 @@ from futurehouse_client.models.app import (
|
|
54
55
|
TaskRequest,
|
55
56
|
TaskResponse,
|
56
57
|
TaskResponseVerbose,
|
58
|
+
TrajectoryQueryParams,
|
57
59
|
)
|
58
60
|
from futurehouse_client.models.rest import (
|
59
61
|
ExecutionStatus,
|
62
|
+
UserAgentRequest,
|
63
|
+
UserAgentRequestPostPayload,
|
64
|
+
UserAgentRequestStatus,
|
65
|
+
UserAgentResponsePayload,
|
60
66
|
WorldModel,
|
61
67
|
WorldModelResponse,
|
62
68
|
)
|
@@ -105,6 +111,22 @@ class JobCreationError(RestClientError):
|
|
105
111
|
"""Raised when there's an error creating a job."""
|
106
112
|
|
107
113
|
|
114
|
+
class UserAgentRequestError(RestClientError):
    """Common base class for every user agent request failure.

    Catch this to handle fetch, creation and response errors uniformly.
    """


class UserAgentRequestFetchError(UserAgentRequestError):
    """Raised when listing or retrieving user agent requests fails."""


class UserAgentRequestCreationError(UserAgentRequestError):
    """Raised when a new user agent request cannot be created."""


class UserAgentRequestResponseError(UserAgentRequestError):
    """Raised when submitting a response to a user agent request fails."""
|
128
|
+
|
129
|
+
|
108
130
|
class WorldModelFetchError(RestClientError):
|
109
131
|
"""Raised when there's an error fetching a world model."""
|
110
132
|
|
@@ -535,6 +557,52 @@ class RestClient:
|
|
535
557
|
return verbose_response
|
536
558
|
return JobNames.get_response_object_from_job(verbose_response.job_name)(**data)
|
537
559
|
|
560
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def cancel_task(self, task_id: str | None = None) -> bool:
    """Request cancellation of a task (trajectory) that is currently running.

    Args:
        task_id: ID of the task to cancel. Falls back to ``self.trajectory_id``
            when omitted or falsy.

    Returns:
        ``True`` if a cancel request was sent and the re-fetched task reports
        the cancelled status. ``False`` if the task was not in progress (no
        request is sent), or if it is not reported as cancelled afterwards.

    Raises:
        PermissionError: If the cancel endpoint answers 401 or 403.
        TaskFetchError: If the cancel endpoint answers 404.
        HTTPStatusError: For any other error status from the cancel endpoint.
    """
    # NOTE(review): @retry wraps the whole method. If a connection error occurs
    # after the POST has reached the server, the rerun presumably sees a
    # non-running task and returns False. Confirm this is acceptable.
    target_id = task_id or self.trajectory_id
    cancel_path = f"/v0.1/trajectories/{target_id}/cancel"
    trace_params = {"operation": "cancel_job", "job_id": target_id}

    with external_trace(
        url=f"{self.base_url}{cancel_path}",
        method="POST",
        library="httpx",
        custom_params=trace_params,
    ):
        # Only a running task can be cancelled; anything else is a no-op.
        if self.get_task(target_id).status != ExecutionStatus.IN_PROGRESS.value:
            return False

        cancel_response = self.client.post(cancel_path)
        try:
            cancel_response.raise_for_status()
        except HTTPStatusError as err:
            status_code = err.response.status_code
            if status_code in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
                raise PermissionError(
                    f"Error canceling task: Permission denied for task {target_id}"
                ) from err
            if status_code == HTTPStatus.NOT_FOUND:
                raise TaskFetchError(
                    f"Error canceling task: Trajectory not found for task {target_id}"
                ) from err
            raise

        # Re-read the task to confirm the server applied the cancellation.
        return self.get_task(target_id).status == ExecutionStatus.CANCELLED.value
|
605
|
+
|
538
606
|
@retry(
|
539
607
|
stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
|
540
608
|
wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
|
@@ -927,6 +995,13 @@ class RestClient:
|
|
927
995
|
if config.task_queues_config
|
928
996
|
else None
|
929
997
|
),
|
998
|
+
"user_input_config": (
|
999
|
+
json.dumps([
|
1000
|
+
entity.model_dump() for entity in config.user_input_config
|
1001
|
+
])
|
1002
|
+
if config.user_input_config
|
1003
|
+
else None
|
1004
|
+
),
|
930
1005
|
}
|
931
1006
|
response = self.multipart_client.post(
|
932
1007
|
"/v0.1/builds",
|
@@ -1787,6 +1862,470 @@ class RestClient:
|
|
1787
1862
|
except Exception as e:
|
1788
1863
|
raise ProjectError(f"Error adding trajectory to project: {e}") from e
|
1789
1864
|
|
1865
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def get_tasks(
    self,
    query_params: TrajectoryQueryParams | None = None,
    *,
    project_id: UUID | None = None,
    name: str | None = None,
    user: str | None = None,
    limit: int = 50,
    offset: int = 0,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> list[dict[str, Any]]:
    """Fetches trajectories with applied filtering.

    Filters may be given either as a prepared ``TrajectoryQueryParams`` model
    or as the individual keyword arguments. When ``query_params`` is provided
    it takes precedence and the keyword filters are ignored.

    Args:
        query_params: Optional TrajectoryQueryParams model with all parameters
        project_id: Optional project ID to filter trajectories by
        name: Optional name filter for trajectories
        user: Optional user email filter for trajectories
        limit: Maximum number of trajectories to return (default: 50, max: 200)
        offset: Number of trajectories to skip for pagination (default: 0)
        sort_by: Field to sort by, either "created_at" or "name" (default: "created_at")
        sort_order: Sort order, either "asc" or "desc" (default: "desc")

    Returns:
        List of trajectory dictionaries (the decoded JSON body of
        ``GET /v0.1/trajectories``).

    Raises:
        PermissionError: If the server responds with 401 or 403.
        TaskFetchError: For any other HTTP error status. Also raised for any
            other exception while building or sending the query, e.g. a
            filter value rejected by ``TrajectoryQueryParams`` validation.
    """
    try:
        if query_params is not None:
            params = query_params.to_query_params()
        else:
            # Validate keyword filters (limit bounds, sort options) through the
            # same model callers can pass directly.
            params = TrajectoryQueryParams(
                project_id=project_id,
                name=name,
                user=user,
                limit=limit,
                offset=offset,
                sort_by=sort_by,
                sort_order=sort_order,
            ).to_query_params()

        response = self.client.get("/v0.1/trajectories", params=params)
        response.raise_for_status()
        return response.json()
    except HTTPStatusError as e:
        if e.response.status_code in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
            raise PermissionError(
                "Error getting trajectories: Permission denied"
            ) from e
        raise TaskFetchError(
            f"Error getting trajectories: {e.response.status_code} - {e.response.text}"
        ) from e
    # NOTE(review): this clause also converts transport errors into
    # TaskFetchError, so the @retry predicate presumably never sees them.
    # Confirm whether retry_if_connection_error is meant to apply here.
    except Exception as e:
        raise TaskFetchError(f"Error getting trajectories: {e!r}") from e
|
1929
|
+
|
1930
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def aget_tasks(
    self,
    query_params: TrajectoryQueryParams | None = None,
    *,
    project_id: UUID | None = None,
    name: str | None = None,
    user: str | None = None,
    limit: int = 50,
    offset: int = 0,
    sort_by: str = "created_at",
    sort_order: str = "desc",
) -> list[dict[str, Any]]:
    """Asynchronously fetch trajectories with applied filtering.

    Filters may be given either as a prepared ``TrajectoryQueryParams`` model
    or as the individual keyword arguments. When ``query_params`` is provided
    it takes precedence and the keyword filters are ignored.

    Args:
        query_params: Optional TrajectoryQueryParams model with all parameters
        project_id: Optional project ID to filter trajectories by
        name: Optional name filter for trajectories
        user: Optional user email filter for trajectories
        limit: Maximum number of trajectories to return (default: 50, max: 200)
        offset: Number of trajectories to skip for pagination (default: 0)
        sort_by: Field to sort by, either "created_at" or "name" (default: "created_at")
        sort_order: Sort order, either "asc" or "desc" (default: "desc")

    Returns:
        List of trajectory dictionaries (the decoded JSON body of
        ``GET /v0.1/trajectories``).

    Raises:
        PermissionError: If the server responds with 401 or 403.
        TaskFetchError: For any other HTTP error status. Also raised for any
            other exception while building or sending the query, e.g. a
            filter value rejected by ``TrajectoryQueryParams`` validation.
    """
    try:
        if query_params is not None:
            params = query_params.to_query_params()
        else:
            # Validate keyword filters (limit bounds, sort options) through the
            # same model callers can pass directly.
            params = TrajectoryQueryParams(
                project_id=project_id,
                name=name,
                user=user,
                limit=limit,
                offset=offset,
                sort_by=sort_by,
                sort_order=sort_order,
            ).to_query_params()

        response = await self.async_client.get("/v0.1/trajectories", params=params)
        response.raise_for_status()
        return response.json()
    except HTTPStatusError as e:
        if e.response.status_code in {HTTPStatus.UNAUTHORIZED, HTTPStatus.FORBIDDEN}:
            raise PermissionError(
                "Error getting trajectories: Permission denied"
            ) from e
        raise TaskFetchError(
            f"Error getting trajectories: {e.response.status_code} - {e.response.text}"
        ) from e
    # NOTE(review): this clause also converts transport errors into
    # TaskFetchError, so the @retry predicate presumably never sees them.
    # Confirm whether retry_if_connection_error is meant to apply here.
    except Exception as e:
        raise TaskFetchError(f"Error getting trajectories: {e!r}") from e
|
1994
|
+
|
1995
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def list_user_agent_requests(
    self,
    user_id: str | None = None,
    trajectory_id: UUID | None = None,
    request_status: UserAgentRequestStatus | None = None,
    limit: int = 50,
    offset: int = 0,
) -> list[UserAgentRequest]:
    """Return the user agent requests that match the given filters.

    Args:
        user_id: Only return requests for this user ID. When omitted, the
            filter is not sent and the server is expected to default to the
            authenticated user's ID.
        trajectory_id: Only return requests attached to this trajectory.
        request_status: Only return requests in this status (e.g. PENDING).
        limit: Maximum number of requests in the returned page.
        offset: Number of requests to skip, for pagination.

    Returns:
        The matching requests, each validated into a ``UserAgentRequest``.

    Raises:
        UserAgentRequestFetchError: If the API call fails for any reason.
            This includes an HTTP error status or an unparsable response.
    """
    # Pair each query key with its wire value. Unset values are then dropped
    # so they are omitted from the query string entirely.
    candidates = (
        ("user_id", user_id),
        ("trajectory_id", str(trajectory_id) if trajectory_id else None),
        ("request_status", request_status.value if request_status else None),
        ("limit", limit),
        ("offset", offset),
    )
    params = {key: value for key, value in candidates if value is not None}
    try:
        response = self.client.get("/v0.1/user-agent-requests", params=params)
        response.raise_for_status()
        items = response.json()
        return [UserAgentRequest.model_validate(item) for item in items]
    except HTTPStatusError as err:
        raise UserAgentRequestFetchError(
            f"Error listing user agent requests: {err.response.status_code} - {err.response.text}"
        ) from err
    except Exception as err:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {err!r}"
        ) from err
|
2045
|
+
|
2046
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def alist_user_agent_requests(
    self,
    user_id: str | None = None,
    trajectory_id: UUID | None = None,
    request_status: UserAgentRequestStatus | None = None,
    limit: int = 50,
    offset: int = 0,
) -> list[UserAgentRequest]:
    """Asynchronously return the user agent requests matching the filters.

    Args:
        user_id: Only return requests for this user ID. When omitted, the
            filter is not sent and the server is expected to default to the
            authenticated user's ID.
        trajectory_id: Only return requests attached to this trajectory.
        request_status: Only return requests in this status (e.g. PENDING).
        limit: Maximum number of requests in the returned page.
        offset: Number of requests to skip, for pagination.

    Returns:
        The matching requests, each validated into a ``UserAgentRequest``.

    Raises:
        UserAgentRequestFetchError: If the API call fails for any reason.
            This includes an HTTP error status or an unparsable response.
    """
    # Pair each query key with its wire value. Unset values are then dropped
    # so they are omitted from the query string entirely.
    candidates = (
        ("user_id", user_id),
        ("trajectory_id", str(trajectory_id) if trajectory_id else None),
        ("request_status", request_status.value if request_status else None),
        ("limit", limit),
        ("offset", offset),
    )
    params = {key: value for key, value in candidates if value is not None}
    try:
        response = await self.async_client.get(
            "/v0.1/user-agent-requests", params=params
        )
        response.raise_for_status()
        items = response.json()
        return [UserAgentRequest.model_validate(item) for item in items]
    except HTTPStatusError as err:
        raise UserAgentRequestFetchError(
            f"Error listing user agent requests: {err.response.status_code} - {err.response.text}"
        ) from err
    except Exception as err:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {err!r}"
        ) from err
|
2098
|
+
|
2099
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def get_user_agent_request(self, request_id: UUID) -> UserAgentRequest:
    """Look up one user agent request by its ID.

    Args:
        request_id: ID of the request to retrieve.

    Returns:
        The request, validated into a ``UserAgentRequest``.

    Raises:
        UserAgentRequestFetchError: If the request does not exist (404), the
            server returns any other error status, or anything else fails.
    """
    try:
        response = self.client.get(f"/v0.1/user-agent-requests/{request_id}")
        response.raise_for_status()
        return UserAgentRequest.model_validate(response.json())
    except HTTPStatusError as err:
        status_code = err.response.status_code
        # A missing request gets a dedicated message; other codes include the body.
        message = (
            f"User agent request with ID {request_id} not found."
            if status_code == codes.NOT_FOUND
            else f"Error fetching user agent request: {status_code} - {err.response.text}"
        )
        raise UserAgentRequestFetchError(message) from err
    except Exception as err:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {err!r}"
        ) from err
|
2133
|
+
|
2134
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def aget_user_agent_request(self, request_id: UUID) -> UserAgentRequest:
    """Asynchronously look up one user agent request by its ID.

    Args:
        request_id: ID of the request to retrieve.

    Returns:
        The request, validated into a ``UserAgentRequest``.

    Raises:
        UserAgentRequestFetchError: If the request does not exist (404), the
            server returns any other error status, or anything else fails.
    """
    try:
        response = await self.async_client.get(
            f"/v0.1/user-agent-requests/{request_id}"
        )
        response.raise_for_status()
        return UserAgentRequest.model_validate(response.json())
    except HTTPStatusError as err:
        status_code = err.response.status_code
        # A missing request gets a dedicated message; other codes include the body.
        message = (
            f"User agent request with ID {request_id} not found."
            if status_code == codes.NOT_FOUND
            else f"Error fetching user agent request: {status_code} - {err.response.text}"
        )
        raise UserAgentRequestFetchError(message) from err
    except Exception as err:
        raise UserAgentRequestFetchError(
            f"An unexpected error occurred: {err!r}"
        ) from err
|
2170
|
+
|
2171
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def create_user_agent_request(self, payload: UserAgentRequestPostPayload) -> UUID:
    """Create a new request from an agent to a user.

    Args:
        payload: Request data, sent as its JSON-mode ``model_dump``.

    Returns:
        The UUID of the new request. It is parsed with ``UUID()`` from the
        decoded JSON body, which is assumed to be a bare UUID string.

    Raises:
        UserAgentRequestCreationError: If the server rejects the payload (422),
            returns any other error status, or anything else fails.
    """
    try:
        response = self.client.post(
            "/v0.1/user-agent-requests", json=payload.model_dump(mode="json")
        )
        response.raise_for_status()
        created_id = response.json()
        return UUID(created_id)
    except HTTPStatusError as err:
        if err.response.status_code == codes.UNPROCESSABLE_ENTITY:
            message = f"Invalid payload for user agent request creation: {err.response.text}."
        else:
            message = f"Error creating user agent request: {err.response.status_code} - {err.response.text}"
        raise UserAgentRequestCreationError(message) from err
    except Exception as err:
        raise UserAgentRequestCreationError(
            f"An unexpected error occurred during user agent request creation: {err!r}."
        ) from err
|
2207
|
+
|
2208
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def acreate_user_agent_request(
    self, payload: UserAgentRequestPostPayload
) -> UUID:
    """Asynchronously create a new request from an agent to a user.

    Args:
        payload: Request data, sent as its JSON-mode ``model_dump``.

    Returns:
        The UUID of the new request. It is parsed with ``UUID()`` from the
        decoded JSON body, which is assumed to be a bare UUID string.

    Raises:
        UserAgentRequestCreationError: If the server rejects the payload (422),
            returns any other error status, or anything else fails.
    """
    try:
        response = await self.async_client.post(
            "/v0.1/user-agent-requests", json=payload.model_dump(mode="json")
        )
        response.raise_for_status()
        created_id = response.json()
        return UUID(created_id)
    except HTTPStatusError as err:
        if err.response.status_code == codes.UNPROCESSABLE_ENTITY:
            message = f"Invalid payload for user agent request creation: {err.response.text}."
        else:
            message = f"Error creating user agent request: {err.response.status_code} - {err.response.text}"
        raise UserAgentRequestCreationError(message) from err
    except Exception as err:
        raise UserAgentRequestCreationError(
            f"An unexpected error occurred during user agent request creation: {err!r}."
        ) from err
|
2246
|
+
|
2247
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def respond_to_user_agent_request(
    self, request_id: UUID, payload: UserAgentResponsePayload
) -> None:
    """Send the user's answer to a pending agent request.

    Args:
        request_id: ID of the request being answered.
        payload: Response data, sent as its JSON-mode ``model_dump``.

    Raises:
        UserAgentRequestResponseError: If the request does not exist (404),
            the payload is rejected (422), the server returns any other error
            status, or anything else fails.
    """
    try:
        response = self.client.post(
            f"/v0.1/user-agent-requests/{request_id}/response",
            json=payload.model_dump(mode="json"),
        )
        response.raise_for_status()
    except HTTPStatusError as err:
        status_code = err.response.status_code
        if status_code == codes.NOT_FOUND:
            message = f"User agent request with ID {request_id} not found."
        elif status_code == codes.UNPROCESSABLE_ENTITY:
            message = f"Invalid response payload: {err.response.text}."
        else:
            message = f"Error responding to user agent request: {status_code} - {err.response.text}"
        raise UserAgentRequestResponseError(message) from err
    except Exception as err:
        raise UserAgentRequestResponseError(
            f"An unexpected error occurred while responding to the request: {err!r}."
        ) from err
|
2287
|
+
|
2288
|
+
@retry(
    stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
    wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
    retry=retry_if_connection_error,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def arespond_to_user_agent_request(
    self, request_id: UUID, payload: UserAgentResponsePayload
) -> None:
    """Asynchronously send the user's answer to a pending agent request.

    Args:
        request_id: ID of the request being answered.
        payload: Response data, sent as its JSON-mode ``model_dump``.

    Raises:
        UserAgentRequestResponseError: If the request does not exist (404),
            the payload is rejected (422), the server returns any other error
            status, or anything else fails.
    """
    try:
        response = await self.async_client.post(
            f"/v0.1/user-agent-requests/{request_id}/response",
            json=payload.model_dump(mode="json"),
        )
        response.raise_for_status()
    except HTTPStatusError as err:
        status_code = err.response.status_code
        if status_code == codes.NOT_FOUND:
            message = f"User agent request with ID {request_id} not found."
        elif status_code == codes.UNPROCESSABLE_ENTITY:
            message = f"Invalid response payload: {err.response.text}."
        else:
            message = f"Error responding to user agent request: {status_code} - {err.response.text}"
        raise UserAgentRequestResponseError(message) from err
    except Exception as err:
        raise UserAgentRequestResponseError(
            f"An unexpected error occurred while responding to the request: {err!r}."
        ) from err
|
2328
|
+
|
1790
2329
|
|
1791
2330
|
def get_installed_packages() -> dict[str, str]:
|
1792
2331
|
"""Returns a dictionary of installed packages and their versions."""
|
@@ -278,6 +278,20 @@ class FramePath(BaseModel):
|
|
278
278
|
)
|
279
279
|
|
280
280
|
|
281
|
+
# One user-facing input slot declared by a job deployment. It is referenced
# by JobDeploymentConfig.user_input_config and serialized with model_dump()
# when the deployment configuration is posted.
class NamedEntity(BaseModel):
    # Required: the tag name users fill in, e.g. "pdb" -> <pdb>...</pdb>.
    name: str = Field(
        description=(
            "Name of an entity for a user to provide a value to during query submission. "
            "This will be used as a key to prompt users for a value. "
            "Example: 'pdb' would result in <pdb>user input here</pdb> in the task string."
        )
    )
    # Optional helper text giving the user context about the expected value.
    description: str | None = Field(
        None,
        description="Helper text to provide the user context to what the name or format needs to be.",
    )
|
293
|
+
|
294
|
+
|
281
295
|
class DockerContainerConfiguration(BaseModel):
|
282
296
|
cpu: str = Field(description="CPU allotment for the container")
|
283
297
|
memory: str = Field(description="Memory allotment for the container")
|
@@ -458,6 +472,15 @@ class JobDeploymentConfig(BaseModel):
|
|
458
472
|
description="The configuration for the task queue(s) that will be created for this deployment.",
|
459
473
|
)
|
460
474
|
|
475
|
+
user_input_config: list[NamedEntity] | None = Field(
|
476
|
+
default=None,
|
477
|
+
description=(
|
478
|
+
"List of NamedEntity objects that represent user input fields "
|
479
|
+
"to be included in the task string. "
|
480
|
+
"These will be used to prompt users for values during query submission."
|
481
|
+
),
|
482
|
+
)
|
483
|
+
|
461
484
|
@field_validator("markdown_template_path")
|
462
485
|
@classmethod
|
463
486
|
def validate_markdown_path(
|
@@ -646,6 +669,62 @@ class RuntimeConfig(BaseModel):
|
|
646
669
|
return value
|
647
670
|
|
648
671
|
|
672
|
+
class TrajectoryQueryParams(BaseModel):
    """Params for trajectories with filtering."""

    # Reject unknown keys so typos in filter names fail loudly.
    model_config = ConfigDict(extra="forbid")

    # Optional filters: each is left out of the query string when None.
    project_id: UUID | None = Field(
        default=None, description="Optional project ID to filter trajectories by"
    )
    name: str | None = Field(
        default=None, description="Optional name filter for trajectories"
    )
    user: str | None = Field(
        default=None, description="Optional user email filter for trajectories"
    )
    # Pagination window, bounds enforced by pydantic: 1 <= limit <= 200, offset >= 0.
    limit: int = Field(
        default=50,
        ge=1,
        le=200,
        description="Maximum number of trajectories to return (max: 200)",
    )
    offset: int = Field(
        default=0, ge=0, description="Number of trajectories to skip for pagination"
    )
    sort_by: str = Field(default="created_at", description="Field to sort by")
    sort_order: str = Field(default="desc", description="Sort order")

    @field_validator("sort_by")
    @classmethod
    def validate_sort_by(cls, value: str) -> str:
        # Only these two columns are accepted for ordering.
        if value in {"created_at", "name"}:
            return value
        raise ValueError("sort_by must be either 'created_at' or 'name'")

    @field_validator("sort_order")
    @classmethod
    def validate_sort_order(cls, value: str) -> str:
        if value in {"asc", "desc"}:
            return value
        raise ValueError("sort_order must be either 'asc' or 'desc'")

    def to_query_params(self) -> dict[str, str | int]:
        """Build the HTTP query-parameter mapping for this filter set.

        Pagination and sort keys are always present. ``project_id``, ``name``
        and ``user`` are appended, in that order, only when they are set;
        ``project_id`` is stringified.
        """
        params: dict[str, str | int] = {
            "limit": self.limit,
            "offset": self.offset,
            "sort_by": self.sort_by,
            "sort_order": self.sort_order,
        }
        optional_filters = (
            ("project_id", None if self.project_id is None else str(self.project_id)),
            ("name", self.name),
            ("user", self.user),
        )
        params.update(
            (key, value) for key, value in optional_filters if value is not None
        )
        return params
|
726
|
+
|
727
|
+
|
649
728
|
class TaskRequest(BaseModel):
|
650
729
|
model_config = ConfigDict(extra="forbid")
|
651
730
|
|
@@ -27,13 +27,17 @@ class InitialState(BaseState):
|
|
27
27
|
|
28
28
|
class ASVState(BaseState, Generic[T]):
    # NOTE(review): "ASV" presumably stands for action/state/value, matching
    # the three fields below — confirm.
    action: OpResult[T] = Field()
    # Typed as Any; serialize_next_state below turns it into str() output.
    next_state: Any = Field()
    value: float = Field()

    @field_serializer("action")
    def serialize_action(self, action: OpResult[T]) -> dict:
        # Delegate serialization to OpResult's own to_dict() conversion.
        return action.to_dict()

    @field_serializer("next_state")
    def serialize_next_state(self, state: Any) -> str:
        # Arbitrary state objects are emitted as their str() form. This
        # presumably drops any structure; confirm consumers only need a
        # human-readable representation.
        return str(state)
|
40
|
+
|
37
41
|
|
38
42
|
class EnvResetState(BaseState):
|
39
43
|
observations: list[Message] = Field()
|
@@ -2,7 +2,7 @@ from datetime import datetime
|
|
2
2
|
from enum import StrEnum, auto
|
3
3
|
from uuid import UUID
|
4
4
|
|
5
|
-
from pydantic import BaseModel, JsonValue
|
5
|
+
from pydantic import BaseModel, ConfigDict, JsonValue
|
6
6
|
|
7
7
|
|
8
8
|
class FinalEnvironmentRequest(BaseModel):
|
@@ -70,3 +70,50 @@ class WorldModelResponse(BaseModel):
|
|
70
70
|
model_metadata: JsonValue | None
|
71
71
|
enabled: bool
|
72
72
|
created_at: datetime
|
73
|
+
|
74
|
+
|
75
|
+
class UserAgentRequestStatus(StrEnum):
|
76
|
+
"""Enum for the status of a user agent request."""
|
77
|
+
|
78
|
+
PENDING = auto()
|
79
|
+
RESPONDED = auto()
|
80
|
+
EXPIRED = auto()
|
81
|
+
CANCELLED = auto()
|
82
|
+
|
83
|
+
|
84
|
+
class UserAgentRequest(BaseModel):
|
85
|
+
"""Sister model for UserAgentRequestsDB."""
|
86
|
+
|
87
|
+
model_config = ConfigDict(from_attributes=True)
|
88
|
+
|
89
|
+
id: UUID
|
90
|
+
user_id: str
|
91
|
+
trajectory_id: UUID
|
92
|
+
response_trajectory_id: UUID | None = None
|
93
|
+
request: JsonValue
|
94
|
+
response: JsonValue | None = None
|
95
|
+
request_world_model_edit_id: UUID | None = None
|
96
|
+
response_world_model_edit_id: UUID | None = None
|
97
|
+
expires_at: datetime | None = None
|
98
|
+
user_response_task: JsonValue | None = None
|
99
|
+
status: UserAgentRequestStatus
|
100
|
+
created_at: datetime | None = None
|
101
|
+
modified_at: datetime | None = None
|
102
|
+
|
103
|
+
|
104
|
+
class UserAgentRequestPostPayload(BaseModel):
|
105
|
+
"""Payload to create a new user agent request."""
|
106
|
+
|
107
|
+
trajectory_id: UUID
|
108
|
+
request: JsonValue
|
109
|
+
request_world_model_edit_id: UUID | None = None
|
110
|
+
status: UserAgentRequestStatus = UserAgentRequestStatus.PENDING
|
111
|
+
expires_in_seconds: int | None = None
|
112
|
+
user_response_task: JsonValue | None = None
|
113
|
+
|
114
|
+
|
115
|
+
class UserAgentResponsePayload(BaseModel):
|
116
|
+
"""Payload for a user to submit a response to a request."""
|
117
|
+
|
118
|
+
response: JsonValue
|
119
|
+
response_world_model_edit_id: UUID | None = None
|
{futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/version.py
RENAMED
@@ -17,5 +17,5 @@ __version__: str
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
18
18
|
version_tuple: VERSION_TUPLE
|
19
19
|
|
20
|
-
__version__ = version = '0.3.20.
|
21
|
-
__version_tuple__ = version_tuple = (0, 3, 20, '
|
20
|
+
__version__ = version = '0.3.20.dev266'
|
21
|
+
__version_tuple__ = version_tuple = (0, 3, 20, 'dev266')
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: futurehouse-client
|
3
|
-
Version: 0.3.20.
|
3
|
+
Version: 0.3.20.dev266
|
4
4
|
Summary: A client for interacting with endpoints of the FutureHouse service.
|
5
5
|
Author-email: FutureHouse technical staff <hello@futurehouse.org>
|
6
6
|
License: Apache License
|
@@ -7,6 +7,7 @@ import tempfile
|
|
7
7
|
import time
|
8
8
|
import types
|
9
9
|
from pathlib import Path
|
10
|
+
from typing import cast
|
10
11
|
from unittest.mock import MagicMock, mock_open, patch
|
11
12
|
from uuid import UUID, uuid4
|
12
13
|
|
@@ -19,6 +20,8 @@ from futurehouse_client.clients.rest_client import (
|
|
19
20
|
ProjectError,
|
20
21
|
RestClient,
|
21
22
|
RestClientError,
|
23
|
+
UserAgentRequestCreationError,
|
24
|
+
UserAgentRequestFetchError,
|
22
25
|
)
|
23
26
|
from futurehouse_client.models.app import (
|
24
27
|
PhoenixTaskResponse,
|
@@ -27,7 +30,13 @@ from futurehouse_client.models.app import (
|
|
27
30
|
TaskRequest,
|
28
31
|
TaskResponseVerbose,
|
29
32
|
)
|
30
|
-
from futurehouse_client.models.rest import
|
33
|
+
from futurehouse_client.models.rest import (
|
34
|
+
ExecutionStatus,
|
35
|
+
UserAgentRequestPostPayload,
|
36
|
+
UserAgentRequestStatus,
|
37
|
+
UserAgentResponsePayload,
|
38
|
+
WorldModel,
|
39
|
+
)
|
31
40
|
from pytest_subtests import SubTests
|
32
41
|
|
33
42
|
ADMIN_API_KEY = os.environ["PLAYWRIGHT_ADMIN_API_KEY"]
|
@@ -78,6 +87,11 @@ def phoenix_task_req():
|
|
78
87
|
)
|
79
88
|
|
80
89
|
|
90
|
+
@pytest.fixture
|
91
|
+
def running_trajectory_id(admin_client: RestClient, task_req: TaskRequest) -> str:
|
92
|
+
return admin_client.create_task(task_req)
|
93
|
+
|
94
|
+
|
81
95
|
@pytest.mark.timeout(300)
|
82
96
|
@pytest.mark.flaky(reruns=3)
|
83
97
|
def test_futurehouse_dummy_env_crow(admin_client: RestClient, task_req: TaskRequest):
|
@@ -247,6 +261,41 @@ async def test_timeout_run_until_done_futurehouse_dummy_env_crow(
|
|
247
261
|
)
|
248
262
|
|
249
263
|
|
264
|
+
@pytest.mark.timeout(300)
|
265
|
+
@pytest.mark.flaky(reruns=3)
|
266
|
+
def test_cancel_task(admin_client: RestClient):
|
267
|
+
"""Test successful task cancellation using MagicMock."""
|
268
|
+
task_id = "1c28bb94-efbf-442f-954c-f0d8ddb0cff5"
|
269
|
+
|
270
|
+
with (
|
271
|
+
patch.object(admin_client.client, "post") as mock_post,
|
272
|
+
patch.object(admin_client, "get_task") as mock_get_task,
|
273
|
+
):
|
274
|
+
mock_response = MagicMock()
|
275
|
+
mock_response.raise_for_status.return_value = None
|
276
|
+
mock_post.return_value = mock_response
|
277
|
+
|
278
|
+
mock_task_response_running = MagicMock()
|
279
|
+
mock_task_response_running.status = "in progress"
|
280
|
+
|
281
|
+
mock_task_response_cancelled = MagicMock()
|
282
|
+
mock_task_response_cancelled.status = ExecutionStatus.CANCELLED.value
|
283
|
+
|
284
|
+
mock_get_task.side_effect = [
|
285
|
+
mock_task_response_running,
|
286
|
+
mock_task_response_cancelled,
|
287
|
+
]
|
288
|
+
|
289
|
+
result = admin_client.cancel_task(task_id)
|
290
|
+
|
291
|
+
assert result is True
|
292
|
+
|
293
|
+
expected_url = f"/v0.1/trajectories/{task_id}/cancel"
|
294
|
+
mock_post.assert_called_once_with(expected_url)
|
295
|
+
|
296
|
+
assert mock_get_task.call_count == 2
|
297
|
+
|
298
|
+
|
250
299
|
class TestParallelChunking:
|
251
300
|
"""Test suite for parallel chunk upload functionality."""
|
252
301
|
|
@@ -752,7 +801,7 @@ def test_world_model_create_and_get(admin_client: RestClient):
|
|
752
801
|
class TestProjectOperations:
|
753
802
|
@pytest.fixture
|
754
803
|
def test_project_name(self):
|
755
|
-
return f"test-project-{
|
804
|
+
return f"test-project-{uuid4()}"
|
756
805
|
|
757
806
|
def test_create_project_success(
|
758
807
|
self, admin_client: RestClient, test_project_name: str
|
@@ -821,7 +870,7 @@ class TestProjectOperations:
|
|
821
870
|
class TestAsyncProjectOperations:
|
822
871
|
@pytest.fixture
|
823
872
|
def test_project_name(self):
|
824
|
-
return f"test-async-project-{
|
873
|
+
return f"test-async-project-{uuid4()}"
|
825
874
|
|
826
875
|
@pytest.mark.asyncio
|
827
876
|
async def test_acreate_project_success(
|
@@ -873,3 +922,230 @@ class TestAsyncProjectOperations:
|
|
873
922
|
await admin_client.aadd_task_to_project(
|
874
923
|
fake_project_id, fake_trajectory_id
|
875
924
|
)
|
925
|
+
|
926
|
+
|
927
|
+
@pytest.mark.timeout(300)
|
928
|
+
@pytest.mark.flaky(reruns=3)
|
929
|
+
def test_get_tasks_with_project_filter(admin_client: RestClient, task_req: TaskRequest):
|
930
|
+
"""Test retrieving trajectories filtered by project_id using real API calls."""
|
931
|
+
project_name = f"e2e-trajectories-fetch-{uuid4()}"
|
932
|
+
project_id = admin_client.create_project(project_name)
|
933
|
+
|
934
|
+
trajectory_id = admin_client.create_task(task_req)
|
935
|
+
admin_client.add_task_to_project(UUID(project_id), trajectory_id)
|
936
|
+
|
937
|
+
while (task_status := admin_client.get_task(trajectory_id).status) in {
|
938
|
+
"queued",
|
939
|
+
"in progress",
|
940
|
+
}:
|
941
|
+
time.sleep(5)
|
942
|
+
|
943
|
+
trajectories = admin_client.get_tasks(project_id=UUID(project_id))
|
944
|
+
|
945
|
+
trajectory_ids = [t["id"] for t in trajectories]
|
946
|
+
assert trajectory_id in trajectory_ids
|
947
|
+
|
948
|
+
|
949
|
+
@pytest.mark.timeout(300)
|
950
|
+
@pytest.mark.flaky(reruns=3)
|
951
|
+
@pytest.mark.asyncio
|
952
|
+
async def test_aget_tasks_with_project_filter(
|
953
|
+
admin_client: RestClient, task_req: TaskRequest
|
954
|
+
):
|
955
|
+
"""Test async retrieving trajectories filtered by project_id using real API calls."""
|
956
|
+
project_name = f"e2e-trajectories-async-fetch-{uuid4()}"
|
957
|
+
project_id = await admin_client.acreate_project(project_name)
|
958
|
+
|
959
|
+
trajectory_id = await admin_client.acreate_task(task_req)
|
960
|
+
await admin_client.aadd_task_to_project(UUID(project_id), trajectory_id)
|
961
|
+
|
962
|
+
while True:
|
963
|
+
task = await admin_client.aget_task(trajectory_id)
|
964
|
+
if task.status not in {"queued", "in progress"}:
|
965
|
+
break
|
966
|
+
await asyncio.sleep(5)
|
967
|
+
|
968
|
+
trajectories = await admin_client.aget_tasks(project_id=UUID(project_id))
|
969
|
+
|
970
|
+
trajectory_ids = [t["id"] for t in trajectories]
|
971
|
+
assert trajectory_id in trajectory_ids
|
972
|
+
|
973
|
+
|
974
|
+
class TestUserAgentRequestOperations:
|
975
|
+
"""Test suite for synchronous User Agent Request operations."""
|
976
|
+
|
977
|
+
@pytest.mark.flaky(reruns=3)
|
978
|
+
def test_e2e_user_agent_request_flow(
|
979
|
+
self,
|
980
|
+
admin_client: RestClient,
|
981
|
+
running_trajectory_id: str,
|
982
|
+
):
|
983
|
+
"""Tests the full lifecycle: create, get, list, and respond."""
|
984
|
+
payload = UserAgentRequestPostPayload(
|
985
|
+
trajectory_id=running_trajectory_id,
|
986
|
+
request={"question": "Do you approve?"},
|
987
|
+
)
|
988
|
+
request_id = admin_client.create_user_agent_request(payload)
|
989
|
+
assert isinstance(request_id, UUID)
|
990
|
+
|
991
|
+
# 2. GET the created request
|
992
|
+
retrieved_req = admin_client.get_user_agent_request(request_id)
|
993
|
+
assert retrieved_req.id == request_id
|
994
|
+
assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
|
995
|
+
assert retrieved_req.status == UserAgentRequestStatus.PENDING
|
996
|
+
assert retrieved_req.request == payload.request
|
997
|
+
|
998
|
+
# 3. LIST requests and find the created one
|
999
|
+
request_list = admin_client.list_user_agent_requests(
|
1000
|
+
trajectory_id=UUID(running_trajectory_id),
|
1001
|
+
request_status=UserAgentRequestStatus.PENDING,
|
1002
|
+
)
|
1003
|
+
assert isinstance(request_list, list)
|
1004
|
+
assert any(req.id == request_id for req in request_list)
|
1005
|
+
|
1006
|
+
# 4. RESPOND to the request
|
1007
|
+
response_payload = UserAgentResponsePayload(response={"answer": "Yes"})
|
1008
|
+
admin_client.respond_to_user_agent_request(request_id, response_payload)
|
1009
|
+
|
1010
|
+
# 5. GET the request again to verify the response
|
1011
|
+
responded_req = admin_client.get_user_agent_request(request_id)
|
1012
|
+
assert responded_req.status == UserAgentRequestStatus.RESPONDED
|
1013
|
+
assert responded_req.response == response_payload.response
|
1014
|
+
|
1015
|
+
def test_get_nonexistent_request_fails(self, admin_client: RestClient):
|
1016
|
+
"""Verifies that fetching a non-existent request raises an error."""
|
1017
|
+
non_existent_id = uuid4()
|
1018
|
+
with pytest.raises(UserAgentRequestFetchError):
|
1019
|
+
admin_client.get_user_agent_request(non_existent_id)
|
1020
|
+
|
1021
|
+
def test_unauthorized_access_fails(self, pub_client: RestClient):
|
1022
|
+
"""Ensures a client with insufficient permissions cannot perform actions."""
|
1023
|
+
# Using a public client that shouldn't have access
|
1024
|
+
with pytest.raises((UserAgentRequestCreationError, PermissionError)): # noqa: PT012
|
1025
|
+
payload = UserAgentRequestPostPayload(
|
1026
|
+
trajectory_id=uuid4(),
|
1027
|
+
request={"data": "test"},
|
1028
|
+
)
|
1029
|
+
pub_client.create_user_agent_request(payload)
|
1030
|
+
|
1031
|
+
with pytest.raises((UserAgentRequestFetchError, PermissionError)):
|
1032
|
+
# Attempt to fetch a request that the user doesn't own
|
1033
|
+
pub_client.get_user_agent_request(uuid4())
|
1034
|
+
|
1035
|
+
|
1036
|
+
class TestAsyncUserAgentRequestOperations:
|
1037
|
+
"""Test suite for asynchronous User Agent Request operations."""
|
1038
|
+
|
1039
|
+
@pytest.mark.asyncio
|
1040
|
+
async def test_async_expiring_e2e_user_agent_request_flow(
|
1041
|
+
self,
|
1042
|
+
admin_client: RestClient,
|
1043
|
+
running_trajectory_id: str,
|
1044
|
+
):
|
1045
|
+
"""Tests the full async lifecycle: acreate, aget, alist, and arespond."""
|
1046
|
+
payload = UserAgentRequestPostPayload(
|
1047
|
+
trajectory_id=running_trajectory_id,
|
1048
|
+
request={"question": "Async: Do you approve?"},
|
1049
|
+
user_response_task=TaskRequest(
|
1050
|
+
name=JobNames.from_string("dummy"),
|
1051
|
+
query="Why would I follow up on this query?",
|
1052
|
+
).model_dump(mode="json"),
|
1053
|
+
expires_in_seconds=10,
|
1054
|
+
)
|
1055
|
+
|
1056
|
+
request_id = await admin_client.acreate_user_agent_request(payload)
|
1057
|
+
assert isinstance(request_id, UUID)
|
1058
|
+
|
1059
|
+
retrieved_req = await admin_client.aget_user_agent_request(request_id)
|
1060
|
+
assert retrieved_req.id == request_id
|
1061
|
+
assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
|
1062
|
+
assert retrieved_req.status == UserAgentRequestStatus.PENDING
|
1063
|
+
|
1064
|
+
request_list = await admin_client.alist_user_agent_requests(
|
1065
|
+
trajectory_id=UUID(running_trajectory_id)
|
1066
|
+
)
|
1067
|
+
assert isinstance(request_list, list)
|
1068
|
+
assert any(req.id == request_id for req in request_list)
|
1069
|
+
|
1070
|
+
# ensure we allow it to expire so auto response can happen
|
1071
|
+
await asyncio.sleep(10)
|
1072
|
+
|
1073
|
+
# now this should be expired
|
1074
|
+
retrieved_req = await admin_client.aget_user_agent_request(request_id)
|
1075
|
+
assert retrieved_req.status == UserAgentRequestStatus.EXPIRED
|
1076
|
+
|
1077
|
+
# we should also see the job having started -- along with the registration of the job in the
|
1078
|
+
job_data = await admin_client.aget_task(
|
1079
|
+
cast(str, retrieved_req.response_trajectory_id)
|
1080
|
+
)
|
1081
|
+
assert job_data.status == "in progress"
|
1082
|
+
|
1083
|
+
# 4. RESPOND to the request -- ensure nothing changes
|
1084
|
+
ignored_response = {"answer": "Async Yes"}
|
1085
|
+
response_payload = UserAgentResponsePayload(response=ignored_response)
|
1086
|
+
await admin_client.arespond_to_user_agent_request(request_id, response_payload)
|
1087
|
+
|
1088
|
+
retrieved_req = await admin_client.aget_user_agent_request(request_id)
|
1089
|
+
assert retrieved_req.response != ignored_response
|
1090
|
+
|
1091
|
+
@pytest.mark.asyncio
|
1092
|
+
async def test_async_e2e_user_agent_request_flow(
|
1093
|
+
self,
|
1094
|
+
admin_client: RestClient,
|
1095
|
+
running_trajectory_id: str,
|
1096
|
+
):
|
1097
|
+
"""Tests the full async lifecycle: acreate, aget, alist, and arespond."""
|
1098
|
+
# 1. CREATE a request
|
1099
|
+
payload = UserAgentRequestPostPayload(
|
1100
|
+
trajectory_id=running_trajectory_id,
|
1101
|
+
request={"question": "Async: Do you approve?"},
|
1102
|
+
user_response_task=TaskRequest(
|
1103
|
+
name=JobNames.from_string("dummy"),
|
1104
|
+
query="Why would I follow up on this query?",
|
1105
|
+
).model_dump(mode="json"),
|
1106
|
+
)
|
1107
|
+
|
1108
|
+
request_id = await admin_client.acreate_user_agent_request(payload)
|
1109
|
+
assert isinstance(request_id, UUID)
|
1110
|
+
|
1111
|
+
# 2. GET the created request
|
1112
|
+
retrieved_req = await admin_client.aget_user_agent_request(request_id)
|
1113
|
+
assert retrieved_req.id == request_id
|
1114
|
+
assert str(retrieved_req.trajectory_id) == str(payload.trajectory_id)
|
1115
|
+
assert retrieved_req.status == UserAgentRequestStatus.PENDING
|
1116
|
+
|
1117
|
+
# 3. LIST requests and find the created one
|
1118
|
+
request_list = await admin_client.alist_user_agent_requests(
|
1119
|
+
trajectory_id=UUID(running_trajectory_id)
|
1120
|
+
)
|
1121
|
+
assert isinstance(request_list, list)
|
1122
|
+
assert any(req.id == request_id for req in request_list)
|
1123
|
+
|
1124
|
+
# 4. RESPOND to the request
|
1125
|
+
response_payload = UserAgentResponsePayload(response={"answer": "Async Yes"})
|
1126
|
+
await admin_client.arespond_to_user_agent_request(request_id, response_payload)
|
1127
|
+
|
1128
|
+
# 5. GET the request again to verify the response
|
1129
|
+
responded_req = await admin_client.aget_user_agent_request(request_id)
|
1130
|
+
assert responded_req.status == UserAgentRequestStatus.RESPONDED
|
1131
|
+
assert responded_req.response == response_payload.response
|
1132
|
+
|
1133
|
+
@pytest.mark.asyncio
|
1134
|
+
async def test_aget_nonexistent_request_fails(self, admin_client: RestClient):
|
1135
|
+
"""Verifies fetching a non-existent request asynchronously raises an error."""
|
1136
|
+
non_existent_id = uuid4()
|
1137
|
+
with pytest.raises(UserAgentRequestFetchError):
|
1138
|
+
await admin_client.aget_user_agent_request(non_existent_id)
|
1139
|
+
|
1140
|
+
@pytest.mark.asyncio
|
1141
|
+
async def test_async_unauthorized_access_fails(self, pub_client: RestClient):
|
1142
|
+
"""Ensures an unauthorized client fails on async methods."""
|
1143
|
+
with pytest.raises((UserAgentRequestCreationError, PermissionError)): # noqa: PT012
|
1144
|
+
payload = UserAgentRequestPostPayload(
|
1145
|
+
trajectory_id=uuid4(),
|
1146
|
+
request={"data": "test"},
|
1147
|
+
)
|
1148
|
+
await pub_client.acreate_user_agent_request(payload)
|
1149
|
+
|
1150
|
+
with pytest.raises((UserAgentRequestFetchError, PermissionError)):
|
1151
|
+
await pub_client.aget_user_agent_request(uuid4())
|
File without changes
|
File without changes
|
File without changes
|
{futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/docs/client_notebook.ipynb
RENAMED
File without changes
|
{futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/__init__.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{futurehouse_client-0.3.20.dev225 → futurehouse_client-0.3.20.dev266}/futurehouse_client/py.typed
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|