futurehouse-client 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,622 @@
+ import os
+ import re
+ from enum import StrEnum, auto
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, ClassVar, Self, cast
+ from uuid import UUID
+
+ from aviary.functional import EnvironmentBuilder
+ from ldp.agent import Agent, AgentConfig
+ from ldp.alg.callbacks import Callback
+ from pydantic import (
+     BaseModel,
+     ConfigDict,
+     Field,
+     field_validator,
+     model_validator,
+ )
+
+ if TYPE_CHECKING:
+     from futurehouse_client.clients import JobNames
+
+
+ MAX_CROW_JOB_RUN_TIMEOUT = 60 * 60 * 24  # 24 hours in sec
+ MIN_CROW_JOB_RUN_TIMEOUT = 0  # sec
+
+
+ class PythonVersion(StrEnum):
+     V3_11 = "3.11"
+     V3_12 = "3.12"
+
+
+ class AuthType(StrEnum):
+     API_KEY = auto()
+     JWT = auto()
+
+
+ class APIKeyPayload(BaseModel):
+     api_key: str = Field(description="A user API key to authenticate with the server.")
+
+
+ class PriorityQueueTypes(StrEnum):
+     LOW = auto()
+     NORMAL = auto()
+     HIGH = auto()
+     ULTRA = auto()
+
+     def rate_percentage(self) -> float:
+         if self == self.LOW:
+             return 0.1
+         if self == self.NORMAL:
+             return 0.5
+         if self == self.HIGH:
+             return 0.75
+         if self == self.ULTRA:
+             return 1.0
+         raise NotImplementedError(f"Unknown priority queue type: {self}")
+
+
+ class RetryConfig(BaseModel):
+     """Configuration for task retry settings."""
+
+     max_attempts: int = Field(
+         -1, description="Maximum number of retry attempts. -1 for infinite retries."
+     )
+     max_retry_duration_seconds: int = Field(
+         604800,  # 7 days in seconds
+         description="Maximum time a task may keep retrying before giving up (in seconds).",
+     )
+     max_backoff_seconds: int = Field(
+         60,  # in a full queue, each entry retries roughly once per minute
+         description="Maximum time to wait between retries (in seconds).",
+     )
+     min_backoff_seconds: int = Field(
+         1, description="Minimum time to wait between retries (in seconds)."
+     )
+     max_doublings: int = Field(
+         7,
+         description="Maximum number of times the retry interval can double before becoming constant.",
+     )
+
+     def to_client_dict(self) -> dict[str, Any]:
+         """Convert retry config to GCP Cloud Tasks client format."""
+         return {
+             "max_attempts": self.max_attempts,
+             "max_retry_duration": {"seconds": self.max_retry_duration_seconds},
+             "max_backoff": {"seconds": self.max_backoff_seconds},
+             "min_backoff": {"seconds": self.min_backoff_seconds},
+             "max_doublings": self.max_doublings,
+         }
+
+
+ class RateLimits(BaseModel):
+     """Configuration for queue rate limits."""
+
+     max_dispatches_per_second: float = Field(
+         10.0,
+         description=(
+             "Maximum number of tasks that can be dispatched per second. "
+             "If this is too high, you can overshoot the rate limit as the "
+             "query to running jobs is not perfectly synchronized."
+         ),
+     )
+     max_concurrent_dispatches: int = Field(
+         100,
+         description=(
+             "Maximum number of concurrent tasks that can be dispatched."
+             " This represents how many jobs are actively trying to get "
+             "a spot as a running job at the same time. The rest will "
+             "simply be waiting in the queue. The higher this is, the "
+             "higher the gatekeeping server load will be."
+         ),
+     )
+
+     MAX_RATIO_FROM_QUEUE_SIZE: ClassVar[float] = 0.1
+
+     @classmethod
+     def from_max_queue_size(cls, max_queue_size: int) -> "RateLimits":
+         """Create rate limits from a max_queue_size to avoid overwhelming the gatekeeping server."""
+         return cls(
+             max_concurrent_dispatches=max(
+                 1, int(max_queue_size * cls.MAX_RATIO_FROM_QUEUE_SIZE)
+             )
+         )
+
+     def to_client_dict(self) -> dict[str, Any]:
+         """Convert rate limits to GCP Cloud Tasks client format."""
+         return {
+             "max_dispatches_per_second": self.max_dispatches_per_second,
+             "max_concurrent_dispatches": self.max_concurrent_dispatches,
+         }
+
+
+ class TaskQueue(BaseModel):
+     """Configuration for a single Task Queue."""
+
+     name: str = Field(..., description="Name of the queue")
+     retry_config: RetryConfig = Field(
+         default_factory=RetryConfig, description="Configuration for task retries"
+     )
+     rate_limits: RateLimits | None = Field(
+         default=None, description="Optional rate limiting configuration"
+     )
+     priority_max_running_fraction: float = Field(
+         default_factory=PriorityQueueTypes.NORMAL.rate_percentage,
+         description=(
+             "Maximum fraction of the total limit that this queue can use, a proxy for priority. "
+             "Higher limits will essentially be preferred because they can run when "
+             "lower priority queues cannot."
+         ),
+         ge=0.0,
+         le=1.0,
+     )
+
+     @classmethod
+     def from_priority_queue_type_and_max_running_jobs(
+         cls, name: str, queue_type: PriorityQueueTypes, max_running_jobs: int
+     ) -> "TaskQueue":
+         """Create a TaskQueue from a PriorityQueueType."""
+         return cls(
+             name=f"{name}-{queue_type.value}",
+             priority_max_running_fraction=queue_type.rate_percentage(),
+             rate_limits=RateLimits.from_max_queue_size(
+                 int(queue_type.rate_percentage() * max_running_jobs)
+             ),
+         )
+
+     def to_client_dict(self, project_id: str, location: str) -> dict[str, Any]:
+         """Convert the queue configuration to GCP Cloud Tasks client format."""
+         parent = f"projects/{project_id}/locations/{location}"
+         queue_path = f"{parent}/queues/{self.name}"
+
+         result = {
+             "name": queue_path,
+             "retry_config": self.retry_config.to_client_dict(),
+         }
+
+         if self.rate_limits:
+             result["rate_limits"] = self.rate_limits.to_client_dict()
+
+         return result
+
+
+ class TaskQueuesConfig(BaseModel):
+     """Configuration for multiple Task Queues."""
+
+     name: str = Field(..., description="Base name for the queue(s).")
+     max_running_jobs: int = Field(
+         default=30,  # low default for now
+         description=(
+             "Maximum concurrency for this crow job, across all queues."
+             " Note: the global max across all crow jobs is 1,000, and the backend will always"
+             " enforce the global limit first. This limit should be set keeping in mind any"
+             " dependent limits like LLM throughput."
+         ),
+     )
+     queues: list[TaskQueue] | None = Field(
+         default=None,
+         description="List of task queues to be created/managed, will be built automatically if None.",
+     )
+
+     @model_validator(mode="after")
+     def add_priority_queues(self):
+         if self.queues is None:
+             self.queues = [
+                 TaskQueue.from_priority_queue_type_and_max_running_jobs(
+                     name=self.name,
+                     queue_type=queue_type,
+                     max_running_jobs=self.max_running_jobs,
+                 )
+                 for queue_type in PriorityQueueTypes
+             ]
+         return self
+
+     def get_queue(self, priority_type: PriorityQueueTypes) -> TaskQueue | None:
+         """Get a queue by its priority type."""
+         if not self.queues:
+             return None
+
+         for queue in self.queues:
+             if queue.name.endswith(f"-{priority_type.value}"):
+                 return queue
+
+         return None
+
+
+ class Stage(StrEnum):
+     DEV = "https://dev.api.platform.futurehouse.org"
+     PROD = "https://api.platform.futurehouse.org"
+     LOCAL = "http://localhost:8080"
+     LOCAL_DOCKER = "http://host.docker.internal:8080"
+
+     @classmethod
+     def from_string(cls, stage: str) -> "Stage":
+         """Convert a case-insensitive string to Stage enum."""
+         try:
+             return cls[stage.upper()]
+         except KeyError as e:
+             raise ValueError(
+                 f"Invalid stage: {stage}. Must be one of: {', '.join(cls.__members__)}",
+             ) from e
+
+
+ class Step(StrEnum):
+     BEFORE_TRANSITION = Callback.before_transition.__name__
+     AFTER_AGENT_INIT_STATE = Callback.after_agent_init_state.__name__
+     AFTER_AGENT_GET_ASV = Callback.after_agent_get_asv.__name__
+     AFTER_ENV_RESET = Callback.after_env_reset.__name__
+     AFTER_ENV_STEP = Callback.after_env_step.__name__
+     AFTER_TRANSITION = Callback.after_transition.__name__
+
+
+ class FramePathContentType(StrEnum):
+     TEXT = auto()
+     IMAGE = auto()
+     MARKDOWN = auto()
+     JSON = auto()
+     PDF_LINK = auto()
+     PDB = auto()
+     NOTEBOOK = auto()
+     PQA = auto()
+
+
+ class FramePath(BaseModel):
+     path: str = Field(
+         description="A JSON path string (e.g. 'input.data.frame') indicating where to find important frame data. None implies all data is important and the UI will render the full environment frame as is.",
+     )
+     type: FramePathContentType = Field(
+         default=FramePathContentType.JSON,
+         description="Content type of the data at this path",
+     )
+     is_iterable: bool = Field(
+         default=False,
+         description="Whether the content at the JSON path is iterable; this tells the rendering component to create multiple components for a single key",
+     )
+
+
+ class DockerContainerConfiguration(BaseModel):
+     cpu: str = Field(description="CPU allotment for the container")
+     memory: str = Field(description="Memory allotment for the container")
+
+     MINIMUM_MEMORY: ClassVar[int] = 2
+     MAXIMUM_MEMORY: ClassVar[int] = 32
+
+     # The python library only supports 1, 2, 4, 8 CPUs
+     # https://cloud.google.com/run/docs/reference/rpc/google.cloud.run.v2#resourcerequirements
+     @field_validator("cpu")
+     @classmethod
+     def validate_cpu(cls, v: str) -> str:
+         valid_cpus = {"1", "2", "4", "8"}
+         if v not in valid_cpus:
+             raise ValueError("CPU must be one of: 1, 2, 4, or 8")
+         return v
+
+     @field_validator("memory")
+     @classmethod
+     def validate_memory(cls, v: str) -> str:
+         # https://regex101.com/r/4kWjKw/1
+         match = re.match(r"^(\d+)Gi$", v)
+
+         if not match:
+             raise ValueError("Memory must be in Gi format (e.g., '2Gi')")
+
+         value = int(match.group(1))
+
+         # GCP Cloud Run has min 512Mi and max 32Gi (32768Mi)
+         # https://cloud.google.com/run/docs/configuring/services/memory-limits
+         # due to the above mentioned restriction in the python client, we must
+         # stay between 2Gi and 32Gi
+         if value < cls.MINIMUM_MEMORY:
+             raise ValueError("Memory must be at least 2Gi")
+         if value > cls.MAXIMUM_MEMORY:
+             raise ValueError("Memory must not exceed 32Gi")
+
+         return v
+
+     @model_validator(mode="after")
+     def validate_cpu_memory_ratio(self) -> Self:
+         cpu = int(self.cpu)
+
+         match = re.match(r"^(\d+)Gi$", self.memory)
+         if match is None:
+             raise ValueError("Memory must be in Gi format (e.g., '2Gi')")
+
+         memory_gi = int(match.group(1))
+         memory_mb = memory_gi * 1024
+
+         min_cpu_requirements = {
+             2048: 1,  # 2Gi requires 1 CPU
+             4096: 2,  # 4Gi requires 2 CPU
+             8192: 4,  # 8Gi requires 4 CPU
+             24576: 8,  # 24Gi requires 8 CPU
+         }
+
+         for mem_threshold, cpu_required in min_cpu_requirements.items():
+             if memory_mb <= mem_threshold:
+                 if cpu < cpu_required:
+                     raise ValueError(
+                         f"For {self.memory} of memory, minimum required CPU is {cpu_required} CPU. Got {cpu} CPU",
+                     )
+                 break
+
+         return self
+
+
+ class CrowDeploymentConfig(BaseModel):
+     model_config = ConfigDict(
+         extra="forbid",
+         arbitrary_types_allowed=True,  # Allows for agent: Agent | str
+     )
+
+     requirements_path: str | os.PathLike | None = Field(
+         default=None,
+         description="The complete path including filename to the requirements.txt file or pyproject.toml file. If not provided explicitly, it will be inferred from the path parameter.",
+     )
+
+     path: str | os.PathLike | None = Field(
+         default=None,
+         description="The path to your python module. Can be either a string path or Path object. "
+         "This path should be the root directory of your module. "
+         "This path either must include a pyproject.toml with UV tooling, or a requirements.txt for dependency resolution. "
+         "Can be None if we are deploying a functional environment (through the functional_environment parameter).",
+     )
+
+     name: str | None = Field(
+         default=None,
+         description="The name of the crow job. If None, the crow job will be "
+         "named using the included python module or functional environment name.",
+     )
+
+     environment: str = Field(
+         description="Your environment path; should be a module reference if we pass an environment. "
+         "Can be an arbitrary name if we are deploying a functional environment (through the functional_environment parameter). "
+         "example: dummy_env.env.DummyEnv",
+     )
+
+     functional_environment: EnvironmentBuilder | None = Field(
+         default=None,
+         description="An object of type EnvironmentBuilder used to construct an environment. "
+         "Can be None if we are deploying a non-functional environment.",
+     )
+
+     requirements: list[str] | None = Field(
+         default=None,
+         description="A list of dependencies required for the deployment, similar to the Python requirements.txt file. "
+         "Each entry in the list specifies a package or module in the format used by pip (e.g., 'package-name==1.0.0'). "
+         "Can be None if we are deploying a non-functional environment (functional_environment parameter is None).",
+     )
+
+     environment_variables: dict[str, str] | None = Field(
+         default=None,
+         description="Any key value pair of environment variables your environment needs to function.",
+     )
+
+     container_config: DockerContainerConfiguration | None = Field(
+         default=None,
+         description="The configuration for the cloud run container.",
+     )
+
+     python_version: PythonVersion = Field(
+         default=PythonVersion.V3_12,
+         description="The python version your docker image should build with.",
+     )
+
+     agent: Agent | str = Field(
+         default="ldp.agent.SimpleAgent",
+         description="Your desired agent path; should be a module reference and a fully qualified name. "
+         "example: ldp.agent.SimpleAgent",
+     )
+
+     requires_aviary_internal: bool = Field(
+         default=False,
+         description="Indicates your project requires aviary-internal to function. "
+         "This is only necessary for envs within aviary-internal.",
+     )
+
+     timeout: int | None = Field(
+         default=600,
+         description="The amount of time in seconds your crow will run on a task before it terminates.",
+         ge=MIN_CROW_JOB_RUN_TIMEOUT,
+         le=MAX_CROW_JOB_RUN_TIMEOUT,
+     )
+
+     force: bool = Field(
+         default=False,
+         description="If true, immediately overwrite any existing job with the same name.",
+     )
+
+     storage_location: str = Field(
+         default="storage",
+         description="The location the container will use to mount a locally accessible GCS folder as a volume. "
+         "This location can be used to store and fetch files safely without GCS apis or direct access.",
+     )
+
+     frame_paths: list[FramePath] | None = Field(
+         default=None,
+         description="List of FramePath which indicates where to find important frame data, and how to render it.",
+     )
+
+     markdown_template_path: str | os.PathLike | None = Field(
+         default=None,
+         description="The path to the markdown template file. This file will be dynamically built within the environment frame section of the UI. "
+         "The keys used in the markdown file follow the same requirement as FramePath.path. None implies no markdown template is present and the UI "
+         "will render the environment frame as is.",
+     )
+
+     task_queues_config: TaskQueuesConfig | None = Field(
+         default=None,
+         description="The configuration for the task queue(s) that will be created for this deployment.",
+     )
+
+     @field_validator("markdown_template_path")
+     @classmethod
+     def validate_markdown_path(
+         cls, v: str | os.PathLike | None
+     ) -> str | os.PathLike | None:
+         if v is not None:
+             path = Path(v)
+             if path.suffix.lower() not in {".md", ".markdown"}:
+                 raise ValueError(
+                     f"Markdown template must have a .md or .markdown extension: {path}"
+                 )
+         return v
+
+     task_description: str | None = Field(
+         default=None,
+         description="Override for the task description, if not included it will be pulled from your "
+         "environment `from_task` docstring. Necessary if you are deploying using an Environment class"
+         " as a dependency.",
+     )
+
+     @field_validator("path")
+     @classmethod
+     def validate_module_path(
+         cls, value: str | os.PathLike | None
+     ) -> str | os.PathLike | None:
+         if value is None:
+             return value
+
+         path = Path(value)
+         if not path.exists():
+             raise ValueError(f"Module path {path} does not exist")
+         if not path.is_dir():
+             raise ValueError(f"Module path {path} is not a directory")
+         return value
+
+     @field_validator("requirements_path")
+     @classmethod
+     def validate_requirements_path(
+         cls, value: str | os.PathLike | None
+     ) -> str | os.PathLike | None:
+         if value is None:
+             return value
+
+         path = Path(value)
+         if not path.exists():
+             raise ValueError(f"Requirements path {path} does not exist")
+         if not path.is_file():
+             raise ValueError(f"Requirements path {path} is not a file")
+         if path.suffix not in {".txt", ".toml"}:
+             raise ValueError(f"Requirements path {path} must be a .txt or .toml file")
+         return value
+
+     @model_validator(mode="after")
+     def validate_path_and_requirements(self) -> Self:
+         if self.path is None:
+             return self
+
+         path = Path(self.path)
+         requirements_path = (
+             Path(self.requirements_path) if self.requirements_path else None
+         )
+
+         if not (
+             (path / "pyproject.toml").exists()
+             or (path / "requirements.txt").exists()
+             or (requirements_path and requirements_path.exists())
+         ):
+             raise ValueError(
+                 f"Module path {path} must contain either pyproject.toml or requirements.txt, "
+                 "or a valid requirements_path must be provided"
+             )
+
+         if not self.task_queues_config:
+             self.task_queues_config = TaskQueuesConfig(name=self.job_name)
+
+         return self
+
+     @field_validator("environment")
+     @classmethod
+     def validate_environment_path(cls, value: str) -> str:
+         if not value or not value.strip():
+             raise ValueError("Environment path cannot be empty")
+         if not all(part.isidentifier() for part in value.split(".")):
+             raise ValueError(f"Invalid environment path format: {value}")
+         return value
+
+     @field_validator("agent")
+     @classmethod
+     def validate_agent_path(cls, value: Agent | str) -> Agent | str:
+         if isinstance(value, Agent):
+             return value
+
+         if not value or not value.strip():
+             raise ValueError("Agent path cannot be empty")
+         if not all(part.isidentifier() for part in value.split(".")):
+             raise ValueError(f"Invalid agent path format: {value}")
+         return value
+
+     @property
+     def module_name(self) -> str:
+         if not self.path and not self.functional_environment:
+             raise ValueError(
+                 "No module specified, either a path or a functional environment must be provided."
+             )
+         return (
+             Path(self.path).name
+             if self.path
+             else cast(EnvironmentBuilder, self.functional_environment).__name__  # type: ignore[attr-defined]
+         )
+
+     @property
+     def job_name(self) -> str:
+         """Name to be used for the crow job deployment."""
+         return self.name or self.module_name
+
+
+ class RuntimeConfig(BaseModel):
+     """Runtime configuration for crow job execution.
+
+     This advanced configuration is only available for supported crows.
+     """
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     upload_id: str | None = Field(
+         default=None,
+         description="GCS folder id for uploaded files associated with this job",
+     )
+     timeout: int | None = Field(
+         default=None, description="Maximum execution time in seconds"
+     )
+     max_steps: int | None = Field(
+         default=None, description="Maximum number of steps to execute"
+     )
+     agent: Agent | AgentConfig | None = Field(
+         default=None,
+         description=(
+             "Agent configuration to use for this job. If None, it will default to the "
+             "agent selected during Crow deployment in the CrowDeploymentConfig object."
+         ),
+     )
+     continued_job_id: UUID | None = Field(
+         default=None,
+         description="Optional job identifier for a continued job",
+     )
+
+     @field_validator("agent")
+     @classmethod
+     def validate_agent(
+         cls, value: Agent | AgentConfig | str | None
+     ) -> Agent | AgentConfig | str | None:
+         if value is None:
+             return None
+
+         if isinstance(value, (Agent, AgentConfig)):
+             return value
+
+         if not value or not value.strip():
+             raise ValueError("Agent path cannot be empty")
+         if not all(part.isidentifier() for part in value.split(".")):
+             raise ValueError(f"Invalid agent path format: {value}")
+         return value
+
+
+ class JobRequest(BaseModel):
+     job_id: UUID | None = Field(
+         default=None,
+         description="Optional job identifier",
+         alias="id",
+     )
+     name: "str | JobNames" = Field(
+         description="Name of the crow to execute, e.g. paperqa-crow"
+     )
+     query: str = Field(description="Query or task to be executed by the crow")
+     runtime_config: RuntimeConfig | None = Field(
+         default=None, description="All optional runtime parameters for the job"
+     )
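
A minimal usage sketch of the queue and container models above. This is an illustration, not code shipped in the package; the import path futurehouse_client.models is an assumption, since the diff does not show file names.

# Hedged usage sketch; `futurehouse_client.models` is an assumed import path.
from futurehouse_client.models import (
    DockerContainerConfiguration,
    PriorityQueueTypes,
    Stage,
    TaskQueuesConfig,
)

# Leaving queues=None makes the model validator build one queue per priority
# tier, named "<base>-low", "<base>-normal", "<base>-high", and "<base>-ultra".
config = TaskQueuesConfig(name="paperqa-crow", max_running_jobs=100)
high = config.get_queue(PriorityQueueTypes.HIGH)
assert high is not None
assert high.priority_max_running_fraction == 0.75  # HIGH gets 75% of the limit

# The container validators enforce the Cloud Run CPU/memory pairings:
# 4Gi needs at least 2 CPUs, so cpu="2" passes, while 8Gi with 1 CPU fails.
DockerContainerConfiguration(cpu="2", memory="4Gi")
try:
    DockerContainerConfiguration(cpu="1", memory="8Gi")
except ValueError as exc:  # pydantic's ValidationError subclasses ValueError
    print(exc)

assert Stage.from_string("prod") is Stage.PROD  # case-insensitive lookup
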
@@ -0,0 +1,72 @@
+ from typing import Any, Generic, TypeAlias, TypeVar
+
+ from aviary.message import Message
+ from aviary.tools.base import Tool
+ from ldp.data_structures import Transition
+ from ldp.graph.ops import OpResult
+ from pydantic import BaseModel, ConfigDict, Field, field_serializer
+
+ T = TypeVar("T")
+
+
+ # TODO: revisit this
+ # unsure what crow states will return
+ # need to revisit after we get more crows deployed
+ class BaseState(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
+
+
+ class BeforeTransitionState(BaseState):
+     current_state: Any = Field()
+     observations: list[Message] = Field()
+
+
+ class InitialState(BaseState):
+     initial_state: Any = Field()
+
+
+ class ASVState(BaseState, Generic[T]):
+     action: OpResult[T] = Field()
+     next_state: Any = Field()
+     value: float = Field()
+
+     @field_serializer("action")
+     def serialize_action(self, action: OpResult[T]) -> dict:
+         return action.to_dict()
+
+     @field_serializer("next_state")
+     def serialize_next_state(self, state: Any) -> str:
+         return str(state)
+
+
+ class EnvResetState(BaseState):
+     observations: list[Message] = Field()
+     tools: list[Tool] = Field()
+
+
+ class EnvStepState(BaseState):
+     observations: list[Message] = Field()
+     reward: float = Field()
+     done: bool = Field()
+     trunc: bool = Field()
+
+
+ class TransitionState(BaseState):
+     transition: Transition = Field()
+
+     @field_serializer("transition")
+     def serialize_transition(self, transition: Transition) -> dict:
+         transition_data = transition.model_dump()
+         return transition_data | {
+             "action": transition.action.to_dict() if transition.action else None,
+         }
+
+
+ StateType: TypeAlias = (
+     BeforeTransitionState
+     | InitialState
+     | ASVState
+     | EnvResetState
+     | EnvStepState
+     | TransitionState
+ )
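
A hedged sketch of the callback-state models in this second file. Again an illustration rather than package code: the import path is an assumption, and aviary's Message is assumed constructible with content= as shown.

# Hedged sketch; the import path below is assumed, not shown in this diff.
from aviary.message import Message
from futurehouse_client.models import EnvStepState  # hypothetical import path

# extra="allow" on BaseState means unknown keys are kept rather than rejected.
state = EnvStepState(
    observations=[Message(content="observation text")],
    reward=1.0,
    done=True,
    trunc=False,
)
print(state.model_dump())
# ASVState and TransitionState additionally register field_serializer hooks so
# that OpResult and Transition values dump as plain dicts/strings.
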