plato-sdk-v2 2.7.6__py3-none-any.whl → 2.7.8__py3-none-any.whl

@@ -0,0 +1,1462 @@
1
+ """Plato SDK v2 - Synchronous Sandbox Client.
2
+
3
+ The SandboxClient provides methods for sandbox development workflows:
4
+ creating sandboxes, managing SSH, syncing files, running flows, etc.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import asyncio
10
+ import json
11
+ import logging
12
+ import os
+ import re
13
+ import shutil
14
+ import signal
15
+ import subprocess
16
+ import tempfile
17
+ import time
18
+ from pathlib import Path
19
+ from urllib.parse import quote
20
+
21
+ import httpx
22
+ import yaml
23
+ from pydantic import BaseModel
24
+ from rich.console import Console
25
+
26
+ from plato._generated.api.v1.gitea import (
27
+ create_simulator_repository,
28
+ get_accessible_simulators,
29
+ get_gitea_credentials,
30
+ get_simulator_repository,
31
+ )
32
+ from plato._generated.api.v1.sandbox import start_worker
33
+ from plato._generated.api.v2.jobs import get_flows as jobs_get_flows
34
+ from plato._generated.api.v2.jobs import state as jobs_state
35
+ from plato._generated.api.v2.sessions import add_ssh_key as sessions_add_ssh_key
36
+ from plato._generated.api.v2.sessions import close as sessions_close
37
+ from plato._generated.api.v2.sessions import connect_network as sessions_connect_network
38
+ from plato._generated.api.v2.sessions import get_public_url as sessions_get_public_url
39
+ from plato._generated.api.v2.sessions import get_session_details
40
+ from plato._generated.api.v2.sessions import snapshot as sessions_snapshot
41
+ from plato._generated.api.v2.sessions import state as sessions_state
42
+ from plato._generated.models import (
43
+ AddSSHKeyRequest,
44
+ AppApiV2SchemasSessionCreateSnapshotResponse,
45
+ AppSchemasBuildModelsSimConfigDataset,
46
+ CloseSessionResponse,
47
+ CreateCheckpointRequest,
48
+ DatabaseMutationListenerConfig,
49
+ Flow,
50
+ PlatoConfig,
51
+ SessionStateResponse,
52
+ VMManagementRequest,
53
+ )
54
+ from plato.v2.async_.flow_executor import FlowExecutor
55
+ from plato.v2.models import SandboxState
56
+ from plato.v2.sync.client import Plato
57
+ from plato.v2.types import Env, EnvFromArtifact, EnvFromResource, EnvFromSimulator, SimConfigCompute
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+
62
+ DEFAULT_BASE_URL = "https://plato.so"
63
+ DEFAULT_TIMEOUT = 600.0
64
+
65
+
66
+ def _get_plato_dir(working_dir: Path | None = None) -> Path:
67
+ """Get the .plato directory path."""
68
+ base = working_dir or Path.cwd()
69
+ return base / ".plato"
70
+
71
+
72
+ def _generate_ssh_key_pair(prefix: str, working_dir: Path | None = None) -> tuple[str, str]:
73
+ """Generate an SSH key pair and save to .plato/ directory.
74
+
75
+ Args:
76
+ prefix: Prefix for key filename.
77
+ working_dir: Working directory for .plato/.
78
+
79
+ Returns:
80
+ Tuple of (public_key_content, private_key_path).
81
+ """
82
+ plato_dir = _get_plato_dir(working_dir)
83
+ plato_dir.mkdir(mode=0o700, exist_ok=True)
84
+
85
+ key_name = f"ssh_key_{prefix}"
86
+ private_key_path = plato_dir / key_name
87
+ public_key_path = plato_dir / f"{key_name}.pub"
88
+
89
+ # Remove existing keys
90
+ if private_key_path.exists():
91
+ private_key_path.unlink()
92
+ if public_key_path.exists():
93
+ public_key_path.unlink()
94
+
95
+ # Generate key pair
96
+ subprocess.run(
97
+ [
98
+ "ssh-keygen",
99
+ "-t",
100
+ "ed25519",
101
+ "-f",
102
+ str(private_key_path),
103
+ "-N",
104
+ "",
105
+ "-q",
106
+ ],
107
+ check=True,
108
+ )
109
+
110
+ # Read public key
111
+ public_key = public_key_path.read_text().strip()
112
+
113
+ return public_key, str(private_key_path)
114
+
115
+
116
+ def _generate_ssh_config(
117
+ job_id: str,
118
+ private_key_path: str,
119
+ working_dir: Path | None = None,
120
+ ssh_host: str = "sandbox",
121
+ ) -> str:
122
+ """Generate SSH config file for easy access via gateway.
123
+
124
+ Args:
125
+ job_id: The job ID for routing.
126
+ private_key_path: Path to private key (absolute or relative).
127
+ working_dir: Working directory for .plato/.
128
+ ssh_host: Host alias in config.
129
+
130
+ Returns:
131
+ Path to the generated SSH config file (relative to working_dir).
132
+
133
+ Note:
134
+ The IdentityFile in the config uses a path relative to working_dir.
135
+ SSH commands must be run from the workspace root for paths to resolve.
136
+ """
137
+ gateway_host = os.getenv("PLATO_GATEWAY_HOST", "gateway.plato.so")
138
+
139
+ # Convert private key path to be relative to working_dir
140
+ # This ensures the config is portable if the workspace moves
141
+ base = working_dir or Path.cwd()
142
+ try:
143
+ relative_key_path = Path(private_key_path).relative_to(base)
144
+ except ValueError:
145
+ # If not relative to working_dir, keep as-is (shouldn't happen normally)
146
+ relative_key_path = Path(private_key_path)
147
+
148
+ # SNI format: {job_id}--{port}.{gateway_host} (matches v1 proxy.py)
149
+ ssh_port = 22
150
+ sni = f"{job_id}--{ssh_port}.{gateway_host}"
151
+
152
+ config_content = f"""# Plato Sandbox SSH Config
153
+ # Generated for job: {job_id}
154
+ # NOTE: Run SSH commands from workspace root for relative paths to resolve
155
+
156
+ Host {ssh_host}
157
+ HostName {job_id}
158
+ User root
159
+ IdentityFile {relative_key_path}
160
+ StrictHostKeyChecking no
161
+ UserKnownHostsFile /dev/null
162
+ LogLevel ERROR
163
+ ProxyCommand openssl s_client -quiet -connect {gateway_host}:443 -servername {sni} 2>/dev/null
164
+ """
165
+
166
+ plato_dir = _get_plato_dir(working_dir)
167
+ plato_dir.mkdir(mode=0o700, exist_ok=True)
168
+
169
+ config_path = plato_dir / "ssh_config"
170
+ config_path.write_text(config_content)
171
+
172
+ return str(config_path)
173
+
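+ # Example (illustrative, not part of the package API): with the generated
+ # config, a shell can be opened from the workspace root, e.g.
+ #
+ #   ssh -F .plato/ssh_config sandbox
+ #
+ # The ProxyCommand wraps the TCP stream in TLS to the gateway, and the SNI
+ # "{job_id}--22.{gateway_host}" tells the gateway which VM and port to route to.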
174
+
175
+ def _run_ssh_command(
176
+ ssh_config_path: str,
177
+ ssh_host: str,
178
+ command: str,
179
+ cwd: Path | str | None = None,
180
+ ) -> tuple[int, str, str]:
181
+ """Run a command via SSH.
182
+
183
+ Args:
184
+ ssh_config_path: Path to SSH config file (can be relative).
185
+ ssh_host: SSH host alias from config.
186
+ command: Command to execute on remote.
187
+ cwd: Working directory to run SSH from. Required when ssh_config_path
188
+ contains relative paths (e.g., for IdentityFile).
189
+
190
+ Returns:
191
+ Tuple of (returncode, stdout, stderr).
192
+ """
193
+ result = subprocess.run(
194
+ ["ssh", "-F", ssh_config_path, ssh_host, command],
195
+ capture_output=True,
196
+ text=True,
197
+ cwd=cwd,
198
+ )
199
+ return result.returncode, result.stdout, result.stderr
200
+
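+ # Illustrative sketch of running a remote command with the helper above
+ # (assumes a sandbox was already started in the current directory, so
+ # .plato/ssh_config exists and "sandbox" is the host alias):
+ #
+ # ret, out, err = _run_ssh_command(".plato/ssh_config", "sandbox", "uptime", cwd=Path.cwd())
+ # if ret != 0:
+ #     raise RuntimeError(f"remote command failed: {err.strip()}")
+ # print(out.strip())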
201
+
202
+ # =============================================================================
203
+ # HEARTBEAT UTILITIES
204
+ # =============================================================================
205
+
206
+
207
+ def _start_heartbeat_process(session_id: str, api_key: str) -> int | None:
208
+ """Start a background process that sends heartbeats.
209
+
210
+ Returns:
211
+ PID of the background process, or None if failed.
212
+ """
213
+ log_file = f"/tmp/plato_heartbeat_{session_id}.log"
214
+ base_url = os.getenv("PLATO_BASE_URL", "https://plato.so")
215
+ # Strip trailing /api if present to avoid double /api/api in URL
216
+ if base_url.endswith("/api"):
217
+ base_url = base_url[:-4]
218
+ base_url = base_url.rstrip("/")
219
+
220
+ heartbeat_script = f'''
221
+ import time
222
+ import os
223
+ import httpx
224
+ from datetime import datetime
225
+
226
+ session_id = "{session_id}"
227
+ api_key = "{api_key}"
228
+ base_url = "{base_url}"
229
+ log_file = "{log_file}"
230
+
231
+ def log(msg):
232
+ timestamp = datetime.now().isoformat()
233
+ with open(log_file, "a") as f:
234
+ f.write(f"[{{timestamp}}] {{msg}}\\n")
235
+ f.flush()
236
+
237
+ log(f"Heartbeat process started for session {{session_id}}")
238
+
239
+ heartbeat_count = 0
240
+ while True:
241
+ heartbeat_count += 1
242
+ try:
243
+ url = f"{{base_url}}/api/v2/sessions/{{session_id}}/heartbeat"
244
+ with httpx.Client(timeout=30) as client:
245
+ response = client.post(
246
+ url,
247
+ headers={{"X-API-Key": api_key}},
248
+ )
249
+ if response.status_code == 200:
250
+ result = response.json()
251
+ success = result.get("success", False)
252
+ log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, success={{success}}")
253
+ else:
254
+ log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, url={{url}}, body={{response.text[:500]}}")
255
+ except Exception as e:
256
+ log(f"Heartbeat #{{heartbeat_count}} EXCEPTION: {{type(e).__name__}}: {{e}}")
257
+ time.sleep(30)
258
+ '''
259
+
260
+ try:
261
+ process = subprocess.Popen(
262
+ ["python3", "-c", heartbeat_script],
263
+ stdout=subprocess.DEVNULL,
264
+ stderr=subprocess.DEVNULL,
265
+ stdin=subprocess.DEVNULL,
266
+ start_new_session=True,
267
+ )
268
+ return process.pid
269
+ except Exception:
270
+ return None
271
+
272
+
273
+ def _stop_heartbeat_process(pid: int) -> bool:
274
+ """Stop the heartbeat process.
275
+
276
+ Returns:
277
+ True if stopped successfully.
278
+ """
279
+ try:
280
+ os.kill(pid, signal.SIGTERM)
281
+ return True
282
+ except ProcessLookupError:
283
+ return True
284
+ except Exception:
285
+ return False
286
+
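+ # Illustrative pairing of the two helpers above (session_id and api_key are
+ # placeholders). The caller must keep the PID, since the heartbeat child is
+ # fully detached (start_new_session=True):
+ #
+ # pid = _start_heartbeat_process("sess-123", "my-api-key")
+ # try:
+ #     ...  # do work while heartbeats keep the session alive every 30s
+ # finally:
+ #     if pid is not None:
+ #         _stop_heartbeat_process(pid)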
287
+
288
+ class SyncResult(BaseModel):
289
+ files_synced: int
290
+ bytes_synced: int
291
+
292
+
293
+ # =============================================================================
294
+ # TUNNEL
295
+ # =============================================================================
296
+
297
+ DEFAULT_GATEWAY_HOST = "gateway.plato.so"
298
+ DEFAULT_GATEWAY_PORT = 443
299
+
300
+
301
+ def _get_gateway_config() -> tuple[str, int]:
302
+ """Get gateway host and port from environment or defaults."""
303
+ host = os.environ.get("PLATO_GATEWAY_HOST", DEFAULT_GATEWAY_HOST)
304
+ port = int(os.environ.get("PLATO_GATEWAY_PORT", str(DEFAULT_GATEWAY_PORT)))
305
+ return host, port
306
+
307
+
308
+ def _create_tls_connection(
309
+ gateway_host: str,
310
+ gateway_port: int,
311
+ sni: str,
312
+ verify_ssl: bool = True,
313
+ ):
314
+ """Create a TLS connection to the gateway with the specified SNI."""
315
+ import socket
316
+ import ssl
317
+
318
+ context = ssl.create_default_context()
319
+ if not verify_ssl:
320
+ context.check_hostname = False
321
+ context.verify_mode = ssl.CERT_NONE
322
+
323
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
324
+ sock.settimeout(30)
325
+ ssl_sock = context.wrap_socket(sock, server_hostname=sni)
326
+
327
+ try:
328
+ ssl_sock.connect((gateway_host, gateway_port))
329
+ except Exception as e:
330
+ ssl_sock.close()
331
+ raise ConnectionError(f"Failed to connect to gateway: {e}") from e
332
+
333
+ return ssl_sock
334
+
335
+
336
+ def _forward_data(src, dst, name: str = "") -> None:
337
+ """Forward data between two sockets until one closes."""
338
+ import socket
339
+
340
+ try:
341
+ while True:
342
+ data = src.recv(4096)
343
+ if not data:
344
+ break
345
+ dst.sendall(data)
346
+ except (ConnectionResetError, BrokenPipeError, OSError):
347
+ pass
348
+ finally:
349
+ try:
350
+ dst.shutdown(socket.SHUT_WR)
351
+ except OSError:
352
+ pass
353
+
354
+
355
+ class Tunnel:
356
+ """A TCP tunnel to a remote port on a sandbox VM via the TLS gateway."""
357
+
358
+ def __init__(
359
+ self,
360
+ job_id: str,
361
+ remote_port: int,
362
+ local_port: int | None = None,
363
+ bind_address: str = "127.0.0.1",
364
+ verify_ssl: bool = True,
365
+ ):
366
+ self.job_id = job_id
367
+ self.remote_port = remote_port
368
+ self.local_port = local_port or remote_port
369
+ self.bind_address = bind_address
370
+ self.verify_ssl = verify_ssl
371
+
372
+ self._server = None
373
+ self._thread = None
374
+ self._running = False
375
+
376
+ def start(self) -> int:
377
+ """Start the tunnel. Returns the local port."""
378
+ import socket
379
+ import threading
380
+
381
+ gateway_host, gateway_port = _get_gateway_config()
382
+ sni = f"{self.job_id}--{self.remote_port}.{gateway_host}"
383
+
384
+ # Create local listener
385
+ self._server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
386
+ self._server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
387
+
388
+ try:
389
+ self._server.bind((self.bind_address, self.local_port))
390
+ self._server.listen(5)
391
+ except OSError as e:
392
+ raise ValueError(f"Could not bind to {self.bind_address}:{self.local_port}: {e}") from e
393
+
394
+ self._running = True
395
+
396
+ def handle_client(client_sock, client_addr):
397
+ try:
398
+ gateway_sock = _create_tls_connection(gateway_host, gateway_port, sni, verify_ssl=self.verify_ssl)
399
+ t1 = threading.Thread(
400
+ target=_forward_data,
401
+ args=(client_sock, gateway_sock, "client->gateway"),
402
+ daemon=True,
403
+ )
404
+ t2 = threading.Thread(
405
+ target=_forward_data,
406
+ args=(gateway_sock, client_sock, "gateway->client"),
407
+ daemon=True,
408
+ )
409
+ t1.start()
410
+ t2.start()
411
+ t1.join()
412
+ t2.join()
413
+ except Exception:
414
+ pass
415
+ finally:
416
+ try:
417
+ client_sock.close()
418
+ except OSError:
419
+ pass
420
+
421
+ def accept_loop():
422
+ server = self._server
423
+ assert server is not None, "Server must be initialized before accept_loop"
424
+ while self._running:
425
+ try:
426
+ server.settimeout(1.0)
427
+ client_sock, client_addr = server.accept()
428
+ threading.Thread(
429
+ target=handle_client,
430
+ args=(client_sock, client_addr),
431
+ daemon=True,
432
+ ).start()
433
+ except TimeoutError:
434
+ continue
435
+ except OSError:
436
+ break
437
+
438
+ self._thread = threading.Thread(target=accept_loop, daemon=True)
439
+ self._thread.start()
440
+
441
+ return self.local_port
442
+
443
+ def stop(self) -> None:
444
+ """Stop the tunnel."""
445
+ self._running = False
446
+ if self._server:
447
+ try:
448
+ self._server.close()
449
+ except OSError:
450
+ pass
451
+ if self._thread:
452
+ self._thread.join(timeout=2.0)
453
+
454
+ def __enter__(self):
455
+ self.start()
456
+ return self
457
+
458
+ def __exit__(self, exc_type, exc_val, exc_tb):
459
+ self.stop()
460
+ return False
461
+
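+ # Illustrative usage (job_id is a placeholder): forward the VM's Postgres
+ # port to localhost for the lifetime of the with-block. Traffic flows
+ # local socket -> TLS gateway (SNI routing) -> VM port:
+ #
+ # with Tunnel(job_id="job-abc123", remote_port=5432) as t:
+ #     conn_str = f"postgresql://user:pass@127.0.0.1:{t.local_port}/db"
+ #     ...  # connect with any client while the tunnel is open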
462
+
463
+ # =============================================================================
464
+ # SANDBOX CLIENT
465
+ # =============================================================================
466
+
467
+
468
+ class SandboxClient:
469
+ """Synchronous client for sandbox development workflows.
470
+
471
+ Supports two modes:
472
+ 1. Stateless (working_dir=None): Pure operations, no file I/O
473
+ 2. Stateful (working_dir=Path): Persists state to .plato/state.yaml
474
+
475
+ Usage (stateless):
476
+ client = SandboxClient(api_key="...")
477
+ result = client.start(mode="blank", service="myservice")
478
+ client.stop(result.session_id)
479
+ client.close()
480
+
481
+ Usage (stateful - recommended for CLI/scripts):
482
+ client = SandboxClient(api_key="...", working_dir=Path("."))
483
+ client.start(mode="blank", service="myservice") # Saves state
484
+ # Later...
485
+ client = SandboxClient(api_key="...", working_dir=Path(".")) # Loads state
486
+ client.stop() # Uses saved session_id
487
+ """
488
+
489
+ # State file paths
490
+ PLATO_DIR = ".plato"
491
+
492
+ def __init__(
493
+ self,
494
+ working_dir: Path,
495
+ api_key: str | None = None,
496
+ base_url: str | None = None,
497
+ timeout: float = DEFAULT_TIMEOUT,
498
+ console: Console | None = None,
499
+ ):
500
+ self.api_key = api_key or os.environ.get("PLATO_API_KEY")
501
+ if not self.api_key:
502
+ raise ValueError("API key required. Set PLATO_API_KEY or pass api_key=")
503
+
504
+ url = base_url or os.environ.get("PLATO_BASE_URL", DEFAULT_BASE_URL)
505
+ if url.endswith("/api"):
506
+ url = url[:-4]
507
+ self.base_url = url.rstrip("/")
508
+ self.console = console or Console()
509
+ self.working_dir = working_dir
510
+
511
+ self._http = httpx.Client(
512
+ base_url=self.base_url,
513
+ timeout=httpx.Timeout(timeout),
514
+ )
515
+
516
+ def _get_plato_dir(self) -> Path:
517
+ return Path(self.working_dir) / self.PLATO_DIR
518
+
519
+ def close(self) -> None:
520
+ """Close the underlying HTTP client."""
521
+ self._http.close()
522
+
523
+ # -------------------------------------------------------------------------
524
+ # START
525
+ # -------------------------------------------------------------------------
526
+
527
+ def start(
528
+ self,
529
+ simulator_name: str | None = None,
530
+ mode: str = "blank",
531
+ # artifact or simulator mode
532
+ artifact_id: str | None = None,
533
+ dataset: str = "base",
534
+ tag: str = "latest",
535
+ # blank or plato-config mode
536
+ cpus: int = 1,
537
+ memory: int = 2048,
538
+ disk: int = 10240,
539
+ app_port: int | None = None,
540
+ messaging_port: int | None = None,
541
+ # common
542
+ connect_network: bool = True,
543
+ timeout: int = 1800,
544
+ ) -> SandboxState:
545
+ """Start a sandbox environment.
546
+
547
+ Uses Plato SDK v2 internally for session creation.
548
+
549
+ Args:
550
+ mode: Start mode - "blank", "simulator", or "artifact".
551
+ simulator_name: Simulator name.
552
+ artifact_id: Artifact UUID.
553
+ dataset: Dataset name.
554
+ tag: Artifact tag.
555
+ cpus: Number of CPUs.
556
+ memory: Memory in MB.
557
+ disk: Disk in MB.
558
+ app_port: App port.
559
+ messaging_port: Messaging port.
560
+ connect_network: Whether to connect WireGuard network.
+ timeout: Seconds to wait for the session VM to become ready.
561
+
562
+ Returns:
563
+ SandboxState with sandbox info.
564
+ """
565
+
566
+ assert self.api_key is not None
567
+
568
+ # Build environment config using Env factory
569
+ env_config: EnvFromSimulator | EnvFromArtifact | EnvFromResource
570
+
571
+ if mode == "artifact" and artifact_id:
572
+ self.console.print(f"Starting from artifact: {artifact_id}")
573
+ env_config = Env.artifact(artifact_id)
574
+ elif mode == "simulator" and simulator_name:
575
+ self.console.print(f"Starting from simulator: {simulator_name}")
576
+ env_config = Env.simulator(simulator_name, tag=tag, dataset=dataset)
577
+ elif mode == "blank" and simulator_name:
578
+ self.console.print("Starting from blank")
579
+ sim_config = SimConfigCompute(
580
+ cpus=cpus, memory=memory, disk=disk, app_port=app_port, plato_messaging_port=messaging_port
581
+ )
582
+ env_config = Env.resource(simulator_name, sim_config)
583
+ elif mode == "config":
584
+ self.console.print("Starting from config")
585
+ # read plato-config.yml
586
+ plato_config_path = self.working_dir / "plato-config.yml"
587
+ with open(plato_config_path, "rb") as f:
588
+ plato_config = yaml.safe_load(f)
589
+ self.console.print(f"plato-config: {plato_config}")
590
+ plato_config_model = PlatoConfig.model_validate(plato_config)
591
+ dataset_config = plato_config_model.datasets[dataset]
592
+ simulator_name = plato_config_model.service
593
+ if not simulator_name:
594
+ raise ValueError("Service name is required in plato-config.yml")
595
+ if not dataset_config.compute:
596
+ raise ValueError(f"Compute configuration is required for dataset '{dataset}'")
597
+ self.console.print(f"simulator_name: {simulator_name}")
598
+ sim_config = SimConfigCompute(
599
+ cpus=dataset_config.compute.cpus,
600
+ memory=dataset_config.compute.memory,
601
+ disk=dataset_config.compute.disk,
602
+ app_port=dataset_config.compute.app_port,
603
+ plato_messaging_port=dataset_config.compute.plato_messaging_port,
604
+ )
605
+ env_config = Env.resource(simulator_name, sim_config)
606
+ self.console.print(f"env_config: {env_config}")
607
+ else:
608
+ raise ValueError(f"Invalid mode '{mode}' or missing required parameter")
609
+
610
+ # Use Plato SDK to create session (handles create, wait, network)
611
+ self.console.print(f"Creating session and waiting for VM to become ready (timeout={timeout}s)...")
612
+ plato = Plato(api_key=self.api_key, http_client=self._http)
613
+ session = plato.sessions.create(
614
+ envs=[env_config],
615
+ connect_network=connect_network,
616
+ timeout=timeout,
617
+ )
618
+ self.console.print(f"session: {session}")
619
+ session_id = session.session_id
620
+ job_id = session.envs[0].job_id if session.envs else None
621
+ if not job_id:
622
+ raise ValueError("No job ID found")
623
+ self.console.print(f"job_id: {job_id}")
624
+
625
+ # Get public URL with router target formatting (logic inlined)
626
+ public_url = None
627
+ try:
628
+ url_response = sessions_get_public_url.sync(
629
+ client=self._http,
630
+ session_id=session_id,
631
+ x_api_key=self.api_key,
632
+ )
633
+ if url_response and url_response.results:
634
+ for result in url_response.results.values():
635
+ url = result.url if hasattr(result, "url") else str(result)
636
+ if not url:
637
+ raise ValueError(f"No public URL found in result dict for job ID {job_id}")
638
+ if "_plato_router_target=" not in url:
639
+ target_param = f"_plato_router_target={simulator_name}.web.plato.so"
640
+ if "?" in url:
641
+ url = f"{url}&{target_param}"
642
+ else:
643
+ url = f"{url}?{target_param}"
644
+ public_url = url
645
+ except Exception as e:
646
+ raise ValueError(f"Error getting public URL: {e}") from e
647
+
648
+ # Setup SSH
649
+ ssh_config_path = None
650
+ try:
651
+ public_key, private_key_path = _generate_ssh_key_pair(session_id[:8], Path(self.working_dir))
652
+
653
+ add_key_request = AddSSHKeyRequest(public_key=public_key, username="root")
654
+ add_response = sessions_add_ssh_key.sync(
655
+ client=self._http,
656
+ session_id=session_id,
657
+ body=add_key_request,
658
+ x_api_key=self.api_key,
659
+ )
660
+
661
+ if add_response.success:
662
+ ssh_config_path = _generate_ssh_config(job_id, private_key_path, Path(self.working_dir))
663
+ except Exception as e:
664
+ logger.warning(f"SSH setup failed: {e}")
665
+
666
+ # Start heartbeat
667
+ heartbeat_pid = _start_heartbeat_process(session_id, self.api_key)
669
+
670
+ # Parse out the simulator_name from service field in jobs in session details.
671
+ session_details = get_session_details.sync(
672
+ client=self._http,
673
+ session_id=session_id,
674
+ x_api_key=self.api_key,
675
+ )
676
+ if hasattr(session_details, "jobs"):
677
+ # session_details["jobs"] if dict, or session_details.jobs if object
678
+ jobs = (
679
+ session_details["jobs"] if isinstance(session_details, dict) else getattr(session_details, "jobs", None)
680
+ )
681
+ if jobs:
682
+ for j in jobs:
683
+ service = j.get("service") if isinstance(j, dict) else getattr(j, "service", None)
684
+ if service:
685
+ simulator_name = service
686
+ break
687
+
688
+ # Convert absolute paths to relative for state storage
689
+ def _to_relative(abs_path: str | None) -> str | None:
690
+ if not abs_path or not self.working_dir:
691
+ return abs_path
692
+ try:
693
+ return str(Path(abs_path).relative_to(self.working_dir))
694
+ except ValueError:
695
+ return abs_path # Keep absolute if not relative to working_dir
696
+
697
+ # Update internal state
698
+ rel_ssh_config = _to_relative(ssh_config_path)
699
+ ssh_host = "sandbox" if ssh_config_path else None
700
+ sandbox_state = SandboxState(
701
+ session_id=session_id,
702
+ job_id=job_id,
703
+ public_url=public_url,
704
+ mode=mode,
705
+ ssh_config_path=rel_ssh_config,
706
+ ssh_host=ssh_host,
707
+ ssh_command=f"ssh -F {rel_ssh_config} {ssh_host}" if rel_ssh_config else None,
708
+ heartbeat_pid=heartbeat_pid,
709
+ simulator_name=simulator_name,
710
+ dataset=dataset,
711
+ )
712
+ if mode == "artifact":
713
+ sandbox_state.artifact_id = artifact_id
714
+ elif mode == "simulator":
715
+ sandbox_state.tag = tag
716
+ elif mode == "blank":
717
+ sandbox_state.cpus = cpus
718
+ sandbox_state.memory = memory
719
+ sandbox_state.disk = disk
720
+ sandbox_state.app_port = app_port
721
+ sandbox_state.messaging_port = messaging_port
722
+
723
+ # Save state to working_dir/.plato/state.json
724
+ state_dir = self.working_dir / self.PLATO_DIR
+ state_dir.mkdir(mode=0o700, exist_ok=True)
+ with open(state_dir / "state.json", "w") as f:
725
+ json.dump(sandbox_state.model_dump(), f)
726
+
727
+ return sandbox_state
728
+
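+ # Illustrative lifecycle sketch (values are placeholders; assumes
+ # PLATO_API_KEY is set in the environment):
+ #
+ # client = SandboxClient(working_dir=Path("."))
+ # state = client.start(mode="blank", simulator_name="myservice", cpus=2)
+ # print(state.public_url, state.ssh_command)
+ # ...
+ # client.stop(state.session_id, heartbeat_pid=state.heartbeat_pid)
+ # client.close()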
729
+ # CHECKED
730
+ def stop(
731
+ self,
732
+ session_id: str,
733
+ heartbeat_pid: int | None = None,
734
+ ) -> CloseSessionResponse:
735
+ if heartbeat_pid:
736
+ _stop_heartbeat_process(heartbeat_pid)
737
+
738
+ return sessions_close.sync(
739
+ client=self._http,
740
+ session_id=session_id,
741
+ x_api_key=self.api_key,
742
+ )
743
+
744
+ # CHECKED
745
+ def status(self, session_id: str) -> dict:
746
+ return get_session_details.sync(
747
+ client=self._http,
748
+ session_id=session_id,
749
+ x_api_key=self.api_key,
750
+ )
751
+
752
+ # CHECKED
753
+ def snapshot(
754
+ self,
755
+ session_id: str,
756
+ mode: str,
757
+ dataset: str,
758
+ ) -> AppApiV2SchemasSessionCreateSnapshotResponse:
759
+ checkpoint_request = CreateCheckpointRequest()
760
+
761
+ if mode == "config":
762
+ # read plato-config.yml
763
+ plato_config_path = self.working_dir / "plato-config.yml"
764
+ with open(plato_config_path, "rb") as f:
765
+ plato_config = yaml.safe_load(f)
766
+ checkpoint_request.plato_config = plato_config
767
+
768
+ # convert plato-config to pydantic model
769
+ plato_config_model = PlatoConfig.model_validate(plato_config)
770
+ dataset_compute = plato_config_model.datasets[dataset].compute
771
+ if not dataset_compute:
772
+ raise ValueError(f"Compute configuration is required for dataset '{dataset}'")
773
+ checkpoint_request.internal_app_port = dataset_compute.app_port
774
+ checkpoint_request.messaging_port = dataset_compute.plato_messaging_port
775
+ # we don't set target
776
+
777
+ # read flows.yml
778
+ flows_path = self.working_dir / "flows.yml"
779
+ with open(flows_path, "rb") as f:
780
+ flows = yaml.safe_load(f)
781
+ checkpoint_request.flows = flows
782
+
783
+ return sessions_snapshot.sync(
784
+ client=self._http,
785
+ session_id=session_id,
786
+ body=checkpoint_request,
787
+ x_api_key=self.api_key,
788
+ )
789
+
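+ # Illustrative sketch (session_id is a placeholder): in "config" mode the
+ # snapshot request bundles plato-config.yml and flows.yml from the working dir:
+ #
+ # snap = client.snapshot(session_id="sess-123", mode="config", dataset="base")
+ # print(snap)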
790
+ # CHECKED
791
+ def connect_network(self, session_id: str) -> dict:
792
+ return sessions_connect_network.sync(
793
+ client=self._http,
794
+ session_id=session_id,
795
+ x_api_key=self.api_key,
796
+ )
797
+
798
+ # CHECKED
799
+ def start_worker(
800
+ self,
801
+ job_id: str,
802
+ simulator: str,
803
+ dataset: str,
804
+ wait_timeout: int = 300, # 5 minutes
805
+ ) -> None:
806
+ with open(self.working_dir / "plato-config.yml", "rb") as f:
807
+ plato_config = yaml.safe_load(f)
808
+ plato_config_model = PlatoConfig.model_validate(plato_config)
809
+ dataset_config = plato_config_model.datasets[dataset]
810
+
811
+ # Convert AppApiV2SchemasArtifactSimConfigDataset to AppSchemasBuildModelsSimConfigDataset
812
+ # They have compatible fields but different nested types
813
+ dataset_config_dict = dataset_config.model_dump(exclude_none=True)
814
+
815
+ _ = start_worker.sync(
816
+ client=self._http,
817
+ public_id=job_id,
818
+ body=VMManagementRequest(
819
+ service=simulator,
820
+ dataset=dataset,
821
+ plato_dataset_config=AppSchemasBuildModelsSimConfigDataset.model_validate(dataset_config_dict),
822
+ ),
823
+ x_api_key=self.api_key,
824
+ )
825
+
826
+ if wait_timeout > 0:
827
+ start_time = time.time()
828
+ poll_interval = 10
829
+
830
+ while time.time() - start_time < wait_timeout:
831
+ try:
832
+ state_response = jobs_state.sync(
833
+ client=self._http,
834
+ job_id=job_id,
835
+ x_api_key=self.api_key,
836
+ )
837
+ if state_response:
838
+ state_dict = (
839
+ state_response.model_dump() if hasattr(state_response, "model_dump") else state_response
840
+ )
841
+ if isinstance(state_dict, dict) and "error" not in state_dict.get("state", {}):
842
+ return
843
+ except Exception:
844
+ pass
845
+
846
+ time.sleep(poll_interval)
847
+
848
+ # CHECKED
849
+ def sync(
850
+ self,
851
+ session_id: str,
852
+ simulator: str,
853
+ timeout: int = 120,
854
+ ) -> SyncResult:
855
+ """Sync local files to sandbox using rsync over SSH.
856
+
857
+ Uses the SSH config from .plato/ssh_config for fast, reliable file transfer.
858
+ """
859
+ local_path = self.working_dir
860
+ remote_path = f"/home/plato/worktree/{simulator}"
861
+
862
+ # Load SSH config from state
863
+ state_file = self.working_dir / ".plato" / "state.json"
864
+ if not state_file.exists():
865
+ raise ValueError("No state file found - run 'plato sandbox start' first")
866
+
867
+ with open(state_file) as f:
868
+ state = json.load(f)
869
+
870
+ ssh_config_path = state.get("ssh_config_path")
871
+ ssh_host = state.get("ssh_host", "sandbox")
872
+
873
+ if not ssh_config_path:
874
+ raise ValueError("No SSH config in state - run 'plato sandbox start' first")
875
+
876
+ exclude_patterns = [
877
+ "__pycache__",
878
+ "*.pyc",
879
+ ".git",
880
+ ".venv",
881
+ "venv",
882
+ "node_modules",
883
+ ".sandbox.yaml",
884
+ "*.egg-info",
885
+ ".pytest_cache",
886
+ ".mypy_cache",
887
+ ".DS_Store",
888
+ "*.swp",
889
+ "*.swo",
890
+ ".plato",
891
+ ]
892
+
893
+ # Build rsync command
894
+ rsync_cmd = [
895
+ "rsync",
896
+ "-avz",
897
+ "--delete",
898
+ "-e",
899
+ f"ssh -F {ssh_config_path}",
900
+ ]
901
+
902
+ # Add excludes
903
+ for pattern in exclude_patterns:
904
+ rsync_cmd.extend(["--exclude", pattern])
905
+
906
+ # Source and destination
907
+ rsync_cmd.append(f"{local_path}/")
908
+ rsync_cmd.append(f"{ssh_host}:{remote_path}/")
909
+
910
+ self.console.print(f"[dim]rsync -> {ssh_host}:{remote_path}/[/dim]")
911
+
912
+ # Ensure rsync is installed on the VM and create remote directory
913
+ setup_result = subprocess.run(
914
+ [
915
+ "ssh",
916
+ "-F",
917
+ ssh_config_path,
918
+ ssh_host,
919
+ f"which rsync >/dev/null 2>&1 || (apt-get update -qq && apt-get install -y -qq rsync) && mkdir -p {remote_path}",
920
+ ],
921
+ capture_output=True,
922
+ text=True,
923
+ cwd=self.working_dir,
924
+ )
925
+ if setup_result.returncode != 0:
926
+ raise ValueError(f"Failed to setup remote: {setup_result.stderr}")
927
+
928
+ # Run rsync
929
+ result = subprocess.run(
930
+ rsync_cmd,
931
+ capture_output=True,
932
+ text=True,
933
+ cwd=self.working_dir,
934
+ timeout=timeout,
935
+ )
936
+
937
+ if result.returncode != 0:
938
+ raise ValueError(f"rsync failed: {result.stderr}")
939
+
940
+ # Count synced files from rsync output
941
+ lines = result.stdout.strip().split("\n") if result.stdout else []
942
+ file_count = len(
943
+ [
944
+ line
945
+ for line in lines
946
+ if line
947
+ and not line.startswith("sending")
948
+ and not line.startswith("sent")
949
+ and not line.startswith("total")
950
+ ]
951
+ )
952
+
953
+ # Get bytes from rsync output (e.g., "sent 1,234 bytes")
954
+ bytes_synced = 0
955
+ for line in lines:
956
+ if "sent" in line and "bytes" in line:
957
+ import re
958
+
959
+ match = re.search(r"sent ([\d,]+) bytes", line)
960
+ if match:
961
+ bytes_synced = int(match.group(1).replace(",", ""))
962
+ break
963
+
964
+ return SyncResult(
965
+ files_synced=file_count,
966
+ bytes_synced=bytes_synced,
967
+ )
968
+
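+ # Illustrative sketch: push the working tree to the VM after local edits
+ # (assumes a running sandbox whose state.json is in .plato/):
+ #
+ # result = client.sync(session_id=state.session_id, simulator="myservice")
+ # print(f"synced {result.files_synced} files ({result.bytes_synced} bytes)")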
969
+ def tunnel(
970
+ self,
971
+ job_id: str,
972
+ remote_port: int,
973
+ local_port: int | None = None,
974
+ bind_address: str = "127.0.0.1",
975
+ ) -> Tunnel:
976
+ return Tunnel(
977
+ job_id=job_id,
978
+ remote_port=remote_port,
979
+ local_port=local_port,
980
+ bind_address=bind_address,
981
+ )
982
+
983
+ def run_audit_ui(
984
+ self,
985
+ job_id: str | None = None,
986
+ dataset: str = "base",
987
+ no_tunnel: bool = False,
988
+ ) -> None:
989
+
991
+ if not shutil.which("streamlit"):
992
+ raise ValueError("streamlit not installed. Run: pip install streamlit psycopg2-binary pymysql")
993
+
994
+ ui_file = Path(__file__).resolve().parent.parent.parent / "cli" / "audit_ui.py"
995
+ if not ui_file.exists():
996
+ raise ValueError(f"UI file not found: {ui_file}")
997
+
998
+ # Get DB listener from plato-config.yml
999
+ db_listener: DatabaseMutationListenerConfig | None = None
1000
+ for config_path in [self.working_dir / "plato-config.yml", self.working_dir / "plato-config.yaml"]:
1001
+ if config_path.exists():
1002
+ with open(config_path) as f:
1003
+ plato_config = PlatoConfig.model_validate(yaml.safe_load(f))
1004
+ dataset_config = plato_config.datasets.get(dataset)
1005
+ if dataset_config and dataset_config.listeners:
1006
+ for listener in dataset_config.listeners.values():
1007
+ if isinstance(listener, DatabaseMutationListenerConfig):
1008
+ db_listener = listener
1009
+ break
1010
+ break
1011
+ tunnel = None
1012
+
1013
+ if db_listener and job_id and not no_tunnel:
1014
+ self.console.print(f"Starting tunnel to {db_listener.db_type} on port {db_listener.db_port}...")
1015
+ tunnel = self.tunnel(job_id, db_listener.db_port)
1016
+ tunnel.start()
1017
+ time.sleep(1) # Let tunnel stabilize
1018
+ self.console.print(
1019
+ f"[green]Tunnel open:[/green] localhost:{db_listener.db_port} -> VM:{db_listener.db_port}"
1020
+ )
1021
+
1022
+ # Pass db config via environment variables
1023
+ env = os.environ.copy()
1024
+ if db_listener:
1025
+ env["PLATO_DB_HOST"] = "127.0.0.1"
1026
+ env["PLATO_DB_PORT"] = str(db_listener.db_port)
1027
+ env["PLATO_DB_USER"] = db_listener.db_user
1028
+ env["PLATO_DB_PASSWORD"] = db_listener.db_password or ""
1029
+ env["PLATO_DB_NAME"] = db_listener.db_database
1030
+ env["PLATO_DB_TYPE"] = str(db_listener.db_type)
1031
+ self.console.print(
1032
+ f"[dim]DB config: {db_listener.db_user}@127.0.0.1:{db_listener.db_port}/{db_listener.db_database}[/dim]"
1033
+ )
1034
+
1035
+ try:
1036
+ subprocess.run(["streamlit", "run", str(ui_file)], env=env)
1037
+ finally:
1038
+ if tunnel:
1039
+ tunnel.stop()
1040
+ self.console.print("[yellow]Tunnel closed[/yellow]")
1041
+
1042
+ def run_flow(
1043
+ self,
1044
+ url: str,
1045
+ flow_name: str,
1046
+ dataset: str,
1047
+ use_api: bool = False,
1048
+ job_id: str | None = None,
1049
+ ) -> None:
1050
+ flow_obj: Flow | None = None
1051
+ screenshots_dir = self.working_dir / "screenshots"
1052
+
1053
+ if use_api:
1054
+ # Fetch from API
1055
+ if not job_id:
1056
+ raise ValueError("job_id required when use_api=True")
1057
+
1058
+ self.console.print("[cyan]Flow source: API[/cyan]")
1059
+ flows_response = jobs_get_flows.sync(
1060
+ client=self._http,
1061
+ job_id=job_id,
1062
+ x_api_key=self.api_key,
1063
+ )
1064
+
1065
+ if flows_response:
1066
+ for flow_data in flows_response:
1067
+ if isinstance(flow_data, dict):
1068
+ if flow_data.get("name") == flow_name:
1069
+ flow_obj = Flow.model_validate(flow_data)
1070
+ break
1071
+ elif hasattr(flow_data, "name") and flow_data.name == flow_name:
1072
+ flow_obj = (
1073
+ flow_data if isinstance(flow_data, Flow) else Flow.model_validate(flow_data.model_dump())
1074
+ )
1075
+ break
1076
+
1077
+ if not flow_obj:
1078
+ available = [
1079
+ f.get("name") if isinstance(f, dict) else getattr(f, "name", "?") for f in (flows_response or [])
1080
+ ]
1081
+ raise ValueError(f"Flow '{flow_name}' not found in API. Available: {available}")
1082
+ else:
1083
+ # Use local flows
1084
+ config_paths = [
1085
+ self.working_dir / "plato-config.yml",
1086
+ self.working_dir / "plato-config.yaml",
1087
+ ]
1088
+
1089
+ for config_path in config_paths:
1090
+ if config_path.exists():
1091
+ with open(config_path) as f:
1092
+ plato_config = PlatoConfig.model_validate(yaml.safe_load(f))
1093
+
1094
+ dataset_config = plato_config.datasets.get(dataset)
1095
+ if dataset_config and dataset_config.metadata:
1096
+ flows_path = dataset_config.metadata.flows_path
1097
+
1098
+ if flows_path:
1099
+ flow_file = (
1100
+ config_path.parent / flows_path
1101
+ if not Path(flows_path).is_absolute()
1102
+ else Path(flows_path)
1103
+ )
1104
+
1105
+ if flow_file.exists():
1106
+ with open(flow_file) as f:
1107
+ flow_dict = yaml.safe_load(f)
1108
+ flow_obj = next(
1109
+ (
1110
+ Flow.model_validate(fl)
1111
+ for fl in flow_dict.get("flows", [])
1112
+ if fl.get("name") == flow_name
1113
+ ),
1114
+ None,
1115
+ )
1116
+ if flow_obj:
1117
+ screenshots_dir = flow_file.parent / "screenshots"
1118
+ self.console.print(f"[cyan]Flow source: local ({flow_file})[/cyan]")
1119
+ break
1120
+
1121
+ if not flow_obj:
1122
+ raise ValueError(f"Flow '{flow_name}' not found in local config")
1123
+
1124
+ # Assert for type narrowing in nested function (checked above in both branches)
1125
+ assert flow_obj is not None
1126
+ validated_flow: Flow = flow_obj
1127
+
1128
+ # Run the flow with Playwright
1129
+ async def _run_flow():
1130
+ from playwright.async_api import async_playwright
1131
+
1132
+ async with async_playwright() as p:
1133
+ browser = await p.chromium.launch(headless=False)
1134
+ try:
1135
+ page = await browser.new_page()
1136
+ await page.goto(url)
1137
+ executor = FlowExecutor(page, validated_flow, screenshots_dir)
1138
+ await executor.execute()
1139
+ finally:
1140
+ await browser.close()
1141
+
1142
+ asyncio.run(_run_flow())
1143
+
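+ # Illustrative sketch (URL and flow name are placeholders): replay a flow
+ # defined in the local flows file against the sandbox's public URL:
+ #
+ # client.run_flow(
+ #     url=state.public_url or "http://localhost:8080",
+ #     flow_name="login",
+ #     dataset="base",
+ # )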
1144
+ # CHECKED
1145
+ def state(self, session_id: str) -> SessionStateResponse:
1146
+ response = sessions_state.sync(
1147
+ client=self._http,
1148
+ session_id=session_id,
1149
+ merge_mutations=True,
1150
+ x_api_key=self.api_key,
1151
+ )
1152
+ return response
1153
+
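+ # Illustrative sketch: fetch the mutation-merged session state, e.g. to
+ # inspect database contents after a flow run (session_id is a placeholder):
+ #
+ # snapshot = client.state("sess-123")
+ # print(snapshot.model_dump())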
1154
+ # CHECKED
1155
+ def start_services(
1156
+ self,
1157
+ simulator_name: str,
1158
+ ssh_config_path: str,
1159
+ ssh_host: str,
1160
+ dataset: str,
1161
+ ) -> list[dict[str, str]]:
1162
+ # Get Gitea credentials
1163
+ creds = get_gitea_credentials.sync(client=self._http, x_api_key=self.api_key)
1164
+
1165
+ # Get accessible simulators
1166
+ simulators = get_accessible_simulators.sync(client=self._http, x_api_key=self.api_key)
1167
+ simulator = None
1168
+ for sim in simulators or []:
1169
+ sim_name = sim.get("name") if isinstance(sim, dict) else getattr(sim, "name", None)
1170
+ if sim_name and sim_name.lower() == simulator_name.lower():
1171
+ simulator = sim
1172
+ break
1173
+ if not simulator:
1174
+ raise ValueError(f"Simulator '{simulator_name}' not found in gitea accessible simulators")
1175
+
1176
+ # Get or create repo
1177
+ sim_id = simulator.get("id") if isinstance(simulator, dict) else getattr(simulator, "id", None)
1178
+ has_repo = simulator.get("has_repo") if isinstance(simulator, dict) else getattr(simulator, "has_repo", False)
1179
+ if has_repo:
1180
+ repo = get_simulator_repository.sync(client=self._http, simulator_id=sim_id, x_api_key=self.api_key) # type: ignore
1181
+ else:
1182
+ repo = create_simulator_repository.sync(client=self._http, simulator_id=sim_id, x_api_key=self.api_key) # type: ignore
1183
+
1184
+ clone_url = repo.clone_url
1185
+ if not clone_url:
1186
+ raise ValueError("No clone URL available for gitea repository")
1187
+
1188
+ # Build authenticated URL
1189
+ encoded_username = quote(creds.username, safe="")
1190
+ encoded_password = quote(creds.password, safe="")
1191
+ auth_clone_url = clone_url.replace("https://", f"https://{encoded_username}:{encoded_password}@", 1)
1192
+
1193
+ repo_dir = f"/home/plato/worktree/{simulator_name}"
1194
+ branch_name = f"workspace-{int(time.time())}"
1195
+
1196
+ # Clone, copy, push
1197
+ with tempfile.TemporaryDirectory(prefix="plato-hub-") as temp_dir:
1198
+ temp_repo = Path(temp_dir) / "repo"
1199
+ git_env = os.environ.copy()
1200
+ git_env["GIT_TERMINAL_PROMPT"] = "0"
1201
+ git_env["GIT_ASKPASS"] = ""
1202
+
1203
+ subprocess.run(
1204
+ ["git", "clone", auth_clone_url, str(temp_repo)], capture_output=True, env=git_env, check=True
1205
+ )
1206
+ subprocess.run(
1207
+ ["git", "checkout", "-b", branch_name], cwd=temp_repo, capture_output=True, env=git_env, check=True
1208
+ )
1209
+
1210
+ # Copy files
1211
+ current_dir = Path(self.working_dir)
1212
+
1213
+ def _copy_files(src_dir: Path, dst_dir: Path) -> None:
1214
+ """Copy files, skipping .git/ and .plato-hub.json."""
1215
+ for src_path in src_dir.rglob("*"):
1216
+ rel_path = src_path.relative_to(src_dir)
1217
+ if ".git" in rel_path.parts or rel_path.name == ".plato-hub.json":
1218
+ continue
1219
+ dst_path = dst_dir / rel_path
1220
+ if src_path.is_dir():
1221
+ dst_path.mkdir(parents=True, exist_ok=True)
1222
+ else:
1223
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
1224
+ shutil.copy2(src_path, dst_path)
1225
+
1226
+ _copy_files(current_dir, temp_repo)
1227
+
1228
+ subprocess.run(["git", "add", "."], cwd=temp_repo, capture_output=True, env=git_env)
1229
+ result = subprocess.run(
1230
+ ["git", "status", "--porcelain"], cwd=temp_repo, capture_output=True, text=True, env=git_env
1231
+ )
1232
+
1233
+ if result.stdout.strip():
1234
+ subprocess.run(
1235
+ ["git", "commit", "-m", "Sync from local workspace"],
1236
+ cwd=temp_repo,
1237
+ capture_output=True,
1238
+ env=git_env,
1239
+ )
1240
+
1241
+ subprocess.run(
1242
+ ["git", "remote", "set-url", "origin", auth_clone_url],
1243
+ cwd=temp_repo,
1244
+ capture_output=True,
1245
+ env=git_env,
1246
+ )
1247
+ subprocess.run(
1248
+ ["git", "push", "-u", "origin", branch_name],
1249
+ cwd=temp_repo,
1250
+ capture_output=True,
1251
+ env=git_env,
1252
+ check=True,
1253
+ )
1254
+
1255
+ # Clone on VM - first verify SSH works
1256
+ # Debug: show SSH config being used
1257
+ ssh_config_full_path = (
1258
+ Path(self.working_dir) / ssh_config_path
1259
+ if not Path(ssh_config_path).is_absolute()
1260
+ else Path(ssh_config_path)
1261
+ )
1262
+ if not ssh_config_full_path.exists():
1263
+ raise ValueError(f"SSH config file not found: {ssh_config_full_path}")
1264
+
1265
+ self.console.print(f"[dim]SSH config: {ssh_config_full_path}[/dim]")
1266
+ self.console.print(f"[dim]SSH host: {ssh_host}[/dim]")
1267
+
1268
+ # Run SSH with verbose to see what's happening
1269
+ ssh_cmd = ["ssh", "-v", "-F", ssh_config_path, ssh_host, "echo 'SSH connection OK'"]
1270
+ self.console.print(f"[dim]Running: {' '.join(ssh_cmd)}[/dim]")
1271
+ self.console.print(f"[dim]Working dir: {self.working_dir}[/dim]")
1272
+
1273
+ result = subprocess.run(
1274
+ ssh_cmd,
1275
+ capture_output=True,
1276
+ text=True,
1277
+ cwd=self.working_dir,
1278
+ )
1279
+ ret, stdout, stderr = result.returncode, result.stdout, result.stderr
1280
+
1281
+ if ret != 0:
1282
+ # Show SSH config contents for debugging
1283
+ try:
1284
+ config_content = ssh_config_full_path.read_text()
1285
+ self.console.print(f"[yellow]SSH config contents:[/yellow]\n{config_content}")
1286
+ except Exception:
1287
+ pass
1288
+ # Show SSH verbose output
1289
+ self.console.print(f"[yellow]SSH stderr (verbose):[/yellow]\n{stderr}")
1290
+ error_output = stderr or stdout or "(no output)"
1291
+ raise ValueError(f"SSH connection failed (exit {ret})")
1292
+
1293
+ _run_ssh_command(ssh_config_path, ssh_host, "mkdir -p /home/plato/worktree", cwd=self.working_dir)
1294
+ _run_ssh_command(ssh_config_path, ssh_host, f"rm -rf {repo_dir}", cwd=self.working_dir)
1295
+
1296
+ # Clone repo - mask credentials in error output
1297
+ ret, stdout, stderr = _run_ssh_command(
1298
+ ssh_config_path,
1299
+ ssh_host,
1300
+ f"git clone -b {branch_name} {auth_clone_url} {repo_dir}",
1301
+ cwd=self.working_dir,
1302
+ )
1303
+ if ret != 0:
1304
+ # Mask credentials in error output
1305
+ safe_url = clone_url # Use non-authenticated URL in error
1306
+ error_output = stderr or stdout or "(no output)"
1307
+ error_output = error_output.replace(creds.username, "***").replace(creds.password, "***")
1308
+ raise ValueError(f"Clone failed (exit {ret}) for {safe_url} branch {branch_name}: {error_output}")
1309
+
1310
+ # ECR auth
1311
+ ecr_result = subprocess.run(
1312
+ ["aws", "ecr", "get-login-password", "--region", "us-west-1"], capture_output=True, text=True
1313
+ )
1314
+ if ecr_result.returncode == 0:
1315
+ ecr_token = ecr_result.stdout.strip()
1316
+ ecr_registry = "383806609161.dkr.ecr.us-west-1.amazonaws.com"
1317
+ _run_ssh_command(
1318
+ ssh_config_path,
1319
+ ssh_host,
1320
+ f"echo '{ecr_token}' | docker login --username AWS --password-stdin {ecr_registry}",
1321
+ cwd=self.working_dir,
1322
+ )
1323
+
1324
+ # Start services
1325
+ services_started = []
1326
+ with open(self.working_dir / "plato-config.yml", "rb") as f:
1327
+ plato_config = yaml.safe_load(f)
1328
+ plato_config_model = PlatoConfig.model_validate(plato_config)
1329
+ services_config = plato_config_model.datasets[dataset].services
1330
+ if not services_config:
1331
+ self.console.print("[yellow]No services configured, skipping service startup[/yellow]")
1332
+ return services_started
1333
+ for svc_name, svc_config in services_config.items():
1334
+ # svc_config is a Pydantic model (DockerComposeServiceConfig), use getattr
1335
+ svc_type = getattr(svc_config, "type", "")
1336
+ if svc_type == "docker-compose":
1337
+ compose_file = getattr(svc_config, "file", "docker-compose.yml")
1338
+ compose_cmd = f"cd {repo_dir} && docker compose -f {compose_file} up -d"
1339
+ ret, _, stderr = _run_ssh_command(ssh_config_path, ssh_host, compose_cmd, cwd=self.working_dir)
1340
+ if ret != 0:
1341
+ raise ValueError(f"Failed to start {svc_name}: {stderr}")
1342
+ services_started.append({"name": svc_name, "type": "docker-compose", "file": compose_file})
1343
+ else:
1344
+ raise ValueError(f"Unsupported service type: {svc_type}")
1345
+
1346
+ return services_started
1347
+
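+ # Illustrative sketch (placeholder values): push the worktree to Gitea,
+ # clone it on the VM, and bring up the docker-compose services declared
+ # for the dataset in plato-config.yml:
+ #
+ # started = client.start_services(
+ #     simulator_name="myservice",
+ #     ssh_config_path=".plato/ssh_config",
+ #     ssh_host="sandbox",
+ #     dataset="base",
+ # )
+ # for svc in started:
+ #     print(svc["name"], svc["type"])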
1348
+ # # -------------------------------------------------------------------------
1349
+ # # CLEAR AUDIT
1350
+ # # -------------------------------------------------------------------------
1351
+
1352
+ # def clear_audit(
1353
+ # self,
1354
+ # job_id: str,
1355
+ # session_id: str | None = None,
1356
+ # db_listeners: list[tuple[str, dict]] | None = None,
1357
+ # ) -> ClearAuditResult:
1358
+ # """Clear audit_log tables in sandbox databases.
1359
+
1360
+ # Args:
1361
+ # job_id: Job ID for the sandbox.
1362
+ # session_id: Session ID for refreshing state cache.
1363
+ # db_listeners: List of (name, config) tuples for database listeners.
1364
+
1365
+ # Returns:
1366
+ # ClearAuditResult with cleanup status.
1367
+ # """
1368
+ # if not db_listeners:
1369
+ # return ClearAuditResult(success=False, error="No database listeners provided")
1370
+
1371
+ # def _execute_db_cleanup(name: str, db_config: dict, local_port: int) -> dict:
1372
+ # """Execute DB cleanup using sync SQLAlchemy."""
1373
+ # db_type = db_config.get("db_type", "postgresql").lower()
1374
+ # db_user = db_config.get("db_user", "postgres" if db_type == "postgresql" else "root")
1375
+ # db_password = db_config.get("db_password", "")
1376
+ # db_database = db_config.get("db_database", "postgres")
1377
+
1378
+ # user = quote_plus(db_user)
1379
+ # password = quote_plus(db_password)
1380
+ # database = quote_plus(db_database)
1381
+
1382
+ # if db_type == "postgresql":
1383
+ # db_url = f"postgresql+psycopg2://{user}:{password}@127.0.0.1:{local_port}/{database}"
1384
+ # elif db_type in ("mysql", "mariadb"):
1385
+ # db_url = f"mysql+pymysql://{user}:{password}@127.0.0.1:{local_port}/{database}"
1386
+ # else:
1387
+ # return {"listener": name, "success": False, "error": f"Unsupported db_type: {db_type}"}
1388
+
1389
+ # engine = create_engine(db_url, pool_pre_ping=True)
1390
+ # tables_truncated = []
1391
+
1392
+ # with engine.begin() as conn:
1393
+ # if db_type == "postgresql":
1394
+ # result = conn.execute(
1395
+ # text("SELECT schemaname, tablename FROM pg_tables WHERE tablename = 'audit_log'")
1396
+ # )
1397
+ # tables = result.fetchall()
1398
+ # for schema, table in tables:
1399
+ # conn.execute(text(f"TRUNCATE TABLE {schema}.{table} RESTART IDENTITY CASCADE"))
1400
+ # tables_truncated.append(f"{schema}.{table}")
1401
+ # elif db_type in ("mysql", "mariadb"):
1402
+ # result = conn.execute(
1403
+ # text(
1404
+ # "SELECT table_schema, table_name FROM information_schema.tables "
1405
+ # "WHERE table_name = 'audit_log' AND table_schema = DATABASE()"
1406
+ # )
1407
+ # )
1408
+ # tables = result.fetchall()
1409
+ # conn.execute(text("SET FOREIGN_KEY_CHECKS = 0"))
1410
+ # for schema, table in tables:
1411
+ # conn.execute(text(f"DELETE FROM `{table}`"))
1412
+ # tables_truncated.append(table)
1413
+ # conn.execute(text("SET FOREIGN_KEY_CHECKS = 1"))
1414
+
1415
+ # engine.dispose()
1416
+ # return {"listener": name, "success": True, "tables_truncated": tables_truncated}
1417
+
1418
+ # async def clear_audit_via_tunnel(name: str, db_config: dict) -> dict:
1419
+ # """Clear audit_log by connecting via proxy tunnel."""
1420
+ # db_type = db_config.get("db_type", "postgresql").lower()
1421
+ # db_port = db_config.get("db_port", 5432 if db_type == "postgresql" else 3306)
1422
+
1423
+ # local_port = find_free_port()
1424
+ # tunnel = ProxyTunnel(
1425
+ # env_id=job_id,
1426
+ # db_port=db_port,
1427
+ # temp_password="newpass",
1428
+ # host_port=local_port,
1429
+ # )
1430
+
1431
+ # try:
1432
+ # await tunnel.start()
1433
+ # result = await asyncio.to_thread(_execute_db_cleanup, name, db_config, local_port)
1434
+ # return result
1435
+ # except Exception as e:
1436
+ # return {"listener": name, "success": False, "error": str(e)}
1437
+ # finally:
1438
+ # await tunnel.stop()
1439
+
1440
+ # async def run_all():
1441
+ # tasks = [clear_audit_via_tunnel(name, db_config) for name, db_config in db_listeners]
1442
+ # return await asyncio.gather(*tasks)
1443
+
1444
+ # try:
1445
+ # results = asyncio.run(run_all())
1446
+
1447
+ # # Refresh state cache
1448
+ # if session_id:
1449
+ # try:
1450
+ # sessions_state.sync(
1451
+ # client=self._http,
1452
+ # session_id=session_id,
1453
+ # x_api_key=self.api_key,
1454
+ # )
1455
+ # except Exception:
1456
+ # pass
1457
+
1458
+ # all_success = all(r["success"] for r in results)
1459
+ # return ClearAuditResult(success=all_success, results=list(results))
1460
+
1461
+ # except Exception as e:
1462
+ # return ClearAuditResult(success=False, error=str(e))