futurehouse-client 0.3.17.dev94__py3-none-any.whl → 0.3.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -30,6 +30,7 @@ class JobNames(StrEnum):
30
30
  OWL = "job-futurehouse-hasanyone"
31
31
  DUMMY = "job-futurehouse-dummy-env"
32
32
  PHOENIX = "job-futurehouse-phoenix"
33
+ FINCH = "job-futurehouse-data-analysis-crow-high"
33
34
 
34
35
  @classmethod
35
36
  def from_stage(cls, job_name: str, stage: Stage | None = None) -> str:
@@ -8,14 +8,14 @@ import inspect
8
8
  import json
9
9
  import logging
10
10
  import os
11
+ import sys
11
12
  import tempfile
12
13
  import time
13
14
  import uuid
14
- from collections.abc import Collection, Mapping
15
- from datetime import datetime
15
+ from collections.abc import Collection
16
16
  from pathlib import Path
17
17
  from types import ModuleType
18
- from typing import Any, ClassVar, assert_never, cast
18
+ from typing import Any, ClassVar, cast
19
19
  from uuid import UUID
20
20
 
21
21
  import cloudpickle
@@ -33,7 +33,6 @@ from httpx import (
33
33
  RemoteProtocolError,
34
34
  )
35
35
  from ldp.agent import AgentConfig
36
- from pydantic import BaseModel, ConfigDict, model_validator
37
36
  from requests.exceptions import RequestException, Timeout
38
37
  from tenacity import (
39
38
  retry,
@@ -46,13 +45,16 @@ from tqdm.asyncio import tqdm
46
45
 
47
46
  from futurehouse_client.clients import JobNames
48
47
  from futurehouse_client.models.app import (
49
- APIKeyPayload,
50
48
  AuthType,
51
49
  JobDeploymentConfig,
50
+ PQATaskResponse,
52
51
  Stage,
53
52
  TaskRequest,
53
+ TaskResponse,
54
+ TaskResponseVerbose,
54
55
  )
55
56
  from futurehouse_client.models.rest import ExecutionStatus
57
+ from futurehouse_client.utils.auth import RefreshingJWT
56
58
  from futurehouse_client.utils.general import gather_with_concurrency
57
59
  from futurehouse_client.utils.module_utils import (
58
60
  OrganizationSelector,
@@ -63,24 +65,14 @@ from futurehouse_client.utils.monitoring import (
63
65
  )
64
66
 
65
67
  logger = logging.getLogger(__name__)
66
-
68
+ logging.basicConfig(
69
+ level=logging.WARNING,
70
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
71
+ stream=sys.stdout,
72
+ )
73
+ logging.getLogger("httpx").setLevel(logging.WARNING)
67
74
  TaskRequest.model_rebuild()
68
75
 
69
- retry_if_connection_error = retry_if_exception_type((
70
- # From requests
71
- Timeout,
72
- ConnectionError,
73
- RequestException,
74
- # From httpx
75
- ConnectError,
76
- ConnectTimeout,
77
- ReadTimeout,
78
- ReadError,
79
- NetworkError,
80
- RemoteProtocolError,
81
- CloseError,
82
- ))
83
-
84
76
  FILE_UPLOAD_IGNORE_PARTS = {
85
77
  ".ruff_cache",
86
78
  "__pycache__",
@@ -111,104 +103,27 @@ class InvalidTaskDescriptionError(Exception):
111
103
  """Raised when the task description is invalid or empty."""
112
104
 
113
105
 
114
- class SimpleOrganization(BaseModel):
115
- id: int
116
- name: str
117
- display_name: str
118
-
119
-
120
- # 5 minute default for JWTs
121
- JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
122
- DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
123
-
106
+ class FileUploadError(RestClientError):
107
+ """Raised when there's an error uploading a file."""
124
108
 
125
- class TaskResponse(BaseModel):
126
- """Base class for task responses. This holds attributes shared over all futurehouse jobs."""
127
-
128
- model_config = ConfigDict(extra="ignore")
129
-
130
- status: str
131
- query: str
132
- user: str | None = None
133
- created_at: datetime
134
- job_name: str
135
- public: bool
136
- shared_with: list[SimpleOrganization] | None = None
137
- build_owner: str | None = None
138
- environment_name: str | None = None
139
- agent_name: str | None = None
140
- task_id: UUID | None = None
141
-
142
- @model_validator(mode="before")
143
- @classmethod
144
- def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
145
- # Extract fields from environment frame state
146
- if not isinstance(data, dict):
147
- return data
148
- # TODO: We probably want to remove these two once we define the final names.
149
- data["job_name"] = data.get("crow")
150
- data["query"] = data.get("task")
151
- data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
152
- if not (metadata := data.get("metadata", {})):
153
- return data
154
- data["environment_name"] = metadata.get("environment_name")
155
- data["agent_name"] = metadata.get("agent_name")
156
- return data
157
-
158
-
159
- class PQATaskResponse(TaskResponse):
160
- model_config = ConfigDict(extra="ignore")
161
-
162
- answer: str | None = None
163
- formatted_answer: str | None = None
164
- answer_reasoning: str | None = None
165
- has_successful_answer: bool | None = None
166
- total_cost: float | None = None
167
- total_queries: int | None = None
168
-
169
- @model_validator(mode="before")
170
- @classmethod
171
- def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
172
- if not isinstance(data, dict):
173
- return data
174
- if not (env_frame := data.get("environment_frame", {})):
175
- return data
176
- state = env_frame.get("state", {}).get("state", {})
177
- response = state.get("response", {})
178
- answer = response.get("answer", {})
179
- usage = state.get("info", {}).get("usage", {})
180
-
181
- # Add additional PQA specific fields to data so that pydantic can validate the model
182
- data["answer"] = answer.get("answer")
183
- data["formatted_answer"] = answer.get("formatted_answer")
184
- data["answer_reasoning"] = answer.get("answer_reasoning")
185
- data["has_successful_answer"] = answer.get("has_successful_answer")
186
- data["total_cost"] = cast(float, usage.get("total_cost"))
187
- data["total_queries"] = cast(int, usage.get("total_queries"))
188
-
189
- return data
190
-
191
- def clean_verbose(self) -> "TaskResponse":
192
- """Clean the verbose response from the server."""
193
- self.request = None
194
- self.response = None
195
- return self
196
-
197
-
198
- class TaskResponseVerbose(TaskResponse):
199
- """Class for responses to include all the fields of a task response."""
200
-
201
- model_config = ConfigDict(extra="allow")
202
-
203
- public: bool
204
- agent_state: list[dict[str, Any]] | None = None
205
- environment_frame: dict[str, Any] | None = None
206
- metadata: dict[str, Any] | None = None
207
- shared_with: list[SimpleOrganization] | None = None
208
109
 
110
+ retry_if_connection_error = retry_if_exception_type((
111
+ # From requests
112
+ Timeout,
113
+ ConnectionError,
114
+ RequestException,
115
+ # From httpx
116
+ ConnectError,
117
+ ConnectTimeout,
118
+ ReadTimeout,
119
+ ReadError,
120
+ NetworkError,
121
+ RemoteProtocolError,
122
+ CloseError,
123
+ FileUploadError,
124
+ ))
209
125
 
210
- class FileUploadError(RestClientError):
211
- """Raised when there's an error uploading a file."""
126
+ DEFAULT_AGENT_TIMEOUT: int = 2400 # seconds
212
127
 
213
128
 
214
129
  class RestClient:
@@ -228,76 +143,98 @@ class RestClient:
228
143
  api_key: str | None = None,
229
144
  jwt: str | None = None,
230
145
  headers: dict[str, str] | None = None,
146
+ verbose_logging: bool = False,
231
147
  ):
148
+ if verbose_logging:
149
+ logger.setLevel(logging.INFO)
150
+ else:
151
+ logger.setLevel(logging.WARNING)
152
+
232
153
  self.base_url = service_uri or stage.value
233
154
  self.stage = stage
234
155
  self.auth_type = auth_type
235
156
  self.api_key = api_key
236
157
  self._clients: dict[str, Client | AsyncClient] = {}
237
158
  self.headers = headers or {}
238
- self.auth_jwt = self._run_auth(jwt=jwt)
159
+ self.jwt = jwt
239
160
  self.organizations: list[str] = self._filter_orgs(organization)
240
161
 
241
162
  @property
242
163
  def client(self) -> Client:
243
- """Lazily initialized and cached HTTP client with authentication."""
244
- return cast(Client, self.get_client("application/json", with_auth=True))
164
+ """Authenticated HTTP client for regular API calls."""
165
+ return cast(Client, self.get_client("application/json", authenticated=True))
245
166
 
246
167
  @property
247
168
  def async_client(self) -> AsyncClient:
248
- """Lazily initialized and cached HTTP client with authentication."""
169
+ """Authenticated async HTTP client for regular API calls."""
249
170
  return cast(
250
171
  AsyncClient,
251
- self.get_client("application/json", with_auth=True, with_async=True),
172
+ self.get_client("application/json", authenticated=True, async_client=True),
252
173
  )
253
174
 
254
175
  @property
255
- def auth_client(self) -> Client:
256
- """Lazily initialized and cached HTTP client without authentication."""
257
- return cast(Client, self.get_client("application/json", with_auth=False))
176
+ def unauthenticated_client(self) -> Client:
177
+ """Unauthenticated HTTP client for auth operations."""
178
+ return cast(Client, self.get_client("application/json", authenticated=False))
258
179
 
259
180
  @property
260
181
  def multipart_client(self) -> Client:
261
- """Lazily initialized and cached HTTP client for multipart uploads."""
262
- return cast(Client, self.get_client(None, with_auth=True))
182
+ """Authenticated HTTP client for multipart uploads."""
183
+ return cast(Client, self.get_client(None, authenticated=True))
263
184
 
264
185
  def get_client(
265
186
  self,
266
187
  content_type: str | None = "application/json",
267
- with_auth: bool = True,
268
- with_async: bool = False,
188
+ authenticated: bool = True,
189
+ async_client: bool = False,
269
190
  ) -> Client | AsyncClient:
270
191
  """Return a cached HTTP client or create one if needed.
271
192
 
272
193
  Args:
273
194
  content_type: The desired content type header. Use None for multipart uploads.
274
- with_auth: Whether the client should include an Authorization header.
275
- with_async: Whether to use an async client.
195
+ authenticated: Whether the client should include authentication.
196
+ async_client: Whether to use an async client.
276
197
 
277
198
  Returns:
278
199
  An HTTP client configured with the appropriate headers.
279
200
  """
280
- # Create a composite key based on content type and auth flag.
281
- key = f"{content_type or 'multipart'}_{with_auth}_{with_async}"
201
+ # Create a composite key based on content type and auth flag
202
+ key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
203
+
282
204
  if key not in self._clients:
283
205
  headers = copy.deepcopy(self.headers)
284
- if with_auth:
285
- headers["Authorization"] = f"Bearer {self.auth_jwt}"
206
+ auth = None
207
+
208
+ if authenticated:
209
+ auth = RefreshingJWT(
210
+ # authenticated=False will always return a synchronous client
211
+ auth_client=cast(
212
+ Client, self.get_client("application/json", authenticated=False)
213
+ ),
214
+ auth_type=self.auth_type,
215
+ api_key=self.api_key,
216
+ jwt=self.jwt,
217
+ )
218
+
286
219
  if content_type:
287
220
  headers["Content-Type"] = content_type
221
+
288
222
  self._clients[key] = (
289
223
  AsyncClient(
290
224
  base_url=self.base_url,
291
225
  headers=headers,
292
226
  timeout=self.REQUEST_TIMEOUT,
227
+ auth=auth,
293
228
  )
294
- if with_async
229
+ if async_client
295
230
  else Client(
296
231
  base_url=self.base_url,
297
232
  headers=headers,
298
233
  timeout=self.REQUEST_TIMEOUT,
234
+ auth=auth,
299
235
  )
300
236
  )
237
+
301
238
  return self._clients[key]
302
239
 
303
240
  def close(self):
@@ -327,31 +264,6 @@ class RestClient:
327
264
  raise ValueError(f"Organization '{organization}' not found.")
328
265
  return filtered_orgs
329
266
 
330
- def _run_auth(self, jwt: str | None = None) -> str:
331
- auth_payload: APIKeyPayload | None
332
- if self.auth_type == AuthType.API_KEY:
333
- auth_payload = APIKeyPayload(api_key=self.api_key)
334
- elif self.auth_type == AuthType.JWT:
335
- auth_payload = None
336
- else:
337
- assert_never(self.auth_type)
338
- try:
339
- # Use the unauthenticated client for login
340
- if auth_payload:
341
- response = self.auth_client.post(
342
- "/auth/login", json=auth_payload.model_dump()
343
- )
344
- response.raise_for_status()
345
- token_data = response.json()
346
- elif jwt:
347
- token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
348
- else:
349
- raise ValueError("JWT token required for JWT authentication.")
350
-
351
- return token_data["access_token"]
352
- except Exception as e:
353
- raise RestClientError(f"Error authenticating: {e!s}") from e
354
-
355
267
  def _check_job(self, name: str, organization: str) -> dict[str, Any]:
356
268
  try:
357
269
  response = self.client.get(
@@ -445,6 +357,7 @@ class RestClient:
445
357
  ),
446
358
  self.client.stream("GET", url, params={"history": history}) as response,
447
359
  ):
360
+ response.raise_for_status()
448
361
  json_data = "".join(response.iter_text(chunk_size=1024))
449
362
  data = json.loads(json_data)
450
363
  if "id" not in data:
@@ -459,8 +372,6 @@ class RestClient:
459
372
  ):
460
373
  return PQATaskResponse(**data)
461
374
  return TaskResponse(**data)
462
- except ValueError as e:
463
- raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
464
375
  except Exception as e:
465
376
  raise TaskFetchError(f"Error getting task: {e!s}") from e
466
377
 
@@ -507,8 +418,6 @@ class RestClient:
507
418
  ):
508
419
  return PQATaskResponse(**data)
509
420
  return TaskResponse(**data)
510
- except ValueError as e:
511
- raise ValueError("Invalid task ID format. Must be a valid UUID.") from e
512
421
  except Exception as e:
513
422
  raise TaskFetchError(f"Error getting task: {e!s}") from e
514
423
 
@@ -714,9 +623,12 @@ class RestClient:
714
623
  )
715
624
  def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
716
625
  """Get the status of a build."""
717
- build_id = build_id or self.build_id
718
- response = self.client.get(f"/v0.1/builds/{build_id}")
719
- response.raise_for_status()
626
+ try:
627
+ build_id = build_id or self.build_id
628
+ response = self.client.get(f"/v0.1/builds/{build_id}")
629
+ response.raise_for_status()
630
+ except Exception as e:
631
+ raise JobFetchError(f"Error getting build status: {e!s}") from e
720
632
  return response.json()
721
633
 
722
634
  # TODO: Refactor later so we don't have to ignore PLR0915
@@ -917,14 +829,14 @@ class RestClient:
917
829
  self,
918
830
  job_name: str,
919
831
  file_path: str | os.PathLike,
920
- folder_name: str | None = None,
832
+ upload_id: str | None = None,
921
833
  ) -> str:
922
834
  """Upload a file or directory to a futurehouse job bucket.
923
835
 
924
836
  Args:
925
837
  job_name: The name of the futurehouse job to upload to.
926
838
  file_path: The local path to the file or directory to upload.
927
- folder_name: Optional folder name to use for the upload. If not provided, a random UUID will be used.
839
+ upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
928
840
 
929
841
  Returns:
930
842
  The upload ID used for the upload.
@@ -936,7 +848,7 @@ class RestClient:
936
848
  if not file_path.exists():
937
849
  raise FileNotFoundError(f"File or directory not found: {file_path}")
938
850
 
939
- upload_id = folder_name or str(uuid.uuid4())
851
+ upload_id = upload_id or str(uuid.uuid4())
940
852
 
941
853
  if file_path.is_dir():
942
854
  # Process directory recursively
@@ -999,6 +911,12 @@ class RestClient:
999
911
  """
1000
912
  file_name = file_name or file_path.name
1001
913
  file_size = file_path.stat().st_size
914
+
915
+ # Skip empty files
916
+ if file_size == 0:
917
+ logger.warning(f"Skipping upload of empty file: {file_path}")
918
+ return
919
+
1002
920
  total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE
1003
921
 
1004
922
  logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")
@@ -1046,7 +964,6 @@ class RestClient:
1046
964
  )
1047
965
 
1048
966
  logger.info(f"Successfully uploaded {file_name}")
1049
-
1050
967
  except Exception as e:
1051
968
  logger.exception(f"Error uploading file {file_path}")
1052
969
  raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
@@ -1056,12 +973,18 @@ class RestClient:
1056
973
  wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
1057
974
  retry=retry_if_connection_error,
1058
975
  )
1059
- def list_files(self, job_name: str, folder_name: str) -> dict[str, list[str]]:
976
+ def list_files(
977
+ self,
978
+ job_name: str,
979
+ trajectory_id: str | None = None,
980
+ upload_id: str | None = None,
981
+ ) -> dict[str, list[str]]:
1060
982
  """List files and directories in a GCS location for a given job_name and upload_id.
1061
983
 
1062
984
  Args:
1063
985
  job_name: The name of the futurehouse job.
1064
- folder_name: The specific folder name (upload_id) to list files from.
986
+ trajectory_id: The specific trajectory id to list files from.
987
+ upload_id: The specific upload id to list files from.
1065
988
 
1066
989
  Returns:
1067
990
  A list of files in the GCS folder.
@@ -1069,22 +992,27 @@ class RestClient:
1069
992
  Raises:
1070
993
  RestClientError: If there is an error listing the files.
1071
994
  """
995
+ if not bool(trajectory_id) ^ bool(upload_id):
996
+ raise RestClientError(
997
+ "Must at least specify one of trajectory_id or upload_id, but not both"
998
+ )
1072
999
  try:
1073
1000
  url = f"/v0.1/crows/{job_name}/list-files"
1074
- params = {"upload_id": folder_name}
1001
+ params = {"trajectory_id": trajectory_id, "upload_id": upload_id}
1002
+ params = {k: v for k, v in params.items() if v is not None}
1075
1003
  response = self.client.get(url, params=params)
1076
1004
  response.raise_for_status()
1077
1005
  return response.json()
1078
1006
  except HTTPStatusError as e:
1079
1007
  logger.exception(
1080
- f"Error listing files for job {job_name}, folder {folder_name}: {e.response.text}"
1008
+ f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
1081
1009
  )
1082
1010
  raise RestClientError(
1083
1011
  f"Error listing files: {e.response.status_code} - {e.response.text}"
1084
1012
  ) from e
1085
1013
  except Exception as e:
1086
1014
  logger.exception(
1087
- f"Error listing files for job {job_name}, folder {folder_name}"
1015
+ f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}"
1088
1016
  )
1089
1017
  raise RestClientError(f"Error listing files: {e!s}") from e
1090
1018
 
@@ -1096,7 +1024,7 @@ class RestClient:
1096
1024
  def download_file(
1097
1025
  self,
1098
1026
  job_name: str,
1099
- folder_name: str,
1027
+ trajectory_id: str,
1100
1028
  file_path: str,
1101
1029
  destination_path: str | os.PathLike,
1102
1030
  ) -> None:
@@ -1104,14 +1032,14 @@ class RestClient:
1104
1032
 
1105
1033
  Args:
1106
1034
  job_name: The name of the futurehouse job.
1107
- folder_name: The specific folder name (upload_id) the file belongs to.
1035
+ trajectory_id: The specific trajectory id the file belongs to.
1108
1036
  file_path: The relative path of the file to download
1109
1037
  (e.g., 'data/my_file.csv' or 'my_image.png').
1110
1038
  destination_path: The local path where the file should be saved.
1111
1039
 
1112
1040
  Raises:
1113
1041
  RestClientError: If there is an error downloading the file.
1114
- FileNotFoundError: If the destination directory does not exist.
1042
+ FileNotFoundError: If the destination directory does not exist or if the file is not found.
1115
1043
  """
1116
1044
  destination_path = Path(destination_path)
1117
1045
  # Ensure the destination directory exists
@@ -1119,17 +1047,24 @@ class RestClient:
1119
1047
 
1120
1048
  try:
1121
1049
  url = f"/v0.1/crows/{job_name}/download-file"
1122
- params = {"upload_id": folder_name, "file_path": file_path}
1050
+ params = {"trajectory_id": trajectory_id, "file_path": file_path}
1123
1051
 
1124
1052
  with self.client.stream("GET", url, params=params) as response:
1125
1053
  response.raise_for_status() # Check for HTTP errors before streaming
1126
1054
  with open(destination_path, "wb") as f:
1127
1055
  for chunk in response.iter_bytes(chunk_size=8192):
1128
1056
  f.write(chunk)
1057
+
1058
+ # Check if the downloaded file is empty
1059
+ if destination_path.stat().st_size == 0:
1060
+ # Remove the empty file
1061
+ destination_path.unlink()
1062
+ raise FileNotFoundError(f"File not found or is empty: {file_path}")
1063
+
1129
1064
  logger.info(f"File {file_path} downloaded to {destination_path}")
1130
1065
  except HTTPStatusError as e:
1131
1066
  logger.exception(
1132
- f"Error downloading file {file_path} for job {job_name}, folder {folder_name}: {e.response.text}"
1067
+ f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
1133
1068
  )
1134
1069
  # Clean up partially downloaded file if an error occurs
1135
1070
  if destination_path.exists():
@@ -1137,9 +1072,20 @@ class RestClient:
1137
1072
  raise RestClientError(
1138
1073
  f"Error downloading file: {e.response.status_code} - {e.response.text}"
1139
1074
  ) from e
1075
+ except RemoteProtocolError as e:
1076
+ logger.error(
1077
+ f"Connection error while downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
1078
+ )
1079
+ # Clean up partially downloaded file
1080
+ if destination_path.exists():
1081
+ destination_path.unlink()
1082
+
1083
+ # Often RemoteProtocolError during download means the file wasn't found
1084
+ # or was empty/corrupted on the server side
1085
+ raise FileNotFoundError(f"File not found or corrupted: {file_path}") from e
1140
1086
  except Exception as e:
1141
1087
  logger.exception(
1142
- f"Error downloading file {file_path} for job {job_name}, folder {folder_name}"
1088
+ f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}"
1143
1089
  )
1144
1090
  if destination_path.exists():
1145
1091
  destination_path.unlink() # Clean up partial file
@@ -3,10 +3,15 @@ from .app import (
3
3
  DockerContainerConfiguration,
4
4
  FramePath,
5
5
  JobDeploymentConfig,
6
+ PQATaskResponse,
6
7
  RuntimeConfig,
7
8
  Stage,
8
9
  Step,
10
+ TaskQueue,
11
+ TaskQueuesConfig,
9
12
  TaskRequest,
13
+ TaskResponse,
14
+ TaskResponseVerbose,
10
15
  )
11
16
 
12
17
  __all__ = [
@@ -14,8 +19,13 @@ __all__ = [
14
19
  "DockerContainerConfiguration",
15
20
  "FramePath",
16
21
  "JobDeploymentConfig",
22
+ "PQATaskResponse",
17
23
  "RuntimeConfig",
18
24
  "Stage",
19
25
  "Step",
26
+ "TaskQueue",
27
+ "TaskQueuesConfig",
20
28
  "TaskRequest",
29
+ "TaskResponse",
30
+ "TaskResponseVerbose",
21
31
  ]
@@ -1,6 +1,9 @@
1
+ import copy
1
2
  import json
2
3
  import os
3
4
  import re
5
+ from collections.abc import Mapping
6
+ from datetime import datetime
4
7
  from enum import StrEnum, auto
5
8
  from pathlib import Path
6
9
  from typing import TYPE_CHECKING, Any, ClassVar, Self, cast
@@ -646,3 +649,96 @@ class TaskRequest(BaseModel):
646
649
  runtime_config: RuntimeConfig | None = Field(
647
650
  default=None, description="All optional runtime parameters for the job"
648
651
  )
652
+
653
+
654
+ class SimpleOrganization(BaseModel):
655
+ id: int
656
+ name: str
657
+ display_name: str
658
+
659
+
660
+ class TaskResponse(BaseModel):
661
+ """Base class for task responses. This holds attributes shared over all futurehouse jobs."""
662
+
663
+ model_config = ConfigDict(extra="ignore")
664
+
665
+ status: str
666
+ query: str
667
+ user: str | None = None
668
+ created_at: datetime
669
+ job_name: str
670
+ public: bool
671
+ shared_with: list[SimpleOrganization] | None = None
672
+ build_owner: str | None = None
673
+ environment_name: str | None = None
674
+ agent_name: str | None = None
675
+ task_id: UUID | None = None
676
+
677
+ @model_validator(mode="before")
678
+ @classmethod
679
+ def validate_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
680
+ data = copy.deepcopy(original_data) # Avoid mutating the original data
681
+ # Extract fields from environment frame state
682
+ if not isinstance(data, dict):
683
+ return data
684
+ # TODO: We probably want to remove these two once we define the final names.
685
+ data["job_name"] = data.get("crow")
686
+ data["query"] = data.get("task")
687
+ data["task_id"] = cast(UUID, data.get("id")) if data.get("id") else None
688
+ if not (metadata := data.get("metadata", {})):
689
+ return data
690
+ data["environment_name"] = metadata.get("environment_name")
691
+ data["agent_name"] = metadata.get("agent_name")
692
+ return data
693
+
694
+
695
+ class PQATaskResponse(TaskResponse):
696
+ model_config = ConfigDict(extra="ignore")
697
+
698
+ answer: str | None = None
699
+ formatted_answer: str | None = None
700
+ answer_reasoning: str | None = None
701
+ has_successful_answer: bool | None = None
702
+ total_cost: float | None = None
703
+ total_queries: int | None = None
704
+
705
+ @model_validator(mode="before")
706
+ @classmethod
707
+ def validate_pqa_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
708
+ data = copy.deepcopy(original_data) # Avoid mutating the original data
709
+ if not isinstance(data, dict):
710
+ return data
711
+ if not (env_frame := data.get("environment_frame", {})):
712
+ return data
713
+ state = env_frame.get("state", {}).get("state", {})
714
+ response = state.get("response", {})
715
+ answer = response.get("answer", {})
716
+ usage = state.get("info", {}).get("usage", {})
717
+
718
+ # Add additional PQA specific fields to data so that pydantic can validate the model
719
+ data["answer"] = answer.get("answer")
720
+ data["formatted_answer"] = answer.get("formatted_answer")
721
+ data["answer_reasoning"] = answer.get("answer_reasoning")
722
+ data["has_successful_answer"] = answer.get("has_successful_answer")
723
+ data["total_cost"] = cast(float, usage.get("total_cost"))
724
+ data["total_queries"] = cast(int, usage.get("total_queries"))
725
+
726
+ return data
727
+
728
+ def clean_verbose(self) -> "TaskResponse":
729
+ """Clean the verbose response from the server."""
730
+ self.request = None
731
+ self.response = None
732
+ return self
733
+
734
+
735
+ class TaskResponseVerbose(TaskResponse):
736
+ """Class for responses to include all the fields of a task response."""
737
+
738
+ model_config = ConfigDict(extra="allow")
739
+
740
+ public: bool
741
+ agent_state: list[dict[str, Any]] | None = None
742
+ environment_frame: dict[str, Any] | None = None
743
+ metadata: dict[str, Any] | None = None
744
+ shared_with: list[SimpleOrganization] | None = None
@@ -0,0 +1,92 @@
1
+ import logging
2
+ from collections.abc import Collection, Generator
3
+ from typing import ClassVar, Final
4
+
5
+ import httpx
6
+
7
+ from futurehouse_client.models.app import APIKeyPayload, AuthType
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ INVALID_REFRESH_TYPE_MSG: Final[str] = (
12
+ "API key auth is required to refresh auth tokens."
13
+ )
14
+ JWT_TOKEN_CACHE_EXPIRY: int = 300 # seconds
15
+
16
+
17
+ def _run_auth(
18
+ client: httpx.Client,
19
+ auth_type: AuthType = AuthType.API_KEY,
20
+ api_key: str | None = None,
21
+ jwt: str | None = None,
22
+ ) -> str:
23
+ auth_payload: APIKeyPayload | None
24
+ if auth_type == AuthType.API_KEY:
25
+ auth_payload = APIKeyPayload(api_key=api_key)
26
+ elif auth_type == AuthType.JWT:
27
+ auth_payload = None
28
+ try:
29
+ if auth_payload:
30
+ response = client.post("/auth/login", json=auth_payload.model_dump())
31
+ response.raise_for_status()
32
+ token_data = response.json()
33
+ elif jwt:
34
+ token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
35
+ else:
36
+ raise ValueError("JWT token required for JWT authentication.")
37
+
38
+ return token_data["access_token"]
39
+ except Exception as e:
40
+ raise Exception("Failed to authenticate") from e # noqa: TRY002
41
+
42
+
43
+ class RefreshingJWT(httpx.Auth):
44
+ """Automatically (re-)inject a JWT and transparently retry exactly once when we hit a 401/403."""
45
+
46
+ RETRY_STATUSES: ClassVar[Collection[httpx.codes]] = {
47
+ httpx.codes.UNAUTHORIZED,
48
+ httpx.codes.FORBIDDEN,
49
+ }
50
+
51
+ def __init__(
52
+ self,
53
+ auth_client: httpx.Client,
54
+ auth_type: AuthType = AuthType.API_KEY,
55
+ api_key: str | None = None,
56
+ jwt: str | None = None,
57
+ ):
58
+ self.auth_type = auth_type
59
+ self.auth_client = auth_client
60
+ self.api_key = api_key
61
+ self._jwt = _run_auth(
62
+ client=auth_client,
63
+ jwt=jwt,
64
+ auth_type=auth_type,
65
+ api_key=api_key,
66
+ )
67
+
68
+ def refresh_token(self) -> None:
69
+ if self.auth_type == AuthType.JWT:
70
+ logger.error(INVALID_REFRESH_TYPE_MSG)
71
+ raise ValueError(INVALID_REFRESH_TYPE_MSG)
72
+ self._jwt = _run_auth(
73
+ client=self.auth_client,
74
+ auth_type=self.auth_type,
75
+ api_key=self.api_key,
76
+ )
77
+
78
+ def auth_flow(
79
+ self, request: httpx.Request
80
+ ) -> Generator[httpx.Request, httpx.Response, None]:
81
+ request.headers["Authorization"] = f"Bearer {self._jwt}"
82
+ response = yield request
83
+
84
+ # If it failed, refresh once and replay the request
85
+ if response.status_code in self.RETRY_STATUSES:
86
+ logger.info(
87
+ "Received %s, refreshing token and retrying …",
88
+ response.status_code,
89
+ )
90
+ self.refresh_token()
91
+ request.headers["Authorization"] = f"Bearer {self._jwt}"
92
+ yield request # second (and final) attempt, again or use a while loop
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: futurehouse-client
3
- Version: 0.3.17.dev94
3
+ Version: 0.3.18
4
4
  Summary: A client for interacting with endpoints of the FutureHouse service.
5
5
  Author-email: FutureHouse technical staff <hello@futurehouse.org>
6
6
  Classifier: Operating System :: OS Independent
@@ -0,0 +1,17 @@
1
+ futurehouse_client/__init__.py,sha256=ddxO7JE97c6bt7LjNglZZ2Ql8bYCGI9laSFeh9MP6VU,344
2
+ futurehouse_client/clients/__init__.py,sha256=tFWqwIAY5PvwfOVsCje4imjTpf6xXNRMh_UHIKVI1_0,320
3
+ futurehouse_client/clients/job_client.py,sha256=uNkqQbeZw7wbA0qDWcIOwOykrosza-jev58paJZ_mbA,11150
4
+ futurehouse_client/clients/rest_client.py,sha256=CwgyjYj-i6U0aVXb1GkxP6KSxR5tVrlXIDE9WIHYtds,43435
5
+ futurehouse_client/models/__init__.py,sha256=5x-f9AoM1hGzJBEHcHAXSt7tPeImST5oZLuMdwp0mXc,554
6
+ futurehouse_client/models/app.py,sha256=24lCnpOtxQNzvcc1HJyP62wfertvTuZ30sX_vrBCUnk,26611
7
+ futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
8
+ futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
9
+ futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ futurehouse_client/utils/auth.py,sha256=tgWELjKfg8eWme_qdcRmc8TjQN9DVZuHHaVXZNHLchk,2960
11
+ futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
12
+ futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
13
+ futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
14
+ futurehouse_client-0.3.18.dist-info/METADATA,sha256=AJVzzXq77PTe-Hg31YXy-KWFAb5IXiZhyC08YlbRec4,12760
15
+ futurehouse_client-0.3.18.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
16
+ futurehouse_client-0.3.18.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
17
+ futurehouse_client-0.3.18.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,16 +0,0 @@
1
- futurehouse_client/__init__.py,sha256=ddxO7JE97c6bt7LjNglZZ2Ql8bYCGI9laSFeh9MP6VU,344
2
- futurehouse_client/clients/__init__.py,sha256=tFWqwIAY5PvwfOVsCje4imjTpf6xXNRMh_UHIKVI1_0,320
3
- futurehouse_client/clients/job_client.py,sha256=Fi3YvN4k82AuXCe8vlwxhkK8CXS164NQrs7paj9qIek,11096
4
- futurehouse_client/clients/rest_client.py,sha256=dsUmpgV5sfyb4GDv6whWVwRN1z2LOfZsPF8vjoioNfY,45472
5
- futurehouse_client/models/__init__.py,sha256=ta3jFLM_LsDz1rKDmx8rja8sT7WtSKoFvMgLF0yFpvA,342
6
- futurehouse_client/models/app.py,sha256=yfZ9tyw4VATVAfYrU7aTdCNPSljLEho09_nIbh8oZDY,23174
7
- futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
8
- futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
9
- futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
11
- futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
12
- futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
13
- futurehouse_client-0.3.17.dev94.dist-info/METADATA,sha256=acLPon9oE1ecVZzz8JrpumcSLmhRkqGGG62gjGEW1IQ,12766
14
- futurehouse_client-0.3.17.dev94.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
15
- futurehouse_client-0.3.17.dev94.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
16
- futurehouse_client-0.3.17.dev94.dist-info/RECORD,,