futurehouse-client 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ from .clients.job_client import JobClient, JobNames
2
+ from .clients.rest_client import JobResponse, JobResponseVerbose, PQAJobResponse
3
+ from .clients.rest_client import RestClient as Client
4
+
5
# Public API of this package: the names exported by ``from <package> import *``.
# ``Client`` is the ``RestClient`` class imported above under an alias.
__all__ = [
    "Client",
    "JobClient",
    "JobNames",
    "JobResponse",
    "JobResponseVerbose",
    "PQAJobResponse",
]
@@ -0,0 +1,12 @@
1
+ from .job_client import JobClient, JobNames
2
+ from .rest_client import JobResponse, JobResponseVerbose, PQAJobResponse
3
+ from .rest_client import RestClient as CrowClient
4
+
5
# Public API of the ``clients`` subpackage: the names exported by ``import *``.
# ``CrowClient`` is the ``RestClient`` class imported above under an alias.
__all__ = [
    "CrowClient",
    "JobClient",
    "JobNames",
    "JobResponse",
    "JobResponseVerbose",
    "PQAJobResponse",
]
@@ -0,0 +1,232 @@
1
import logging
from enum import StrEnum
from typing import Any, ClassVar
from uuid import UUID, uuid4

import httpx
from aviary.env import Frame
from pydantic import BaseModel
from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential

from futurehouse_client.models.app import Stage
from futurehouse_client.models.rest import (
    FinalEnvironmentRequest,
    StoreAgentStatePostRequest,
    StoreEnvironmentFrameRequest,
)
17
+
18
logger = logging.getLogger(__name__)


class JobNames(StrEnum):
    """Enum of available crow jobs.

    Member names (``CROW``, ``FALCON``, ``OWL``, ``DUMMY``) are the short aliases;
    member values are the full job identifiers (e.g. ``"job-futurehouse-paperqa2"``).
    """

    CROW = "job-futurehouse-paperqa2"
    FALCON = "job-futurehouse-paperqa2-deep"
    OWL = "job-futurehouse-hasanyone"
    DUMMY = "job-futurehouse-dummy-env"

    @classmethod
    def from_stage(cls, job_name: str, stage: Stage | None = None) -> str:
        """Return the full job identifier for ``job_name``.

        Args:
            job_name: Anything accepted by ``from_string`` (alias, member, or value).
            stage: Deployment stage. A warning is logged when it is omitted.
                NOTE(review): the stage currently does not influence the returned
                identifier; the same value is returned for every stage.

        Raises:
            ValueError: If ``job_name`` does not match any job.
        """
        if stage is None:
            logger.warning(
                "Stage is not provided, Stage.PROD as default stage. "
                "Explicitly providing the stage is recommended."
            )
        return cls.from_string(job_name).value

    @classmethod
    def from_string(cls, job_name: str) -> "JobNames":
        """Resolve ``job_name`` to a member.

        Accepted forms, tried in order:
          * a ``JobNames`` member itself (returned unchanged);
          * a member name in any letter case, e.g. ``"crow"``;
          * a member value, e.g. ``"job-futurehouse-paperqa2"``.

        Raises:
            ValueError: If ``job_name`` matches none of the above.
        """
        # A member is itself a ``str`` (StrEnum); uppercasing its value would
        # not match any member name, so short-circuit it explicitly.
        if isinstance(job_name, cls):
            return job_name
        try:
            return cls[job_name.upper()]
        except KeyError:
            pass
        try:
            return cls(job_name)
        except ValueError as e:
            raise ValueError(
                f"Invalid job name: {job_name}. \nOptions are: {', '.join([name.name for name in cls])}"
            ) from e
48
+
49
+
50
class JobClient:
    """Report one trajectory's agent states and environment frames over HTTP.

    Every request targets ``{base_uri}/v0.1/trajectories/{trajectory_id}/<endpoint>``
    and carries a bearer ``Authorization`` header plus an ``x-trajectory-id`` header.
    """

    REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec
    MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
    RETRY_MULTIPLIER: ClassVar[int] = 1
    MAX_RETRY_WAIT: ClassVar[int] = 10

    def __init__(
        self,
        environment: str,
        agent: str,
        auth_token: str,
        base_uri: str = Stage.LOCAL.value,
        trajectory_id: str | UUID | None = None,
    ):
        """Create a client bound to a single trajectory.

        Args:
            environment: Environment name; stored on the instance but not used by
                the request methods of this class.
            agent: Agent identifier, sent as ``agent_id`` with each stored agent state.
            auth_token: Bearer token placed in the ``Authorization`` header.
            base_uri: Root URL of the API; defaults to ``Stage.LOCAL.value``.
            trajectory_id: Existing trajectory ID as ``str`` or ``UUID``; a fresh
                UUID4 is generated when omitted.

        Raises:
            ValueError: If ``trajectory_id`` is not ``None``, a ``str`` or a ``UUID``.
        """
        self.base_uri = base_uri
        self.agent = agent
        self.environment = environment
        self.oauth_jwt = auth_token
        # Last ``transition.timestep`` seen in a stored agent state; also used to
        # tag environment frames.
        self.current_timestep = 0
        # Most recent ``step`` passed to ``store_agent_state``.
        self.current_step: str | None = None
        try:
            self.trajectory_id = self._cast_trajectory_id(trajectory_id)
            logger.info(
                f"Initialized JobClient for agent {agent} with trajectory_id {self.trajectory_id}",
            )
        except ValueError:
            logger.exception("Failed to initialize JobClient")
            raise

    @staticmethod
    def _cast_trajectory_id(provided_trajectory_id: str | UUID | None) -> str:
        """Normalise a trajectory ID to ``str``, generating a UUID4 when ``None``."""
        if provided_trajectory_id is None:
            return str(uuid4())
        if isinstance(provided_trajectory_id, str):
            return provided_trajectory_id
        if isinstance(provided_trajectory_id, UUID):
            return str(provided_trajectory_id)
        raise ValueError("Invalid trajectory ID provided")

    def _auth_headers(self) -> dict[str, str]:
        """Headers sent with every request for this trajectory."""
        return {
            "Authorization": f"Bearer {self.oauth_jwt}",
            "x-trajectory-id": self.trajectory_id,
        }

    async def _send(
        self, method: str, endpoint: str, payload: BaseModel
    ) -> httpx.Response:
        """Send ``payload`` as JSON to this trajectory's ``endpoint``.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response (via ``raise_for_status``).
            httpx.TimeoutException, httpx.NetworkError: On transport failures.
        """
        async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
            response = await client.request(
                method,
                url=f"{self.base_uri}/v0.1/trajectories/{self.trajectory_id}/{endpoint}",
                json=payload.model_dump(mode="json"),
                headers=self._auth_headers(),
            )
            response.raise_for_status()
        return response

    async def finalize_environment(self, status: str) -> None:
        """Mark this trajectory's environment as finished with ``status``.

        Sends a PATCH to ``.../environment-frame``; this call is not retried.
        HTTP status errors are logged and swallowed. Timeouts, network errors and
        any other exception are logged and re-raised.
        """
        data = FinalEnvironmentRequest(status=status)
        try:
            await self._send("PATCH", "environment-frame", data)
            logger.debug(f"Environment updated with status {status}")
        except httpx.HTTPStatusError as e:
            # NOTE(review): unlike the branches below this is not re-raised, so a
            # rejected request is only logged. Confirm best-effort is intended.
            logger.exception(
                "HTTP error while finalizing environment. "
                f"Status code: {e.response.status_code}, "
                f"Response: {e.response.text}",
            )
        except httpx.TimeoutException:
            logger.exception(
                f"Timeout while finalizing environment after {self.REQUEST_TIMEOUT}s",
            )
            raise
        except httpx.NetworkError:
            logger.exception("Network error while finalizing environment")
            raise
        except Exception:
            logger.exception("Unexpected error while finalizing environment")
            raise

    @retry(
        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    async def store_agent_state(self, step: str, state: BaseModel | dict) -> Any:
        """Store an agent state for ``step``; return the decoded JSON response.

        A pydantic ``state`` is serialised with ``model_dump(mode="json")``. If the
        state contains ``transition.timestep``, that value becomes
        ``current_timestep``. ``current_step`` is set to ``step`` before sending.

        Attempted at most ``MAX_RETRY_ATTEMPTS`` times in total when an exception
        propagates (timeouts, network errors, anything unexpected). Because
        ``reraise`` is not set on the tenacity decorator, exhausting the attempts
        raises ``tenacity.RetryError``. HTTP status errors are logged and
        swallowed: they are not retried and the method returns ``None``.
        """
        self.current_step = step

        state_data = (
            state.model_dump(mode="json") if isinstance(state, BaseModel) else state
        )

        transition = state_data.get("transition") if state_data else None
        if transition:
            timestep = transition.get("timestep")
            if timestep is not None:
                self.current_timestep = timestep

        data = StoreAgentStatePostRequest(
            agent_id=self.agent,
            step=self.current_step,
            state=state_data,
            trajectory_timestep=self.current_timestep,
        )

        try:
            response = await self._send("POST", "agent-state", data)
            logger.info(f"Successfully stored agent state for step {step}")
            return response.json()
        except httpx.HTTPStatusError as e:
            # NOTE(review): swallowed, hence never retried (even on 5xx).
            logger.exception(
                "HTTP error storing agent state. "
                f"Status code: {e.response.status_code}, "
                f"Response: {e.response.text}",
            )
            return None
        except httpx.TimeoutException:
            logger.exception(
                f"Timeout while storing agent state after {self.REQUEST_TIMEOUT}s",
            )
            raise
        except httpx.NetworkError:
            logger.exception("Network error while storing agent state")
            raise
        except Exception:
            logger.exception(f"Unexpected error storing agent state for step {step}")
            raise

    @retry(
        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    async def store_environment_frame(self, state: Frame) -> Any:
        """Store an environment frame; return the decoded JSON response.

        The frame is tagged with the current step and timestep recorded by the
        latest ``store_agent_state`` call. The ``"{agent}-{step}-{timestep}"``
        identifier is only built once a step has been stored; before that it is
        ``None``.

        Retry and error behaviour match ``store_agent_state``. HTTP status errors
        are logged and swallowed, and ``None`` is returned.
        """
        state_identifier = None
        if self.current_step is not None:
            state_identifier = (
                f"{self.agent}-{self.current_step}-{self.current_timestep}"
            )

        logger.debug(f"Storing environment frame for state {state_identifier}")

        data = StoreEnvironmentFrameRequest(
            agent_state_point_in_time=state_identifier,
            current_agent_step=self.current_step,
            state=state.model_dump(mode="json"),
            trajectory_timestep=self.current_timestep,
        )

        try:
            response = await self._send("POST", "environment-frame", data)
            logger.debug(
                f"Successfully stored environment frame for state {state_identifier}",
            )
            return response.json()
        except httpx.HTTPStatusError as e:
            # NOTE(review): swallowed, hence never retried (even on 5xx).
            logger.exception(
                "HTTP error storing environment frame. "
                f"Status code: {e.response.status_code}, "
                f"Response: {e.response.text}",
            )
            return None
        except httpx.TimeoutException:
            logger.exception(
                f"Timeout while storing environment frame after {self.REQUEST_TIMEOUT}s",
            )
            raise
        except httpx.NetworkError:
            logger.exception("Network error while storing environment frame")
            raise
        except Exception:
            logger.exception(
                f"Unexpected error storing environment frame for state {state_identifier}",
            )
            raise