plato-sdk-v2 2.8.7__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,11 +39,17 @@ class Environment:
         job_id: str,
         alias: str,
         artifact_id: str | None = None,
+        simulator: str | None = None,
+        status: str | None = None,
+        public_url: str | None = None,
     ):
         self._session = session
         self.job_id = job_id
         self.alias = alias
         self.artifact_id = artifact_id
+        self.simulator = simulator
+        self.status = status
+        self.public_url = public_url
 
     @property
     def _http(self):
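The three new fields are plain attributes on Environment and are populated by Session when environments are created (see the session.py hunks below). A minimal reading sketch, using the Plato sync client and sessions.create() call that appear elsewhere in this diff; the API key is a placeholder:

from plato.v2.sync.client import Plato
from plato.v2.types import Env

plato = Plato(api_key="YOUR_API_KEY")  # placeholder key
session = plato.sessions.create(envs=[Env.simulator("espocrm")], timeout=600)
for env in session.envs:
    # simulator, status and public_url are new in 2.9.0 and may be None
    print(env.alias, env.simulator, env.status, env.public_url)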
plato/v2/sync/sandbox.py CHANGED
@@ -53,7 +53,6 @@ from plato._generated.models import (
 )
 from plato.v2.async_.flow_executor import FlowExecutor
 from plato.v2.models import SandboxState
-from plato.v2.sync.client import Plato
 from plato.v2.types import Env, EnvFromArtifact, EnvFromResource, EnvFromSimulator, SimConfigCompute
 
 logger = logging.getLogger(__name__)
@@ -207,6 +206,8 @@ def _run_ssh_command(
 def _start_heartbeat_process(session_id: str, api_key: str) -> int | None:
     """Start a background process that sends heartbeats.
 
+    Uses only stdlib (urllib) to work on any machine without dependencies.
+
     Returns:
         PID of the background process, or None if failed.
     """
@@ -217,10 +218,12 @@ def _start_heartbeat_process(session_id: str, api_key: str) -> int | None:
         base_url = base_url[:-4]
     base_url = base_url.rstrip("/")
 
+    # Use only stdlib - no external dependencies
     heartbeat_script = f'''
 import time
-import os
-import httpx
+import json
+import urllib.request
+import urllib.error
 from datetime import datetime
 
 session_id = "{session_id}"
@@ -235,23 +238,27 @@ def log(msg):
         f.flush()
 
 log(f"Heartbeat process started for session {{session_id}}")
+log(f"URL: {{base_url}}/api/v2/sessions/{{session_id}}/heartbeat")
 
 heartbeat_count = 0
 while True:
     heartbeat_count += 1
     try:
         url = f"{{base_url}}/api/v2/sessions/{{session_id}}/heartbeat"
-        with httpx.Client(timeout=30) as client:
-            response = client.post(
-                url,
-                headers={{"X-API-Key": api_key}},
-            )
-            if response.status_code == 200:
-                result = response.json()
-                success = result.get("success", False)
-                log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, success={{success}}")
-            else:
-                log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, url={{url}}, body={{response.text[:500]}}")
+        req = urllib.request.Request(
+            url,
+            method="POST",
+            headers={{"X-API-Key": api_key, "Content-Type": "application/json"}},
+            data=b"{{}}",
+        )
+        with urllib.request.urlopen(req, timeout=30) as resp:
+            status = resp.status
+            body = resp.read().decode("utf-8")
+            result = json.loads(body)
+            success = result.get("success", False)
+            log(f"Heartbeat #{{heartbeat_count}}: status={{status}}, success={{success}}")
+    except urllib.error.HTTPError as e:
+        log(f"Heartbeat #{{heartbeat_count}}: HTTP {{e.code}} - {{e.reason}}")
     except Exception as e:
        log(f"Heartbeat #{{heartbeat_count}} EXCEPTION: {{type(e).__name__}}: {{e}}")
     time.sleep(30)
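For context, the doubled braces above are escapes inside the f-string that generates the heartbeat script; the spawned process ends up making an ordinary urllib POST. A standalone sketch of that request, using the same endpoint path and header names as the generated script (send_heartbeat is an illustrative helper, not part of the SDK):

import json
import urllib.request


def send_heartbeat(base_url: str, session_id: str, api_key: str) -> bool:
    # Same endpoint and X-API-Key header as the generated heartbeat script
    url = f"{base_url}/api/v2/sessions/{session_id}/heartbeat"
    req = urllib.request.Request(
        url,
        method="POST",
        headers={"X-API-Key": api_key, "Content-Type": "application/json"},
        data=b"{}",
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.loads(resp.read().decode("utf-8")).get("success", False)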
@@ -290,6 +297,77 @@ class SyncResult(BaseModel):
     bytes_synced: int
 
 
+class SSHConfigInfo(BaseModel):
+    """SSH config information for connecting to a job."""
+
+    config_content: str
+    private_key_path: str
+    job_id: str
+    gateway_host: str
+
+
+def _generate_temp_ssh_key_pair() -> tuple[str, str]:
+    """Generate a temporary SSH key pair.
+
+    Returns:
+        Tuple of (public_key_content, private_key_path).
+    """
+    # Create temp directory for keys
+    temp_dir = tempfile.mkdtemp(prefix="plato_ssh_")
+    private_key_path = os.path.join(temp_dir, "id_ed25519")
+
+    # Generate key pair
+    subprocess.run(
+        [
+            "ssh-keygen",
+            "-t",
+            "ed25519",
+            "-f",
+            private_key_path,
+            "-N",
+            "",
+            "-q",
+        ],
+        check=True,
+    )
+
+    # Read public key
+    public_key = Path(f"{private_key_path}.pub").read_text().strip()
+
+    return public_key, private_key_path
+
+
+def _generate_ssh_config_content(job_id: str, private_key_path: str) -> str:
+    """Generate SSH config content for a job.
+
+    Args:
+        job_id: The job ID for routing.
+        private_key_path: Path to private key.
+
+    Returns:
+        SSH config content as a string.
+    """
+    gateway_host = os.getenv("PLATO_GATEWAY_HOST", "gateway.plato.so")
+
+    # SNI format: {job_id}--{port}.{gateway_host}
+    ssh_port = 22
+    sni = f"{job_id}--{ssh_port}.{gateway_host}"
+
+    config_content = f"""# Plato SSH Config for job: {job_id}
+# Generated dynamically for -J/--job-id option
+
+Host sandbox
+    HostName {job_id}
+    User root
+    IdentityFile {private_key_path}
+    StrictHostKeyChecking no
+    UserKnownHostsFile /dev/null
+    LogLevel ERROR
+    ProxyCommand openssl s_client -quiet -connect {gateway_host}:443 -servername {sni} 2>/dev/null
+"""
+    return config_content
+
+
 # =============================================================================
 # TUNNEL
 # =============================================================================
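As an illustration of how these two helpers fit together (they are module-private, so calling them directly is only for demonstration, and the job ID below is made up): generate a throwaway key, render the config, write it to disk, and let ssh tunnel through the gateway via the ProxyCommand/SNI line.

import os
import subprocess

# Hypothetical direct use of the private helpers added above
public_key, private_key_path = _generate_temp_ssh_key_pair()
config_content = _generate_ssh_config_content("job-1234abcd", private_key_path)  # fake job ID

# public_key would still need to be installed on the VM
# (get_ssh_config_for_job, added later in this diff, does that step)
config_path = os.path.join(os.path.dirname(private_key_path), "ssh_config")
with open(config_path, "w") as f:
    f.write(config_content)

# The ProxyCommand wraps SSH in TLS and uses SNI to route to the job
subprocess.run(["ssh", "-F", config_path, "sandbox"], check=False)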
@@ -569,24 +647,26 @@ class SandboxClient:
         env_config: EnvFromSimulator | EnvFromArtifact | EnvFromResource
 
         if mode == "artifact" and artifact_id:
-            self.console.print(f"Starting from artifact: {artifact_id}")
+            self.console.print(f"[cyan]Mode:[/cyan] artifact ({artifact_id})")
             env_config = Env.artifact(artifact_id)
         elif mode == "simulator" and simulator_name:
-            self.console.print(f"Starting from simulator: {simulator_name}")
+            self.console.print(f"[cyan]Mode:[/cyan] simulator ({simulator_name}:{tag})")
             env_config = Env.simulator(simulator_name, tag=tag, dataset=dataset)
-        elif mode == "blank" and simulator_name:
-            self.console.print("Starting from blank")
+        elif mode == "blank":
+            # Use provided simulator_name or default to "sandbox"
+            sim_name = simulator_name or "sandbox"
+            self.console.print(f"[cyan]Mode:[/cyan] blank VM ({sim_name})")
+            self.console.print(f"[dim] cpus={cpus}, memory={memory}MB, disk={disk}MB[/dim]")
             sim_config = SimConfigCompute(
                 cpus=cpus, memory=memory, disk=disk, app_port=app_port, plato_messaging_port=messaging_port
             )
-            env_config = Env.resource(simulator_name, sim_config)
+            env_config = Env.resource(sim_name, sim_config)
         elif mode == "config":
-            self.console.print("Starting from config")
+            self.console.print("[cyan]Mode:[/cyan] config (plato-config.yml)")
             # read plato-config.yml
             plato_config_path = self.working_dir / "plato-config.yml"
             with open(plato_config_path, "rb") as f:
                 plato_config = yaml.safe_load(f)
-            self.console.print(f"plato-config: {plato_config}")
             plato_config_model = PlatoConfig.model_validate(plato_config)
             dataset_config = plato_config_model.datasets[dataset]
             simulator_name = plato_config_model.service
@@ -609,24 +689,67 @@ class SandboxClient:
                 plato_messaging_port=config_messaging_port,
             )
             env_config = Env.resource(simulator_name, sim_config)
-            self.console.print(f"env_config: {env_config}")
         else:
             raise ValueError(f"Invalid mode '{mode}' or missing required parameter")
 
-        # Use Plato SDK to create session (handles create, wait, network)
-        self.console.print(f"Creating session and waiting for VM to become ready (timeout={timeout}s)...")
-        plato = Plato(api_key=self.api_key, http_client=self._http)
-        session = plato.sessions.create(
-            envs=[env_config],
-            connect_network=connect_network,
+        # Track total time
+        total_start = time.time()
+
+        # The sessions.create handles: create -> wait_for_ready -> connect_network
+        # But we want to show progress, so we'll do it step by step
+        from plato._generated.api.v2.sessions import make as sessions_make
+        from plato._generated.api.v2.sessions import wait_for_ready as sessions_wait_for_ready
+        from plato._generated.models import CreateSessionFromEnvs, Envs, RunSessionSource
+
+        # Step 1: Create session
+        self.console.print("[yellow]Creating session...[/yellow]")
+        step_start = time.time()
+        request_body = CreateSessionFromEnvs(
+            envs=[Envs(root=env_config)],
             timeout=timeout,
+            source=RunSessionSource.SDK,
+        )
+        response = sessions_make.sync(
+            client=self._http,
+            body=request_body,
+            x_api_key=self.api_key,
         )
-        self.console.print(f"session: {session}")
-        session_id = session.session_id
-        job_id = session.envs[0].job_id if session.envs else None
+        session_id = response.session_id
+        elapsed = time.time() - step_start
+        self.console.print(f"[green]Session created:[/green] {session_id} [dim]({elapsed:.1f}s)[/dim]")
+
+        # Check if any envs failed to create
+        if response.envs:
+            for env_result in response.envs:
+                if not env_result.success:
+                    raise RuntimeError(f"Failed to create environment: {env_result.error}")
+                if env_result.job_id:
+                    self.console.print(f"[dim] Job: {env_result.job_id}[/dim]")
+        else:
+            raise RuntimeError("No environments created in session")
+
+        # Step 2: Wait for VM
+        self.console.print("[yellow]Waiting for VM to start...[/yellow]")
+        step_start = time.time()
+        ready_response = sessions_wait_for_ready.sync(
+            client=self._http,
+            session_id=session_id,
+            timeout=timeout,
+            x_api_key=self.api_key,
+        )
+        if not ready_response.ready:
+            errors = []
+            if ready_response.results:
+                for jid, result in ready_response.results.items():
+                    if not result.ready:
+                        errors.append(f"{jid}: {result.error or 'Unknown error'}")
+            raise RuntimeError(f"VM failed to start: {', '.join(errors) if errors else 'timeout'}")
+
+        job_id = response.envs[0].job_id if response.envs else None
         if not job_id:
             raise ValueError("No job ID found")
-        self.console.print(f"job_id: {job_id}")
+        elapsed = time.time() - step_start
+        self.console.print(f"[green]VM ready:[/green] {job_id} [dim]({elapsed:.1f}s)[/dim]")
 
         # For artifact mode, we need to get simulator_name from session details BEFORE generating public URL
         # Note: get_session_details returns a dict, not a Pydantic model
@@ -670,10 +793,14 @@ class SandboxClient:
             else:
                 url = f"{url}?{target_param}"
             public_url = url
+            elapsed = time.time() - step_start
+            self.console.print(f"[green]Public URL:[/green] {public_url} [dim]({elapsed:.1f}s)[/dim]")
         except Exception as e:
-            raise ValueError(f"Error getting public URL: {e}") from e
+            self.console.print(f"[dim]Public URL not available: {e}[/dim]")
 
         # Setup SSH
+        self.console.print("[yellow]Setting up SSH...[/yellow]")
+        step_start = time.time()
         ssh_config_path = None
         try:
             public_key, private_key_path = _generate_ssh_key_pair(session_id[:8], Path(self.working_dir))
@@ -688,12 +815,24 @@ class SandboxClient:
 
             if add_response.success:
                 ssh_config_path = _generate_ssh_config(job_id, private_key_path, Path(self.working_dir))
+                elapsed = time.time() - step_start
+                self.console.print(
+                    f"[green]SSH configured:[/green] ssh -F .plato/ssh_config sandbox [dim]({elapsed:.1f}s)[/dim]"
+                )
+            else:
+                self.console.print("[dim]SSH key upload failed[/dim]")
         except Exception as e:
-            logger.warning(f"SSH setup failed: {e}")
+            self.console.print(f"[dim]SSH setup failed: {e}[/dim]")
 
         # Start heartbeat
-        heartbeat_pid = None
+        self.console.print("[yellow]Starting heartbeat...[/yellow]")
+        step_start = time.time()
         heartbeat_pid = _start_heartbeat_process(session_id, self.api_key)
+        elapsed = time.time() - step_start
+        if heartbeat_pid:
+            self.console.print(f"[green]Heartbeat started[/green] (pid={heartbeat_pid}) [dim]({elapsed:.1f}s)[/dim]")
+        else:
+            self.console.print("[dim]Heartbeat failed to start[/dim]")
 
         # Convert absolute paths to relative for state storage
         def _to_relative(abs_path: str | None) -> str | None:
@@ -718,6 +857,7 @@ class SandboxClient:
             heartbeat_pid=heartbeat_pid,
             simulator_name=simulator_name,
             dataset=dataset,
+            network_connected=connect_network,
         )
         if mode == "artifact":
             sandbox_state.artifact_id = artifact_id
@@ -740,6 +880,10 @@ class SandboxClient:
         with open(self.working_dir / self.PLATO_DIR / "state.json", "w") as f:
             json.dump(sandbox_state.model_dump(), f)
 
+        total_elapsed = time.time() - total_start
+        self.console.print("")
+        self.console.print(f"[bold green]Sandbox ready![/bold green] [dim](total: {total_elapsed:.1f}s)[/dim]")
+
         return sandbox_state
 
     # CHECKED
@@ -784,7 +928,7 @@ class SandboxClient:
         plato_config_model = PlatoConfig.model_validate(plato_config)
         dataset_config = plato_config_model.datasets[dataset]
         # Convert dataset config back to dict for YAML serialization
-        dataset_dict = dataset_config.model_dump(exclude_none=True, by_alias=True)
+        dataset_dict = dataset_config.model_dump(exclude_none=True, by_alias=True, mode="json")
         checkpoint_request.plato_config = yaml.dump(dataset_dict, default_flow_style=False)
 
         dataset_compute = dataset_config.compute
@@ -839,7 +983,7 @@ class SandboxClient:
 
         # Convert AppApiV2SchemasArtifactSimConfigDataset to AppSchemasBuildModelsSimConfigDataset
         # They have compatible fields but different nested types
-        dataset_config_dict = dataset_config.model_dump(exclude_none=True)
+        dataset_config_dict = dataset_config.model_dump(exclude_none=True, mode="json")
 
         _ = start_worker.sync(
             client=self._http,
@@ -1011,6 +1155,60 @@ class SandboxClient:
             bind_address=bind_address,
         )
 
+    def get_ssh_config_for_job(self, job_id: str) -> SSHConfigInfo:
+        """Get SSH config for connecting to a specific job.
+
+        Generates a temporary SSH key pair, adds the public key to the VM,
+        and returns an SSH config that routes through the Plato gateway.
+
+        Args:
+            job_id: The job public ID to connect to.
+
+        Returns:
+            SSHConfigInfo with the config content, private key path, and metadata.
+
+        Note:
+            The caller is responsible for cleaning up the temporary key files
+            after the SSH session ends. The private key is stored in a temp
+            directory that should be removed when done.
+        """
+        from plato._generated.api.v2.jobs import execute as jobs_execute
+        from plato._generated.models import ExecuteCommandRequest
+
+        gateway_host = os.getenv("PLATO_GATEWAY_HOST", "gateway.plato.so")
+
+        # Generate temp SSH key pair
+        public_key, private_key_path = _generate_temp_ssh_key_pair()
+
+        # Add public key to the VM via execute
+        add_key_cmd = f'mkdir -p /root/.ssh && chmod 700 /root/.ssh && echo "{public_key}" >> /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys'
+        try:
+            jobs_execute.sync(
+                client=self._http,
+                job_id=job_id,
+                body=ExecuteCommandRequest(command=add_key_cmd, timeout=30),
+                x_api_key=self.api_key,
+            )
+        except Exception as e:
+            # Clean up temp key on failure
+            try:
+                import shutil
+
+                shutil.rmtree(os.path.dirname(private_key_path))
+            except Exception:
+                pass
+            raise RuntimeError(f"Failed to add SSH key to job {job_id}: {e}") from e
+
+        # Generate SSH config
+        config_content = _generate_ssh_config_content(job_id, private_key_path)
+
+        return SSHConfigInfo(
+            config_content=config_content,
+            private_key_path=private_key_path,
+            job_id=job_id,
+            gateway_host=gateway_host,
+        )
+
     def run_audit_ui(
         self,
         job_id: str | None = None,
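A hedged usage sketch of the new method: obtain the config, write it out, connect, and clean up the temporary key directory afterwards as the docstring's Note asks. Here sandbox_client stands in for however the SandboxClient is constructed in your code, and the job ID is made up:

import os
import shutil
import subprocess

ssh_info = sandbox_client.get_ssh_config_for_job("job-1234abcd")  # fake job ID
config_path = os.path.join(os.path.dirname(ssh_info.private_key_path), "ssh_config")
with open(config_path, "w") as f:
    f.write(ssh_info.config_content)

try:
    # "sandbox" is the Host alias written into the generated config
    subprocess.run(["ssh", "-F", config_path, "sandbox"], check=False)
finally:
    # Caller cleans up the temp key directory when the session ends
    shutil.rmtree(os.path.dirname(ssh_info.private_key_path), ignore_errors=True)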
plato/v2/sync/session.py CHANGED
@@ -438,6 +438,8 @@ class Session:
                 job_id=ctx.job_id,
                 alias=ctx.alias,
                 artifact_id=ctx.artifact_id,
+                simulator=ctx.simulator,
+                status="running",  # Environments are running after from_envs completes
             )
             for ctx in env_contexts
         ]
@@ -640,6 +642,7 @@ class Session:
         override_service: str | None = None,
         override_version: str | None = None,
         override_dataset: str | None = None,
+        target: str | None = None,
     ) -> CreateDiskSnapshotResponse:
         """Create a disk-only snapshot of all environments in the session.
 
@@ -647,10 +650,13 @@ class Session:
        will do a fresh boot with the preserved disk state. This is faster to
        create and smaller to store than full snapshots.
 
+        Uses snapshot-store backend for chunk-based deduplication and efficient storage.
+
         Args:
             override_service: Override simulator/service name in artifact metadata.
             override_version: Override version/git_hash in artifact metadata.
             override_dataset: Override dataset name in artifact metadata.
+            target: Target domain for routing (e.g., "sims.plato.so").
 
         Returns:
             CreateDiskSnapshotResponse with artifact_id per job_id.
@@ -664,6 +670,7 @@ class Session:
                 override_service=override_service,
                 override_version=override_version,
                 override_dataset=override_dataset,
+                target=target,
             ),
             x_api_key=self._api_key,
         )
@@ -793,6 +800,8 @@ class Session:
             job_id=job_id,
             alias=env.alias,
             artifact_id=response.env.artifact_id,
+            simulator=getattr(env, "simulator", None),
+            status="running",  # Newly added environments are running
         )
 
         logger.info(f"Added job {job_id} (alias={env.alias}) to session {self.session_id}")
plato/v2/types.py CHANGED
@@ -30,17 +30,27 @@ class Env:
     def simulator(
         simulator: str,
         *,
-        tag: str = "latest",
+        tag: str | None = None,
+        version: str | None = None,
         dataset: str | None = None,
         alias: str | None = None,
+        restore_memory: bool = True,
     ) -> EnvFromSimulator:
-        """Create env from simulator with tag.
+        """Create env from simulator with tag or version lookup.
+
+        Supports string formats:
+        - "simname" -> uses 'latest' tag
+        - "simname:tag" -> uses specified tag
+        - "simname:version@dataset" -> uses version (git_hash) lookup
 
         Args:
-            simulator: Simulator name, or "env:tag" format (e.g., "espocrm:staging")
-            tag: Artifact tag (default: "latest"). Ignored if simulator contains ":"
+            simulator: Simulator name with optional tag/version format
+            tag: Artifact tag. If neither tag nor version provided, defaults to 'latest'.
+            version: Artifact version (git_hash). If provided, looks up by version instead of tag.
             dataset: Dataset name (e.g., "base", "blank"). If not specified, uses default.
             alias: Custom name for this environment
+            restore_memory: If True (default), resume from memory snapshot.
+                If False, do a fresh boot with disk state only (for disk snapshots).
 
         Returns:
             EnvFromSimulator
@@ -48,21 +58,33 @@ class Env:
         Examples:
             >>> Env.simulator("espocrm")  # -> uses "latest" tag
            >>> Env.simulator("espocrm:staging")  # -> uses "staging" tag
-            >>> Env.simulator("espocrm", tag="staging")  # -> uses "staging" tag
-            >>> Env.simulator("gitea", dataset="blank")  # -> uses "blank" dataset
+            >>> Env.simulator("espocrm:v1@base")  # -> uses version "v1", dataset "base"
+            >>> Env.simulator("espocrm:v1@base", restore_memory=False)  # disk snapshot
         """
-        # Support "env:tag" format
-        if ":" in simulator:
+        sim_name = simulator
+
+        # Parse "simname:version@dataset" format (version lookup)
+        if "@" in simulator:
+            parts = simulator.split("@", 1)
+            dataset = dataset or parts[1]
+            if ":" in parts[0]:
+                sim_name, version = parts[0].split(":", 1)
+            else:
+                sim_name = parts[0]
+        # Parse "simname:tag" format (tag lookup)
+        elif ":" in simulator:
             sim_name, tag = simulator.split(":", 1)
-        else:
-            sim_name = simulator
 
-        # Build kwargs, only including dataset if specified
+        # Build kwargs
         kwargs: dict[str, Any] = {
             "simulator": sim_name,
-            "tag": tag,
             "alias": alias,
+            "restore_memory": restore_memory,
         }
+        if tag is not None:
+            kwargs["tag"] = tag
+        if version is not None:
+            kwargs["version"] = version
         if dataset is not None:
             kwargs["dataset"] = dataset
 
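The string shorthand and the explicit keyword arguments resolve to the same request; a few equivalences implied by the parsing logic above:

from plato.v2.types import Env

# Tag lookup (these two are equivalent)
Env.simulator("espocrm:staging")
Env.simulator("espocrm", tag="staging")

# Version (git_hash) lookup with an explicit dataset (equivalent forms)
Env.simulator("espocrm:v1@base")
Env.simulator("espocrm", version="v1", dataset="base")

# Disk-snapshot boot: fresh boot from disk state, no memory restore
Env.simulator("espocrm:v1@base", restore_memory=False)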
@@ -94,27 +116,36 @@ class Env:
     @staticmethod
     def resource(
         simulator: str,
-        sim_config: SimConfigCompute,
+        sim_config: SimConfigCompute | None = None,
         *,
         alias: str | None = None,
+        docker_image_url: str | None = None,
+        upload_rootfs: bool | None = True,
     ) -> EnvFromResource:
         """Create env from resource specification (blank VM).
 
         Args:
             simulator: Simulator/service name
-            sim_config: Resource configuration (CPUs, memory, disk)
+            sim_config: Resource configuration (CPUs, memory, disk). If None, uses defaults.
             alias: Custom name for this environment
+            docker_image_url: Custom Docker image URL (ECR). If not set, uses default.
+            upload_rootfs: Upload rootfs to S3/snapshot-store if not cached. Set False for one-off VMs.
 
         Returns:
             EnvFromResource
 
-        Example:
+        Examples:
             >>> Env.resource("redis", SimConfigCompute(cpus=4, memory=8192, disk=20000))
+            >>> Env.resource("agent", docker_image_url="123.dkr.ecr.us-west-1.amazonaws.com/my-image:v1")
         """
+        if sim_config is None:
+            sim_config = SimConfigCompute()
         return EnvFromResource(
             simulator=simulator,
             sim_config=sim_config,
             alias=alias,
+            docker_image_url=docker_image_url,
+            upload_rootfs=upload_rootfs,
         )
 
 
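Correspondingly, sim_config is now optional: a blank VM can be requested with default compute, or with a custom image as in the docstring example (the ECR URL below is illustrative only):

from plato.v2.types import Env, SimConfigCompute

# Default compute: sim_config is filled in with SimConfigCompute() when omitted
Env.resource("redis")

# Explicit compute, as before
Env.resource("redis", SimConfigCompute(cpus=4, memory=8192, disk=20000))

# Custom image; skip the rootfs upload for a one-off VM
Env.resource(
    "agent",
    docker_image_url="123456789.dkr.ecr.us-west-1.amazonaws.com/my-image:v1",  # illustrative URL
    upload_rootfs=False,
)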
plato/worlds/__init__.py CHANGED
@@ -46,23 +46,19 @@ from plato._generated.models import (
 from plato.worlds.base import (
     BaseWorld,
     ConfigT,
-    Observation,
-    StepResult,
     get_registered_worlds,
     get_world,
     register_world,
 )
 from plato.worlds.config import (
-    Agent,
     AgentConfig,
     CheckpointConfig,
-    Env,
     EnvConfig,
-    EnvList,
     RunConfig,
-    Secret,
     StateConfig,
 )
+from plato.worlds.markers import Agent, Env, EnvList, Secret
+from plato.worlds.models import Observation, StepResult
 from plato.worlds.runner import run_world
 
 __all__ = [
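The marker and model classes moved to new modules, but plato.worlds still imports them in its __init__, so (assuming __all__, which is truncated in this diff, continues to list them) existing top-level imports keep working; direct imports from the new modules are also available:

# Top-level imports, unchanged for callers (assumes __all__ still exports these names)
from plato.worlds import Agent, Env, EnvList, Observation, Secret, StepResult

# New canonical locations in 2.9.0
from plato.worlds.markers import Agent, Env, EnvList, Secret
from plato.worlds.models import Observation, StepResult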