futurehouse-client 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- futurehouse_client/__init__.py +12 -0
- futurehouse_client/clients/__init__.py +12 -0
- futurehouse_client/clients/job_client.py +232 -0
- futurehouse_client/clients/rest_client.py +674 -0
- futurehouse_client/models/__init__.py +21 -0
- futurehouse_client/models/app.py +622 -0
- futurehouse_client/models/client.py +72 -0
- futurehouse_client/models/rest.py +19 -0
- futurehouse_client/utils/__init__.py +0 -0
- futurehouse_client/utils/module_utils.py +149 -0
- futurehouse_client-0.0.1.dist-info/METADATA +151 -0
- futurehouse_client-0.0.1.dist-info/RECORD +14 -0
- futurehouse_client-0.0.1.dist-info/WHEEL +5 -0
- futurehouse_client-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,622 @@
|
|
1
|
+
import os
|
2
|
+
import re
|
3
|
+
from enum import StrEnum, auto
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Self, cast
|
6
|
+
from uuid import UUID
|
7
|
+
|
8
|
+
from aviary.functional import EnvironmentBuilder
|
9
|
+
from ldp.agent import Agent, AgentConfig
|
10
|
+
from ldp.alg.callbacks import Callback
|
11
|
+
from pydantic import (
|
12
|
+
BaseModel,
|
13
|
+
ConfigDict,
|
14
|
+
Field,
|
15
|
+
field_validator,
|
16
|
+
model_validator,
|
17
|
+
)
|
18
|
+
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
from futurehouse_client.clients import JobNames
|
21
|
+
|
22
|
+
|
23
|
+
# Inclusive bounds, in seconds, for a crow job's run timeout (used as the
# ``le``/``ge`` limits of ``CrowDeploymentConfig.timeout``).
MAX_CROW_JOB_RUN_TIMEOUT = 24 * 60 * 60  # one day
MIN_CROW_JOB_RUN_TIMEOUT = 0
|
25
|
+
|
26
|
+
|
27
|
+
class PythonVersion(StrEnum):
|
28
|
+
V3_11 = "3.11"
|
29
|
+
V3_12 = "3.12"
|
30
|
+
|
31
|
+
|
32
|
+
class AuthType(StrEnum):
|
33
|
+
API_KEY = auto()
|
34
|
+
JWT = auto()
|
35
|
+
|
36
|
+
|
37
|
+
class APIKeyPayload(BaseModel):
    # Request body that carries a single, required user API key.
    api_key: str = Field(
        description="A user API key to authenticate with the server.",
    )
|
39
|
+
|
40
|
+
|
41
|
+
class PriorityQueueTypes(StrEnum):
|
42
|
+
LOW = auto()
|
43
|
+
NORMAL = auto()
|
44
|
+
HIGH = auto()
|
45
|
+
ULTRA = auto()
|
46
|
+
|
47
|
+
def rate_percentage(self) -> float:
|
48
|
+
if self == self.LOW:
|
49
|
+
return 0.1
|
50
|
+
if self == self.NORMAL:
|
51
|
+
return 0.5
|
52
|
+
if self == self.HIGH:
|
53
|
+
return 0.75
|
54
|
+
if self == self.ULTRA:
|
55
|
+
return 1.0
|
56
|
+
raise NotImplementedError(f"Unknown priority queue type: {self}")
|
57
|
+
|
58
|
+
|
59
|
+
class RetryConfig(BaseModel):
    """Configuration for task retry settings."""

    # Retry policy for a task queue; ``to_client_dict`` maps it to the
    # Cloud Tasks client format.
    max_attempts: int = Field(
        default=-1,
        description="Maximum number of retry attempts. -1 for infinite retries.",
    )
    max_retry_duration_seconds: int = Field(
        default=7 * 24 * 60 * 60,  # 7 days == 604800 seconds
        description="Maximum time a task can be retrying for before giving up (in seconds).",
    )
    max_backoff_seconds: int = Field(
        # In a full queue this makes each entry retry roughly once per minute.
        default=60,
        description="Maximum time to wait between retries (in seconds).",
    )
    min_backoff_seconds: int = Field(
        default=1,
        description="Minimum time to wait between retries (in seconds).",
    )
    max_doublings: int = Field(
        default=7,
        description="Maximum number of times the retry interval can double before becoming constant.",
    )

    def to_client_dict(self) -> dict[str, Any]:
        """Return this retry policy in GCP Cloud Tasks client format.

        Durations are wrapped as ``{"seconds": n}`` mappings; counts are passed
        through unchanged.
        """

        def as_duration(seconds: int) -> dict[str, int]:
            return {"seconds": seconds}

        return {
            "max_attempts": self.max_attempts,
            "max_retry_duration": as_duration(self.max_retry_duration_seconds),
            "max_backoff": as_duration(self.max_backoff_seconds),
            "min_backoff": as_duration(self.min_backoff_seconds),
            "max_doublings": self.max_doublings,
        }
|
90
|
+
|
91
|
+
|
92
|
+
class RateLimits(BaseModel):
    """Configuration for queue rate limits."""

    # Both limits are forwarded unchanged by ``to_client_dict``.
    max_dispatches_per_second: float = Field(
        10.0,
        description=(
            # Trailing space added: the pieces previously ran together as
            # "...per second.If this...".
            "Maximum number of tasks that can be dispatched per second. "
            "If this is too high, you can overshoot the rate limit as the "
            "query to running jobs is not perfectly synchronized."
        ),
    )
    max_concurrent_dispatches: int = Field(
        100,
        description=(
            "Maximum number of concurrent tasks that can be dispatched."
            " This represents how many jobs are actively trying to get "
            "a spot as a running job at the same time. The rest will "
            "simply be waiting in the queue. The higher this is, the "
            # Leading space removed: it produced "the  higher" (double space).
            "higher gatekeeping server load will be."
        ),
    )

    # Fraction of a queue's size used as its concurrent-dispatch limit.
    MAX_RATIO_FROM_QUEUE_SIZE: ClassVar[float] = 0.1

    @classmethod
    def from_max_queue_size(cls, max_queue_size: int) -> Self:
        """Create rate limits from a max_queue_size to avoid overwhelming the gatekeeping server.

        ``max_concurrent_dispatches`` is set to ``MAX_RATIO_FROM_QUEUE_SIZE`` times
        ``max_queue_size`` (truncated to an int), but never below 1.
        ``max_dispatches_per_second`` keeps its default.
        """
        scaled = int(max_queue_size * cls.MAX_RATIO_FROM_QUEUE_SIZE)
        return cls(max_concurrent_dispatches=max(1, scaled))

    def to_client_dict(self) -> dict[str, Any]:
        """Convert rate limits to GCP Cloud Tasks client format."""
        return {
            "max_dispatches_per_second": self.max_dispatches_per_second,
            "max_concurrent_dispatches": self.max_concurrent_dispatches,
        }
|
131
|
+
|
132
|
+
|
133
|
+
class TaskQueue(BaseModel):
    """Configuration for a single Task Queue."""

    name: str = Field(..., description="Name of the queue")
    retry_config: RetryConfig = Field(
        default_factory=RetryConfig, description="Configuration for task retries"
    )
    rate_limits: RateLimits | None = Field(
        default=None, description="Optional rate limiting configuration"
    )
    priority_max_running_fraction: float = Field(
        default_factory=PriorityQueueTypes.NORMAL.rate_percentage,
        description=(
            # Trailing space added: the sentences previously ran together as
            # "...proxy for priority.Higher limits...".
            "Maximum fraction of the total limit that this queue can use, proxy for priority. "
            "Higher limits will essentially be preferred because they can run when "
            "lower priority queues cannot."
        ),
        ge=0.0,
        le=1.0,
    )

    @classmethod
    def from_priority_queue_type_and_max_running_jobs(
        cls, name: str, queue_type: PriorityQueueTypes, max_running_jobs: int
    ) -> "TaskQueue":
        """Create a TaskQueue from a PriorityQueueType.

        The queue is named ``"{name}-{queue_type.value}"``. Its running fraction
        is the tier's ``rate_percentage()``. Its rate limits are derived from
        that fraction of ``max_running_jobs`` (truncated to an int).
        """
        fraction = queue_type.rate_percentage()
        return cls(
            name=f"{name}-{queue_type.value}",
            priority_max_running_fraction=fraction,
            rate_limits=RateLimits.from_max_queue_size(
                int(fraction * max_running_jobs)
            ),
        )

    def to_client_dict(self, project_id: str, location: str) -> dict[str, Any]:
        """Convert the queue configuration to GCP Cloud Tasks client format.

        The queue name is expanded to its full resource path
        ``projects/{project_id}/locations/{location}/queues/{name}``.
        ``rate_limits`` is included only when configured.
        """
        parent = f"projects/{project_id}/locations/{location}"
        queue_path = f"{parent}/queues/{self.name}"

        result: dict[str, Any] = {
            "name": queue_path,
            "retry_config": self.retry_config.to_client_dict(),
        }

        if self.rate_limits:
            result["rate_limits"] = self.rate_limits.to_client_dict()

        return result
|
181
|
+
|
182
|
+
|
183
|
+
class TaskQueuesConfig(BaseModel):
    """Configuration for multiple Task Queues."""

    name: str = Field(..., description="Base name for the queue(s).")
    max_running_jobs: int = Field(
        default=30,  # intentionally low default for now
        description=(
            "Maximum concurrency for this crow job, across all queues."
            " Note: Global max across all crow jobs is 1,000, the backend will always enforce"
            " the global limit first. This limit should be set keeping in mind any dependent limits"
            " like LLM throughput."
        ),
    )
    queues: list[TaskQueue] | None = Field(
        default=None,
        description="List of task queues to be created/managed, will be built automatically if None.",
    )

    @model_validator(mode="after")
    def add_priority_queues(self):
        # Explicitly supplied queues (even an empty list) are left untouched.
        if self.queues is not None:
            return self
        # Otherwise build one queue per priority tier, in enum order.
        generated: list[TaskQueue] = []
        for tier in PriorityQueueTypes:
            generated.append(
                TaskQueue.from_priority_queue_type_and_max_running_jobs(
                    name=self.name,
                    queue_type=tier,
                    max_running_jobs=self.max_running_jobs,
                )
            )
        self.queues = generated
        return self

    def get_queue(self, priority_type: PriorityQueueTypes) -> TaskQueue | None:
        """Return the first queue whose name ends with ``-{priority_type.value}``.

        Returns None when there are no queues or none match.
        """
        wanted_suffix = f"-{priority_type.value}"
        return next(
            (
                candidate
                for candidate in self.queues or []
                if candidate.name.endswith(wanted_suffix)
            ),
            None,
        )
|
224
|
+
|
225
|
+
|
226
|
+
class Stage(StrEnum):
|
227
|
+
DEV = "https://dev.api.platform.futurehouse.org"
|
228
|
+
PROD = "https://api.platform.futurehouse.org"
|
229
|
+
LOCAL = "http://localhost:8080"
|
230
|
+
LOCAL_DOCKER = "http://host.docker.internal:8080"
|
231
|
+
|
232
|
+
@classmethod
|
233
|
+
def from_string(cls, stage: str) -> "Stage":
|
234
|
+
"""Convert a case-insensitive string to Stage enum."""
|
235
|
+
try:
|
236
|
+
return cls[stage.upper()]
|
237
|
+
except KeyError as e:
|
238
|
+
raise ValueError(
|
239
|
+
f"Invalid stage: {stage}. Must be one of: {', '.join(cls.__members__)}",
|
240
|
+
) from e
|
241
|
+
|
242
|
+
|
243
|
+
class Step(StrEnum):
    """Identifiers for the ldp ``Callback`` hook points.

    Each value is read from the corresponding ``Callback`` method's
    ``__name__`` instead of being hard-coded. The enum therefore always
    matches the hook method names.
    """

    BEFORE_TRANSITION = Callback.before_transition.__name__
    AFTER_AGENT_INIT_STATE = Callback.after_agent_init_state.__name__
    AFTER_AGENT_GET_ASV = Callback.after_agent_get_asv.__name__
    AFTER_ENV_RESET = Callback.after_env_reset.__name__
    AFTER_ENV_STEP = Callback.after_env_step.__name__
    AFTER_TRANSITION = Callback.after_transition.__name__
|
250
|
+
|
251
|
+
|
252
|
+
class FramePathContentType(StrEnum):
|
253
|
+
TEXT = auto()
|
254
|
+
IMAGE = auto()
|
255
|
+
MARKDOWN = auto()
|
256
|
+
JSON = auto()
|
257
|
+
PDF_LINK = auto()
|
258
|
+
PDB = auto()
|
259
|
+
NOTEBOOK = auto()
|
260
|
+
PQA = auto()
|
261
|
+
|
262
|
+
|
263
|
+
class FramePath(BaseModel):
    # One location of important data inside an environment frame, plus how the
    # UI should render it.
    # The ``path`` description previously described a list of paths that may be
    # None. That is wrong: the field is a single, required string.
    path: str = Field(
        description="JSON path string (e.g. 'input.data.frame') indicating where to find important frame data.",
    )
    type: FramePathContentType = Field(
        default=FramePathContentType.JSON,
        description="Content type of the data at this path",
    )
    is_iterable: bool = Field(
        default=False,
        description="Whether the content at this JSON path is iterable; this key tells us if the rendering component should create multiple components for a single key",
    )
|
275
|
+
|
276
|
+
|
277
|
+
class DockerContainerConfiguration(BaseModel):
    # Cloud Run container resources. ``cpu`` is a whole CPU count given as a
    # string, and ``memory`` is a "<N>Gi" string. Each is range-checked by its
    # field validator; the combination is checked by validate_cpu_memory_ratio.
    cpu: str = Field(description="CPU allotment for the container")
    memory: str = Field(description="Memory allotment for the container")

    MINIMUM_MEMORY: ClassVar[int] = 2  # Gi
    MAXIMUM_MEMORY: ClassVar[int] = 32  # Gi

    @field_validator("cpu")
    @classmethod
    def validate_cpu(cls, v: str) -> str:
        """Accept only the CPU counts supported by the python client.

        The python library only supports 1, 2, 4, 8 CPUs:
        https://cloud.google.com/run/docs/reference/rpc/google.cloud.run.v2#resourcerequirements

        Raises:
            ValueError: If ``v`` is not one of "1", "2", "4", "8".
        """
        valid_cpus = {"1", "2", "4", "8"}
        if v not in valid_cpus:
            raise ValueError("CPU must be one of: 1, 2, 4, or 8")
        return v

    @field_validator("memory")
    @classmethod
    def validate_memory(cls, v: str) -> str:
        """Require "<N>Gi" memory with N between MINIMUM_MEMORY and MAXIMUM_MEMORY.

        Raises:
            ValueError: On a malformed string or an out-of-range amount.
        """
        # https://regex101.com/r/4kWjKw/1
        match = re.match(r"^(\d+)Gi$", v)

        if not match:
            raise ValueError("Memory must be in Gi format (e.g., '2Gi')")

        value = int(match.group(1))

        # GCP Cloud Run has min 512Mi and max 32Gi (32768Mi)
        # https://cloud.google.com/run/docs/configuring/services/memory-limits
        # Because of the CPU restriction in the python client (see validate_cpu),
        # we must stay between 2Gi and 32Gi.
        if value < cls.MINIMUM_MEMORY:
            raise ValueError("Memory must be at least 2Gi")
        if value > cls.MAXIMUM_MEMORY:
            raise ValueError("Memory must not exceed 32Gi")

        return v

    @model_validator(mode="after")
    def validate_cpu_memory_ratio(self) -> Self:
        """Ensure enough CPUs are allotted for the requested memory.

        Raises:
            ValueError: If memory is malformed or ``cpu`` is below the minimum
                required for that memory.
        """
        cpu = int(self.cpu)

        match = re.match(r"^(\d+)Gi$", self.memory)
        if match is None:
            raise ValueError("Memory must be in Gi format (e.g., '2Gi')")

        memory_gi = int(match.group(1))
        memory_mb = memory_gi * 1024

        # Memory ceiling (MiB) -> minimum CPUs, in ascending order of memory.
        min_cpu_requirements = {
            2048: 1,  # up to 2Gi requires 1 CPU
            4096: 2,  # up to 4Gi requires 2 CPU
            8192: 4,  # up to 8Gi requires 4 CPU
            24576: 8,  # up to 24Gi requires 8 CPU
        }

        # Use the smallest tier that covers the requested memory. Memory above
        # the largest tier (25Gi-32Gi) previously matched no tier and skipped
        # the check entirely, e.g. 32Gi with 1 CPU was accepted. Such memory
        # now requires the largest CPU count.
        cpu_required = next(
            (
                required
                for mem_threshold, required in min_cpu_requirements.items()
                if memory_mb <= mem_threshold
            ),
            max(min_cpu_requirements.values()),
        )
        if cpu < cpu_required:
            raise ValueError(
                f"For {self.memory} of memory, minimum required CPU is {cpu_required} CPU. Got {cpu} CPU",
            )

        return self
|
343
|
+
|
344
|
+
|
345
|
+
class CrowDeploymentConfig(BaseModel):
    # Everything needed to deploy a crow job: either a module ``path`` or a
    # ``functional_environment``, plus build/runtime settings. Path-like inputs
    # are checked against the local filesystem at validation time.
    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,  # Allows for agent: Agent | str
    )

    requirements_path: str | os.PathLike | None = Field(
        default=None,
        description="The complete path including filename to the requirements.txt file or pyproject.toml file. If not provided explicitly, it will be inferred from the path parameter.",
    )

    path: str | os.PathLike | None = Field(
        default=None,
        description="The path to your python module. Can be either a string path or Path object. "
        "This path should be the root directory of your module. "
        "This path either must include a pyproject.toml with UV tooling, or a requirements.txt for dependency resolution. "
        "Can be None if we are deploying a functional environment (through the functional_environment parameter).",
    )

    name: str | None = Field(
        default=None,
        description="The name of the crow job. If None, the crow job will be "
        "named using the included python module or functional environment name.",
    )

    environment: str = Field(
        description="Your environment path, should be a module reference if we pass an environment. "
        "Can be an arbitrary name if we are deploying a functional environment (through the functional_environment parameter). "
        "example: dummy_env.env.DummyEnv",
    )

    functional_environment: EnvironmentBuilder | None = Field(
        default=None,
        description="An object of type EnvironmentBuilder used to construct an environment. "
        "Can be None if we are deploying a non functional environment.",
    )

    requirements: list[str] | None = Field(
        default=None,
        description="A list of dependencies required for the deployment, similar to the Python requirements.txt file. "
        "Each entry in the list specifies a package or module in the format used by pip (e.g., 'package-name==1.0.0'). "
        "Can be None if we are deploying a non functional environment (functional_environment parameter is None)",
    )

    environment_variables: dict[str, str] | None = Field(
        default=None,
        description="Any key value pair of environment variables your environment needs to function.",
    )

    container_config: DockerContainerConfiguration | None = Field(
        default=None,
        description="The configuration for the cloud run container.",
    )

    python_version: PythonVersion = Field(
        default=PythonVersion.V3_12,
        description="The python version your docker image should build with.",
    )

    agent: Agent | str = Field(
        default="ldp.agent.SimpleAgent",
        description="Your desired agent path, should be a module reference and a fully qualified name. "
        "example: ldp.agent.SimpleAgent",
    )

    requires_aviary_internal: bool = Field(
        default=False,
        description="Indicates your project requires aviary-internal to function. "
        "This is only necessary for envs within aviary-internal.",
    )

    timeout: int | None = Field(
        default=600,
        description="The amount of time in seconds your crow will run on a task before it terminates.",
        ge=MIN_CROW_JOB_RUN_TIMEOUT,
        le=MAX_CROW_JOB_RUN_TIMEOUT,
    )

    force: bool = Field(
        default=False,
        description="If true, immediately overwrite any existing job with the same name.",
    )

    storage_location: str = Field(
        default="storage",
        description="The location the container will use to mount a locally accessible GCS folder as a volume. "
        "This location can be used to store and fetch files safely without GCS apis or direct access.",
    )

    frame_paths: list[FramePath] | None = Field(
        default=None,
        description="List of FramePath which indicates where to find important frame data, and how to render it.",
    )

    markdown_template_path: str | os.PathLike | None = Field(
        default=None,
        description="The path to the markdown template file. This file will be dynamically built within the environment frame section of the UI. "
        "The keys used in the markdown file follow the same requirement as FramePath.path. None implies no markdown template is present and the UI "
        "will render the environment frame as is.",
    )

    task_queues_config: TaskQueuesConfig | None = Field(
        default=None,
        description="The configuration for the task queue(s) that will be created for this deployment.",
    )

    task_description: str | None = Field(
        default=None,
        description="Override for the task description, if not included it will be pulled from your "
        "environment `from_task` docstring. Necessary if you are deploying using an Environment class"
        " as a dependency.",
    )

    @field_validator("markdown_template_path")
    @classmethod
    def validate_markdown_path(
        cls, v: str | os.PathLike | None
    ) -> str | os.PathLike | None:
        """Require a ``.md``/``.markdown`` extension (case-insensitive) when set."""
        if v is not None:
            path = Path(v)
            if path.suffix.lower() not in {".md", ".markdown"}:
                raise ValueError(
                    f"Markdown template must be a .md or .markdown extension: {path}"
                )
        return v

    @field_validator("path")
    @classmethod
    def validate_module_path(
        cls, value: str | os.PathLike | None
    ) -> str | os.PathLike | None:
        """Require ``path``, when given, to be an existing directory.

        Raises:
            ValueError: If the path does not exist or is not a directory.
        """
        # ``path=None`` is valid for functional-environment deployments. Without
        # this guard an explicit ``path=None`` reached ``Path(None)``, which raises
        # a TypeError that pydantic does not convert into a ValidationError.
        if value is None:
            return value
        path = Path(value)
        if not path.exists():
            raise ValueError(f"Module path {path} does not exist")
        if not path.is_dir():
            raise ValueError(f"Module path {path} is not a directory")
        return value

    @field_validator("requirements_path")
    @classmethod
    def validate_requirements_path(
        cls, value: str | os.PathLike | None
    ) -> str | os.PathLike | None:
        """Require ``requirements_path``, when given, to be an existing .txt/.toml file."""
        if value is None:
            return value

        path = Path(value)
        if not path.exists():
            raise ValueError(f"Requirements path {path} does not exist")
        if not path.is_file():
            raise ValueError(f"Requirements path {path} is not a file")
        if path.suffix not in {".txt", ".toml"}:
            raise ValueError(f"Requirements path {path} must be a .txt or .toml file")
        return value

    @model_validator(mode="after")
    def validate_path_and_requirements(self) -> Self:
        """Check dependency files for module deployments and default the queue config.

        When ``path`` is set, the module directory must contain pyproject.toml
        or requirements.txt, or an existing ``requirements_path`` must be
        provided. A default ``TaskQueuesConfig`` named after ``job_name`` is then
        filled in if none was given.
        """
        if self.path is None:
            # NOTE(review): task_queues_config is therefore never defaulted for
            # functional-environment deployments — confirm this is intended.
            return self

        path = Path(self.path)
        requirements_path = (
            Path(self.requirements_path) if self.requirements_path else None
        )

        if not (
            (path / "pyproject.toml").exists()
            or (path / "requirements.txt").exists()
            or (requirements_path and requirements_path.exists())
        ):
            raise ValueError(
                f"Module path {path} must contain either pyproject.toml or requirements.txt, "
                "or a valid requirements_path must be provided"
            )

        if not self.task_queues_config:
            self.task_queues_config = TaskQueuesConfig(name=self.job_name)

        return self

    @field_validator("environment")
    @classmethod
    def validate_environment_path(cls, value: str) -> str:
        """Require a non-blank, dot-separated path of Python identifiers."""
        if not value or not value.strip():
            raise ValueError("Environment path cannot be empty")
        if not all(part.isidentifier() for part in value.split(".")):
            raise ValueError(f"Invalid environment path format: {value}")
        return value

    @field_validator("agent")
    @classmethod
    def validate_agent_path(cls, value: Agent | str) -> Agent | str:
        """Accept an ``Agent`` instance, or a non-blank dotted identifier path."""
        if isinstance(value, Agent):
            return value

        if not value or not value.strip():
            raise ValueError("Agent path cannot be empty")
        if not all(part.isidentifier() for part in value.split(".")):
            raise ValueError(f"Invalid agent path format: {value}")
        return value

    @property
    def module_name(self) -> str:
        """Directory name of ``path``, or else the functional environment's ``__name__``.

        Raises:
            ValueError: If neither ``path`` nor ``functional_environment`` is set.
        """
        if not self.path and not self.functional_environment:
            raise ValueError(
                "No module specified, either a path or a functional environment must be provided."
            )
        return (
            Path(self.path).name
            if self.path
            else cast(EnvironmentBuilder, self.functional_environment).__name__  # type: ignore[attr-defined]
        )

    @property
    def job_name(self) -> str:
        """Name to be used for the crow job deployment."""
        return self.name or self.module_name
|
560
|
+
|
561
|
+
|
562
|
+
class RuntimeConfig(BaseModel):
    """Runtime configuration for crow job execution.

    This advanced configuration is only available for supported crows.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    upload_id: str | None = Field(
        default=None,
        description="GCS folder id for uploaded files associated with this job",
    )
    timeout: int | None = Field(
        default=None, description="Maximum execution time in seconds"
    )
    max_steps: int | None = Field(
        default=None, description="Maximum number of steps to execute"
    )
    agent: Agent | AgentConfig | None = Field(
        default=None,
        description=(
            "Agent configuration to use for this job. If None, it will default to the "
            "agent selected during Crow deployment in the CrowDeploymentConfig object."
        ),
    )
    continued_job_id: UUID | None = Field(
        default=None,
        description="Optional job identifier for a continued job",
    )

    @field_validator("agent")
    @classmethod
    def validate_agent(
        cls, value: Agent | AgentConfig | str | None
    ) -> Agent | AgentConfig | str | None:
        """Pass through None/``Agent``/``AgentConfig``; validate dotted string paths.

        ``Agent`` instances are permitted by the field type but previously fell
        through to ``value.strip()``, which raised AttributeError. They are now
        returned unchanged, consistent with
        ``CrowDeploymentConfig.validate_agent_path``.

        Raises:
            ValueError: For a blank or malformed agent path string.
        """
        if value is None or isinstance(value, (Agent, AgentConfig)):
            return value

        if not value or not value.strip():
            raise ValueError("Agent path cannot be empty")
        if not all(part.isidentifier() for part in value.split(".")):
            raise ValueError(f"Invalid agent path format: {value}")
        return value
|
608
|
+
|
609
|
+
|
610
|
+
class JobRequest(BaseModel):
    # Payload for submitting a single crow job.
    job_id: UUID | None = Field(
        default=None,
        description="Optional job identifier",
        alias="id",  # pydantic alias: supplied as "id" in input data
    )
    # NOTE(review): ``JobNames`` is imported only under TYPE_CHECKING, so pydantic
    # cannot resolve this forward reference when the class is created. Presumably
    # ``JobRequest.model_rebuild()`` is called elsewhere once JobNames is
    # importable; confirm, otherwise instantiation fails.
    name: "str | JobNames" = Field(
        description="Name of the crow to execute eg. paperqa-crow"
    )
    query: str = Field(description="Query or task to be executed by the crow")
    runtime_config: RuntimeConfig | None = Field(
        default=None, description="All optional runtime parameters for the job"
    )
|
@@ -0,0 +1,72 @@
|
|
1
|
+
from typing import Any, Generic, TypeAlias, TypeVar
|
2
|
+
|
3
|
+
from aviary.message import Message
|
4
|
+
from aviary.tools.base import Tool
|
5
|
+
from ldp.data_structures import Transition
|
6
|
+
from ldp.graph.ops import OpResult
|
7
|
+
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
8
|
+
|
9
|
+
T = TypeVar("T")  # type parameter for ASVState's OpResult action


# TODO: revisit this
# unsure what crow states will return
# need to revisit after we get more crows deployed
class BaseState(BaseModel):
    # Common base for the callback-state models below. It permits arbitrary
    # (non-pydantic) field types and keeps any extra fields passed in.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
17
|
+
|
18
|
+
|
19
|
+
class BeforeTransitionState(BaseState):
    # Both fields are required (no defaults); a bare annotation is equivalent
    # to ``= Field()`` in pydantic v2.
    current_state: Any
    observations: list[Message]
|
22
|
+
|
23
|
+
|
24
|
+
class InitialState(BaseState):
    # Single required field holding an arbitrary initial-state object.
    initial_state: Any
|
26
|
+
|
27
|
+
|
28
|
+
class ASVState(BaseState, Generic[T]):
    # All three fields are required. ``action`` and ``next_state`` have custom
    # serializers below; ``value`` serializes as a plain float.
    action: OpResult[T]
    next_state: Any
    value: float

    @field_serializer("action")
    def serialize_action(self, action: OpResult[T]) -> dict:
        """Dump the action through ``OpResult.to_dict``."""
        as_dict = action.to_dict()
        return as_dict

    @field_serializer("next_state")
    def serialize_next_state(self, state: Any) -> str:
        """Render the next state as its ``str()`` form."""
        rendered = str(state)
        return rendered
|
40
|
+
|
41
|
+
|
42
|
+
class EnvResetState(BaseState):
    # Required: the observations and tools returned alongside an env reset.
    observations: list[Message]
    tools: list[Tool]
|
45
|
+
|
46
|
+
|
47
|
+
class EnvStepState(BaseState):
    # Required fields mirroring an env step's outputs: observations, reward,
    # and the done/trunc flags.
    observations: list[Message]
    reward: float
    done: bool
    trunc: bool
|
52
|
+
|
53
|
+
|
54
|
+
class TransitionState(BaseState):
    # Wraps one ldp ``Transition``; it has a custom serializer for its action.
    transition: Transition

    @field_serializer("transition")
    def serialize_transition(self, transition: Transition) -> dict:
        """Dump the transition, replacing ``action`` with ``action.to_dict()`` (or None)."""
        act = transition.action
        return {
            **transition.model_dump(),
            "action": act.to_dict() if act else None,
        }
|
63
|
+
|
64
|
+
|
65
|
+
# Union of the six BaseState subclasses defined in this module. Member order is
# kept as-is in case the alias is used as a pydantic field type, where union
# order can matter.
StateType: TypeAlias = (
    BeforeTransitionState
    | InitialState
    | ASVState
    | EnvResetState
    | EnvStepState
    | TransitionState
)
|