futurehouse-client 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/__init__.py +12 -0
- futurehouse_client/clients/__init__.py +12 -0
- futurehouse_client/clients/job_client.py +232 -0
- futurehouse_client/clients/rest_client.py +674 -0
- futurehouse_client/models/__init__.py +21 -0
- futurehouse_client/models/app.py +622 -0
- futurehouse_client/models/client.py +72 -0
- futurehouse_client/models/rest.py +19 -0
- futurehouse_client/utils/__init__.py +0 -0
- futurehouse_client/utils/module_utils.py +149 -0
- futurehouse_client-0.0.1.dist-info/METADATA +151 -0
- futurehouse_client-0.0.1.dist-info/RECORD +14 -0
- futurehouse_client-0.0.1.dist-info/WHEEL +5 -0
- futurehouse_client-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,674 @@
|
|
1
|
+
import ast
import base64
import copy
import importlib.metadata
import inspect
import json
import logging
import os
import textwrap
from collections.abc import Mapping
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import Any, ClassVar, assert_never, cast
from uuid import UUID

import cloudpickle
from aviary.functional import EnvironmentBuilder
from httpx import Client, HTTPStatusError, Response, TimeoutException
from pydantic import BaseModel, ConfigDict, model_validator
from requests.exceptions import Timeout
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from futurehouse_client.clients import JobNames
from futurehouse_client.models.app import (
    APIKeyPayload,
    AuthType,
    CrowDeploymentConfig,
    JobRequest,
    Stage,
)
from futurehouse_client.utils.module_utils import (
    OrganizationSelector,
    fetch_environment_function_docstring,
)
|
40
|
+
|
41
|
+
# Module-level logger for the REST client.
logger = logging.getLogger(__name__)

# NOTE(review): presumably rebuilds JobRequest's pydantic schema so that forward
# references declared in futurehouse_client.models.app resolve before the model
# is used in this module — confirm.
JobRequest.model_rebuild()
|
44
|
+
|
45
|
+
# Path components (VCS metadata, virtualenvs, tool caches) whose contents are
# skipped when RestClient.create_crow collects module files for upload.
FILE_UPLOAD_IGNORE_PARTS = {
    # version control
    ".git",
    # virtual environment
    ".venv",
    # interpreter / tool caches
    "__pycache__",
    ".mypy_cache",
    ".pytest_cache",
    ".ruff_cache",
}
|
53
|
+
|
54
|
+
|
55
|
+
class RestClientError(Exception):
    """Base exception for REST client errors."""


class JobFetchError(RestClientError):
    """Raised when there's an error fetching a job.

    The client also raises it for failed job creation and build submission.
    """


class JobCreationError(RestClientError):
    """Raised when there's an error creating a job."""


class InvalidTaskDescriptionError(RestClientError):
    """Raised when the task description is invalid or empty.

    Derives from RestClientError (and therefore still from Exception), so
    handlers for the client's base error also catch it.
    """
|
69
|
+
|
70
|
+
|
71
|
+
# Lifetime recorded for caller-supplied JWTs: five minutes, expressed in seconds.
JWT_TOKEN_CACHE_EXPIRY: int = 5 * 60
|
73
|
+
|
74
|
+
|
75
|
+
class JobResponse(BaseModel):
    """Base class for job responses. This holds attributes shared over all crows."""

    # Unknown keys in the server payload are dropped.
    model_config = ConfigDict(extra="ignore")

    status: str
    task: str
    user: str
    created_at: datetime
    crow: str
    public: bool
    shared_with: list[str]
    build_owner: str | None = None
    environment_name: str | None = None
    agent_name: str | None = None
    job_id: UUID | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """Lift job_id, environment_name and agent_name out of nested payload sections.

        The job id is read from environment_frame["state"]["state"]["id"].
        environment_name and agent_name are read from data["metadata"].

        Pass-through cases:
        - Non-dict input is returned unchanged.
        - Payloads without an environment_frame are returned unchanged.

        Nested sections that are missing or null are treated as empty instead of
        raising AttributeError. The caller's dict is never mutated; a shallow
        copy is returned instead.
        """
        if not isinstance(data, dict):
            return data
        if not (env_frame := data.get("environment_frame") or {}):
            return data
        # Copy so that model_validate(payload) does not rewrite the caller's payload.
        data = dict(data)
        state = (env_frame.get("state") or {}).get("state") or {}
        job_id = state.get("id")
        data["job_id"] = cast(UUID, job_id) if job_id else None
        if not (metadata := data.get("metadata") or {}):
            return data
        data["environment_name"] = metadata.get("environment_name")
        data["agent_name"] = metadata.get("agent_name")
        return data
|
107
|
+
|
108
|
+
|
109
|
+
class PQAJobResponse(JobResponse):
    """Job response for PQA-style crows: adds answer and usage fields from the job state."""

    model_config = ConfigDict(extra="ignore")

    answer: str | None = None
    formatted_answer: str | None = None
    answer_reasoning: str | None = None
    has_successful_answer: bool | None = None
    total_cost: float | None = None
    total_queries: int | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_pqa_fields(cls, data: Mapping[str, Any]) -> Mapping[str, Any]:
        """Copy PQA answer and usage fields out of environment_frame.state.state.

        Sources within environment_frame.state.state:
        - answer, formatted_answer, answer_reasoning and has_successful_answer
          come from response.answer.
        - total_cost and total_queries come from info.usage.

        Missing or null sections are treated as empty. The caller's dict is not
        mutated.
        """
        if not isinstance(data, dict):
            return data
        if not (env_frame := data.get("environment_frame") or {}):
            return data
        data = dict(data)
        state = (env_frame.get("state") or {}).get("state") or {}
        answer = (state.get("response") or {}).get("answer") or {}
        usage = (state.get("info") or {}).get("usage") or {}

        # Add additional PQA specific fields to data so that pydantic can validate the model
        data["answer"] = answer.get("answer")
        data["formatted_answer"] = answer.get("formatted_answer")
        data["answer_reasoning"] = answer.get("answer_reasoning")
        data["has_successful_answer"] = answer.get("has_successful_answer")
        data["total_cost"] = cast(float, usage.get("total_cost"))
        data["total_queries"] = cast(int, usage.get("total_queries"))

        return data

    def clean_verbose(self) -> "JobResponse":
        """Clear the verbose request/response payloads, when this model carries them.

        The original implementation assigned to undeclared attributes. Pydantic
        rejects that for extra="ignore" models, so the call always raised.
        Now only attributes that are declared fields, or stored extras on
        extra="allow" models, are reset to None. Otherwise this is a no-op.
        """
        extras = self.model_extra or {}
        for attr in ("request", "response"):
            if attr in type(self).model_fields or attr in extras:
                setattr(self, attr, None)
        return self
|
147
|
+
|
148
|
+
|
149
|
+
class JobResponseVerbose(JobResponse):
    """Class for responses to include all the fields of a job response.

    Unlike JobResponse, unknown payload keys are kept (extra="allow"), and the
    raw agent_state, environment_frame and metadata sections are exposed as
    typed fields.
    """

    # Keep every key the server sends, not only the declared fields.
    model_config = ConfigDict(extra="allow")

    # NOTE(review): `public` and `shared_with` re-declare the parent's fields with
    # the same types; presumably redundant — confirm before removing, since
    # removal can change field order.
    public: bool
    agent_state: list[dict[str, Any]] | None = None
    environment_frame: dict[str, Any] | None = None
    metadata: dict[str, Any] | None = None
    shared_with: list[str]
|
159
|
+
|
160
|
+
|
161
|
+
class RestClient:
    """Synchronous client for the FutureHouse REST API (crows, jobs and builds).

    Creating an instance performs network I/O. It authenticates, either by
    exchanging an API key for a token or by using a caller-supplied JWT. It then
    fetches the organizations visible to the caller.
    """

    REQUEST_TIMEOUT: ClassVar[float] = 30.0  # sec
    MAX_RETRY_ATTEMPTS: ClassVar[int] = 3
    RETRY_MULTIPLIER: ClassVar[int] = 1
    MAX_RETRY_WAIT: ClassVar[int] = 10

    def __init__(
        self,
        stage: Stage = Stage.DEV,
        service_uri: str | None = None,
        organization: str | None = None,
        auth_type: AuthType = AuthType.API_KEY,
        api_key: str | None = None,
        jwt: str | None = None,
        headers: dict[str, str] | None = None,
    ):
        """Authenticate and resolve the caller's organizations.

        Args:
            stage: Deployment stage; its value is the base URL unless service_uri is given.
            service_uri: Explicit base URL overriding stage.value.
            organization: Restrict the client to this organization (default: all of the user's).
            auth_type: API_KEY exchanges api_key for a token; JWT uses jwt directly.
            api_key: API key used with AuthType.API_KEY.
            jwt: Token used with AuthType.JWT.
            headers: Extra headers sent on every request.

        Raises:
            RestClientError: If authentication fails.
            ValueError: If no matching organization is found.
        """
        self.base_url = service_uri or stage.value
        self.stage = stage
        self.auth_type = auth_type
        self.api_key = api_key
        self._clients: dict[str, Client] = {}
        self.headers = headers or {}
        # Set by create_job / create_crow. Initialised here so that get_job and
        # get_build_status fail with a clear message rather than AttributeError.
        self.trajectory_id: str | None = None
        self.build_id: UUID | str | None = None
        self.auth_jwt = self._run_auth(jwt=jwt)
        self.organizations: list[str] = self._filter_orgs(organization)

    @property
    def client(self) -> Client:
        """Lazily initialized and cached HTTP client with authentication."""
        return self.get_client("application/json", with_auth=True)

    @property
    def auth_client(self) -> Client:
        """Lazily initialized and cached HTTP client without authentication."""
        return self.get_client("application/json", with_auth=False)

    @property
    def multipart_client(self) -> Client:
        """Lazily initialized and cached HTTP client for multipart uploads."""
        return self.get_client(None, with_auth=True)

    def get_client(
        self, content_type: str | None = "application/json", with_auth: bool = True
    ) -> Client:
        """Return a cached HTTP client or create one if needed.

        Args:
            content_type: The desired content type header. Use None for multipart uploads.
            with_auth: Whether the client should include an Authorization header.

        Returns:
            An HTTP client configured with the appropriate headers.
        """
        # Create a composite key based on content type and auth flag.
        key = f"{content_type or 'multipart'}_{with_auth}"
        if key not in self._clients:
            headers = copy.deepcopy(self.headers)
            if with_auth:
                headers["Authorization"] = f"Bearer {self.auth_jwt}"
            if content_type:
                headers["Content-Type"] = content_type
            self._clients[key] = Client(
                base_url=self.base_url,
                headers=headers,
                timeout=self.REQUEST_TIMEOUT,
            )
        return self._clients[key]

    def __del__(self):
        """Ensure all cached clients are properly closed when the instance is destroyed."""
        # _clients may be missing if __init__ raised before assigning it.
        clients = getattr(self, "_clients", None) or {}
        for client in clients.values():
            client.close()
        clients.clear()

    def _filter_orgs(self, organization: str | None = None) -> list[str]:
        """Return the caller's organizations, optionally restricted to `organization`.

        Raises:
            ValueError: If no organization matches.
        """
        filtered_orgs = [
            org
            for org in self._fetch_my_orgs()
            if organization is None or org == organization
        ]
        if not filtered_orgs:
            if organization is None:
                raise ValueError("No organizations found for the authenticated user.")
            raise ValueError(f"Organization '{organization}' not found.")
        return filtered_orgs

    def _run_auth(self, jwt: str | None = None) -> str:
        """Obtain the bearer token used by the authenticated clients.

        Raises:
            RestClientError: If login fails or no JWT was supplied for JWT auth.
        """
        auth_payload: APIKeyPayload | None
        if self.auth_type == AuthType.API_KEY:
            auth_payload = APIKeyPayload(api_key=self.api_key)
        elif self.auth_type == AuthType.JWT:
            auth_payload = None
        else:
            assert_never(self.auth_type)
        try:
            # Use the unauthenticated client for login
            if auth_payload:
                response = self.auth_client.post(
                    "/auth/login", json=auth_payload.model_dump()
                )
                response.raise_for_status()
                token_data = response.json()
            elif jwt:
                token_data = {"access_token": jwt, "expires_in": JWT_TOKEN_CACHE_EXPIRY}
            else:
                raise ValueError("JWT token required for JWT authentication.")

            return token_data["access_token"]
        except Exception as e:
            raise RestClientError(f"Error authenticating: {e!s}") from e

    def _check_job(self, name: str, organization: str) -> dict[str, Any]:
        """Fetch deployment status for crow `name` in `organization`.

        Raises:
            JobFetchError: On any request or decoding failure.
        """
        try:
            response = self.client.get(
                f"/v0.1/crows/{name}/organizations/{organization}"
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            raise JobFetchError(f"Error checking job: {e!s}") from e

    def _fetch_my_orgs(self) -> list[str]:
        """Return the names of the organizations visible to the authenticated user."""
        response = self.client.get(f"/v0.1/organizations?filter={True}")
        response.raise_for_status()
        orgs = response.json()
        return [org["name"] for org in orgs]

    @staticmethod
    def _validate_module_path(path: Path) -> None:
        """Validates that the given path exists and is a directory.

        Args:
            path: Path to validate

        Raises:
            JobFetchError: If the path is not a directory

        """
        if not path.is_dir():
            raise JobFetchError(f"Path {path} is not a directory")

    @staticmethod
    def _validate_template_path(template_path: str | os.PathLike) -> None:
        """
        Validates that a template path exists and is a file.

        Args:
            template_path: Path to validate

        Raises:
            FileNotFoundError: If the template path doesn't exist
            ValueError: If the path exists but isn't a file
        """
        template_path = Path(template_path)
        if not template_path.exists():
            raise FileNotFoundError(
                f"Markdown template file not found: {template_path}"
            )
        if not template_path.is_file():
            raise ValueError(
                f"Markdown template path exists but is not a file: {template_path}"
            )

    @staticmethod
    def _validate_files(files: list, path: str | os.PathLike) -> None:
        """Validates that files were found in the given path.

        Args:
            files: List of collected files
            path: Path that was searched for files

        Raises:
            JobFetchError: If no files were found

        """
        if not files:
            raise JobFetchError(f"No files found in {path}")

    @retry(
        stop=stop_after_attempt(MAX_RETRY_ATTEMPTS),
        wait=wait_exponential(multiplier=RETRY_MULTIPLIER, max=MAX_RETRY_WAIT),
        retry=retry_if_exception_type(TimeoutException),
        reraise=True,
    )
    def _request(self, client: Client, method: str, url: str, **kwargs: Any) -> Response:
        """Send one request through `client`, retrying on httpx timeouts.

        Retries live here, on the raw HTTP call, so they actually fire. The
        public methods previously keyed retries on requests' Timeout, which
        httpx never raises, and they wrap every error before tenacity could see
        it. After the last attempt the final TimeoutException is re-raised.

        NOTE: a timed-out POST may still have reached the server, so a retry
        can submit it twice.
        """
        return client.request(method, url, **kwargs)

    def get_job(
        self, job_id: str | None = None, history: bool = False, verbose: bool = False
    ) -> "JobResponse":
        """Get details for a specific crow job.

        Args:
            job_id: Trajectory id; defaults to the one from the last create_job call.
            history: Forwarded as the `history` query parameter.
            verbose: Return the full JobResponseVerbose payload.

        Returns:
            The response model, chosen as follows:
            - JobResponseVerbose when verbose is set.
            - PQAJobResponse for PQA-family crows.
            - JobResponse otherwise.

        Raises:
            ValueError: If the response cannot be decoded or validated.
            JobFetchError: For any other failure, including no job id being available.
        """
        job_id = job_id or self.trajectory_id
        if not job_id:
            raise JobFetchError(
                "Error getting job: no job_id given and no job has been created by this client."
            )
        try:
            response = self._request(
                self.client,
                "GET",
                f"/v0.1/trajectories/{job_id}",
                params={"history": history},
            )
            response.raise_for_status()
            payload = response.json()
            verbose_response = JobResponseVerbose(**payload)
            if verbose:
                return verbose_response
            if any(
                JobNames.from_string(job_name) in verbose_response.crow
                for job_name in ["crow", "falcon", "owl", "dummy"]
            ):
                return PQAJobResponse(**payload)
            return JobResponse(**payload)
        except ValueError as e:
            # JSON decoding and pydantic validation errors land here. The job id
            # itself is never validated client-side, so don't claim it is a UUID issue.
            raise ValueError(f"Invalid job ID or unparseable job response: {e!s}") from e
        except Exception as e:
            raise JobFetchError(f"Error getting job: {e!s}") from e

    def create_job(self, job_data: JobRequest | dict[str, Any]):
        """Create a new crow job and return its trajectory id (also stored on self).

        Raises:
            JobFetchError: If the request fails.
        """
        if isinstance(job_data, dict):
            job_data = JobRequest.model_validate(job_data)

        if isinstance(job_data.name, JobNames):
            job_data.name = job_data.name.from_stage(
                job_data.name.name,
                self.stage,
            )

        try:
            response = self._request(
                self.client, "POST", "/v0.1/crows", json=job_data.model_dump(mode="json")
            )
            response.raise_for_status()
            self.trajectory_id = response.json()["trajectory_id"]
        except Exception as e:
            raise JobFetchError(f"Error creating job: {e!s}") from e
        return self.trajectory_id

    def get_build_status(self, build_id: UUID | None = None) -> dict[str, Any]:
        """Get the status of a build (defaults to the one from the last create_crow call).

        Raises:
            ValueError: If no build id is available.
            httpx.HTTPStatusError: If the server returns an error status.
        """
        build_id = build_id or self.build_id
        if build_id is None:
            raise ValueError(
                "No build_id given and no build has been created by this client."
            )
        response = self._request(self.client, "GET", f"/v0.1/builds/{build_id}")
        response.raise_for_status()
        return response.json()

    def create_crow(self, config: CrowDeploymentConfig) -> dict[str, Any]:
        """Creates a crow deployment from the environment and environment files.

        Args:
            config: Configuration object containing all necessary parameters for crow deployment.

        Returns:
            The build metadata returned by the server; its build_id is also stored
            on self.build_id. If organization selection or the overwrite prompt is
            declined, a {"status": "cancelled", ...} dict is returned instead.

        Raises:
            InvalidTaskDescriptionError: If no non-empty task description is available.
            JobFetchError: If collecting files or submitting the build fails.
        """
        task_description = self._resolve_task_description(config)
        if not task_description or not task_description.strip():
            raise InvalidTaskDescriptionError(
                "Task description cannot be None or empty. Ensure your from_task environment function has a valid docstring."
                " If you are deploying with your Environment as a dependency, "
                "you must add a `task_description` to your `CrowDeploymentConfig`.",
            )
        selected_org = OrganizationSelector.select_organization(self.organizations)
        if selected_org is None:
            return {
                "status": "cancelled",
                "message": "Organization selection cancelled",
            }
        try:
            if not self._confirm_deployment(config, selected_org):
                return {
                    "status": "cancelled",
                    "message": "User cancelled deployment",
                }
            encoded_pickle = None
            if config.functional_environment is not None:
                # TODO(remo): change aviary fenv code to have this happen automatically.
                for t in config.functional_environment.tools:
                    t._force_pickle_fn = True
                pickled_env = cloudpickle.dumps(config.functional_environment)
                encoded_pickle = base64.b64encode(pickled_env).decode("utf-8")
            files = self._collect_upload_files(config)
            if config.path:
                self._validate_files(files, config.path)
            markdown_template_file = None
            if config.markdown_template_path:
                self._validate_template_path(config.markdown_template_path)
                template_path = Path(config.markdown_template_path)
                markdown_template_file = (
                    "files",
                    (
                        "markdown_template",
                        template_path.read_bytes(),
                        "application/octet-stream",
                    ),
                )
            logger.debug("Sending files: %s", [f[1][0] for f in files])
            data = {
                "agent": config.agent,
                "job_name": config.job_name,
                "organization": selected_org,
                "environment": config.environment,
                "functional_environment_pickle": encoded_pickle,
                "python_version": config.python_version,
                "task_description": task_description,
                "environment_variables": (
                    json.dumps(config.environment_variables)
                    if config.environment_variables
                    else None
                ),
                "container_config": (
                    config.container_config.model_dump_json()
                    if config.container_config
                    else None
                ),
                "timeout": config.timeout,
                "storage_dir": config.storage_location,
                "frame_paths": (
                    json.dumps(
                        [fp.model_dump() for fp in config.frame_paths],
                    )
                    if config.frame_paths
                    else None
                ),
                "task_queues_config": (
                    config.task_queues_config.model_dump_json()
                    if config.task_queues_config
                    else None
                ),
            }
            response = self._request(
                self.multipart_client,
                "POST",
                "/v0.1/builds",
                data=data,
                files=(
                    [*files, markdown_template_file]
                    if markdown_template_file
                    else files
                ),
                headers={"Accept": "application/json"},
                params={"internal-deps": config.requires_aviary_internal},
            )
            try:
                response.raise_for_status()
                build_context = response.json()
                self.build_id = build_context["build_id"]
            except HTTPStatusError as e:
                error_message = self._server_error_detail(response, e)
                raise JobFetchError(f"Server validation error: {error_message}") from e
        except JobFetchError:
            # Already descriptive; don't re-wrap it as "Error generating docker image".
            raise
        except Exception as e:
            raise JobFetchError(f"Error generating docker image: {e!s}") from e
        return build_context

    @staticmethod
    def _resolve_task_description(config: CrowDeploymentConfig) -> str:
        """Return config.task_description, else the from_task / start_fn docstring.

        A missing docstring becomes "" rather than the string "None". The old
        str(...) wrapping turned it into "None", which slipped past the empty check.
        """
        if config.task_description:
            return config.task_description
        if config.functional_environment is None:
            docstring = fetch_environment_function_docstring(
                config.environment,
                config.path,  # type: ignore[arg-type]
                "from_task",
            )
        else:
            docstring = config.functional_environment.start_fn.__doc__
        return "" if docstring is None else str(docstring)

    def _confirm_deployment(self, config: CrowDeploymentConfig, organization: str) -> bool:
        """Return False only if an existing deployment is found and the user declines.

        Behavior:
        - With config.force, overwriting is logged and allowed.
        - Without it, the user is prompted.
        - Any failure (including the prompt raising, e.g. EOFError when stdin
          is not interactive) is logged and treated as "proceed", as before.
        """
        try:
            job_status = self._check_job(config.job_name, organization)
            if job_status["exists"]:
                if config.force:
                    logger.warning(
                        "Overwriting existing deployment '%s'", job_status["name"]
                    )
                else:
                    user_response = input(
                        f"A deployment named '{config.job_name}' already exists. Do you want to proceed? [y/N]: "
                    )
                    if user_response.lower() != "y":
                        logger.info("Deployment cancelled.")
                        return False
        except Exception:
            logger.warning("Unable to check for existing deployment, proceeding.")
        return True

    @staticmethod
    def _collect_upload_files(
        config: CrowDeploymentConfig,
    ) -> list[tuple[str, tuple[str, bytes, str]]]:
        """Build the multipart `files` entries for a crow build.

        Collects:
        - Every regular file under config.path, skipping anything under a
          FILE_UPLOAD_IGNORE_PARTS component.
        - A generated requirements.txt, for functional environments with
          config.requirements.
        - The file at config.requirements_path, when set.
        """
        files: list[tuple[str, tuple[str, bytes, str]]] = []
        if config.path:
            root = Path(config.path)
            for file_path in root.rglob("*"):
                relative = file_path.relative_to(root)
                # Test only components below the module root. A root that itself
                # lives under e.g. ".venv/" (an installed package) must not be
                # skipped wholesale.
                if any(part in FILE_UPLOAD_IGNORE_PARTS for part in relative.parts):
                    continue
                if file_path.is_file():
                    files.append(
                        (
                            "files",
                            (
                                f"{config.module_name}/{relative}",
                                file_path.read_bytes(),
                                "application/octet-stream",
                            ),
                        ),
                    )
        if (
            config.functional_environment is not None
            and config.requirements is not None
        ):
            requirements_content = "\n".join(config.requirements)
            files.append(
                (
                    "files",
                    (
                        f"{config.environment}/requirements.txt",
                        requirements_content.encode(),
                        "text/plain",
                    ),
                ),
            )
        if config.requirements_path:
            requirements_path = Path(config.requirements_path)
            files.append(
                (
                    "files",
                    (
                        f"{config.module_name}/{requirements_path.name}",
                        requirements_path.read_bytes(),
                        "application/octet-stream",
                    ),
                ),
            )
        return files

    @staticmethod
    def _server_error_detail(response: Response, error: HTTPStatusError) -> str:
        """Return the `detail` field of an error body, falling back to str(error).

        Non-JSON bodies and non-object JSON no longer raise inside the error
        handler.
        """
        try:
            body = response.json()
        except ValueError:
            return str(error)
        if isinstance(body, dict):
            return str(body.get("detail", str(error)))
        return str(error)
|
585
|
+
|
586
|
+
|
587
|
+
def get_installed_packages() -> dict[str, str]:
    """Returns a dictionary of installed packages and their versions.

    Keys are lower-cased distribution names; values are version strings.

    - Distributions whose metadata has no Name field are skipped. Previously
      they crashed on .lower().
    - When a distribution is found more than once, the first one found wins.
      importlib.metadata searches sys.path in order, so that is normally the copy
      the import system uses. Previously the last, shadowed copy overwrote it.
    """
    packages: dict[str, str] = {}
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name")
        if not name:
            continue
        packages.setdefault(name.lower(), dist.version)
    return packages
|
593
|
+
|
594
|
+
|
595
|
+
def get_global_imports(global_scope: dict) -> dict[str, str]:
    """Map each name in `global_scope` that is bound to a module to that module's __name__.

    Non-module values are ignored, so `import numpy as np` yields {"np": "numpy"}.
    """
    alias_to_module: dict[str, str] = {}
    for alias, value in global_scope.items():
        if not isinstance(value, ModuleType):
            continue
        alias_to_module[alias] = value.__name__
    return alias_to_module
|
602
|
+
|
603
|
+
|
604
|
+
def get_referenced_globals_from_source(source_code: str) -> set[str]:
    """Return every identifier that `source_code` reads (ast.Load context).

    Only assignments are excluded. Local variables and builtins that are read
    are included too, not just true globals.
    """
    loaded: set[str] = set()
    for node in ast.walk(ast.parse(source_code)):
        if not isinstance(node, ast.Name):
            continue
        if isinstance(node.ctx, ast.Load):
            loaded.add(node.id)
    return loaded
|
612
|
+
|
613
|
+
|
614
|
+
def get_used_global_imports(
    func,
    global_imports: dict[str, str],
    global_scope: dict,
    visited=None,
) -> set[str]:
    """Retrieve global imports used by a function.

    The AST of func's source is walked once. For every name it reads that is a
    key of global_imports, the mapped module name is collected. Plain-name calls
    that resolve to a callable in global_scope are followed recursively, and
    `visited` guards against cycles.

    Source is dedented before parsing, so methods and nested functions (whose
    source is indented) work. Callees whose source cannot be retrieved or parsed
    are skipped instead of aborting the whole analysis; examples are builtins,
    C-extension functions and some lambdas.

    Args:
        func: Object accepted by inspect.getsource to analyse.
        global_imports: Mapping of global alias -> module name.
        global_scope: Namespace used to resolve called names.
        visited: Callables already analysed; created on the first call.

    Returns:
        Module names referenced directly or through followed calls.

    Raises:
        OSError, TypeError: If the source of `func` itself cannot be retrieved.
        SyntaxError: If that source cannot be parsed.
    """
    if visited is None:
        visited = set()
    if func in visited:
        return set()
    visited.add(func)

    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))

    used_imports: set[str] = set()
    called_names: list[str] = []
    # Single pass: collect imported aliases that are read, and plain-name call targets.
    for node in ast.walk(tree):
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            if node.id in global_imports:
                used_imports.add(global_imports[node.id])
        elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            called_names.append(node.func.id)

    for name in called_names:
        ref_func = global_scope.get(name)
        if not callable(ref_func):
            continue
        try:
            used_imports.update(
                get_used_global_imports(
                    ref_func,
                    global_imports,
                    global_scope,
                    visited,
                ),
            )
        except (OSError, TypeError, SyntaxError):
            # No retrievable or parsable source for this callee; nothing to scan.
            continue
    return used_imports
|
646
|
+
|
647
|
+
|
648
|
+
def get_used_modules(env_builder: EnvironmentBuilder, global_scope: dict) -> set[str]:
    """Collect the globally imported module names used by env_builder's start_fn and tools.

    Raises:
        TypeError: If env_builder is not an EnvironmentBuilder.
    """
    if not isinstance(env_builder, EnvironmentBuilder):
        raise TypeError("The provided object is not an instance of EnvironmentBuilder.")
    alias_to_module = get_global_imports(global_scope)
    modules = get_used_global_imports(env_builder.start_fn, alias_to_module, global_scope)
    # Each tool is analysed independently (fresh visited set), as before.
    for tool in env_builder.tools:
        modules |= get_used_global_imports(tool._tool_fn, alias_to_module, global_scope)
    return modules
|
663
|
+
|
664
|
+
|
665
|
+
def generate_requirements(
    env_builder: EnvironmentBuilder,
    global_scope: dict,
) -> list[str]:
    """Build sorted, pinned `name==version` lines for the modules an environment uses.

    Candidates are the module names from get_used_modules plus "cloudpickle".
    Only names that exactly match an installed distribution key from
    get_installed_packages are kept.
    """
    candidates = get_used_modules(env_builder, global_scope) | {"cloudpickle"}
    versions = get_installed_packages()
    return [
        f"{name}=={versions[name]}"
        for name in sorted(candidates)
        if name in versions
    ]
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from .app import (
|
2
|
+
AuthType,
|
3
|
+
CrowDeploymentConfig,
|
4
|
+
DockerContainerConfiguration,
|
5
|
+
FramePath,
|
6
|
+
JobRequest,
|
7
|
+
RuntimeConfig,
|
8
|
+
Stage,
|
9
|
+
Step,
|
10
|
+
)
|
11
|
+
|
12
|
+
# Public API of this package: the names re-exported from .app by the import above.
__all__ = [
    "AuthType",
    "CrowDeploymentConfig",
    "DockerContainerConfiguration",
    "FramePath",
    "JobRequest",
    "RuntimeConfig",
    "Stage",
    "Step",
]
|