plato-sdk-v2 2.3.3__py3-none-any.whl → 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plato/worlds/base.py CHANGED
@@ -3,6 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import logging
6
+ import os
6
7
  import subprocess
7
8
  from abc import ABC, abstractmethod
8
9
  from pathlib import Path
@@ -16,15 +17,29 @@ if TYPE_CHECKING:
16
17
  from plato.v2.async_.environment import Environment
17
18
  from plato.v2.async_.session import Session
18
19
 
19
- from plato.agents.logging import init_logging as _init_chronos_logging
20
- from plato.agents.logging import log_event as _log_event
21
- from plato.agents.logging import reset_logging as _reset_chronos_logging
22
- from plato.agents.logging import span as _span
23
- from plato.agents.logging import upload_artifact as _upload_artifact
24
- from plato.agents.logging import upload_checkpoint as _upload_checkpoint
20
+ from plato.agents.artifacts import (
21
+ upload_artifact as _upload_artifact_raw,
22
+ )
23
+ from plato.agents.otel import (
24
+ get_tracer,
25
+ init_tracing,
26
+ shutdown_tracing,
27
+ )
28
+ from plato.agents.runner import run_agent as _run_agent_raw
25
29
 
26
30
  logger = logging.getLogger(__name__)
27
31
 
32
+
33
+ def _get_plato_version() -> str:
34
+ """Get the installed plato SDK version."""
35
+ try:
36
+ from importlib.metadata import version
37
+
38
+ return version("plato")
39
+ except Exception:
40
+ return "unknown"
41
+
42
+
28
43
  # Global registry of worlds
29
44
  _WORLD_REGISTRY: dict[str, type[BaseWorld]] = {}
30
45
 
@@ -111,6 +126,8 @@ class BaseWorld(ABC, Generic[ConfigT]):
111
126
  self._step_count: int = 0
112
127
  self.plato_session = None
113
128
  self._current_step_id: str | None = None
129
+ self._session_id: str | None = None
130
+ self._agent_containers: list[str] = [] # Track spawned agent containers for cleanup
114
131
 
115
132
  @classmethod
116
133
  def get_config_class(cls) -> type[RunConfig]:
@@ -170,7 +187,70 @@ class BaseWorld(ABC, Generic[ConfigT]):
170
187
 
171
188
  async def close(self) -> None:
172
189
  """Cleanup resources. Called after run completes."""
173
- pass
190
+ await self._cleanup_agent_containers()
191
+
192
+ async def _cleanup_agent_containers(self) -> None:
193
+ """Stop any agent containers spawned by this world."""
194
+ import asyncio
195
+
196
+ if not self._agent_containers:
197
+ return
198
+
199
+ self.logger.info(f"Stopping {len(self._agent_containers)} agent container(s)...")
200
+ for container_name in self._agent_containers:
201
+ try:
202
+ proc = await asyncio.create_subprocess_exec(
203
+ "docker",
204
+ "stop",
205
+ container_name,
206
+ stdout=asyncio.subprocess.DEVNULL,
207
+ stderr=asyncio.subprocess.DEVNULL,
208
+ )
209
+ await proc.wait()
210
+ self.logger.debug(f"Stopped container: {container_name}")
211
+ except Exception as e:
212
+ self.logger.warning(f"Failed to stop container {container_name}: {e}")
213
+ self._agent_containers.clear()
214
+ self.logger.info("Agent containers stopped")
215
+
216
+ async def run_agent(
217
+ self,
218
+ image: str,
219
+ config: dict,
220
+ secrets: dict[str, str],
221
+ instruction: str,
222
+ workspace: str | None = None,
223
+ logs_dir: str | None = None,
224
+ pull: bool = True,
225
+ ) -> str:
226
+ """Run an agent in a Docker container, tracking the container for cleanup.
227
+
228
+ This is a wrapper around plato.agents.runner.run_agent that automatically
229
+ tracks spawned containers so they can be cleaned up when the world closes.
230
+
231
+ Args:
232
+ image: Docker image URI
233
+ config: Agent configuration dict
234
+ secrets: Secret values (API keys, etc.)
235
+ instruction: Task instruction for the agent
236
+ workspace: Docker volume name for workspace
237
+ logs_dir: Ignored (kept for backwards compatibility)
238
+ pull: Whether to pull the image first
239
+
240
+ Returns:
241
+ The container name that was created
242
+ """
243
+ container_name = await _run_agent_raw(
244
+ image=image,
245
+ config=config,
246
+ secrets=secrets,
247
+ instruction=instruction,
248
+ workspace=workspace,
249
+ logs_dir=logs_dir,
250
+ pull=pull,
251
+ )
252
+ self._agent_containers.append(container_name)
253
+ return container_name
174
254
 
175
255
  async def _connect_plato_session(self) -> None:
176
256
  """Connect to Plato session from config.
@@ -390,17 +470,39 @@ class BaseWorld(ABC, Generic[ConfigT]):
390
470
  self.logger.warning(f"Failed to create state bundle: {e.stderr}")
391
471
  return None
392
472
 
393
- async def _create_and_upload_checkpoint(self) -> dict[str, Any] | None:
473
+ async def _upload_artifact(
474
+ self,
475
+ data: bytes,
476
+ content_type: str = "application/octet-stream",
477
+ ) -> bool:
478
+ """Upload an artifact directly to S3.
479
+
480
+ Args:
481
+ data: Raw bytes of the artifact
482
+ content_type: MIME type of the content
483
+
484
+ Returns:
485
+ True if successful, False otherwise
486
+ """
487
+ if not self.config.upload_url:
488
+ self.logger.warning("Cannot upload artifact: upload_url not set")
489
+ return False
490
+ return await _upload_artifact_raw(
491
+ upload_url=self.config.upload_url,
492
+ data=data,
493
+ content_type=content_type,
494
+ )
495
+
496
+ async def _create_and_upload_checkpoint(self) -> tuple[dict[str, str], bool]:
394
497
  """Create a full checkpoint including env snapshots and state bundle.
395
498
 
396
499
  This method:
397
500
  1. Commits any pending state changes
398
501
  2. Creates env snapshots using snapshot_store
399
- 3. Creates and uploads state bundle as an artifact
400
- 4. Calls the checkpoint endpoint with all data
502
+ 3. Creates and uploads state bundle to S3
401
503
 
402
504
  Returns:
403
- Checkpoint result dict if successful, None otherwise.
505
+ Tuple of (env_snapshots dict, state_bundle_uploaded bool)
404
506
  """
405
507
  # Commit state changes first
406
508
  self._commit_state(f"Checkpoint at step {self._step_count}")
@@ -410,36 +512,24 @@ class BaseWorld(ABC, Generic[ConfigT]):
410
512
  if env_snapshots is None:
411
513
  env_snapshots = {}
412
514
 
515
+ state_bundle_uploaded = True # Default to True if state not enabled
516
+
413
517
  # Create and upload state bundle
414
- state_artifact_id: str | None = None
415
518
  if self.config.state.enabled:
416
519
  bundle_data = self._create_state_bundle()
417
520
  if bundle_data:
418
- result = await _upload_artifact(
521
+ success = await self._upload_artifact(
419
522
  data=bundle_data,
420
- artifact_type="state",
421
- filename=f"state_step_{self._step_count}.bundle",
422
- extra={
423
- "step_number": self._step_count,
424
- "state_path": self.config.state.path,
425
- },
523
+ content_type="application/octet-stream",
426
524
  )
427
- if result:
428
- state_artifact_id = result.get("artifact_id")
429
- self.logger.info(f"Uploaded state artifact: {state_artifact_id}")
430
-
431
- # Upload checkpoint with all data
432
- checkpoint_result = await _upload_checkpoint(
433
- step_number=self._step_count,
434
- env_snapshots=env_snapshots,
435
- state_artifact_id=state_artifact_id,
436
- extra={
437
- "world_name": self.name,
438
- "world_version": self.get_version(),
439
- },
440
- )
525
+ if success:
526
+ self.logger.info(f"Uploaded state bundle at step {self._step_count}")
527
+ state_bundle_uploaded = True
528
+ else:
529
+ self.logger.warning(f"Failed to upload state bundle at step {self._step_count}")
530
+ state_bundle_uploaded = False
441
531
 
442
- return checkpoint_result
532
+ return env_snapshots, state_bundle_uploaded
443
533
 
444
534
  def get_env(self, alias: str) -> Environment | None:
445
535
  """Get an environment by alias.
@@ -630,81 +720,112 @@ The following services are available for your use:
630
720
  # Initialize state directory (creates git repo if needed)
631
721
  self._init_state_directory()
632
722
 
633
- # Initialize the logging singleton for agents to use
634
- if config.callback_url and config.session_id:
635
- _init_chronos_logging(
636
- callback_url=config.callback_url,
637
- session_id=config.session_id,
638
- )
723
+ # Initialize OTel tracing and session info for artifact uploads
724
+ if config.session_id:
725
+ self._session_id = config.session_id
726
+
727
+ # Set environment variables for agent runners (which run in Docker)
728
+ os.environ["SESSION_ID"] = config.session_id
729
+ if config.otel_url:
730
+ # For agents in Docker, convert localhost to host.docker.internal
731
+ # so they can reach the host machine's Chronos instance
732
+ agent_otel_url = config.otel_url
733
+ if "localhost" in agent_otel_url or "127.0.0.1" in agent_otel_url:
734
+ agent_otel_url = agent_otel_url.replace("localhost", "host.docker.internal")
735
+ agent_otel_url = agent_otel_url.replace("127.0.0.1", "host.docker.internal")
736
+ os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = agent_otel_url
737
+ os.environ["OTEL_EXPORTER_OTLP_PROTOCOL"] = "http/protobuf"
738
+ if config.upload_url:
739
+ os.environ["UPLOAD_URL"] = config.upload_url
740
+
741
+ # Initialize OTel tracing for the world itself (runs on host, not in Docker)
742
+ if config.otel_url:
743
+ logger.debug(f"Initializing OTel tracing with endpoint: {config.otel_url}")
744
+ init_tracing(
745
+ service_name=f"world-{self.name}",
746
+ session_id=config.session_id,
747
+ otlp_endpoint=config.otel_url,
748
+ )
749
+ else:
750
+ logger.debug("No otel_url in config - OTel tracing disabled")
751
+
752
+ # Log version info (goes to OTel after init_tracing)
753
+ plato_version = _get_plato_version()
754
+ world_version = self.get_version()
755
+ self.logger.info(f"World version: {world_version}, Plato SDK version: {plato_version}")
639
756
 
640
757
  # Connect to Plato session if configured (for heartbeats)
641
758
  await self._connect_plato_session()
642
759
 
643
- # Log session start
644
- await _log_event(
645
- span_type="session_start",
646
- content=f"World '{self.name}' started",
647
- source="world",
648
- extra={"world_name": self.name, "world_version": self.get_version()},
649
- )
650
-
651
- try:
652
- # Execute reset with automatic span tracking
653
- async with _span("reset", span_type="reset", source="world") as reset_span:
654
- reset_span.log(f"Resetting world '{self.name}'")
655
- obs = await self.reset()
656
- reset_span.set_extra({"observation": obs.model_dump() if hasattr(obs, "model_dump") else str(obs)})
657
- self.logger.info(f"World reset complete: {obs}")
658
-
659
- while True:
660
- self._step_count += 1
661
-
662
- # Execute step with automatic span tracking
663
- # The span automatically sets itself as the current parent,
664
- # so agent trajectories will nest under this step
665
- async with _span(
666
- f"step_{self._step_count}",
667
- span_type="step",
668
- source="world",
669
- ) as step_span:
670
- self._current_step_id = step_span.event_id
671
- step_span.log(f"Step {self._step_count} started")
672
- result = await self.step()
673
- step_span.set_extra(
674
- {
675
- "done": result.done,
676
- "observation": result.observation.model_dump()
677
- if hasattr(result.observation, "model_dump")
678
- else str(result.observation),
679
- "info": result.info,
680
- }
681
- )
760
+ # Get tracer for spans
761
+ tracer = get_tracer("plato.world")
682
762
 
683
- self.logger.info(f"Step {self._step_count}: done={result.done}")
763
+ # Create root session span that encompasses everything
764
+ # This ensures all child spans share the same trace_id
765
+ with tracer.start_as_current_span("session") as session_span:
766
+ session_span.set_attribute("plato.world.name", self.name)
767
+ session_span.set_attribute("plato.world.version", self.get_version())
768
+ session_span.set_attribute("plato.session.id", config.session_id)
684
769
 
685
- # Create checkpoint if enabled and interval matches
686
- # Note: The checkpoint event is created by the callback endpoint,
687
- # so we don't need a span wrapper here (would create duplicates)
688
- if self.config.checkpoint.enabled and self._step_count % self.config.checkpoint.interval == 0:
689
- self.logger.info(f"Creating checkpoint after step {self._step_count}")
690
- await self._create_and_upload_checkpoint()
770
+ try:
771
+ # Execute reset with OTel span
772
+ with tracer.start_as_current_span("reset") as reset_span:
773
+ obs = await self.reset()
774
+ obs_data = obs.model_dump() if hasattr(obs, "model_dump") else str(obs)
775
+ reset_span.set_attribute("plato.observation", str(obs_data)[:1000])
776
+ self.logger.info(f"World reset complete: {obs}")
691
777
 
692
- if result.done:
693
- break
778
+ while True:
779
+ self._step_count += 1
694
780
 
695
- finally:
696
- await self.close()
697
- await self._disconnect_plato_session()
781
+ # Execute step with OTel span
782
+ with tracer.start_as_current_span(f"step_{self._step_count}") as step_span:
783
+ step_span.set_attribute("plato.step.number", self._step_count)
698
784
 
699
- # Log session end
700
- await _log_event(
701
- span_type="session_end",
702
- content=f"World '{self.name}' completed after {self._step_count} steps",
703
- source="world",
704
- extra={"total_steps": self._step_count},
705
- )
785
+ # Store span context for nested agent spans
786
+ self._current_step_id = format(step_span.get_span_context().span_id, "016x")
706
787
 
707
- # Reset the logging singleton
708
- _reset_chronos_logging()
788
+ result = await self.step()
709
789
 
710
- self.logger.info(f"World '{self.name}' completed after {self._step_count} steps")
790
+ step_span.set_attribute("plato.step.done", result.done)
791
+ obs_data = (
792
+ result.observation.model_dump()
793
+ if hasattr(result.observation, "model_dump")
794
+ else str(result.observation)
795
+ )
796
+ step_span.set_attribute("plato.step.observation", str(obs_data)[:1000])
797
+
798
+ self.logger.info(f"Step {self._step_count}: done={result.done}")
799
+
800
+ # Create checkpoint if enabled and interval matches
801
+ if self.config.checkpoint.enabled and self._step_count % self.config.checkpoint.interval == 0:
802
+ self.logger.info(f"Creating checkpoint after step {self._step_count}")
803
+ with tracer.start_as_current_span("checkpoint") as checkpoint_span:
804
+ checkpoint_span.set_attribute("plato.checkpoint.step", self._step_count)
805
+ env_snapshots, state_bundle_uploaded = await self._create_and_upload_checkpoint()
806
+
807
+ checkpoint_span.set_attribute("plato.checkpoint.success", len(env_snapshots) > 0)
808
+ checkpoint_span.set_attribute(
809
+ "plato.checkpoint.state_bundle_uploaded", state_bundle_uploaded
810
+ )
811
+
812
+ if env_snapshots:
813
+ checkpoint_span.set_attribute(
814
+ "plato.checkpoint.environments", list(env_snapshots.keys())
815
+ )
816
+ checkpoint_span.set_attribute(
817
+ "plato.checkpoint.artifact_ids", list(env_snapshots.values())
818
+ )
819
+
820
+ if result.done:
821
+ break
822
+
823
+ finally:
824
+ await self.close()
825
+ await self._disconnect_plato_session()
826
+
827
+ # Shutdown OTel tracing and clear session info (outside the span)
828
+ shutdown_tracing()
829
+ self._session_id = None
830
+
831
+ self.logger.info(f"World '{self.name}' completed after {self._step_count} steps")
plato/worlds/config.py CHANGED
@@ -126,13 +126,15 @@ class RunConfig(BaseModel):
126
126
 
127
127
  Attributes:
128
128
  session_id: Unique Chronos session identifier
129
- callback_url: Callback URL for status updates
129
+ otel_url: OTel endpoint URL (e.g., https://chronos.plato.so/api/otel)
130
+ upload_url: Presigned S3 URL for uploading artifacts (provided by Chronos)
130
131
  plato_session: Serialized Plato session for connecting to existing VM session
131
132
  checkpoint: Configuration for automatic checkpoints after steps
132
133
  """
133
134
 
134
135
  session_id: str = ""
135
- callback_url: str = ""
136
+ otel_url: str = "" # OTel endpoint URL
137
+ upload_url: str = "" # Presigned S3 URL for uploads
136
138
  all_secrets: dict[str, str] = Field(default_factory=dict) # All secrets (world + agent)
137
139
 
138
140
  # Serialized Plato session for connecting to VM and sending heartbeats
@@ -182,7 +184,7 @@ class RunConfig(BaseModel):
182
184
  envs = []
183
185
 
184
186
  # Skip runtime fields
185
- runtime_fields = {"session_id", "callback_url", "all_secrets", "plato_session", "checkpoint", "state"}
187
+ runtime_fields = {"session_id", "otel_url", "upload_url", "all_secrets", "plato_session", "checkpoint", "state"}
186
188
 
187
189
  for field_name, prop_schema in properties.items():
188
190
  if field_name in runtime_fields: