futurehouse-client 0.3.18.dev186__py3-none-any.whl → 0.3.19.dev111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/__init__.py +1 -2
- futurehouse_client/clients/__init__.py +1 -2
- futurehouse_client/clients/job_client.py +27 -1
- futurehouse_client/clients/rest_client.py +399 -180
- futurehouse_client/models/app.py +70 -2
- futurehouse_client/utils/auth.py +86 -101
- {futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dev111.dist-info}/METADATA +1 -1
- futurehouse_client-0.3.19.dev111.dist-info/RECORD +17 -0
- {futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dev111.dist-info}/WHEEL +1 -1
- futurehouse_client-0.3.18.dev186.dist-info/RECORD +0 -17
- {futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dev111.dist-info}/top_level.txt +0 -0
futurehouse_client/__init__.py
CHANGED
```diff
@@ -1,12 +1,11 @@
 from .clients.job_client import JobClient, JobNames
-from .clients.rest_client import PQATaskResponse, TaskResponse, TaskResponseVerbose
 from .clients.rest_client import RestClient as FutureHouseClient
+from .clients.rest_client import TaskResponse, TaskResponseVerbose

 __all__ = [
     "FutureHouseClient",
     "JobClient",
     "JobNames",
-    "PQATaskResponse",
     "TaskResponse",
     "TaskResponseVerbose",
 ]
```
futurehouse_client/clients/__init__.py
CHANGED
```diff
@@ -1,12 +1,11 @@
 from .job_client import JobClient, JobNames
-from .rest_client import PQATaskResponse, TaskResponse, TaskResponseVerbose
 from .rest_client import RestClient as FutureHouseClient
+from .rest_client import TaskResponse, TaskResponseVerbose

 __all__ = [
     "FutureHouseClient",
     "JobClient",
     "JobNames",
-    "PQATaskResponse",
     "TaskResponse",
     "TaskResponseVerbose",
 ]
```
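Both `__init__` modules drop `PQATaskResponse` from the public exports; it now lives with the other response models in `futurehouse_client.models.app` (see the `job_client.py` imports below). A minimal sketch of the updated import path, assuming downstream code previously pulled it from the package root:

```python
# Before (0.3.18.dev186): PQATaskResponse was re-exported at the package root.
# from futurehouse_client import PQATaskResponse

# After (0.3.19.dev111): import it from the models module instead.
from futurehouse_client.models.app import PQATaskResponse

# TaskResponse and TaskResponseVerbose remain available at the root.
from futurehouse_client import TaskResponse, TaskResponseVerbose
```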
futurehouse_client/clients/job_client.py
CHANGED
```diff
@@ -8,7 +8,13 @@ from aviary.env import Frame
 from pydantic import BaseModel
 from tenacity import before_sleep_log, retry, stop_after_attempt, wait_exponential

-from futurehouse_client.models.app import
+from futurehouse_client.models.app import (
+    FinchTaskResponse,
+    PhoenixTaskResponse,
+    PQATaskResponse,
+    Stage,
+    TaskResponse,
+)
 from futurehouse_client.models.rest import (
     FinalEnvironmentRequest,
     StoreAgentStatePostRequest,
@@ -31,6 +37,19 @@ class JobNames(StrEnum):
     DUMMY = "job-futurehouse-dummy-env"
     PHOENIX = "job-futurehouse-phoenix"
     FINCH = "job-futurehouse-data-analysis-crow-high"
+    CHIMP = "job-futurehouse-chimp"
+
+    @classmethod
+    def _get_response_mapping(cls) -> dict[str, type[TaskResponse]]:
+        return {
+            cls.CROW: PQATaskResponse,
+            cls.FALCON: PQATaskResponse,
+            cls.OWL: PQATaskResponse,
+            cls.CHIMP: PQATaskResponse,
+            cls.PHOENIX: PhoenixTaskResponse,
+            cls.FINCH: FinchTaskResponse,
+            cls.DUMMY: TaskResponse,
+        }

     @classmethod
     def from_stage(cls, job_name: str, stage: Stage | None = None) -> str:
@@ -52,6 +71,13 @@ class JobNames(StrEnum):
                 f"Invalid job name: {job_name}. \nOptions are: {', '.join([name.name for name in cls])}"
             ) from e

+    @staticmethod
+    def get_response_object_from_job(job_name: str) -> type[TaskResponse]:
+        return JobNames._get_response_mapping()[job_name]
+
+    def get_response_object(self) -> type[TaskResponse]:
+        return self._get_response_mapping()[self.name]
+

 class JobClient:
     REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec
```
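`JobNames` now owns the job-to-response-class mapping, so callers can resolve the right model without touching `rest_client`. A hedged sketch of the lookup (`StrEnum` members hash as their string values, so either the member or the raw job key resolves the same entry):

```python
from futurehouse_client.clients.job_client import JobNames
from futurehouse_client.models.app import FinchTaskResponse

# Lookup by enum member...
assert JobNames.get_response_object_from_job(JobNames.FINCH) is FinchTaskResponse

# ...or by the raw job key string.
assert (
    JobNames.get_response_object_from_job("job-futurehouse-data-analysis-crow-high")
    is FinchTaskResponse
)
```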
futurehouse_client/clients/rest_client.py
CHANGED
```diff
@@ -13,9 +13,10 @@ import tempfile
 import time
 import uuid
 from collections.abc import Collection
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from types import ModuleType
-from typing import Any, ClassVar,
+from typing import Any, ClassVar, cast
 from uuid import UUID

 import cloudpickle
@@ -31,6 +32,7 @@ from httpx import (
     ReadError,
     ReadTimeout,
     RemoteProtocolError,
+    codes,
 )
 from ldp.agent import AgentConfig
 from requests.exceptions import RequestException, Timeout
@@ -45,21 +47,15 @@ from tqdm.asyncio import tqdm

 from futurehouse_client.clients import JobNames
 from futurehouse_client.models.app import (
-    APIKeyPayload,
     AuthType,
     JobDeploymentConfig,
-    PQATaskResponse,
     Stage,
     TaskRequest,
     TaskResponse,
     TaskResponseVerbose,
 )
 from futurehouse_client.models.rest import ExecutionStatus
-from futurehouse_client.utils.auth import (
-    AUTH_ERRORS_TO_RETRY_ON,
-    AuthError,
-    refresh_token_on_auth_error,
-)
+from futurehouse_client.utils.auth import RefreshingJWT
 from futurehouse_client.utils.general import gather_with_concurrency
 from futurehouse_client.utils.module_utils import (
     OrganizationSelector,
@@ -128,8 +124,6 @@ retry_if_connection_error = retry_if_exception_type((
     FileUploadError,
 ))

-# 5 minute default for JWTs
-JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
 DEFAULT_AGENT_TIMEOUT: int = 2400  # seconds


@@ -140,6 +134,9 @@ class RestClient:
     MAX_RETRY_WAIT: ClassVar[int] = 10
     DEFAULT_POLLING_TIME: ClassVar[int] = 5  # seconds
     CHUNK_SIZE: ClassVar[int] = 16 * 1024 * 1024  # 16MB chunks
+    ASSEMBLY_POLLING_INTERVAL: ClassVar[int] = 10  # seconds
+    MAX_ASSEMBLY_WAIT_TIME: ClassVar[int] = 1800  # 30 minutes
+    MAX_CONCURRENT_CHUNKS: ClassVar[int] = 12  # Maximum concurrent chunk uploads

     def __init__(
         self,
@@ -163,69 +160,87 @@ class RestClient:
         self.api_key = api_key
         self._clients: dict[str, Client | AsyncClient] = {}
         self.headers = headers or {}
-        self.
+        self.jwt = jwt
         self.organizations: list[str] = self._filter_orgs(organization)

     @property
     def client(self) -> Client:
-        """
-        return cast(Client, self.get_client("application/json",
+        """Authenticated HTTP client for regular API calls."""
+        return cast(Client, self.get_client("application/json", authenticated=True))

     @property
     def async_client(self) -> AsyncClient:
-        """
+        """Authenticated async HTTP client for regular API calls."""
         return cast(
             AsyncClient,
-            self.get_client("application/json",
+            self.get_client("application/json", authenticated=True, async_client=True),
         )

     @property
-    def
-        """
-        return cast(Client, self.get_client("application/json",
+    def unauthenticated_client(self) -> Client:
+        """Unauthenticated HTTP client for auth operations."""
+        return cast(Client, self.get_client("application/json", authenticated=False))

     @property
     def multipart_client(self) -> Client:
-        """
-        return cast(Client, self.get_client(None,
+        """Authenticated HTTP client for multipart uploads."""
+        return cast(Client, self.get_client(None, authenticated=True))

     def get_client(
         self,
         content_type: str | None = "application/json",
-        …
-        …
+        authenticated: bool = True,
+        async_client: bool = False,
     ) -> Client | AsyncClient:
         """Return a cached HTTP client or create one if needed.

         Args:
             content_type: The desired content type header. Use None for multipart uploads.
-            …
-            …
+            authenticated: Whether the client should include authentication.
+            async_client: Whether to use an async client.

         Returns:
             An HTTP client configured with the appropriate headers.
         """
-        # Create a composite key based on content type and auth flag
-        key = f"{content_type or 'multipart'}_{
+        # Create a composite key based on content type and auth flag
+        key = f"{content_type or 'multipart'}_{authenticated}_{async_client}"
+
         if key not in self._clients:
             headers = copy.deepcopy(self.headers)
-            …
-            …
+            auth = None
+
+            if authenticated:
+                auth = RefreshingJWT(
+                    # authenticated=False will always return a synchronous client
+                    auth_client=cast(
+                        Client, self.get_client("application/json", authenticated=False)
+                    ),
+                    auth_type=self.auth_type,
+                    api_key=self.api_key,
+                    jwt=self.jwt,
+                )
+
             if content_type:
                 headers["Content-Type"] = content_type
+
+            headers["x-client"] = "sdk"
+
             self._clients[key] = (
                 AsyncClient(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
-                if
+                if async_client
                 else Client(
                     base_url=self.base_url,
                     headers=headers,
                     timeout=self.REQUEST_TIMEOUT,
+                    auth=auth,
                 )
             )
+
         return self._clients[key]

     def close(self):
```
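Each distinct combination of content type, auth flag, and sync/async flag is cached once, so a `RestClient` builds at most one client per configuration, and authenticated variants carry the `RefreshingJWT` hook plus the new `x-client: sdk` header. An illustrative restatement of the keying scheme (not library API, just the f-string from the hunk above):

```python
def cache_key(content_type: str | None, authenticated: bool, async_client: bool) -> str:
    # Mirrors the f-string in get_client(); multipart clients use a None content type.
    return f"{content_type or 'multipart'}_{authenticated}_{async_client}"

assert cache_key("application/json", True, False) == "application/json_True_False"
assert cache_key(None, True, False) == "multipart_True_False"
```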
```diff
@@ -255,32 +270,6 @@ class RestClient:
             raise ValueError(f"Organization '{organization}' not found.")
         return filtered_orgs

-    def _run_auth(self, jwt: str | None = None) -> str:
-        auth_payload: APIKeyPayload | None
-        if self.auth_type == AuthType.API_KEY:
-            auth_payload = APIKeyPayload(api_key=self.api_key)
-        elif self.auth_type == AuthType.JWT:
-            auth_payload = None
-        else:
-            assert_never(self.auth_type)
-        try:
-            # Use the unauthenticated client for login
-            if auth_payload:
-                response = self.auth_client.post(
-                    "/auth/login", json=auth_payload.model_dump()
-                )
-                response.raise_for_status()
-                token_data = response.json()
-            elif jwt:
-                token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
-            else:
-                raise ValueError("JWT token required for JWT authentication.")
-
-            return token_data["access_token"]
-        except Exception as e:
-            raise RestClientError(f"Error authenticating: {e!s}") from e
-
-    @refresh_token_on_auth_error()
     def _check_job(self, name: str, organization: str) -> dict[str, Any]:
         try:
             response = self.client.get(
@@ -288,25 +277,113 @@ class RestClient:
             )
             response.raise_for_status()
             return response.json()
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
         except Exception as e:
             raise JobFetchError(f"Error checking job: {e!s}") from e

-    @refresh_token_on_auth_error()
     def _fetch_my_orgs(self) -> list[str]:
         response = self.client.get(f"/v0.1/organizations?filter={True}")
         response.raise_for_status()
         orgs = response.json()
         return [org["name"] for org in orgs]

+    def _check_assembly_status(
+        self, job_name: str, upload_id: str, file_name: str
+    ) -> dict[str, Any]:
+        """Check the assembly status of an uploaded file.
+
+        Args:
+            job_name: The name of the futurehouse job
+            upload_id: The upload ID
+            file_name: The name of the file
+
+        Returns:
+            Dict containing status information
+
+        Raises:
+            RestClientError: If there's an error checking status
+        """
+        try:
+            url = f"/v0.1/crows/{job_name}/assembly-status/{upload_id}/{file_name}"
+            response = self.client.get(url)
+            response.raise_for_status()
+            return response.json()
+        except Exception as e:
+            raise RestClientError(f"Error checking assembly status: {e}") from e
+
+    def _wait_for_all_assemblies_completion(
+        self,
+        job_name: str,
+        upload_id: str,
+        file_names: list[str],
+        timeout: int = MAX_ASSEMBLY_WAIT_TIME,
+    ) -> bool:
+        """Wait for all file assemblies to complete.
+
+        Args:
+            job_name: The name of the futurehouse job
+            upload_id: The upload ID
+            file_names: List of file names to wait for
+            timeout: Maximum time to wait in seconds
+
+        Returns:
+            True if all assemblies succeeded, False if any failed or timed out
+
+        Raises:
+            RestClientError: If any assembly fails
+        """
+        if not file_names:
+            return True
+
+        start_time = time.time()
+        logger.info(f"Waiting for assembly of {len(file_names)} file(s) to complete...")
+
+        completed_files: set[str] = set()
+
+        while (time.time() - start_time) < timeout and len(completed_files) < len(
+            file_names
+        ):
+            for file_name in file_names:
+                if file_name in completed_files:
+                    continue
+
+                try:
+                    status_data = self._check_assembly_status(
+                        job_name, upload_id, file_name
+                    )
+                    status = status_data.get("status")
+
+                    if status == ExecutionStatus.SUCCESS.value:
+                        logger.info(f"Assembly completed for {file_name}")
+                        completed_files.add(file_name)
+                    elif status == ExecutionStatus.FAIL.value:
+                        error_msg = status_data.get("error", "Unknown assembly error")
+                        raise RestClientError(
+                            f"Assembly failed for {file_name}: {error_msg}"
+                        )
+                    elif status == ExecutionStatus.IN_PROGRESS.value:
+                        logger.debug(f"Assembly in progress for {file_name}...")
+
+                except RestClientError:
+                    raise  # Re-raise assembly errors
+                except Exception as e:
+                    logger.warning(
+                        f"Error checking assembly status for {file_name}: {e}"
+                    )
+
+            # Don't sleep if all files are complete
+            if len(completed_files) < len(file_names):
+                time.sleep(self.ASSEMBLY_POLLING_INTERVAL)
+
+        if len(completed_files) < len(file_names):
+            remaining_files = set(file_names) - completed_files
+            logger.warning(
+                f"Assembly timeout for files: {remaining_files} after {timeout} seconds"
+            )
+
+            return False
+
+        logger.info(f"All {len(file_names)} file assemblies completed successfully")
+        return True
+
     @staticmethod
     def _validate_module_path(path: Path) -> None:
         """Validates that the given path exists and is a directory.
```
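The poller above keys off `ExecutionStatus`: `SUCCESS` marks a file done, `FAIL` aborts the upload, and `IN_PROGRESS` keeps the loop sleeping in `ASSEMBLY_POLLING_INTERVAL` steps. A small restatement of the terminal-state check, assuming only the enum members visible in the hunk:

```python
from futurehouse_client.models.rest import ExecutionStatus

def is_terminal(status_payload: dict) -> bool:
    # Mirrors the branching in _wait_for_all_assemblies_completion:
    # SUCCESS and FAIL stop polling for a file; IN_PROGRESS keeps waiting.
    return status_payload.get("status") in (
        ExecutionStatus.SUCCESS.value,
        ExecutionStatus.FAIL.value,
    )
```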
```diff
@@ -358,7 +435,6 @@ class RestClient:
         if not files:
             raise TaskFetchError(f"No files found in {path}")

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -394,25 +470,12 @@ class RestClient:

             if verbose:
                 return verbose_response
-            …
-            …
-            …
-            ):
-                return PQATaskResponse(**data)
-            return TaskResponse(**data)
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
+            return JobNames.get_response_object_from_job(verbose_response.job_name)(
+                **data
+            )
         except Exception as e:
             raise TaskFetchError(f"Error getting task: {e!s}") from e

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -450,25 +513,12 @@ class RestClient:

             if verbose:
                 return verbose_response
-            …
-            …
-            …
-            ):
-                return PQATaskResponse(**data)
-            return TaskResponse(**data)
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
+            return JobNames.get_response_object_from_job(verbose_response.job_name)(
+                **data
+            )
         except Exception as e:
             raise TaskFetchError(f"Error getting task: {e!s}") from e

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
```
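Both `get_task` variants previously special-cased `PQATaskResponse` and re-raised 401/403 as `AuthError`; with `RefreshingJWT` handling auth at the transport layer, the body reduces to one dispatch on the job name. A hedged sketch (`data` is a hypothetical fetched payload; only `job_name` drives the lookup):

```python
from futurehouse_client.clients.job_client import JobNames

data = {"job_name": "job-futurehouse-dummy-env"}  # hypothetical fetched payload

response_cls = JobNames.get_response_object_from_job(data["job_name"])
assert response_cls.__name__ == "TaskResponse"
# get_task() then finishes with response_cls(**data) instead of the old
# hard-coded PQATaskResponse / TaskResponse branch.
```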
```diff
@@ -492,20 +542,10 @@ class RestClient:
             response.raise_for_status()
             trajectory_id = response.json()["trajectory_id"]
             self.trajectory_id = trajectory_id
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
         except Exception as e:
             raise TaskFetchError(f"Error creating task: {e!s}") from e
         return trajectory_id

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -529,15 +569,6 @@ class RestClient:
             response.raise_for_status()
             trajectory_id = response.json()["trajectory_id"]
             self.trajectory_id = trajectory_id
-        except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
         except Exception as e:
             raise TaskFetchError(f"Error creating task: {e!s}") from e
         return trajectory_id
@@ -683,7 +714,6 @@ class RestClient:
             for task_id in trajectory_ids
         ]

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -695,19 +725,11 @@ class RestClient:
             build_id = build_id or self.build_id
             response = self.client.get(f"/v0.1/builds/{build_id}")
             response.raise_for_status()
-        except
-            …
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
+        except Exception as e:
+            raise JobFetchError(f"Error getting build status: {e!s}") from e
         return response.json()

     # TODO: Refactor later so we don't have to ignore PLR0915
-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -887,13 +909,6 @@ class RestClient:
             build_context = response.json()
             self.build_id = build_context["build_id"]
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             error_detail = response.json()
             error_message = error_detail.get("detail", str(e))
             raise JobCreationError(
@@ -903,6 +918,8 @@ class RestClient:
             raise JobCreationError(f"Error generating docker image: {e!s}") from e
         return build_context

+    # TODO: we should have have an async upload_file, check_assembly_status,
+    # wait_for_assembly_completion, upload_directory, upload_single_file
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
```
```diff
@@ -913,6 +930,8 @@ class RestClient:
         job_name: str,
         file_path: str | os.PathLike,
         upload_id: str | None = None,
+        wait_for_assembly: bool = True,
+        assembly_timeout: int = MAX_ASSEMBLY_WAIT_TIME,
     ) -> str:
         """Upload a file or directory to a futurehouse job bucket.

@@ -920,29 +939,47 @@ class RestClient:
             job_name: The name of the futurehouse job to upload to.
             file_path: The local path to the file or directory to upload.
             upload_id: Optional folder name to use for the upload. If not provided, a random UUID will be used.
+            wait_for_assembly: After file chunking, wait for the assembly to be processed.
+            assembly_timeout: Maximum time to wait for assembly in seconds.

         Returns:
             The upload ID used for the upload.

         Raises:
             FileUploadError: If there's an error uploading the file.
+            RestClientError: If assembly fails or times out.
         """
         file_path = Path(file_path)
         if not file_path.exists():
             raise FileNotFoundError(f"File or directory not found: {file_path}")

         upload_id = upload_id or str(uuid.uuid4())
+        uploaded_files: list[str] = []

         if file_path.is_dir():
             # Process directory recursively
-            self._upload_directory(job_name, file_path, upload_id)
+            uploaded_files = self._upload_directory(job_name, file_path, upload_id)
         else:
             # Process single file
             self._upload_single_file(job_name, file_path, upload_id)
+            uploaded_files = [file_path.name]
+
+        # Wait for all assemblies if requested and we have files
+        if wait_for_assembly and uploaded_files:
+            success = self._wait_for_all_assemblies_completion(
+                job_name, upload_id, uploaded_files, assembly_timeout
+            )
+            if not success:
+                raise RestClientError(
+                    f"Assembly failed or timed out for one or more files: {uploaded_files}"
+                )
+
         logger.info(f"Successfully uploaded {file_path} to {upload_id}")
         return upload_id

-    def _upload_directory(
+    def _upload_directory(
+        self, job_name: str, dir_path: Path, upload_id: str
+    ) -> list[str]:
         """Upload all files in a directory recursively.

         Args:
```
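`upload_file` now blocks by default until server-side assembly of the chunked upload completes. A hedged usage sketch (client construction is illustrative; constructor arguments depend on your deployment, and the job key is `JobNames.FINCH`'s value):

```python
from futurehouse_client import FutureHouseClient

client = FutureHouseClient(api_key="...")  # hypothetical setup

upload_id = client.upload_file(
    "job-futurehouse-data-analysis-crow-high",
    "results/experiment.csv",
    wait_for_assembly=True,   # default: poll assembly status before returning
    assembly_timeout=600,     # raise RestClientError after 10 minutes
)
```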
```diff
@@ -950,12 +987,17 @@ class RestClient:
             dir_path: The path to the directory to upload.
             upload_id: The upload ID to use.

+        Returns:
+            List of uploaded file names.
+
         Raises:
             FileUploadError: If there's an error uploading any file.
         """
         # Skip common directories that shouldn't be uploaded
         if any(ignore in dir_path.parts for ignore in FILE_UPLOAD_IGNORE_PARTS):
-            return
+            return []
+
+        uploaded_files: list[str] = []

         try:
             # Upload all files in the directory recursively
@@ -965,24 +1007,27 @@ class RestClient:
             ):
                 # Use path relative to the original directory as file name
                 rel_path = path.relative_to(dir_path)
+                file_name = str(rel_path)
                 self._upload_single_file(
                     job_name,
                     path,
                     upload_id,
-                    file_name=
+                    file_name=file_name,
                 )
+                uploaded_files.append(file_name)
         except Exception as e:
             raise FileUploadError(f"Error uploading directory {dir_path}: {e}") from e

-        …
+        return uploaded_files
+
     def _upload_single_file(
         self,
         job_name: str,
         file_path: Path,
         upload_id: str,
         file_name: str | None = None,
-    ) -> None:
-        """Upload a single file in chunks.
+    ) -> str | None:
+        """Upload a single file in chunks using parallel uploads.

         Args:
             job_name: The key of the crow to upload to.
@@ -990,6 +1035,9 @@ class RestClient:
             upload_id: The upload ID to use.
             file_name: Optional name to use for the file. If not provided, the file's name will be used.

+        Returns:
+            The status URL if this was the last chunk, None otherwise.
+
         Raises:
             FileUploadError: If there's an error uploading the file.
         """
@@ -999,16 +1047,190 @@ class RestClient:
         # Skip empty files
         if file_size == 0:
             logger.warning(f"Skipping upload of empty file: {file_path}")
-            return
+            return None

         total_chunks = (file_size + self.CHUNK_SIZE - 1) // self.CHUNK_SIZE

         logger.info(f"Uploading {file_path} as {file_name} ({total_chunks} chunks)")

+        status_url = None
+
         try:
-            …
-            …
-            …
+            # Upload all chunks except the last one in parallel
+            if total_chunks > 1:
+                self._upload_chunks_parallel(
+                    job_name,
+                    file_path,
+                    file_name,
+                    upload_id,
+                    total_chunks - 1,
+                    total_chunks,
+                )
+
+            # Upload the last chunk separately (handles assembly)
+            status_url = self._upload_final_chunk(
+                job_name,
+                file_path,
+                file_name,
+                upload_id,
+                total_chunks - 1,
+                total_chunks,
+            )
+
+            logger.info(f"Successfully uploaded {file_name}")
+        except Exception as e:
+            logger.exception(f"Error uploading file {file_path}")
+            raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
+        return status_url
+
+    def _upload_chunks_parallel(
+        self,
+        job_name: str,
+        file_path: Path,
+        file_name: str,
+        upload_id: str,
+        num_regular_chunks: int,
+        total_chunks: int,
+    ) -> None:
+        """Upload chunks in parallel batches.
+
+        Args:
+            job_name: The key of the crow to upload to.
+            file_path: The path to the file to upload.
+            file_name: The name to use for the file.
+            upload_id: The upload ID to use.
+            num_regular_chunks: Number of regular chunks (excluding final chunk).
+            total_chunks: Total number of chunks.
+
+        Raises:
+            FileUploadError: If there's an error uploading any chunk.
+        """
+        if num_regular_chunks <= 0:
+            return
+
+        # Process chunks in batches
+        for batch_start in range(0, num_regular_chunks, self.MAX_CONCURRENT_CHUNKS):
+            batch_end = min(
+                batch_start + self.MAX_CONCURRENT_CHUNKS, num_regular_chunks
+            )
+
+            # Upload chunks in this batch concurrently
+            with ThreadPoolExecutor(max_workers=self.MAX_CONCURRENT_CHUNKS) as executor:
+                futures = {
+                    executor.submit(
+                        self._upload_single_chunk,
+                        job_name,
+                        file_path,
+                        file_name,
+                        upload_id,
+                        chunk_index,
+                        total_chunks,
+                    ): chunk_index
+                    for chunk_index in range(batch_start, batch_end)
+                }
+
+                for future in as_completed(futures):
+                    chunk_index = futures[future]
+                    try:
+                        future.result()
+                        logger.debug(
+                            f"Uploaded chunk {chunk_index + 1}/{total_chunks} of {file_name}"
+                        )
+                    except Exception as e:
+                        logger.error(f"Error uploading chunk {chunk_index}: {e}")
+                        raise FileUploadError(
+                            f"Error uploading chunk {chunk_index} of {file_name}: {e}"
+                        ) from e
+
+    def _upload_single_chunk(
+        self,
+        job_name: str,
+        file_path: Path,
+        file_name: str,
+        upload_id: str,
+        chunk_index: int,
+        total_chunks: int,
+    ) -> None:
+        """Upload a single chunk.
+
+        Args:
+            job_name: The key of the crow to upload to.
+            file_path: The path to the file to upload.
+            file_name: The name to use for the file.
+            upload_id: The upload ID to use.
+            chunk_index: The index of this chunk.
+            total_chunks: Total number of chunks.
+
+        Raises:
+            Exception: If there's an error uploading the chunk.
+        """
+        with open(file_path, "rb") as f:
+            # Read the chunk from the file
+            f.seek(chunk_index * self.CHUNK_SIZE)
+            chunk_data = f.read(self.CHUNK_SIZE)
+
+            # Prepare and send the chunk
+            with tempfile.NamedTemporaryFile() as temp_file:
+                temp_file.write(chunk_data)
+                temp_file.flush()
+
+                # Create form data
+                with open(temp_file.name, "rb") as chunk_file_obj:
+                    files = {
+                        "chunk": (
+                            file_name,
+                            chunk_file_obj,
+                            "application/octet-stream",
+                        )
+                    }
+                    data = {
+                        "file_name": file_name,
+                        "chunk_index": chunk_index,
+                        "total_chunks": total_chunks,
+                        "upload_id": upload_id,
+                    }
+
+                    # Send the chunk
+                    response = self.multipart_client.post(
+                        f"/v0.1/crows/{job_name}/upload-chunk",
+                        files=files,
+                        data=data,
+                    )
+                    response.raise_for_status()
+
+    def _upload_final_chunk(
+        self,
+        job_name: str,
+        file_path: Path,
+        file_name: str,
+        upload_id: str,
+        chunk_index: int,
+        total_chunks: int,
+    ) -> str | None:
+        """Upload the final chunk with retry logic for missing chunks.
+
+        Args:
+            job_name: The key of the crow to upload to.
+            file_path: The path to the file to upload.
+            file_name: The name to use for the file.
+            upload_id: The upload ID to use.
+            chunk_index: The index of the final chunk.
+            total_chunks: Total number of chunks.
+
+        Returns:
+            The status URL from the response.
+
+        Raises:
+            FileUploadError: If there's an error uploading the final chunk.
+        """
+        retries = 0
+        max_retries = 3
+        retry_delay = 2.0  # seconds
+
+        while retries < max_retries:
+            try:
+                with open(file_path, "rb") as f:
+                    # Read the final chunk from the file
                     f.seek(chunk_index * self.CHUNK_SIZE)
                     chunk_data = f.read(self.CHUNK_SIZE)

@@ -1033,35 +1255,47 @@ class RestClient:
                                 "upload_id": upload_id,
                             }

-                            # Send the chunk
+                            # Send the final chunk
                             response = self.multipart_client.post(
                                 f"/v0.1/crows/{job_name}/upload-chunk",
                                 files=files,
                                 data=data,
                             )
+
+                            # Handle missing chunks (status 409)
+                            if response.status_code == codes.CONFLICT:
+                                retries += 1
+                                if retries < max_retries:
+                                    logger.warning(
+                                        f"Missing chunks detected for {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries})"
+                                    )
+                                    time.sleep(retry_delay)
+                                    continue
+
                             response.raise_for_status()
+                            response_data = response.json()
+                            status_url = response_data.get("status_url")

-                            …
+                            logger.debug(
+                                f"Uploaded final chunk {chunk_index + 1}/{total_chunks} of {file_name}"
+                            )
+                            return status_url

-            …
-            …
-            …
+            except Exception as e:
+                if retries >= max_retries - 1:
+                    raise FileUploadError(
+                        f"Error uploading final chunk of {file_name}: {e}"
+                    ) from e
+                retries += 1
+                logger.warning(
+                    f"Error uploading final chunk of {file_name}, retrying in {retry_delay}s... (attempt {retries}/{max_retries}): {e}"
+                )
+                time.sleep(retry_delay)

-            …
-            …
-            …
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
-            raise
-        except Exception as e:
-            logger.exception(f"Error uploading file {file_path}")
-            raise FileUploadError(f"Error uploading file {file_path}: {e}") from e
+        raise FileUploadError(
+            f"Failed to upload final chunk of {file_name} after {max_retries} retries"
+        )

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
```
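The chunk arithmetic: `CHUNK_SIZE` is 16 MiB and `total_chunks` uses ceiling division; every chunk except the last goes up in parallel batches of at most `MAX_CONCURRENT_CHUNKS` threads, and the final chunk triggers server-side assembly. A worked sketch with a hypothetical file size:

```python
CHUNK_SIZE = 16 * 1024 * 1024  # 16 MiB, as in RestClient.CHUNK_SIZE
MAX_CONCURRENT_CHUNKS = 12

file_size = 100 * 1024 * 1024  # hypothetical 100 MiB file
total_chunks = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE  # ceiling division -> 7

# Chunks 0-5 upload in a single batch of <= 12 threads;
# chunk index 6 (the 7th) is sent last and kicks off assembly.
```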
```diff
@@ -1098,13 +1332,6 @@ class RestClient:
             response.raise_for_status()
             return response.json()
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             logger.exception(
                 f"Error listing files for job {job_name}, trajectory {trajectory_id}, upload_id {upload_id}: {e.response.text}"
             )
@@ -1117,7 +1344,6 @@ class RestClient:
             )
             raise RestClientError(f"Error listing files: {e!s}") from e

-    @refresh_token_on_auth_error()
     @retry(
         stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
         wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
@@ -1165,13 +1391,6 @@ class RestClient:

             logger.info(f"File {file_path} downloaded to {destination_path}")
         except HTTPStatusError as e:
-            if e.response.status_code in AUTH_ERRORS_TO_RETRY_ON:
-                raise AuthError(
-                    e.response.status_code,
-                    f"Authentication failed: {e}",
-                    request=e.request,
-                    response=e.response,
-                ) from e
             logger.exception(
                 f"Error downloading file {file_path} for job {job_name}, trajectory_id {trajectory_id}: {e.response.text}"
             )
```
futurehouse_client/models/app.py
CHANGED
```diff
@@ -1,3 +1,4 @@
+import copy
 import json
 import os
 import re
@@ -675,7 +676,8 @@ class TaskResponse(BaseModel):

     @model_validator(mode="before")
     @classmethod
-    def validate_fields(cls,
+    def validate_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
+        data = copy.deepcopy(original_data)  # Avoid mutating the original data
         # Extract fields from environment frame state
         if not isinstance(data, dict):
             return data
@@ -690,7 +692,72 @@ class TaskResponse(BaseModel):
         return data


+class PhoenixTaskResponse(TaskResponse):
+    """
+    Response scheme for tasks executed with Phoenix.
+
+    Additional fields:
+        answer: Final answer from Phoenix
+    """
+
+    model_config = ConfigDict(extra="ignore")
+    answer: str | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_phoenix_fields(
+        cls, original_data: Mapping[str, Any]
+    ) -> Mapping[str, Any]:
+        data = copy.deepcopy(original_data)
+        if not isinstance(data, dict):
+            return data
+        if not (env_frame := data.get("environment_frame", {})):
+            return data
+        state = env_frame.get("state", {}).get("state", {})
+        data["answer"] = state.get("answer")
+        return data
+
+
+class FinchTaskResponse(TaskResponse):
+    """
+    Response scheme for tasks executed with Finch.
+
+    Additional fields:
+        answer: Final answer from Finch
+        notebook: a dictionary with `cells` and `metadata` regarding the notebook content
+    """
+
+    model_config = ConfigDict(extra="ignore")
+    answer: str | None = None
+    notebook: dict[str, Any] | None = None
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_finch_fields(
+        cls, original_data: Mapping[str, Any]
+    ) -> Mapping[str, Any]:
+        data = copy.deepcopy(original_data)
+        if not isinstance(data, dict):
+            return data
+        if not (env_frame := data.get("environment_frame", {})):
+            return data
+        state = env_frame.get("state", {}).get("state", {})
+        data["answer"] = state.get("answer")
+        data["notebook"] = state.get("nb_state")
+        return data
+
+
 class PQATaskResponse(TaskResponse):
+    """
+    Response scheme for tasks executed with PQA.
+
+    Additional fields:
+        answer: Final answer from PQA
+        formatted_answer: Formatted answer from PQA
+        answer_reasoning: Reasoning used to generate the final answer, if available
+        has_successful_answer: Whether the answer is successful
+    """
+
     model_config = ConfigDict(extra="ignore")

     answer: str | None = None
@@ -702,7 +769,8 @@ class PQATaskResponse(TaskResponse):

     @model_validator(mode="before")
     @classmethod
-    def validate_pqa_fields(cls,
+    def validate_pqa_fields(cls, original_data: Mapping[str, Any]) -> Mapping[str, Any]:
+        data = copy.deepcopy(original_data)  # Avoid mutating the original data
         if not isinstance(data, dict):
             return data
         if not (env_frame := data.get("environment_frame", {})):
```
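The new response models pull their fields out of the nested `environment_frame.state.state` payload before pydantic validation runs. A standalone restatement of the Finch traversal (not the library API; the payload is hypothetical and shows only the keys the validator reads):

```python
from typing import Any

def extract_finch_fields(payload: dict[str, Any]) -> dict[str, Any]:
    # Same traversal as FinchTaskResponse.validate_finch_fields above.
    env_frame = payload.get("environment_frame", {})
    if not env_frame:
        return payload
    state = env_frame.get("state", {}).get("state", {})
    return {**payload, "answer": state.get("answer"), "notebook": state.get("nb_state")}

payload = {
    "environment_frame": {
        "state": {"state": {"answer": "42", "nb_state": {"cells": [], "metadata": {}}}}
    }
}
assert extract_finch_fields(payload)["answer"] == "42"
```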
futurehouse_client/utils/auth.py
CHANGED
```diff
@@ -1,107 +1,92 @@
-import asyncio
 import logging
-from collections.abc import
-from
-from typing import Any, Final, Optional, ParamSpec, TypeVar, overload
+from collections.abc import Collection, Generator
+from typing import ClassVar, Final

 import httpx
-from httpx import HTTPStatusError

-
-
-T = TypeVar("T")
-P = ParamSpec("P")
-
-AUTH_ERRORS_TO_RETRY_ON: Final[set[int]] = {
-    httpx.codes.UNAUTHORIZED,
-    httpx.codes.FORBIDDEN,
-}
-
-
-class AuthError(Exception):
-    """Raised when authentication fails with 401/403 status."""
-
-    def __init__(self, status_code: int, message: str, request=None, response=None):
-        self.status_code = status_code
-        self.request = request
-        self.response = response
-        super().__init__(message)
-
-
-def is_auth_error(e: Exception) -> bool:
-    if isinstance(e, AuthError):
-        return True
-    if isinstance(e, HTTPStatusError):
-        return e.response.status_code in AUTH_ERRORS_TO_RETRY_ON
-    return False
-
-
-def get_status_code(e: Exception) -> Optional[int]:
-    if isinstance(e, AuthError):
-        return e.status_code
-    if isinstance(e, HTTPStatusError):
-        return e.response.status_code
-    return None
+from futurehouse_client.models.app import APIKeyPayload, AuthType

+logger = logging.getLogger(__name__)

-… [61 further removed lines (old 47-107) are not preserved in this diff view]
+INVALID_REFRESH_TYPE_MSG: Final[str] = (
+    "API key auth is required to refresh auth tokens."
+)
+JWT_TOKEN_CACHE_EXPIRY: int = 300  # seconds
+
+
+def _run_auth(
+    client: httpx.Client,
+    auth_type: AuthType = AuthType.API_KEY,
+    api_key: str | None = None,
+    jwt: str | None = None,
+) -> str:
+    auth_payload: APIKeyPayload | None
+    if auth_type == AuthType.API_KEY:
+        auth_payload = APIKeyPayload(api_key=api_key)
+    elif auth_type == AuthType.JWT:
+        auth_payload = None
+    try:
+        if auth_payload:
+            response = client.post("/auth/login", json=auth_payload.model_dump())
+            response.raise_for_status()
+            token_data = response.json()
+        elif jwt:
+            token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
+        else:
+            raise ValueError("JWT token required for JWT authentication.")
+
+        return token_data["access_token"]
+    except Exception as e:
+        raise Exception("Failed to authenticate") from e  # noqa: TRY002
+
+
+class RefreshingJWT(httpx.Auth):
+    """Automatically (re-)inject a JWT and transparently retry exactly once when we hit a 401/403."""
+
+    RETRY_STATUSES: ClassVar[Collection[httpx.codes]] = {
+        httpx.codes.UNAUTHORIZED,
+        httpx.codes.FORBIDDEN,
+    }
+
+    def __init__(
+        self,
+        auth_client: httpx.Client,
+        auth_type: AuthType = AuthType.API_KEY,
+        api_key: str | None = None,
+        jwt: str | None = None,
+    ):
+        self.auth_type = auth_type
+        self.auth_client = auth_client
+        self.api_key = api_key
+        self._jwt = _run_auth(
+            client=auth_client,
+            jwt=jwt,
+            auth_type=auth_type,
+            api_key=api_key,
+        )
+
+    def refresh_token(self) -> None:
+        if self.auth_type == AuthType.JWT:
+            logger.error(INVALID_REFRESH_TYPE_MSG)
+            raise ValueError(INVALID_REFRESH_TYPE_MSG)
+        self._jwt = _run_auth(
+            client=self.auth_client,
+            auth_type=self.auth_type,
+            api_key=self.api_key,
+        )
+
+    def auth_flow(
+        self, request: httpx.Request
+    ) -> Generator[httpx.Request, httpx.Response, None]:
+        request.headers["Authorization"] = f"Bearer {self._jwt}"
+        response = yield request
+
+        # If it failed, refresh once and replay the request
+        if response.status_code in self.RETRY_STATUSES:
+            logger.info(
+                "Received %s, refreshing token and retrying …",
+                response.status_code,
+            )
+            self.refresh_token()
+            request.headers["Authorization"] = f"Bearer {self._jwt}"
+            yield request  # second (and final) attempt, again or use a while loop
```
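`RefreshingJWT` is a standard `httpx.Auth` subclass, so it attaches to any httpx client. A hedged wiring sketch (the base URL and key are placeholders; note that constructing `RefreshingJWT` performs the login call immediately, and `RestClient.get_client` does this for you):

```python
import httpx
from futurehouse_client.models.app import AuthType
from futurehouse_client.utils.auth import RefreshingJWT

BASE_URL = "https://api.example.org"  # placeholder, not the real service URL

# Unauthenticated client used only for the /auth/login exchange.
auth_client = httpx.Client(base_url=BASE_URL)

client = httpx.Client(
    base_url=BASE_URL,
    auth=RefreshingJWT(
        auth_client=auth_client,
        auth_type=AuthType.API_KEY,
        api_key="...",
    ),
)
# On a 401/403, auth_flow() refreshes the JWT and replays the request once.
```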
{futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dev111.dist-info}/METADATA
RENAMED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: futurehouse-client
-Version: 0.3.18.dev186
+Version: 0.3.19.dev111
 Summary: A client for interacting with endpoints of the FutureHouse service.
 Author-email: FutureHouse technical staff <hello@futurehouse.org>
 Classifier: Operating System :: OS Independent
```
futurehouse_client-0.3.19.dev111.dist-info/RECORD
ADDED
```diff
@@ -0,0 +1,17 @@
+futurehouse_client/__init__.py,sha256=OzGDkVm5UTUzd4n8yOmRjMF73YrK0FaIQX5gS3Dk8Zo,304
+futurehouse_client/clients/__init__.py,sha256=-HXNj-XJ3LRO5XM6MZ709iPs29YpApss0Q2YYg1qMZw,280
+futurehouse_client/clients/job_client.py,sha256=JgB5IUAyCmnhGRsYc3bgKldA-lkM1JLwHRwwUeOCdus,11944
+futurehouse_client/clients/rest_client.py,sha256=_XgkzA9OhUKjL9vpkU6ixh2lUW9StgqfGgLk2qHjGgI,55518
+futurehouse_client/models/__init__.py,sha256=5x-f9AoM1hGzJBEHcHAXSt7tPeImST5oZLuMdwp0mXc,554
+futurehouse_client/models/app.py,sha256=VCtg0ygd-TSrR6DtfljTBt9jnl1eBNal8UXHFdkDg88,28587
+futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
+futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
+futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+futurehouse_client/utils/auth.py,sha256=tgWELjKfg8eWme_qdcRmc8TjQN9DVZuHHaVXZNHLchk,2960
+futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
+futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
+futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
+futurehouse_client-0.3.19.dev111.dist-info/METADATA,sha256=N4Msi8W4mMBXFs_-Pl8Ii12RcLRm2eBl9NiIFCy5--E,12767
+futurehouse_client-0.3.19.dev111.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+futurehouse_client-0.3.19.dev111.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
+futurehouse_client-0.3.19.dev111.dist-info/RECORD,,
```
futurehouse_client-0.3.18.dev186.dist-info/RECORD
DELETED
```diff
@@ -1,17 +0,0 @@
-futurehouse_client/__init__.py,sha256=ddxO7JE97c6bt7LjNglZZ2Ql8bYCGI9laSFeh9MP6VU,344
-futurehouse_client/clients/__init__.py,sha256=tFWqwIAY5PvwfOVsCje4imjTpf6xXNRMh_UHIKVI1_0,320
-futurehouse_client/clients/job_client.py,sha256=uNkqQbeZw7wbA0qDWcIOwOykrosza-jev58paJZ_mbA,11150
-futurehouse_client/clients/rest_client.py,sha256=W9ASP1ZKYS7UL5J9b-Km77YXEiDQ9hCf4X_9PqaZZZc,47914
-futurehouse_client/models/__init__.py,sha256=5x-f9AoM1hGzJBEHcHAXSt7tPeImST5oZLuMdwp0mXc,554
-futurehouse_client/models/app.py,sha256=w_1e4F0IiC-BKeOLqYkABYo4U-Nka1S-F64S_eHB2KM,26421
-futurehouse_client/models/client.py,sha256=n4HD0KStKLm6Ek9nL9ylP-bkK10yzAaD1uIDF83Qp_A,1828
-futurehouse_client/models/rest.py,sha256=lgwkMIXz0af-49BYSkKeS7SRqvN3motqnAikDN4YGTc,789
-futurehouse_client/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-futurehouse_client/utils/auth.py,sha256=Lq9mjSGc7iuRP6fmLICCS6KjzLHN6-tJUuhYp0XXrkE,3342
-futurehouse_client/utils/general.py,sha256=A_rtTiYW30ELGEZlWCIArO7q1nEmqi8hUlmBRYkMQ_c,767
-futurehouse_client/utils/module_utils.py,sha256=aFyd-X-pDARXz9GWpn8SSViUVYdSbuy9vSkrzcVIaGI,4955
-futurehouse_client/utils/monitoring.py,sha256=UjRlufe67kI3VxRHOd5fLtJmlCbVA2Wqwpd4uZhXkQM,8728
-futurehouse_client-0.3.18.dev186.dist-info/METADATA,sha256=PvjehEQZu2ihl7kG1uDvWJVUxyYbV7J-VmAe42Ml3zo,12767
-futurehouse_client-0.3.18.dev186.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-futurehouse_client-0.3.18.dev186.dist-info/top_level.txt,sha256=TRuLUCt_qBnggdFHCX4O_BoCu1j2X43lKfIZC-ElwWY,19
-futurehouse_client-0.3.18.dev186.dist-info/RECORD,,
```
{futurehouse_client-0.3.18.dev186.dist-info → futurehouse_client-0.3.19.dev111.dist-info}/top_level.txt
RENAMED
File without changes