plato-sdk-v2 2.3.3__py3-none-any.whl → 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/agents/__init__.py +24 -16
- plato/agents/artifacts.py +108 -0
- plato/agents/config.py +16 -13
- plato/agents/otel.py +261 -0
- plato/agents/runner.py +223 -149
- plato/chronos/models/__init__.py +9 -1
- plato/v1/cli/agent.py +7 -7
- plato/v1/cli/chronos.py +767 -0
- plato/v1/cli/main.py +2 -0
- plato/v1/cli/pm.py +3 -3
- plato/v1/cli/sandbox.py +58 -6
- plato/v1/cli/ssh.py +21 -14
- plato/v1/cli/templates/world-runner.Dockerfile +27 -0
- plato/v1/cli/utils.py +32 -12
- plato/worlds/README.md +2 -1
- plato/worlds/base.py +222 -101
- plato/worlds/config.py +5 -3
- plato/worlds/runner.py +1 -391
- {plato_sdk_v2-2.3.3.dist-info → plato_sdk_v2-2.4.2.dist-info}/METADATA +4 -3
- {plato_sdk_v2-2.3.3.dist-info → plato_sdk_v2-2.4.2.dist-info}/RECORD +22 -25
- plato/agents/logging.py +0 -515
- plato/chronos/api/callback/__init__.py +0 -11
- plato/chronos/api/callback/push_agent_logs.py +0 -61
- plato/chronos/api/callback/update_agent_status.py +0 -57
- plato/chronos/api/callback/upload_artifacts.py +0 -59
- plato/chronos/api/callback/upload_logs_zip.py +0 -57
- plato/chronos/api/callback/upload_trajectory.py +0 -57
- {plato_sdk_v2-2.3.3.dist-info → plato_sdk_v2-2.4.2.dist-info}/WHEEL +0 -0
- {plato_sdk_v2-2.3.3.dist-info → plato_sdk_v2-2.4.2.dist-info}/entry_points.txt +0 -0
plato/worlds/base.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
|
+
import os
|
|
6
7
|
import subprocess
|
|
7
8
|
from abc import ABC, abstractmethod
|
|
8
9
|
from pathlib import Path
|
|
@@ -16,15 +17,29 @@ if TYPE_CHECKING:
|
|
|
16
17
|
from plato.v2.async_.environment import Environment
|
|
17
18
|
from plato.v2.async_.session import Session
|
|
18
19
|
|
|
19
|
-
from plato.agents.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
from plato.agents.
|
|
23
|
-
|
|
24
|
-
|
|
20
|
+
from plato.agents.artifacts import (
|
|
21
|
+
upload_artifact as _upload_artifact_raw,
|
|
22
|
+
)
|
|
23
|
+
from plato.agents.otel import (
|
|
24
|
+
get_tracer,
|
|
25
|
+
init_tracing,
|
|
26
|
+
shutdown_tracing,
|
|
27
|
+
)
|
|
28
|
+
from plato.agents.runner import run_agent as _run_agent_raw
|
|
25
29
|
|
|
26
30
|
logger = logging.getLogger(__name__)
|
|
27
31
|
|
|
32
|
+
|
|
33
|
+
def _get_plato_version() -> str:
|
|
34
|
+
"""Get the installed plato SDK version."""
|
|
35
|
+
try:
|
|
36
|
+
from importlib.metadata import version
|
|
37
|
+
|
|
38
|
+
return version("plato")
|
|
39
|
+
except Exception:
|
|
40
|
+
return "unknown"
|
|
41
|
+
|
|
42
|
+
|
|
28
43
|
# Global registry of worlds
|
|
29
44
|
_WORLD_REGISTRY: dict[str, type[BaseWorld]] = {}
|
|
30
45
|
|
|
@@ -111,6 +126,8 @@ class BaseWorld(ABC, Generic[ConfigT]):
|
|
|
111
126
|
self._step_count: int = 0
|
|
112
127
|
self.plato_session = None
|
|
113
128
|
self._current_step_id: str | None = None
|
|
129
|
+
self._session_id: str | None = None
|
|
130
|
+
self._agent_containers: list[str] = [] # Track spawned agent containers for cleanup
|
|
114
131
|
|
|
115
132
|
@classmethod
|
|
116
133
|
def get_config_class(cls) -> type[RunConfig]:
|
|
@@ -170,7 +187,70 @@ class BaseWorld(ABC, Generic[ConfigT]):
|
|
|
170
187
|
|
|
171
188
|
async def close(self) -> None:
    """Release resources held by the world once the run has finished.

    The only work done here is stopping the agent containers that were
    tracked for this world.
    """
    await self._cleanup_agent_containers()
|
|
191
|
+
|
|
192
|
+
async def _cleanup_agent_containers(self) -> None:
|
|
193
|
+
"""Stop any agent containers spawned by this world."""
|
|
194
|
+
import asyncio
|
|
195
|
+
|
|
196
|
+
if not self._agent_containers:
|
|
197
|
+
return
|
|
198
|
+
|
|
199
|
+
self.logger.info(f"Stopping {len(self._agent_containers)} agent container(s)...")
|
|
200
|
+
for container_name in self._agent_containers:
|
|
201
|
+
try:
|
|
202
|
+
proc = await asyncio.create_subprocess_exec(
|
|
203
|
+
"docker",
|
|
204
|
+
"stop",
|
|
205
|
+
container_name,
|
|
206
|
+
stdout=asyncio.subprocess.DEVNULL,
|
|
207
|
+
stderr=asyncio.subprocess.DEVNULL,
|
|
208
|
+
)
|
|
209
|
+
await proc.wait()
|
|
210
|
+
self.logger.debug(f"Stopped container: {container_name}")
|
|
211
|
+
except Exception as e:
|
|
212
|
+
self.logger.warning(f"Failed to stop container {container_name}: {e}")
|
|
213
|
+
self._agent_containers.clear()
|
|
214
|
+
self.logger.info("Agent containers stopped")
|
|
215
|
+
|
|
216
|
+
async def run_agent(
    self,
    image: str,
    config: dict,
    secrets: dict[str, str],
    instruction: str,
    workspace: str | None = None,
    logs_dir: str | None = None,
    pull: bool = True,
) -> str:
    """Launch an agent container and remember it for later cleanup.

    Delegates to ``plato.agents.runner.run_agent`` with the arguments
    unchanged, then records the returned container name so the container
    can be stopped when the world is closed.

    Args:
        image: Docker image URI
        config: Agent configuration dict
        secrets: Secret values (API keys, etc.)
        instruction: Task instruction for the agent
        workspace: Docker volume name for workspace
        logs_dir: Passed through as-is; documented as ignored and kept only
            for backwards compatibility
        pull: Whether to pull the image first

    Returns:
        The container name that was created
    """
    launch_kwargs = {
        "image": image,
        "config": config,
        "secrets": secrets,
        "instruction": instruction,
        "workspace": workspace,
        "logs_dir": logs_dir,
        "pull": pull,
    }
    spawned = await _run_agent_raw(**launch_kwargs)
    # Track the container so it is stopped during cleanup.
    self._agent_containers.append(spawned)
    return spawned
|
|
174
254
|
|
|
175
255
|
async def _connect_plato_session(self) -> None:
|
|
176
256
|
"""Connect to Plato session from config.
|
|
@@ -390,17 +470,39 @@ class BaseWorld(ABC, Generic[ConfigT]):
|
|
|
390
470
|
self.logger.warning(f"Failed to create state bundle: {e.stderr}")
|
|
391
471
|
return None
|
|
392
472
|
|
|
393
|
-
async def
|
|
473
|
+
async def _upload_artifact(
|
|
474
|
+
self,
|
|
475
|
+
data: bytes,
|
|
476
|
+
content_type: str = "application/octet-stream",
|
|
477
|
+
) -> bool:
|
|
478
|
+
"""Upload an artifact directly to S3.
|
|
479
|
+
|
|
480
|
+
Args:
|
|
481
|
+
data: Raw bytes of the artifact
|
|
482
|
+
content_type: MIME type of the content
|
|
483
|
+
|
|
484
|
+
Returns:
|
|
485
|
+
True if successful, False otherwise
|
|
486
|
+
"""
|
|
487
|
+
if not self.config.upload_url:
|
|
488
|
+
self.logger.warning("Cannot upload artifact: upload_url not set")
|
|
489
|
+
return False
|
|
490
|
+
return await _upload_artifact_raw(
|
|
491
|
+
upload_url=self.config.upload_url,
|
|
492
|
+
data=data,
|
|
493
|
+
content_type=content_type,
|
|
494
|
+
)
|
|
495
|
+
|
|
496
|
+
async def _create_and_upload_checkpoint(self) -> tuple[dict[str, str], bool]:
|
|
394
497
|
"""Create a full checkpoint including env snapshots and state bundle.
|
|
395
498
|
|
|
396
499
|
This method:
|
|
397
500
|
1. Commits any pending state changes
|
|
398
501
|
2. Creates env snapshots using snapshot_store
|
|
399
|
-
3. Creates and uploads state bundle
|
|
400
|
-
4. Calls the checkpoint endpoint with all data
|
|
502
|
+
3. Creates and uploads state bundle to S3
|
|
401
503
|
|
|
402
504
|
Returns:
|
|
403
|
-
|
|
505
|
+
Tuple of (env_snapshots dict, state_bundle_uploaded bool)
|
|
404
506
|
"""
|
|
405
507
|
# Commit state changes first
|
|
406
508
|
self._commit_state(f"Checkpoint at step {self._step_count}")
|
|
@@ -410,36 +512,24 @@ class BaseWorld(ABC, Generic[ConfigT]):
|
|
|
410
512
|
if env_snapshots is None:
|
|
411
513
|
env_snapshots = {}
|
|
412
514
|
|
|
515
|
+
state_bundle_uploaded = True # Default to True if state not enabled
|
|
516
|
+
|
|
413
517
|
# Create and upload state bundle
|
|
414
|
-
state_artifact_id: str | None = None
|
|
415
518
|
if self.config.state.enabled:
|
|
416
519
|
bundle_data = self._create_state_bundle()
|
|
417
520
|
if bundle_data:
|
|
418
|
-
|
|
521
|
+
success = await self._upload_artifact(
|
|
419
522
|
data=bundle_data,
|
|
420
|
-
|
|
421
|
-
filename=f"state_step_{self._step_count}.bundle",
|
|
422
|
-
extra={
|
|
423
|
-
"step_number": self._step_count,
|
|
424
|
-
"state_path": self.config.state.path,
|
|
425
|
-
},
|
|
523
|
+
content_type="application/octet-stream",
|
|
426
524
|
)
|
|
427
|
-
if
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
step_number=self._step_count,
|
|
434
|
-
env_snapshots=env_snapshots,
|
|
435
|
-
state_artifact_id=state_artifact_id,
|
|
436
|
-
extra={
|
|
437
|
-
"world_name": self.name,
|
|
438
|
-
"world_version": self.get_version(),
|
|
439
|
-
},
|
|
440
|
-
)
|
|
525
|
+
if success:
|
|
526
|
+
self.logger.info(f"Uploaded state bundle at step {self._step_count}")
|
|
527
|
+
state_bundle_uploaded = True
|
|
528
|
+
else:
|
|
529
|
+
self.logger.warning(f"Failed to upload state bundle at step {self._step_count}")
|
|
530
|
+
state_bundle_uploaded = False
|
|
441
531
|
|
|
442
|
-
return
|
|
532
|
+
return env_snapshots, state_bundle_uploaded
|
|
443
533
|
|
|
444
534
|
def get_env(self, alias: str) -> Environment | None:
|
|
445
535
|
"""Get an environment by alias.
|
|
@@ -630,81 +720,112 @@ The following services are available for your use:
|
|
|
630
720
|
# Initialize state directory (creates git repo if needed)
|
|
631
721
|
self._init_state_directory()
|
|
632
722
|
|
|
633
|
-
# Initialize
|
|
634
|
-
if config.
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
723
|
+
# Initialize OTel tracing and session info for artifact uploads
|
|
724
|
+
if config.session_id:
|
|
725
|
+
self._session_id = config.session_id
|
|
726
|
+
|
|
727
|
+
# Set environment variables for agent runners (which run in Docker)
|
|
728
|
+
os.environ["SESSION_ID"] = config.session_id
|
|
729
|
+
if config.otel_url:
|
|
730
|
+
# For agents in Docker, convert localhost to host.docker.internal
|
|
731
|
+
# so they can reach the host machine's Chronos instance
|
|
732
|
+
agent_otel_url = config.otel_url
|
|
733
|
+
if "localhost" in agent_otel_url or "127.0.0.1" in agent_otel_url:
|
|
734
|
+
agent_otel_url = agent_otel_url.replace("localhost", "host.docker.internal")
|
|
735
|
+
agent_otel_url = agent_otel_url.replace("127.0.0.1", "host.docker.internal")
|
|
736
|
+
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = agent_otel_url
|
|
737
|
+
os.environ["OTEL_EXPORTER_OTLP_PROTOCOL"] = "http/protobuf"
|
|
738
|
+
if config.upload_url:
|
|
739
|
+
os.environ["UPLOAD_URL"] = config.upload_url
|
|
740
|
+
|
|
741
|
+
# Initialize OTel tracing for the world itself (runs on host, not in Docker)
|
|
742
|
+
if config.otel_url:
|
|
743
|
+
logger.debug(f"Initializing OTel tracing with endpoint: {config.otel_url}")
|
|
744
|
+
init_tracing(
|
|
745
|
+
service_name=f"world-{self.name}",
|
|
746
|
+
session_id=config.session_id,
|
|
747
|
+
otlp_endpoint=config.otel_url,
|
|
748
|
+
)
|
|
749
|
+
else:
|
|
750
|
+
logger.debug("No otel_url in config - OTel tracing disabled")
|
|
751
|
+
|
|
752
|
+
# Log version info (goes to OTel after init_tracing)
|
|
753
|
+
plato_version = _get_plato_version()
|
|
754
|
+
world_version = self.get_version()
|
|
755
|
+
self.logger.info(f"World version: {world_version}, Plato SDK version: {plato_version}")
|
|
639
756
|
|
|
640
757
|
# Connect to Plato session if configured (for heartbeats)
|
|
641
758
|
await self._connect_plato_session()
|
|
642
759
|
|
|
643
|
-
#
|
|
644
|
-
|
|
645
|
-
span_type="session_start",
|
|
646
|
-
content=f"World '{self.name}' started",
|
|
647
|
-
source="world",
|
|
648
|
-
extra={"world_name": self.name, "world_version": self.get_version()},
|
|
649
|
-
)
|
|
650
|
-
|
|
651
|
-
try:
|
|
652
|
-
# Execute reset with automatic span tracking
|
|
653
|
-
async with _span("reset", span_type="reset", source="world") as reset_span:
|
|
654
|
-
reset_span.log(f"Resetting world '{self.name}'")
|
|
655
|
-
obs = await self.reset()
|
|
656
|
-
reset_span.set_extra({"observation": obs.model_dump() if hasattr(obs, "model_dump") else str(obs)})
|
|
657
|
-
self.logger.info(f"World reset complete: {obs}")
|
|
658
|
-
|
|
659
|
-
while True:
|
|
660
|
-
self._step_count += 1
|
|
661
|
-
|
|
662
|
-
# Execute step with automatic span tracking
|
|
663
|
-
# The span automatically sets itself as the current parent,
|
|
664
|
-
# so agent trajectories will nest under this step
|
|
665
|
-
async with _span(
|
|
666
|
-
f"step_{self._step_count}",
|
|
667
|
-
span_type="step",
|
|
668
|
-
source="world",
|
|
669
|
-
) as step_span:
|
|
670
|
-
self._current_step_id = step_span.event_id
|
|
671
|
-
step_span.log(f"Step {self._step_count} started")
|
|
672
|
-
result = await self.step()
|
|
673
|
-
step_span.set_extra(
|
|
674
|
-
{
|
|
675
|
-
"done": result.done,
|
|
676
|
-
"observation": result.observation.model_dump()
|
|
677
|
-
if hasattr(result.observation, "model_dump")
|
|
678
|
-
else str(result.observation),
|
|
679
|
-
"info": result.info,
|
|
680
|
-
}
|
|
681
|
-
)
|
|
760
|
+
# Get tracer for spans
|
|
761
|
+
tracer = get_tracer("plato.world")
|
|
682
762
|
|
|
683
|
-
|
|
763
|
+
# Create root session span that encompasses everything
|
|
764
|
+
# This ensures all child spans share the same trace_id
|
|
765
|
+
with tracer.start_as_current_span("session") as session_span:
|
|
766
|
+
session_span.set_attribute("plato.world.name", self.name)
|
|
767
|
+
session_span.set_attribute("plato.world.version", self.get_version())
|
|
768
|
+
session_span.set_attribute("plato.session.id", config.session_id)
|
|
684
769
|
|
|
685
|
-
|
|
686
|
-
#
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
770
|
+
try:
|
|
771
|
+
# Execute reset with OTel span
|
|
772
|
+
with tracer.start_as_current_span("reset") as reset_span:
|
|
773
|
+
obs = await self.reset()
|
|
774
|
+
obs_data = obs.model_dump() if hasattr(obs, "model_dump") else str(obs)
|
|
775
|
+
reset_span.set_attribute("plato.observation", str(obs_data)[:1000])
|
|
776
|
+
self.logger.info(f"World reset complete: {obs}")
|
|
691
777
|
|
|
692
|
-
|
|
693
|
-
|
|
778
|
+
while True:
|
|
779
|
+
self._step_count += 1
|
|
694
780
|
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
781
|
+
# Execute step with OTel span
|
|
782
|
+
with tracer.start_as_current_span(f"step_{self._step_count}") as step_span:
|
|
783
|
+
step_span.set_attribute("plato.step.number", self._step_count)
|
|
698
784
|
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
span_type="session_end",
|
|
702
|
-
content=f"World '{self.name}' completed after {self._step_count} steps",
|
|
703
|
-
source="world",
|
|
704
|
-
extra={"total_steps": self._step_count},
|
|
705
|
-
)
|
|
785
|
+
# Store span context for nested agent spans
|
|
786
|
+
self._current_step_id = format(step_span.get_span_context().span_id, "016x")
|
|
706
787
|
|
|
707
|
-
|
|
708
|
-
_reset_chronos_logging()
|
|
788
|
+
result = await self.step()
|
|
709
789
|
|
|
710
|
-
|
|
790
|
+
step_span.set_attribute("plato.step.done", result.done)
|
|
791
|
+
obs_data = (
|
|
792
|
+
result.observation.model_dump()
|
|
793
|
+
if hasattr(result.observation, "model_dump")
|
|
794
|
+
else str(result.observation)
|
|
795
|
+
)
|
|
796
|
+
step_span.set_attribute("plato.step.observation", str(obs_data)[:1000])
|
|
797
|
+
|
|
798
|
+
self.logger.info(f"Step {self._step_count}: done={result.done}")
|
|
799
|
+
|
|
800
|
+
# Create checkpoint if enabled and interval matches
|
|
801
|
+
if self.config.checkpoint.enabled and self._step_count % self.config.checkpoint.interval == 0:
|
|
802
|
+
self.logger.info(f"Creating checkpoint after step {self._step_count}")
|
|
803
|
+
with tracer.start_as_current_span("checkpoint") as checkpoint_span:
|
|
804
|
+
checkpoint_span.set_attribute("plato.checkpoint.step", self._step_count)
|
|
805
|
+
env_snapshots, state_bundle_uploaded = await self._create_and_upload_checkpoint()
|
|
806
|
+
|
|
807
|
+
checkpoint_span.set_attribute("plato.checkpoint.success", len(env_snapshots) > 0)
|
|
808
|
+
checkpoint_span.set_attribute(
|
|
809
|
+
"plato.checkpoint.state_bundle_uploaded", state_bundle_uploaded
|
|
810
|
+
)
|
|
811
|
+
|
|
812
|
+
if env_snapshots:
|
|
813
|
+
checkpoint_span.set_attribute(
|
|
814
|
+
"plato.checkpoint.environments", list(env_snapshots.keys())
|
|
815
|
+
)
|
|
816
|
+
checkpoint_span.set_attribute(
|
|
817
|
+
"plato.checkpoint.artifact_ids", list(env_snapshots.values())
|
|
818
|
+
)
|
|
819
|
+
|
|
820
|
+
if result.done:
|
|
821
|
+
break
|
|
822
|
+
|
|
823
|
+
finally:
|
|
824
|
+
await self.close()
|
|
825
|
+
await self._disconnect_plato_session()
|
|
826
|
+
|
|
827
|
+
# Shutdown OTel tracing and clear session info (outside the span)
|
|
828
|
+
shutdown_tracing()
|
|
829
|
+
self._session_id = None
|
|
830
|
+
|
|
831
|
+
self.logger.info(f"World '{self.name}' completed after {self._step_count} steps")
|
plato/worlds/config.py
CHANGED
|
@@ -126,13 +126,15 @@ class RunConfig(BaseModel):
|
|
|
126
126
|
|
|
127
127
|
Attributes:
|
|
128
128
|
session_id: Unique Chronos session identifier
|
|
129
|
-
|
|
129
|
+
otel_url: OTel endpoint URL (e.g., https://chronos.plato.so/api/otel)
|
|
130
|
+
upload_url: Presigned S3 URL for uploading artifacts (provided by Chronos)
|
|
130
131
|
plato_session: Serialized Plato session for connecting to existing VM session
|
|
131
132
|
checkpoint: Configuration for automatic checkpoints after steps
|
|
132
133
|
"""
|
|
133
134
|
|
|
134
135
|
session_id: str = ""
|
|
135
|
-
|
|
136
|
+
otel_url: str = "" # OTel endpoint URL
|
|
137
|
+
upload_url: str = "" # Presigned S3 URL for uploads
|
|
136
138
|
all_secrets: dict[str, str] = Field(default_factory=dict) # All secrets (world + agent)
|
|
137
139
|
|
|
138
140
|
# Serialized Plato session for connecting to VM and sending heartbeats
|
|
@@ -182,7 +184,7 @@ class RunConfig(BaseModel):
|
|
|
182
184
|
envs = []
|
|
183
185
|
|
|
184
186
|
# Skip runtime fields
|
|
185
|
-
runtime_fields = {"session_id", "
|
|
187
|
+
runtime_fields = {"session_id", "otel_url", "upload_url", "all_secrets", "plato_session", "checkpoint", "state"}
|
|
186
188
|
|
|
187
189
|
for field_name, prop_schema in properties.items():
|
|
188
190
|
if field_name in runtime_fields:
|