plato-sdk-v2 2.0.50__py3-none-any.whl → 2.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/__init__.py +7 -6
- plato/_generated/__init__.py +1 -1
- plato/_generated/api/v1/env/evaluate_session.py +3 -3
- plato/_generated/api/v1/env/log_state_mutation.py +4 -4
- plato/_generated/api/v1/sandbox/checkpoint_vm.py +3 -3
- plato/_generated/api/v1/sandbox/save_vm_snapshot.py +3 -3
- plato/_generated/api/v1/sandbox/setup_sandbox.py +8 -8
- plato/_generated/api/v1/session/__init__.py +2 -0
- plato/_generated/api/v1/session/get_sessions_for_archival.py +100 -0
- plato/_generated/api/v1/testcases/__init__.py +6 -2
- plato/_generated/api/v1/testcases/get_mutation_groups_for_testcase.py +98 -0
- plato/_generated/api/v1/testcases/{get_next_output_testcase_for_scoring.py → get_next_testcase_for_scoring.py} +23 -10
- plato/_generated/api/v1/testcases/get_testcase_metadata_for_scoring.py +74 -0
- plato/_generated/api/v2/__init__.py +2 -1
- plato/_generated/api/v2/jobs/__init__.py +4 -0
- plato/_generated/api/v2/jobs/checkpoint.py +3 -3
- plato/_generated/api/v2/jobs/disk_snapshot.py +3 -3
- plato/_generated/api/v2/jobs/log_for_job.py +4 -39
- plato/_generated/api/v2/jobs/make.py +4 -4
- plato/_generated/api/v2/jobs/setup_sandbox.py +97 -0
- plato/_generated/api/v2/jobs/snapshot.py +3 -3
- plato/_generated/api/v2/jobs/snapshot_store.py +91 -0
- plato/_generated/api/v2/sessions/__init__.py +4 -0
- plato/_generated/api/v2/sessions/checkpoint.py +3 -3
- plato/_generated/api/v2/sessions/disk_snapshot.py +3 -3
- plato/_generated/api/v2/sessions/evaluate.py +3 -3
- plato/_generated/api/v2/sessions/log_job_mutation.py +4 -39
- plato/_generated/api/v2/sessions/make.py +4 -4
- plato/_generated/api/v2/sessions/setup_sandbox.py +98 -0
- plato/_generated/api/v2/sessions/snapshot.py +3 -3
- plato/_generated/api/v2/sessions/snapshot_store.py +94 -0
- plato/_generated/api/v2/user/__init__.py +7 -0
- plato/_generated/api/v2/user/get_current_user.py +76 -0
- plato/_generated/models/__init__.py +174 -23
- plato/_sims_generator/__init__.py +19 -4
- plato/_sims_generator/instruction.py +203 -0
- plato/_sims_generator/templates/instruction/helpers.py.jinja +161 -0
- plato/_sims_generator/templates/instruction/init.py.jinja +43 -0
- plato/agents/__init__.py +107 -517
- plato/agents/base.py +145 -0
- plato/agents/build.py +61 -0
- plato/agents/config.py +160 -0
- plato/agents/logging.py +401 -0
- plato/agents/runner.py +161 -0
- plato/agents/trajectory.py +266 -0
- plato/chronos/__init__.py +37 -0
- plato/chronos/api/__init__.py +3 -0
- plato/chronos/api/agents/__init__.py +13 -0
- plato/chronos/api/agents/create_agent.py +63 -0
- plato/chronos/api/agents/delete_agent.py +61 -0
- plato/chronos/api/agents/get_agent.py +62 -0
- plato/chronos/api/agents/get_agent_schema.py +72 -0
- plato/chronos/api/agents/get_agent_versions.py +62 -0
- plato/chronos/api/agents/list_agents.py +57 -0
- plato/chronos/api/agents/lookup_agent.py +74 -0
- plato/chronos/api/auth/__init__.py +9 -0
- plato/chronos/api/auth/debug_auth_api_auth_debug_get.py +43 -0
- plato/chronos/api/auth/get_auth_status_api_auth_status_get.py +61 -0
- plato/chronos/api/auth/get_current_user_route_api_auth_me_get.py +60 -0
- plato/chronos/api/callback/__init__.py +11 -0
- plato/chronos/api/callback/push_agent_logs.py +61 -0
- plato/chronos/api/callback/update_agent_status.py +57 -0
- plato/chronos/api/callback/upload_artifacts.py +59 -0
- plato/chronos/api/callback/upload_logs_zip.py +57 -0
- plato/chronos/api/callback/upload_trajectory.py +57 -0
- plato/chronos/api/default/__init__.py +7 -0
- plato/chronos/api/default/health.py +43 -0
- plato/chronos/api/jobs/__init__.py +7 -0
- plato/chronos/api/jobs/launch_job.py +63 -0
- plato/chronos/api/registry/__init__.py +19 -0
- plato/chronos/api/registry/get_agent_schema_api_registry_agents__agent_name__schema_get.py +62 -0
- plato/chronos/api/registry/get_agent_versions_api_registry_agents__agent_name__versions_get.py +52 -0
- plato/chronos/api/registry/get_world_schema_api_registry_worlds__package_name__schema_get.py +68 -0
- plato/chronos/api/registry/get_world_versions_api_registry_worlds__package_name__versions_get.py +52 -0
- plato/chronos/api/registry/list_registry_agents_api_registry_agents_get.py +44 -0
- plato/chronos/api/registry/list_registry_worlds_api_registry_worlds_get.py +44 -0
- plato/chronos/api/runtimes/__init__.py +11 -0
- plato/chronos/api/runtimes/create_runtime.py +63 -0
- plato/chronos/api/runtimes/delete_runtime.py +61 -0
- plato/chronos/api/runtimes/get_runtime.py +62 -0
- plato/chronos/api/runtimes/list_runtimes.py +57 -0
- plato/chronos/api/runtimes/test_runtime.py +67 -0
- plato/chronos/api/secrets/__init__.py +11 -0
- plato/chronos/api/secrets/create_secret.py +63 -0
- plato/chronos/api/secrets/delete_secret.py +61 -0
- plato/chronos/api/secrets/get_secret.py +62 -0
- plato/chronos/api/secrets/list_secrets.py +57 -0
- plato/chronos/api/secrets/update_secret.py +68 -0
- plato/chronos/api/sessions/__init__.py +10 -0
- plato/chronos/api/sessions/get_session.py +62 -0
- plato/chronos/api/sessions/get_session_logs.py +72 -0
- plato/chronos/api/sessions/get_session_logs_download.py +62 -0
- plato/chronos/api/sessions/list_sessions.py +57 -0
- plato/chronos/api/status/__init__.py +8 -0
- plato/chronos/api/status/get_status_api_status_get.py +44 -0
- plato/chronos/api/status/get_version_info_api_version_get.py +44 -0
- plato/chronos/api/templates/__init__.py +11 -0
- plato/chronos/api/templates/create_template.py +63 -0
- plato/chronos/api/templates/delete_template.py +61 -0
- plato/chronos/api/templates/get_template.py +62 -0
- plato/chronos/api/templates/list_templates.py +57 -0
- plato/chronos/api/templates/update_template.py +68 -0
- plato/chronos/api/trajectories/__init__.py +8 -0
- plato/chronos/api/trajectories/get_trajectory.py +62 -0
- plato/chronos/api/trajectories/list_trajectories.py +62 -0
- plato/chronos/api/worlds/__init__.py +10 -0
- plato/chronos/api/worlds/create_world.py +63 -0
- plato/chronos/api/worlds/delete_world.py +61 -0
- plato/chronos/api/worlds/get_world.py +62 -0
- plato/chronos/api/worlds/list_worlds.py +57 -0
- plato/chronos/client.py +171 -0
- plato/chronos/errors.py +141 -0
- plato/chronos/models/__init__.py +647 -0
- plato/chronos/py.typed +0 -0
- plato/sims/cli.py +299 -123
- plato/sims/registry.py +77 -4
- plato/v1/cli/agent.py +88 -84
- plato/v1/cli/main.py +2 -0
- plato/v1/cli/pm.py +441 -119
- plato/v1/cli/sandbox.py +747 -191
- plato/v1/cli/sim.py +11 -0
- plato/v1/cli/verify.py +1269 -0
- plato/v1/cli/world.py +3 -0
- plato/v1/flow_executor.py +21 -17
- plato/v1/models/env.py +11 -11
- plato/v1/sdk.py +2 -2
- plato/v1/sync_env.py +11 -11
- plato/v1/sync_flow_executor.py +21 -17
- plato/v1/sync_sdk.py +4 -2
- plato/v2/__init__.py +2 -0
- plato/v2/async_/environment.py +20 -1
- plato/v2/async_/session.py +54 -3
- plato/v2/sync/environment.py +2 -1
- plato/v2/sync/session.py +52 -2
- plato/worlds/README.md +218 -0
- plato/worlds/__init__.py +54 -18
- plato/worlds/base.py +304 -93
- plato/worlds/config.py +239 -73
- plato/worlds/runner.py +391 -80
- {plato_sdk_v2-2.0.50.dist-info → plato_sdk_v2-2.2.4.dist-info}/METADATA +1 -3
- {plato_sdk_v2-2.0.50.dist-info → plato_sdk_v2-2.2.4.dist-info}/RECORD +143 -68
- {plato_sdk_v2-2.0.50.dist-info → plato_sdk_v2-2.2.4.dist-info}/entry_points.txt +1 -0
- plato/_generated/api/v2/interfaces/__init__.py +0 -27
- plato/_generated/api/v2/interfaces/v2_interface_browser_create.py +0 -68
- plato/_generated/api/v2/interfaces/v2_interface_cdp_url.py +0 -65
- plato/_generated/api/v2/interfaces/v2_interface_click.py +0 -64
- plato/_generated/api/v2/interfaces/v2_interface_close.py +0 -59
- plato/_generated/api/v2/interfaces/v2_interface_computer_create.py +0 -68
- plato/_generated/api/v2/interfaces/v2_interface_cursor.py +0 -64
- plato/_generated/api/v2/interfaces/v2_interface_key.py +0 -68
- plato/_generated/api/v2/interfaces/v2_interface_screenshot.py +0 -65
- plato/_generated/api/v2/interfaces/v2_interface_scroll.py +0 -70
- plato/_generated/api/v2/interfaces/v2_interface_type.py +0 -64
- plato/world/__init__.py +0 -44
- plato/world/base.py +0 -267
- plato/world/config.py +0 -139
- plato/world/types.py +0 -47
- {plato_sdk_v2-2.0.50.dist-info → plato_sdk_v2-2.2.4.dist-info}/WHEEL +0 -0
plato/world/base.py
DELETED
|
@@ -1,267 +0,0 @@
|
|
|
1
|
-
"""World base class - OpenAI Gym-like interface for Plato simulators."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
import logging
|
|
6
|
-
from abc import ABC, abstractmethod
|
|
7
|
-
from datetime import datetime, timedelta
|
|
8
|
-
from typing import TYPE_CHECKING, cast
|
|
9
|
-
|
|
10
|
-
from plato.v2 import AsyncPlato, AsyncSession, EnvFromArtifact, EnvFromResource, EnvFromSimulator
|
|
11
|
-
from plato.world.config import WorldConfig
|
|
12
|
-
from plato.world.types import Observation, StepResult
|
|
13
|
-
|
|
14
|
-
if TYPE_CHECKING:
|
|
15
|
-
from plato.storage import RolloutStorage
|
|
16
|
-
|
|
17
|
-
logger = logging.getLogger(__name__)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class World(ABC):
|
|
21
|
-
"""OpenAI Gym-like environment backed by Plato simulators.
|
|
22
|
-
|
|
23
|
-
World manages the Plato session lifecycle and provides a gym-like interface
|
|
24
|
-
for agent interaction. Agents interact with simulators via the plato.sims SDK
|
|
25
|
-
using URLs exposed by the World.
|
|
26
|
-
|
|
27
|
-
Subclasses must implement:
|
|
28
|
-
- _on_reset(): Custom reset logic, returns initial Observation
|
|
29
|
-
- _on_step(): Custom step logic (e.g., run consumers), returns StepResult
|
|
30
|
-
- get_prompt(): Generate agent prompt for current state
|
|
31
|
-
|
|
32
|
-
Example:
|
|
33
|
-
class MerchantWorld(World):
|
|
34
|
-
async def _on_reset(self) -> Observation:
|
|
35
|
-
return Observation(step=0, date=self.config.start_date, data={"balance": 1000})
|
|
36
|
-
|
|
37
|
-
async def _on_step(self) -> StepResult:
|
|
38
|
-
# Run consumer simulation, calculate reward
|
|
39
|
-
return StepResult(observation=obs, reward=100.0)
|
|
40
|
-
|
|
41
|
-
def get_prompt(self) -> str:
|
|
42
|
-
return f"Day {self._current_step}: Manage your store..."
|
|
43
|
-
|
|
44
|
-
# Agent receives env vars:
|
|
45
|
-
# STORE_BASE_URL=http://...
|
|
46
|
-
# ACCOUNTING_BASE_URL=http://...
|
|
47
|
-
# And uses: from plato.sims import spree; client = spree.Client.create()
|
|
48
|
-
"""
|
|
49
|
-
|
|
50
|
-
def __init__(self, config: WorldConfig, storage: RolloutStorage | None = None):
|
|
51
|
-
"""Initialize World with configuration.
|
|
52
|
-
|
|
53
|
-
Args:
|
|
54
|
-
config: WorldConfig with environment specs and parameters
|
|
55
|
-
storage: Optional RolloutStorage for persisting data
|
|
56
|
-
"""
|
|
57
|
-
self.config = config
|
|
58
|
-
self.storage = storage
|
|
59
|
-
|
|
60
|
-
# Plato session state
|
|
61
|
-
self._plato: AsyncPlato | None = None
|
|
62
|
-
self._session: AsyncSession | None = None
|
|
63
|
-
self._connect_urls: dict[str, str] = {}
|
|
64
|
-
|
|
65
|
-
# Simulation state
|
|
66
|
-
self._current_step: int = 0
|
|
67
|
-
self._current_date: datetime | None = None
|
|
68
|
-
self._done: bool = False
|
|
69
|
-
|
|
70
|
-
@property
|
|
71
|
-
def current_step(self) -> int:
|
|
72
|
-
"""Get the current step number."""
|
|
73
|
-
return self._current_step
|
|
74
|
-
|
|
75
|
-
@property
|
|
76
|
-
def current_date(self) -> str:
|
|
77
|
-
"""Get the current simulation date as ISO string."""
|
|
78
|
-
if self._current_date is None:
|
|
79
|
-
return self.config.start_date
|
|
80
|
-
return self._current_date.strftime("%Y-%m-%d")
|
|
81
|
-
|
|
82
|
-
@property
|
|
83
|
-
def is_done(self) -> bool:
|
|
84
|
-
"""Check if the simulation is complete."""
|
|
85
|
-
return self._done
|
|
86
|
-
|
|
87
|
-
@property
|
|
88
|
-
def session(self) -> AsyncSession | None:
|
|
89
|
-
"""Get the underlying Plato session."""
|
|
90
|
-
return self._session
|
|
91
|
-
|
|
92
|
-
async def reset(self) -> Observation:
|
|
93
|
-
"""Reset world: create Plato session, return initial observation.
|
|
94
|
-
|
|
95
|
-
Creates a new Plato session with the configured environments and
|
|
96
|
-
waits for them to be ready. Returns the initial observation.
|
|
97
|
-
|
|
98
|
-
Returns:
|
|
99
|
-
Initial Observation from _on_reset()
|
|
100
|
-
"""
|
|
101
|
-
logger.info(f"Resetting world '{self.config.name}'")
|
|
102
|
-
|
|
103
|
-
# Create Plato client and session
|
|
104
|
-
self._plato = AsyncPlato()
|
|
105
|
-
envs = [spec.to_plato_env() for spec in self.config.envs]
|
|
106
|
-
|
|
107
|
-
logger.info(f"Creating Plato session with {len(envs)} environments")
|
|
108
|
-
self._session = await self._plato.sessions.create(
|
|
109
|
-
envs=cast(list[EnvFromSimulator | EnvFromArtifact | EnvFromResource], envs),
|
|
110
|
-
timeout=self.config.session_timeout,
|
|
111
|
-
)
|
|
112
|
-
|
|
113
|
-
# Get connect URLs - these become env vars for agents
|
|
114
|
-
self._connect_urls = await self._session.get_connect_url()
|
|
115
|
-
logger.info(f"Got connect URLs: {list(self._connect_urls.keys())}")
|
|
116
|
-
|
|
117
|
-
# Initialize state
|
|
118
|
-
self._current_step = 0
|
|
119
|
-
self._current_date = datetime.fromisoformat(self.config.start_date)
|
|
120
|
-
self._done = False
|
|
121
|
-
|
|
122
|
-
# Call subclass hook
|
|
123
|
-
obs = await self._on_reset()
|
|
124
|
-
|
|
125
|
-
# Log to storage
|
|
126
|
-
if self.storage:
|
|
127
|
-
self.storage.log_reset(obs)
|
|
128
|
-
|
|
129
|
-
logger.info(f"World reset complete. Initial observation: step={obs.step}, date={obs.date}")
|
|
130
|
-
return obs
|
|
131
|
-
|
|
132
|
-
async def step(self) -> StepResult:
|
|
133
|
-
"""Advance simulation by one step.
|
|
134
|
-
|
|
135
|
-
Increments the step counter, advances the simulation date,
|
|
136
|
-
and calls the subclass _on_step() hook for custom logic.
|
|
137
|
-
|
|
138
|
-
Returns:
|
|
139
|
-
StepResult with observation, reward, done, and info
|
|
140
|
-
"""
|
|
141
|
-
if self._session is None:
|
|
142
|
-
raise RuntimeError("World not initialized. Call reset() first.")
|
|
143
|
-
|
|
144
|
-
self._current_step += 1
|
|
145
|
-
|
|
146
|
-
# Advance simulation date
|
|
147
|
-
await self._advance_date()
|
|
148
|
-
|
|
149
|
-
# Call subclass hook for custom step logic
|
|
150
|
-
result = await self._on_step()
|
|
151
|
-
|
|
152
|
-
# Update done status
|
|
153
|
-
if result.done:
|
|
154
|
-
self._done = True
|
|
155
|
-
|
|
156
|
-
# Log to storage
|
|
157
|
-
if self.storage:
|
|
158
|
-
self.storage.log_step(self._current_step, result)
|
|
159
|
-
|
|
160
|
-
logger.info(f"Step {self._current_step} complete. Reward: {result.reward:.2f}, Done: {result.done}")
|
|
161
|
-
return result
|
|
162
|
-
|
|
163
|
-
async def _advance_date(self) -> None:
|
|
164
|
-
"""Advance the simulation date by one day.
|
|
165
|
-
|
|
166
|
-
Updates the date on all environments in the session.
|
|
167
|
-
"""
|
|
168
|
-
if self._current_date is None:
|
|
169
|
-
self._current_date = datetime.fromisoformat(self.config.start_date)
|
|
170
|
-
|
|
171
|
-
self._current_date += timedelta(days=1)
|
|
172
|
-
|
|
173
|
-
if self._session:
|
|
174
|
-
try:
|
|
175
|
-
await self._session.set_date(self._current_date)
|
|
176
|
-
logger.debug(f"Advanced date to {self.current_date}")
|
|
177
|
-
except Exception as e:
|
|
178
|
-
logger.warning(f"Failed to set date: {e}")
|
|
179
|
-
|
|
180
|
-
async def close(self) -> None:
|
|
181
|
-
"""Close Plato session and cleanup resources."""
|
|
182
|
-
logger.info(f"Closing world '{self.config.name}'")
|
|
183
|
-
|
|
184
|
-
if self._session:
|
|
185
|
-
try:
|
|
186
|
-
await self._session.close()
|
|
187
|
-
except Exception as e:
|
|
188
|
-
logger.warning(f"Error closing session: {e}")
|
|
189
|
-
self._session = None
|
|
190
|
-
|
|
191
|
-
if self._plato:
|
|
192
|
-
try:
|
|
193
|
-
await self._plato.close()
|
|
194
|
-
except Exception as e:
|
|
195
|
-
logger.warning(f"Error closing Plato client: {e}")
|
|
196
|
-
self._plato = None
|
|
197
|
-
|
|
198
|
-
self._connect_urls = {}
|
|
199
|
-
|
|
200
|
-
def get_env_vars(self) -> dict[str, str]:
|
|
201
|
-
"""Get environment variables for agent.
|
|
202
|
-
|
|
203
|
-
Returns dict with simulator URLs and current simulation date:
|
|
204
|
-
{
|
|
205
|
-
"STORE_BASE_URL": "http://...",
|
|
206
|
-
"ACCOUNTING_BASE_URL": "http://...",
|
|
207
|
-
"SIMULATION_DATE": "2024-01-15",
|
|
208
|
-
}
|
|
209
|
-
|
|
210
|
-
Agent can then: from plato.sims import spree; client = spree.Client.create()
|
|
211
|
-
"""
|
|
212
|
-
env_vars = {f"{alias.upper()}_BASE_URL": url for alias, url in self._connect_urls.items()}
|
|
213
|
-
env_vars["SIMULATION_DATE"] = self.current_date
|
|
214
|
-
env_vars["SIMULATION_STEP"] = str(self._current_step)
|
|
215
|
-
return env_vars
|
|
216
|
-
|
|
217
|
-
def get_connect_urls(self) -> dict[str, str]:
|
|
218
|
-
"""Get raw connect URLs by alias.
|
|
219
|
-
|
|
220
|
-
Returns:
|
|
221
|
-
Dict mapping alias to connect URL
|
|
222
|
-
"""
|
|
223
|
-
return self._connect_urls.copy()
|
|
224
|
-
|
|
225
|
-
@abstractmethod
|
|
226
|
-
async def _on_reset(self) -> Observation:
|
|
227
|
-
"""Subclass hook: custom reset logic.
|
|
228
|
-
|
|
229
|
-
Called after Plato session is created and ready.
|
|
230
|
-
Should return the initial observation.
|
|
231
|
-
|
|
232
|
-
Returns:
|
|
233
|
-
Initial Observation
|
|
234
|
-
"""
|
|
235
|
-
pass
|
|
236
|
-
|
|
237
|
-
@abstractmethod
|
|
238
|
-
async def _on_step(self) -> StepResult:
|
|
239
|
-
"""Subclass hook: custom step logic.
|
|
240
|
-
|
|
241
|
-
Called after date is advanced. Should implement any custom
|
|
242
|
-
step logic (e.g., run consumer simulation) and return the result.
|
|
243
|
-
|
|
244
|
-
Returns:
|
|
245
|
-
StepResult with observation, reward, done, info
|
|
246
|
-
"""
|
|
247
|
-
pass
|
|
248
|
-
|
|
249
|
-
@abstractmethod
|
|
250
|
-
def get_prompt(self) -> str:
|
|
251
|
-
"""Generate agent prompt for current state.
|
|
252
|
-
|
|
253
|
-
Should return a prompt string that describes what the agent
|
|
254
|
-
should do in this step.
|
|
255
|
-
|
|
256
|
-
Returns:
|
|
257
|
-
Prompt string for the agent
|
|
258
|
-
"""
|
|
259
|
-
pass
|
|
260
|
-
|
|
261
|
-
async def __aenter__(self) -> World:
|
|
262
|
-
"""Async context manager entry."""
|
|
263
|
-
return self
|
|
264
|
-
|
|
265
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
266
|
-
"""Async context manager exit."""
|
|
267
|
-
await self.close()
|
plato/world/config.py
DELETED
|
@@ -1,139 +0,0 @@
|
|
|
1
|
-
"""World configuration - WorldConfig and EnvSpec."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
from pathlib import Path
|
|
6
|
-
from typing import Any
|
|
7
|
-
|
|
8
|
-
import yaml
|
|
9
|
-
from pydantic import BaseModel, Field, model_validator
|
|
10
|
-
|
|
11
|
-
from plato.v2 import Env, EnvFromArtifact, EnvFromSimulator
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
class EnvSpec(BaseModel):
|
|
15
|
-
"""Specification for a single environment in the world.
|
|
16
|
-
|
|
17
|
-
Environments can be specified either by:
|
|
18
|
-
- simulator + dataset: Create a fresh simulator instance
|
|
19
|
-
- artifact_id: Load from a saved artifact
|
|
20
|
-
|
|
21
|
-
Attributes:
|
|
22
|
-
alias: Unique identifier for this environment (e.g., "store", "accounting")
|
|
23
|
-
simulator: Simulator type (e.g., "spree", "firefly", "espocrm")
|
|
24
|
-
artifact_id: Optional artifact ID to load from
|
|
25
|
-
dataset: Dataset to use when creating fresh instance (default: "blank")
|
|
26
|
-
"""
|
|
27
|
-
|
|
28
|
-
alias: str
|
|
29
|
-
simulator: str | None = None
|
|
30
|
-
artifact_id: str | None = None
|
|
31
|
-
dataset: str | None = None
|
|
32
|
-
|
|
33
|
-
@model_validator(mode="after")
|
|
34
|
-
def validate_env_source(self) -> EnvSpec:
|
|
35
|
-
"""Ensure either simulator or artifact_id is provided."""
|
|
36
|
-
if not self.artifact_id and not self.simulator:
|
|
37
|
-
raise ValueError(f"EnvSpec '{self.alias}' must have either simulator or artifact_id")
|
|
38
|
-
return self
|
|
39
|
-
|
|
40
|
-
def to_plato_env(self) -> EnvFromArtifact | EnvFromSimulator:
|
|
41
|
-
"""Convert to Plato Env object for session creation.
|
|
42
|
-
|
|
43
|
-
Returns:
|
|
44
|
-
Env object suitable for AsyncPlato.sessions.create()
|
|
45
|
-
"""
|
|
46
|
-
if self.artifact_id:
|
|
47
|
-
return Env.artifact(self.artifact_id, alias=self.alias)
|
|
48
|
-
# validator ensures simulator is set if artifact_id is not
|
|
49
|
-
assert self.simulator is not None
|
|
50
|
-
return Env.simulator(self.simulator, alias=self.alias, dataset=self.dataset)
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class WorldConfig(BaseModel):
|
|
54
|
-
"""Configuration for a World.
|
|
55
|
-
|
|
56
|
-
All configs can be loaded from YAML files.
|
|
57
|
-
|
|
58
|
-
Example YAML format:
|
|
59
|
-
```yaml
|
|
60
|
-
name: "merchant-sim"
|
|
61
|
-
num_steps: 30
|
|
62
|
-
start_date: "2024-01-01"
|
|
63
|
-
envs:
|
|
64
|
-
- alias: store
|
|
65
|
-
simulator: spree
|
|
66
|
-
- alias: accounting
|
|
67
|
-
artifact_id: "abc-123"
|
|
68
|
-
```
|
|
69
|
-
|
|
70
|
-
Attributes:
|
|
71
|
-
name: Human-readable name for this world
|
|
72
|
-
envs: List of environment specifications
|
|
73
|
-
num_steps: Number of simulation steps (e.g., days)
|
|
74
|
-
start_date: Simulation start date (ISO format)
|
|
75
|
-
seed: Random seed for reproducibility
|
|
76
|
-
"""
|
|
77
|
-
|
|
78
|
-
name: str = "world"
|
|
79
|
-
envs: list[EnvSpec] = Field(default_factory=list)
|
|
80
|
-
num_steps: int = 30
|
|
81
|
-
start_date: str = "2024-01-01"
|
|
82
|
-
seed: int = 42
|
|
83
|
-
session_timeout: int = Field(default=300, description="Timeout for session creation in seconds")
|
|
84
|
-
|
|
85
|
-
@classmethod
|
|
86
|
-
def from_yaml(cls, path: str | Path) -> WorldConfig:
|
|
87
|
-
"""Load configuration from a YAML file.
|
|
88
|
-
|
|
89
|
-
Args:
|
|
90
|
-
path: Path to YAML config file
|
|
91
|
-
|
|
92
|
-
Returns:
|
|
93
|
-
Loaded WorldConfig instance
|
|
94
|
-
|
|
95
|
-
Raises:
|
|
96
|
-
FileNotFoundError: If config file doesn't exist
|
|
97
|
-
"""
|
|
98
|
-
path = Path(path)
|
|
99
|
-
if not path.exists():
|
|
100
|
-
raise FileNotFoundError(f"Config file not found: {path}")
|
|
101
|
-
|
|
102
|
-
with open(path) as f:
|
|
103
|
-
data = yaml.safe_load(f)
|
|
104
|
-
|
|
105
|
-
return cls(**data)
|
|
106
|
-
|
|
107
|
-
@classmethod
|
|
108
|
-
def from_dict(cls, data: dict[str, Any]) -> WorldConfig:
|
|
109
|
-
"""Create configuration from a dictionary.
|
|
110
|
-
|
|
111
|
-
Args:
|
|
112
|
-
data: Configuration dictionary
|
|
113
|
-
|
|
114
|
-
Returns:
|
|
115
|
-
WorldConfig instance
|
|
116
|
-
"""
|
|
117
|
-
return cls(**data)
|
|
118
|
-
|
|
119
|
-
def get_env(self, alias: str) -> EnvSpec | None:
|
|
120
|
-
"""Get environment spec by alias.
|
|
121
|
-
|
|
122
|
-
Args:
|
|
123
|
-
alias: Environment alias to find
|
|
124
|
-
|
|
125
|
-
Returns:
|
|
126
|
-
EnvSpec if found, None otherwise
|
|
127
|
-
"""
|
|
128
|
-
for env in self.envs:
|
|
129
|
-
if env.alias == alias:
|
|
130
|
-
return env
|
|
131
|
-
return None
|
|
132
|
-
|
|
133
|
-
def get_env_aliases(self) -> list[str]:
|
|
134
|
-
"""Get list of all environment aliases.
|
|
135
|
-
|
|
136
|
-
Returns:
|
|
137
|
-
List of alias strings
|
|
138
|
-
"""
|
|
139
|
-
return [env.alias for env in self.envs]
|
plato/world/types.py
DELETED
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
"""World types - Observation and StepResult."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
from typing import Any
|
|
6
|
-
|
|
7
|
-
from pydantic import BaseModel, Field
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class Observation(BaseModel):
|
|
11
|
-
"""Observation from the world after a step.
|
|
12
|
-
|
|
13
|
-
Contains the current state information that the agent observes.
|
|
14
|
-
|
|
15
|
-
Attributes:
|
|
16
|
-
step: Current step number (0 for initial reset)
|
|
17
|
-
date: Current simulation date (ISO format)
|
|
18
|
-
data: Custom observation data (world-specific)
|
|
19
|
-
"""
|
|
20
|
-
|
|
21
|
-
step: int
|
|
22
|
-
date: str
|
|
23
|
-
data: dict[str, Any] = Field(default_factory=dict)
|
|
24
|
-
|
|
25
|
-
model_config = {"extra": "allow"}
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
class StepResult(BaseModel):
|
|
29
|
-
"""Result of a world step.
|
|
30
|
-
|
|
31
|
-
Follows OpenAI Gym convention with observation, reward, done, info.
|
|
32
|
-
|
|
33
|
-
Attributes:
|
|
34
|
-
observation: Current state/observation after the step
|
|
35
|
-
reward: Numeric reward for this step (e.g., profit, score)
|
|
36
|
-
done: Whether the simulation has ended
|
|
37
|
-
truncated: Whether step was cut short (e.g., timeout)
|
|
38
|
-
info: Additional information about the step
|
|
39
|
-
"""
|
|
40
|
-
|
|
41
|
-
observation: Observation
|
|
42
|
-
reward: float = 0.0
|
|
43
|
-
done: bool = False
|
|
44
|
-
truncated: bool = False
|
|
45
|
-
info: dict[str, Any] = Field(default_factory=dict)
|
|
46
|
-
|
|
47
|
-
model_config = {"arbitrary_types_allowed": True}
|
|
File without changes
|