plato-sdk-v2 2.7.6__py3-none-any.whl → 2.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1461 @@
"""Plato SDK v2 - Synchronous Sandbox Client.

The SandboxClient provides methods for sandbox development workflows:
creating sandboxes, managing SSH, syncing files, running flows, etc.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import shutil
import signal
import subprocess
import tempfile
import time
from pathlib import Path
from urllib.parse import quote

import httpx
import yaml
from playwright.async_api import async_playwright
from pydantic import BaseModel
from rich.console import Console

from plato._generated.api.v1.gitea import (
    create_simulator_repository,
    get_accessible_simulators,
    get_gitea_credentials,
    get_simulator_repository,
)
from plato._generated.api.v1.sandbox import start_worker
from plato._generated.api.v2.jobs import get_flows as jobs_get_flows
from plato._generated.api.v2.jobs import state as jobs_state
from plato._generated.api.v2.sessions import add_ssh_key as sessions_add_ssh_key
from plato._generated.api.v2.sessions import close as sessions_close
from plato._generated.api.v2.sessions import connect_network as sessions_connect_network
from plato._generated.api.v2.sessions import get_public_url as sessions_get_public_url
from plato._generated.api.v2.sessions import get_session_details
from plato._generated.api.v2.sessions import snapshot as sessions_snapshot
from plato._generated.api.v2.sessions import state as sessions_state
from plato._generated.models import (
    AddSSHKeyRequest,
    AppApiV2SchemasSessionCreateSnapshotResponse,
    AppSchemasBuildModelsSimConfigDataset,
    CloseSessionResponse,
    CreateCheckpointRequest,
    DatabaseMutationListenerConfig,
    Flow,
    PlatoConfig,
    SessionStateResponse,
    VMManagementRequest,
)
from plato.v2.async_.flow_executor import FlowExecutor
from plato.v2.models import SandboxState
from plato.v2.sync.client import Plato
from plato.v2.types import Env, EnvFromArtifact, EnvFromResource, EnvFromSimulator, SimConfigCompute

logger = logging.getLogger(__name__)


DEFAULT_BASE_URL = "https://plato.so"
DEFAULT_TIMEOUT = 600.0


def _get_plato_dir(working_dir: Path | None = None) -> Path:
    """Get the .plato directory path."""
    base = working_dir or Path.cwd()
    return base / ".plato"


def _generate_ssh_key_pair(prefix: str, working_dir: Path | None = None) -> tuple[str, str]:
    """Generate an SSH key pair and save to .plato/ directory.

    Args:
        prefix: Prefix for key filename.
        working_dir: Working directory for .plato/.

    Returns:
        Tuple of (public_key_content, private_key_path).
    """
    plato_dir = _get_plato_dir(working_dir)
    plato_dir.mkdir(mode=0o700, exist_ok=True)

    key_name = f"ssh_key_{prefix}"
    private_key_path = plato_dir / key_name
    public_key_path = plato_dir / f"{key_name}.pub"

    # Remove existing keys
    if private_key_path.exists():
        private_key_path.unlink()
    if public_key_path.exists():
        public_key_path.unlink()

    # Generate key pair
    subprocess.run(
        [
            "ssh-keygen",
            "-t",
            "ed25519",
            "-f",
            str(private_key_path),
            "-N",
            "",
            "-q",
        ],
        check=True,
    )

    # Read public key
    public_key = public_key_path.read_text().strip()

    return public_key, str(private_key_path)


def _generate_ssh_config(
    job_id: str,
    private_key_path: str,
    working_dir: Path | None = None,
    ssh_host: str = "sandbox",
) -> str:
    """Generate SSH config file for easy access via gateway.

    Args:
        job_id: The job ID for routing.
        private_key_path: Path to private key (absolute or relative).
        working_dir: Working directory for .plato/.
        ssh_host: Host alias in config.

    Returns:
        Path to the generated SSH config file (relative to working_dir).

    Note:
        The IdentityFile in the config uses a path relative to working_dir.
        SSH commands must be run from the workspace root for paths to resolve.
    """
    gateway_host = os.getenv("PLATO_GATEWAY_HOST", "gateway.plato.so")

    # Convert private key path to be relative to working_dir.
    # This ensures the config is portable if the workspace moves.
    base = working_dir or Path.cwd()
    try:
        relative_key_path = Path(private_key_path).relative_to(base)
    except ValueError:
        # If not relative to working_dir, keep as-is (shouldn't happen normally)
        relative_key_path = Path(private_key_path)

    # SNI format: {job_id}--{port}.{gateway_host} (matches v1 proxy.py)
    ssh_port = 22
    sni = f"{job_id}--{ssh_port}.{gateway_host}"

    config_content = f"""# Plato Sandbox SSH Config
# Generated for job: {job_id}
# NOTE: Run SSH commands from workspace root for relative paths to resolve

Host {ssh_host}
    HostName {job_id}
    User root
    IdentityFile {relative_key_path}
    StrictHostKeyChecking no
    UserKnownHostsFile /dev/null
    LogLevel ERROR
    ProxyCommand openssl s_client -quiet -connect {gateway_host}:443 -servername {sni} 2>/dev/null
"""

    plato_dir = _get_plato_dir(working_dir)
    plato_dir.mkdir(mode=0o700, exist_ok=True)

    config_path = plato_dir / "ssh_config"
    config_path.write_text(config_content)

    return str(config_path)


def _run_ssh_command(
    ssh_config_path: str,
    ssh_host: str,
    command: str,
    cwd: Path | str | None = None,
) -> tuple[int, str, str]:
    """Run a command via SSH.

    Args:
        ssh_config_path: Path to SSH config file (can be relative).
        ssh_host: SSH host alias from config.
        command: Command to execute on remote.
        cwd: Working directory to run SSH from. Required when ssh_config_path
            contains relative paths (e.g., for IdentityFile).

    Returns:
        Tuple of (returncode, stdout, stderr).
    """
    result = subprocess.run(
        ["ssh", "-F", ssh_config_path, ssh_host, command],
        capture_output=True,
        text=True,
        cwd=cwd,
    )
    return result.returncode, result.stdout, result.stderr


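# The three helpers above compose into the sandbox SSH flow: generate a key
# pair, write a config whose ProxyCommand routes through the TLS gateway via
# the SNI {job_id}--22.{gateway_host}, then run commands with `ssh -F`. A
# minimal sketch (illustrative only, not part of the released module; the
# job ID is a placeholder):
def _demo_ssh_roundtrip(job_id: str = "job-00000000") -> None:  # pragma: no cover
    public_key, private_key_path = _generate_ssh_key_pair(job_id[:8], Path.cwd())
    # public_key still has to be registered with the session (see
    # SandboxClient.start below) before the connection will authenticate.
    config_path = _generate_ssh_config(job_id, private_key_path, Path.cwd())
    code, out, err = _run_ssh_command(config_path, "sandbox", "uname -a", cwd=Path.cwd())
    print(code, out or err)

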
# =============================================================================
# HEARTBEAT UTILITIES
# =============================================================================


def _start_heartbeat_process(session_id: str, api_key: str) -> int | None:
    """Start a background process that sends heartbeats.

    Returns:
        PID of the background process, or None if failed.
    """
    log_file = f"/tmp/plato_heartbeat_{session_id}.log"
    base_url = os.getenv("PLATO_BASE_URL", "https://plato.so")
    # Strip trailing /api if present to avoid double /api/api in URL
    if base_url.endswith("/api"):
        base_url = base_url[:-4]
    base_url = base_url.rstrip("/")

    heartbeat_script = f'''
import time
import os
import httpx
from datetime import datetime

session_id = "{session_id}"
api_key = "{api_key}"
base_url = "{base_url}"
log_file = "{log_file}"

def log(msg):
    timestamp = datetime.now().isoformat()
    with open(log_file, "a") as f:
        f.write(f"[{{timestamp}}] {{msg}}\\n")
        f.flush()

log(f"Heartbeat process started for session {{session_id}}")

heartbeat_count = 0
while True:
    heartbeat_count += 1
    try:
        url = f"{{base_url}}/api/v2/sessions/{{session_id}}/heartbeat"
        with httpx.Client(timeout=30) as client:
            response = client.post(
                url,
                headers={{"X-API-Key": api_key}},
            )
            if response.status_code == 200:
                result = response.json()
                success = result.get("success", False)
                log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, success={{success}}")
            else:
                log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, url={{url}}, body={{response.text[:500]}}")
    except Exception as e:
        log(f"Heartbeat #{{heartbeat_count}} EXCEPTION: {{type(e).__name__}}: {{e}}")
    time.sleep(30)
'''

    try:
        process = subprocess.Popen(
            ["python3", "-c", heartbeat_script],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            stdin=subprocess.DEVNULL,
            start_new_session=True,
        )
        return process.pid
    except Exception:
        return None


def _stop_heartbeat_process(pid: int) -> bool:
    """Stop the heartbeat process.

    Returns:
        True if stopped successfully.
    """
    try:
        os.kill(pid, signal.SIGTERM)
        return True
    except ProcessLookupError:
        return True
    except Exception:
        return False


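# For reference, each heartbeat the spawned script sends is a plain POST to
# the v2 sessions endpoint. A single inline heartbeat looks like this sketch
# (same trailing-/api caveat as above; illustrative only, not part of the
# released module):
def _demo_send_heartbeat(session_id: str, api_key: str) -> bool:  # pragma: no cover
    base_url = os.getenv("PLATO_BASE_URL", "https://plato.so")
    if base_url.endswith("/api"):
        base_url = base_url[:-4]  # avoid double /api/api in the URL
    url = f"{base_url.rstrip('/')}/api/v2/sessions/{session_id}/heartbeat"
    with httpx.Client(timeout=30) as client:
        return client.post(url, headers={"X-API-Key": api_key}).status_code == 200

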
class SyncResult(BaseModel):
    files_synced: int
    bytes_synced: int


# =============================================================================
# TUNNEL
# =============================================================================

DEFAULT_GATEWAY_HOST = "gateway.plato.so"
DEFAULT_GATEWAY_PORT = 443


def _get_gateway_config() -> tuple[str, int]:
    """Get gateway host and port from environment or defaults."""
    host = os.environ.get("PLATO_GATEWAY_HOST", DEFAULT_GATEWAY_HOST)
    port = int(os.environ.get("PLATO_GATEWAY_PORT", str(DEFAULT_GATEWAY_PORT)))
    return host, port


def _create_tls_connection(
    gateway_host: str,
    gateway_port: int,
    sni: str,
    verify_ssl: bool = True,
):
    """Create a TLS connection to the gateway with the specified SNI."""
    import socket
    import ssl

    context = ssl.create_default_context()
    if not verify_ssl:
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(30)
    ssl_sock = context.wrap_socket(sock, server_hostname=sni)

    try:
        ssl_sock.connect((gateway_host, gateway_port))
    except Exception as e:
        ssl_sock.close()
        raise ConnectionError(f"Failed to connect to gateway: {e}") from e

    return ssl_sock


def _forward_data(src, dst, name: str = "") -> None:
    """Forward data between two sockets until one closes."""
    import socket

    try:
        while True:
            data = src.recv(4096)
            if not data:
                break
            dst.sendall(data)
    except (ConnectionResetError, BrokenPipeError, OSError):
        pass
    finally:
        try:
            dst.shutdown(socket.SHUT_WR)
        except OSError:
            pass


class Tunnel:
    """A TCP tunnel to a remote port on a sandbox VM via the TLS gateway."""

    def __init__(
        self,
        job_id: str,
        remote_port: int,
        local_port: int | None = None,
        bind_address: str = "127.0.0.1",
        verify_ssl: bool = True,
    ):
        self.job_id = job_id
        self.remote_port = remote_port
        self.local_port = local_port or remote_port
        self.bind_address = bind_address
        self.verify_ssl = verify_ssl

        self._server = None
        self._thread = None
        self._running = False

    def start(self) -> int:
        """Start the tunnel. Returns the local port."""
        import socket
        import threading

        gateway_host, gateway_port = _get_gateway_config()
        sni = f"{self.job_id}--{self.remote_port}.{gateway_host}"

        # Create local listener
        self._server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        try:
            self._server.bind((self.bind_address, self.local_port))
            self._server.listen(5)
        except OSError as e:
            raise ValueError(f"Could not bind to {self.bind_address}:{self.local_port}: {e}") from e

        self._running = True

        def handle_client(client_sock, client_addr):
            try:
                gateway_sock = _create_tls_connection(gateway_host, gateway_port, sni, verify_ssl=self.verify_ssl)
                t1 = threading.Thread(
                    target=_forward_data,
                    args=(client_sock, gateway_sock, "client->gateway"),
                    daemon=True,
                )
                t2 = threading.Thread(
                    target=_forward_data,
                    args=(gateway_sock, client_sock, "gateway->client"),
                    daemon=True,
                )
                t1.start()
                t2.start()
                t1.join()
                t2.join()
            except Exception:
                pass
            finally:
                try:
                    client_sock.close()
                except OSError:
                    pass

        def accept_loop():
            server = self._server
            assert server is not None, "Server must be initialized before accept_loop"
            while self._running:
                try:
                    server.settimeout(1.0)
                    client_sock, client_addr = server.accept()
                    threading.Thread(
                        target=handle_client,
                        args=(client_sock, client_addr),
                        daemon=True,
                    ).start()
                except TimeoutError:
                    continue
                except OSError:
                    break

        self._thread = threading.Thread(target=accept_loop, daemon=True)
        self._thread.start()

        return self.local_port

    def stop(self) -> None:
        """Stop the tunnel."""
        self._running = False
        if self._server:
            try:
                self._server.close()
            except OSError:
                pass
        if self._thread:
            self._thread.join(timeout=2.0)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        return False


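# Tunnel is a context manager, so an ad-hoc port-forward reads naturally as a
# `with` block; while it is open, local clients reach the VM port through the
# gateway. A sketch (the job ID and port are placeholders, not from the
# released module):
def _demo_tunnel(job_id: str = "job-00000000") -> None:  # pragma: no cover
    with Tunnel(job_id, remote_port=5432) as t:
        # e.g. point a local psql/mysql client at 127.0.0.1:<local_port> here
        print(f"forwarding 127.0.0.1:{t.local_port} -> VM:{t.remote_port}")

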
# =============================================================================
# SANDBOX CLIENT
# =============================================================================


class SandboxClient:
    """Synchronous client for sandbox development workflows.

    The client is stateful: start() persists sandbox state to
    <working_dir>/.plato/state.json, so a later invocation constructed with
    the same working_dir (e.g. another CLI command or script) can operate on
    the same sandbox.

    Usage:
        client = SandboxClient(working_dir=Path("."), api_key="...")
        state = client.start(mode="blank", simulator_name="myservice")  # Saves state
        client.sync(state.session_id, simulator="myservice")  # Reads saved state for SSH
        client.stop(state.session_id)
        client.close()
    """

    # State file paths
    PLATO_DIR = ".plato"

    def __init__(
        self,
        working_dir: Path,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float = DEFAULT_TIMEOUT,
        console: Console = Console(),
    ):
        self.api_key = api_key or os.environ.get("PLATO_API_KEY")
        if not self.api_key:
            raise ValueError("API key required. Set PLATO_API_KEY or pass api_key=")

        url = base_url or os.environ.get("PLATO_BASE_URL", DEFAULT_BASE_URL)
        if url.endswith("/api"):
            url = url[:-4]
        self.base_url = url.rstrip("/")
        self.console = console
        self.working_dir = working_dir

        self._http = httpx.Client(
            base_url=self.base_url,
            timeout=httpx.Timeout(timeout),
        )

    def _get_plato_dir(self) -> Path:
        return Path(self.working_dir) / self.PLATO_DIR

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._http.close()

    # -------------------------------------------------------------------------
    # START
    # -------------------------------------------------------------------------

    def start(
        self,
        simulator_name: str | None = None,
        mode: str = "blank",
        # artifact or simulator mode
        artifact_id: str | None = None,
        dataset: str = "base",
        tag: str = "latest",
        # blank or plato-config mode
        cpus: int = 1,
        memory: int = 2048,
        disk: int = 10240,
        app_port: int | None = None,
        messaging_port: int | None = None,
        # common
        connect_network: bool = True,
        timeout: int = 1800,
    ) -> SandboxState:
        """Start a sandbox environment.

        Uses Plato SDK v2 internally for session creation.

        Args:
            mode: Start mode - "blank", "simulator", "artifact", or "config".
            simulator_name: Simulator name.
            artifact_id: Artifact UUID.
            dataset: Dataset name.
            tag: Artifact tag.
            cpus: Number of CPUs.
            memory: Memory in MB.
            disk: Disk in MB.
            app_port: App port.
            messaging_port: Messaging port.
            connect_network: Whether to connect WireGuard network.
            timeout: Seconds to wait for the VM to become ready.

        Returns:
            SandboxState with sandbox info.
        """

        assert self.api_key is not None

        # Build environment config using Env factory
        env_config: EnvFromSimulator | EnvFromArtifact | EnvFromResource

        if mode == "artifact" and artifact_id:
            self.console.print(f"Starting from artifact: {artifact_id}")
            env_config = Env.artifact(artifact_id)
        elif mode == "simulator" and simulator_name:
            self.console.print(f"Starting from simulator: {simulator_name}")
            env_config = Env.simulator(simulator_name, tag=tag, dataset=dataset)
        elif mode == "blank" and simulator_name:
            self.console.print("Starting from blank")
            sim_config = SimConfigCompute(
                cpus=cpus, memory=memory, disk=disk, app_port=app_port, plato_messaging_port=messaging_port
            )
            env_config = Env.resource(simulator_name, sim_config)
        elif mode == "config":
            self.console.print("Starting from config")
            # read plato-config.yml
            plato_config_path = self.working_dir / "plato-config.yml"
            with open(plato_config_path, "rb") as f:
                plato_config = yaml.safe_load(f)
            self.console.print(f"plato-config: {plato_config}")
            plato_config_model = PlatoConfig.model_validate(plato_config)
            dataset_config = plato_config_model.datasets[dataset]
            simulator_name = plato_config_model.service
            if not simulator_name:
                raise ValueError("Service name is required in plato-config.yml")
            if not dataset_config.compute:
                raise ValueError(f"Compute configuration is required for dataset '{dataset}'")
            self.console.print(f"simulator_name: {simulator_name}")
            sim_config = SimConfigCompute(
                cpus=dataset_config.compute.cpus,
                memory=dataset_config.compute.memory,
                disk=dataset_config.compute.disk,
                app_port=dataset_config.compute.app_port,
                plato_messaging_port=dataset_config.compute.plato_messaging_port,
            )
            env_config = Env.resource(simulator_name, sim_config)
            self.console.print(f"env_config: {env_config}")
        else:
            raise ValueError(f"Invalid mode '{mode}' or missing required parameter")

        # Use Plato SDK to create session (handles create, wait, network)
        self.console.print(f"Creating session and waiting for VM to become ready (timeout={timeout}s)...")
        plato = Plato(api_key=self.api_key, http_client=self._http)
        session = plato.sessions.create(
            envs=[env_config],
            connect_network=connect_network,
            timeout=timeout,
        )
        self.console.print(f"session: {session}")
        session_id = session.session_id
        job_id = session.envs[0].job_id if session.envs else None
        if not job_id:
            raise ValueError("No job ID found")
        self.console.print(f"job_id: {job_id}")

        # Get public URL with router target formatting (logic inlined)
        public_url = None
        try:
            url_response = sessions_get_public_url.sync(
                client=self._http,
                session_id=session_id,
                x_api_key=self.api_key,
            )
            if url_response and url_response.results:
                for result in url_response.results.values():
                    url = result.url if hasattr(result, "url") else str(result)
                    if not url:
                        raise ValueError(f"No public URL found in result dict for job ID {job_id}")
                    if "_plato_router_target=" not in url:
                        target_param = f"_plato_router_target={simulator_name}.web.plato.so"
                        if "?" in url:
                            url = f"{url}&{target_param}"
                        else:
                            url = f"{url}?{target_param}"
                    public_url = url
        except Exception as e:
            raise ValueError(f"Error getting public URL: {e}") from e

        # Setup SSH
        ssh_config_path = None
        try:
            public_key, private_key_path = _generate_ssh_key_pair(session_id[:8], Path(self.working_dir))

            add_key_request = AddSSHKeyRequest(public_key=public_key, username="root")
            add_response = sessions_add_ssh_key.sync(
                client=self._http,
                session_id=session_id,
                body=add_key_request,
                x_api_key=self.api_key,
            )

            if add_response.success:
                ssh_config_path = _generate_ssh_config(job_id, private_key_path, Path(self.working_dir))
        except Exception as e:
            logger.warning(f"SSH setup failed: {e}")

        # Start heartbeat
        heartbeat_pid = _start_heartbeat_process(session_id, self.api_key)

        # Parse out the simulator_name from service field in jobs in session details.
        session_details = get_session_details.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )
        # session_details["jobs"] if dict, or session_details.jobs if object
        jobs = (
            session_details.get("jobs")
            if isinstance(session_details, dict)
            else getattr(session_details, "jobs", None)
        )
        if jobs:
            for j in jobs:
                service = j.get("service") if isinstance(j, dict) else getattr(j, "service", None)
                if service:
                    simulator_name = service
                    break

        # Convert absolute paths to relative for state storage
        def _to_relative(abs_path: str | None) -> str | None:
            if not abs_path or not self.working_dir:
                return abs_path
            try:
                return str(Path(abs_path).relative_to(self.working_dir))
            except ValueError:
                return abs_path  # Keep absolute if not relative to working_dir

        # Update internal state
        rel_ssh_config = _to_relative(ssh_config_path)
        ssh_host = "sandbox" if ssh_config_path else None
        sandbox_state = SandboxState(
            session_id=session_id,
            job_id=job_id,
            public_url=public_url,
            mode=mode,
            ssh_config_path=rel_ssh_config,
            ssh_host=ssh_host,
            ssh_command=f"ssh -F {rel_ssh_config} {ssh_host}" if rel_ssh_config else None,
            heartbeat_pid=heartbeat_pid,
            simulator_name=simulator_name,
            dataset=dataset,
        )
        if mode == "artifact":
            sandbox_state.artifact_id = artifact_id
        elif mode == "simulator":
            sandbox_state.tag = tag
        elif mode == "blank":
            sandbox_state.cpus = cpus
            sandbox_state.memory = memory
            sandbox_state.disk = disk
            sandbox_state.app_port = app_port
            sandbox_state.messaging_port = messaging_port

        # Save state to working_dir/.plato/state.json
        with open(self.working_dir / self.PLATO_DIR / "state.json", "w") as f:
            json.dump(sandbox_state.model_dump(), f)

        return sandbox_state

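    # A sketch of the stateful round-trip start() enables: a later invocation
    # reloads .plato/state.json (exactly what sync() does) and tears the
    # sandbox down. Illustrative only, not part of the released module.
    def _demo_resume_and_stop(self) -> None:  # pragma: no cover
        with open(self.working_dir / self.PLATO_DIR / "state.json") as f:
            saved = json.load(f)
        self.stop(saved["session_id"], heartbeat_pid=saved.get("heartbeat_pid"))
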
    # CHECKED
    def stop(
        self,
        session_id: str,
        heartbeat_pid: int | None = None,
    ) -> CloseSessionResponse:
        if heartbeat_pid:
            _stop_heartbeat_process(heartbeat_pid)

        return sessions_close.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )

    # CHECKED
    def status(self, session_id: str) -> dict:
        return get_session_details.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )

    # CHECKED
    def snapshot(
        self,
        session_id: str,
        mode: str,
        dataset: str,
    ) -> AppApiV2SchemasSessionCreateSnapshotResponse:
        checkpoint_request = CreateCheckpointRequest()

        if mode == "config":
            # read plato-config.yml
            plato_config_path = self.working_dir / "plato-config.yml"
            with open(plato_config_path, "rb") as f:
                plato_config = yaml.safe_load(f)
            checkpoint_request.plato_config = plato_config

            # convert plato-config to pydantic model
            plato_config_model = PlatoConfig.model_validate(plato_config)
            dataset_compute = plato_config_model.datasets[dataset].compute
            if not dataset_compute:
                raise ValueError(f"Compute configuration is required for dataset '{dataset}'")
            checkpoint_request.internal_app_port = dataset_compute.app_port
            checkpoint_request.messaging_port = dataset_compute.plato_messaging_port
            # we don't set target

            # read flows.yml
            flows_path = self.working_dir / "flows.yml"
            with open(flows_path, "rb") as f:
                flows = yaml.safe_load(f)
            checkpoint_request.flows = flows

        return sessions_snapshot.sync(
            client=self._http,
            session_id=session_id,
            body=checkpoint_request,
            x_api_key=self.api_key,
        )

    # CHECKED
    def connect_network(self, session_id: str) -> dict:
        return sessions_connect_network.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )

    # CHECKED
    def start_worker(
        self,
        job_id: str,
        simulator: str,
        dataset: str,
        wait_timeout: int = 300,  # 5 minutes
    ) -> None:
        with open(self.working_dir / "plato-config.yml", "rb") as f:
            plato_config = yaml.safe_load(f)
        plato_config_model = PlatoConfig.model_validate(plato_config)
        dataset_config = plato_config_model.datasets[dataset]

        # Convert AppApiV2SchemasArtifactSimConfigDataset to AppSchemasBuildModelsSimConfigDataset
        # They have compatible fields but different nested types
        dataset_config_dict = dataset_config.model_dump(exclude_none=True)

        _ = start_worker.sync(
            client=self._http,
            public_id=job_id,
            body=VMManagementRequest(
                service=simulator,
                dataset=dataset,
                plato_dataset_config=AppSchemasBuildModelsSimConfigDataset.model_validate(dataset_config_dict),
            ),
            x_api_key=self.api_key,
        )

        if wait_timeout > 0:
            start_time = time.time()
            poll_interval = 10

            while time.time() - start_time < wait_timeout:
                try:
                    state_response = jobs_state.sync(
                        client=self._http,
                        job_id=job_id,
                        x_api_key=self.api_key,
                    )
                    if state_response:
                        state_dict = (
                            state_response.model_dump() if hasattr(state_response, "model_dump") else state_response
                        )
                        if isinstance(state_dict, dict) and "error" not in state_dict.get("state", {}):
                            return
                except Exception:
                    pass

                time.sleep(poll_interval)

    # CHECKED
    def sync(
        self,
        session_id: str,
        simulator: str,
        timeout: int = 120,
    ) -> SyncResult:
        """Sync local files to sandbox using rsync over SSH.

        Uses the SSH config from .plato/ssh_config for fast, reliable file transfer.
        """
        local_path = self.working_dir
        remote_path = f"/home/plato/worktree/{simulator}"

        # Load SSH config from state
        state_file = self.working_dir / ".plato" / "state.json"
        if not state_file.exists():
            raise ValueError("No state file found - run 'plato sandbox start' first")

        with open(state_file) as f:
            state = json.load(f)

        ssh_config_path = state.get("ssh_config_path")
        ssh_host = state.get("ssh_host", "sandbox")

        if not ssh_config_path:
            raise ValueError("No SSH config in state - run 'plato sandbox start' first")

        exclude_patterns = [
            "__pycache__",
            "*.pyc",
            ".git",
            ".venv",
            "venv",
            "node_modules",
            ".sandbox.yaml",
            "*.egg-info",
            ".pytest_cache",
            ".mypy_cache",
            ".DS_Store",
            "*.swp",
            "*.swo",
            ".plato",
        ]

        # Build rsync command
        rsync_cmd = [
            "rsync",
            "-avz",
            "--delete",
            "-e",
            f"ssh -F {ssh_config_path}",
        ]

        # Add excludes
        for pattern in exclude_patterns:
            rsync_cmd.extend(["--exclude", pattern])

        # Source and destination
        rsync_cmd.append(f"{local_path}/")
        rsync_cmd.append(f"{ssh_host}:{remote_path}/")

        self.console.print(f"[dim]rsync -> {ssh_host}:{remote_path}/[/dim]")

        # Ensure rsync is installed on the VM and create remote directory
        setup_result = subprocess.run(
            [
                "ssh",
                "-F",
                ssh_config_path,
                ssh_host,
                f"which rsync >/dev/null 2>&1 || (apt-get update -qq && apt-get install -y -qq rsync) && mkdir -p {remote_path}",
            ],
            capture_output=True,
            text=True,
            cwd=self.working_dir,
        )
        if setup_result.returncode != 0:
            raise ValueError(f"Failed to setup remote: {setup_result.stderr}")

        # Run rsync
        result = subprocess.run(
            rsync_cmd,
            capture_output=True,
            text=True,
            cwd=self.working_dir,
            timeout=timeout,
        )

        if result.returncode != 0:
            raise ValueError(f"rsync failed: {result.stderr}")

        # Count synced files from rsync output
        lines = result.stdout.strip().split("\n") if result.stdout else []
        file_count = len(
            [
                line
                for line in lines
                if line
                and not line.startswith("sending")
                and not line.startswith("sent")
                and not line.startswith("total")
            ]
        )

        # Get bytes from rsync output (e.g., "sent 1,234 bytes")
        import re

        bytes_synced = 0
        for line in lines:
            if "sent" in line and "bytes" in line:
                match = re.search(r"sent ([\d,]+) bytes", line)
                if match:
                    bytes_synced = int(match.group(1).replace(",", ""))
                    break

        return SyncResult(
            files_synced=file_count,
            bytes_synced=bytes_synced,
        )

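    # Taken together, the methods above support a start -> sync -> serve dev
    # loop. A condensed sketch (assumes a workspace with plato-config.yml;
    # illustrative only, not part of the released module):
    def _demo_dev_loop(self) -> None:  # pragma: no cover
        state = self.start(mode="config")
        simulator = state.simulator_name or "app"
        self.sync(state.session_id, simulator=simulator)
        if state.ssh_config_path and state.ssh_host:
            self.start_services(simulator, state.ssh_config_path, state.ssh_host, dataset="base")
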
    def tunnel(
        self,
        job_id: str,
        remote_port: int,
        local_port: int | None = None,
        bind_address: str = "127.0.0.1",
    ) -> Tunnel:
        return Tunnel(
            job_id=job_id,
            remote_port=remote_port,
            local_port=local_port,
            bind_address=bind_address,
        )

    def run_audit_ui(
        self,
        job_id: str | None = None,
        dataset: str = "base",
        no_tunnel: bool = False,
    ) -> None:
        if not shutil.which("streamlit"):
            raise ValueError("streamlit not installed. Run: pip install streamlit psycopg2-binary pymysql")

        ui_file = Path(__file__).resolve().parent.parent.parent / "cli" / "audit_ui.py"
        if not ui_file.exists():
            raise ValueError(f"UI file not found: {ui_file}")

        # Get DB listener from plato-config.yml
        db_listener: DatabaseMutationListenerConfig | None = None
        for config_path in [self.working_dir / "plato-config.yml", self.working_dir / "plato-config.yaml"]:
            if config_path.exists():
                with open(config_path) as f:
                    plato_config = PlatoConfig.model_validate(yaml.safe_load(f))
                dataset_config = plato_config.datasets.get(dataset)
                if dataset_config and dataset_config.listeners:
                    for listener in dataset_config.listeners.values():
                        if isinstance(listener, DatabaseMutationListenerConfig):
                            db_listener = listener
                            break
                break
        tunnel = None

        if db_listener and job_id and not no_tunnel:
            self.console.print(f"Starting tunnel to {db_listener.db_type} on port {db_listener.db_port}...")
            tunnel = self.tunnel(job_id, db_listener.db_port)
            tunnel.start()
            time.sleep(1)  # Let tunnel stabilize
            self.console.print(
                f"[green]Tunnel open:[/green] localhost:{db_listener.db_port} -> VM:{db_listener.db_port}"
            )

        # Pass db config via environment variables
        env = os.environ.copy()
        if db_listener:
            env["PLATO_DB_HOST"] = "127.0.0.1"
            env["PLATO_DB_PORT"] = str(db_listener.db_port)
            env["PLATO_DB_USER"] = db_listener.db_user
            env["PLATO_DB_PASSWORD"] = db_listener.db_password or ""
            env["PLATO_DB_NAME"] = db_listener.db_database
            env["PLATO_DB_TYPE"] = str(db_listener.db_type)
            self.console.print(
                f"[dim]DB config: {db_listener.db_user}@127.0.0.1:{db_listener.db_port}/{db_listener.db_database}[/dim]"
            )

        try:
            subprocess.run(["streamlit", "run", str(ui_file)], env=env)
        finally:
            if tunnel:
                tunnel.stop()
                self.console.print("[yellow]Tunnel closed[/yellow]")

    def run_flow(
        self,
        url: str,
        flow_name: str,
        dataset: str,
        use_api: bool = False,
        job_id: str | None = None,
    ) -> None:
        flow_obj: Flow | None = None
        screenshots_dir = self.working_dir / "screenshots"

        if use_api:
            # Fetch from API
            if not job_id:
                raise ValueError("job_id required when use_api=True")

            self.console.print("[cyan]Flow source: API[/cyan]")
            flows_response = jobs_get_flows.sync(
                client=self._http,
                job_id=job_id,
                x_api_key=self.api_key,
            )

            if flows_response:
                for flow_data in flows_response:
                    if isinstance(flow_data, dict):
                        if flow_data.get("name") == flow_name:
                            flow_obj = Flow.model_validate(flow_data)
                            break
                    elif hasattr(flow_data, "name") and flow_data.name == flow_name:
                        flow_obj = (
                            flow_data if isinstance(flow_data, Flow) else Flow.model_validate(flow_data.model_dump())
                        )
                        break

            if not flow_obj:
                available = [
                    f.get("name") if isinstance(f, dict) else getattr(f, "name", "?") for f in (flows_response or [])
                ]
                raise ValueError(f"Flow '{flow_name}' not found in API. Available: {available}")
        else:
            # Use local flows
            config_paths = [
                self.working_dir / "plato-config.yml",
                self.working_dir / "plato-config.yaml",
            ]

            for config_path in config_paths:
                if config_path.exists():
                    with open(config_path) as f:
                        plato_config = PlatoConfig.model_validate(yaml.safe_load(f))

                    dataset_config = plato_config.datasets.get(dataset)
                    if dataset_config and dataset_config.metadata:
                        flows_path = dataset_config.metadata.flows_path

                        if flows_path:
                            flow_file = (
                                config_path.parent / flows_path
                                if not Path(flows_path).is_absolute()
                                else Path(flows_path)
                            )

                            if flow_file.exists():
                                with open(flow_file) as f:
                                    flow_dict = yaml.safe_load(f)
                                flow_obj = next(
                                    (
                                        Flow.model_validate(fl)
                                        for fl in flow_dict.get("flows", [])
                                        if fl.get("name") == flow_name
                                    ),
                                    None,
                                )
                                if flow_obj:
                                    screenshots_dir = flow_file.parent / "screenshots"
                                    self.console.print(f"[cyan]Flow source: local ({flow_file})[/cyan]")
                                    break

            if not flow_obj:
                raise ValueError(f"Flow '{flow_name}' not found in local config")

        # Assert for type narrowing in nested function (checked above in both branches)
        assert flow_obj is not None
        validated_flow: Flow = flow_obj

        # Run the flow with Playwright
        async def _run_flow():
            async with async_playwright() as p:
                browser = await p.chromium.launch(headless=False)
                try:
                    page = await browser.new_page()
                    await page.goto(url)
                    executor = FlowExecutor(page, validated_flow, screenshots_dir)
                    await executor.execute()
                finally:
                    await browser.close()

        asyncio.run(_run_flow())

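    # run_flow's local branch expects the file at metadata.flows_path to hold
    # a top-level "flows" list whose entries carry at least a "name" (that is
    # all the lookup above relies on). A sketch of driving one flow against
    # the sandbox (flow name and dataset are placeholders; illustrative only,
    # not part of the released module):
    def _demo_run_flow(self, state: SandboxState) -> None:  # pragma: no cover
        if not state.public_url:
            raise ValueError("Sandbox has no public URL")
        self.run_flow(url=state.public_url, flow_name="checkout", dataset="base")
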
    # CHECKED
    def state(self, session_id: str) -> SessionStateResponse:
        response = sessions_state.sync(
            client=self._http,
            session_id=session_id,
            merge_mutations=True,
            x_api_key=self.api_key,
        )
        return response

    # CHECKED
    def start_services(
        self,
        simulator_name: str,
        ssh_config_path: str,
        ssh_host: str,
        dataset: str,
    ) -> list[dict[str, str]]:
        # Get Gitea credentials
        creds = get_gitea_credentials.sync(client=self._http, x_api_key=self.api_key)

        # Get accessible simulators
        simulators = get_accessible_simulators.sync(client=self._http, x_api_key=self.api_key)
        simulator = None
        for sim in simulators:
            sim_name = sim.get("name") if isinstance(sim, dict) else getattr(sim, "name", None)
            if sim_name and sim_name.lower() == simulator_name.lower():
                simulator = sim
                break
        if not simulator:
            raise ValueError(f"Simulator '{simulator_name}' not found in gitea accessible simulators")

        # Get or create repo
        sim_id = simulator.get("id") if isinstance(simulator, dict) else getattr(simulator, "id", None)
        has_repo = simulator.get("has_repo") if isinstance(simulator, dict) else getattr(simulator, "has_repo", False)
        if has_repo:
            repo = get_simulator_repository.sync(client=self._http, simulator_id=sim_id, x_api_key=self.api_key)  # type: ignore
        else:
            repo = create_simulator_repository.sync(client=self._http, simulator_id=sim_id, x_api_key=self.api_key)  # type: ignore

        clone_url = repo.clone_url
        if not clone_url:
            raise ValueError("No clone URL available for gitea repository")

        # Build authenticated URL
        encoded_username = quote(creds.username, safe="")
        encoded_password = quote(creds.password, safe="")
        auth_clone_url = clone_url.replace("https://", f"https://{encoded_username}:{encoded_password}@", 1)

        repo_dir = f"/home/plato/worktree/{simulator_name}"
        branch_name = f"workspace-{int(time.time())}"

        # Clone, copy, push
        with tempfile.TemporaryDirectory(prefix="plato-hub-") as temp_dir:
            temp_repo = Path(temp_dir) / "repo"
            git_env = os.environ.copy()
            git_env["GIT_TERMINAL_PROMPT"] = "0"
            git_env["GIT_ASKPASS"] = ""

            subprocess.run(
                ["git", "clone", auth_clone_url, str(temp_repo)], capture_output=True, env=git_env, check=True
            )
            subprocess.run(
                ["git", "checkout", "-b", branch_name], cwd=temp_repo, capture_output=True, env=git_env, check=True
            )

            # Copy files
            current_dir = Path(self.working_dir)

            def _copy_files(src_dir: Path, dst_dir: Path) -> None:
                """Copy files, skipping .git/ and .plato-hub.json."""
                for src_path in src_dir.rglob("*"):
                    rel_path = src_path.relative_to(src_dir)
                    if ".git" in rel_path.parts or rel_path.name == ".plato-hub.json":
                        continue
                    dst_path = dst_dir / rel_path
                    if src_path.is_dir():
                        dst_path.mkdir(parents=True, exist_ok=True)
                    else:
                        dst_path.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy2(src_path, dst_path)

            _copy_files(current_dir, temp_repo)

            subprocess.run(["git", "add", "."], cwd=temp_repo, capture_output=True, env=git_env)
            result = subprocess.run(
                ["git", "status", "--porcelain"], cwd=temp_repo, capture_output=True, text=True, env=git_env
            )

            if result.stdout.strip():
                subprocess.run(
                    ["git", "commit", "-m", "Sync from local workspace"],
                    cwd=temp_repo,
                    capture_output=True,
                    env=git_env,
                )

            subprocess.run(
                ["git", "remote", "set-url", "origin", auth_clone_url],
                cwd=temp_repo,
                capture_output=True,
                env=git_env,
            )
            subprocess.run(
                ["git", "push", "-u", "origin", branch_name],
                cwd=temp_repo,
                capture_output=True,
                env=git_env,
                check=True,
            )

        # Clone on VM - first verify SSH works
        # Debug: show SSH config being used
        ssh_config_full_path = (
            Path(self.working_dir) / ssh_config_path
            if not Path(ssh_config_path).is_absolute()
            else Path(ssh_config_path)
        )
        if not ssh_config_full_path.exists():
            raise ValueError(f"SSH config file not found: {ssh_config_full_path}")

        self.console.print(f"[dim]SSH config: {ssh_config_full_path}[/dim]")
        self.console.print(f"[dim]SSH host: {ssh_host}[/dim]")

        # Run SSH with verbose to see what's happening
        ssh_cmd = ["ssh", "-v", "-F", ssh_config_path, ssh_host, "echo 'SSH connection OK'"]
        self.console.print(f"[dim]Running: {' '.join(ssh_cmd)}[/dim]")
        self.console.print(f"[dim]Working dir: {self.working_dir}[/dim]")

        result = subprocess.run(
            ssh_cmd,
            capture_output=True,
            text=True,
            cwd=self.working_dir,
        )
        ret, stdout, stderr = result.returncode, result.stdout, result.stderr

        if ret != 0:
            # Show SSH config contents for debugging
            try:
                config_content = ssh_config_full_path.read_text()
                self.console.print(f"[yellow]SSH config contents:[/yellow]\n{config_content}")
            except Exception:
                pass
            # Show SSH verbose output
            self.console.print(f"[yellow]SSH stderr (verbose):[/yellow]\n{stderr}")
            error_output = stderr or stdout or "(no output)"
            raise ValueError(f"SSH connection failed (exit {ret}): {error_output}")

        _run_ssh_command(ssh_config_path, ssh_host, "mkdir -p /home/plato/worktree", cwd=self.working_dir)
        _run_ssh_command(ssh_config_path, ssh_host, f"rm -rf {repo_dir}", cwd=self.working_dir)

        # Clone repo - mask credentials in error output
        ret, stdout, stderr = _run_ssh_command(
            ssh_config_path,
            ssh_host,
            f"git clone -b {branch_name} {auth_clone_url} {repo_dir}",
            cwd=self.working_dir,
        )
        if ret != 0:
            # Mask credentials in error output
            safe_url = clone_url  # Use non-authenticated URL in error
            error_output = stderr or stdout or "(no output)"
            error_output = error_output.replace(creds.username, "***").replace(creds.password, "***")
            raise ValueError(f"Clone failed (exit {ret}) for {safe_url} branch {branch_name}: {error_output}")

        # ECR auth
        ecr_result = subprocess.run(
            ["aws", "ecr", "get-login-password", "--region", "us-west-1"], capture_output=True, text=True
        )
        if ecr_result.returncode == 0:
            ecr_token = ecr_result.stdout.strip()
            ecr_registry = "383806609161.dkr.ecr.us-west-1.amazonaws.com"
            _run_ssh_command(
                ssh_config_path,
                ssh_host,
                f"echo '{ecr_token}' | docker login --username AWS --password-stdin {ecr_registry}",
                cwd=self.working_dir,
            )

        # Start services
        services_started = []
        with open(self.working_dir / "plato-config.yml", "rb") as f:
            plato_config = yaml.safe_load(f)
        plato_config_model = PlatoConfig.model_validate(plato_config)
        services_config = plato_config_model.datasets[dataset].services
        if not services_config:
            self.console.print("[yellow]No services configured, skipping service startup[/yellow]")
            return services_started
        for svc_name, svc_config in services_config.items():
            # svc_config is a Pydantic model (DockerComposeServiceConfig), use getattr
            svc_type = getattr(svc_config, "type", "")
            if svc_type == "docker-compose":
                compose_file = getattr(svc_config, "file", "docker-compose.yml")
                compose_cmd = f"cd {repo_dir} && docker compose -f {compose_file} up -d"
                ret, _, stderr = _run_ssh_command(ssh_config_path, ssh_host, compose_cmd, cwd=self.working_dir)
                if ret != 0:
                    raise ValueError(f"Failed to start {svc_name}: {stderr}")
                services_started.append({"name": svc_name, "type": "docker-compose", "file": compose_file})
            else:
                raise ValueError(f"Unsupported service type: {svc_type}")

        return services_started

    # # -------------------------------------------------------------------------
    # # CLEAR AUDIT
    # # -------------------------------------------------------------------------

    # def clear_audit(
    #     self,
    #     job_id: str,
    #     session_id: str | None = None,
    #     db_listeners: list[tuple[str, dict]] | None = None,
    # ) -> ClearAuditResult:
    #     """Clear audit_log tables in sandbox databases.

    #     Args:
    #         job_id: Job ID for the sandbox.
    #         session_id: Session ID for refreshing state cache.
    #         db_listeners: List of (name, config) tuples for database listeners.

    #     Returns:
    #         ClearAuditResult with cleanup status.
    #     """
    #     if not db_listeners:
    #         return ClearAuditResult(success=False, error="No database listeners provided")

    #     def _execute_db_cleanup(name: str, db_config: dict, local_port: int) -> dict:
    #         """Execute DB cleanup using sync SQLAlchemy."""
    #         db_type = db_config.get("db_type", "postgresql").lower()
    #         db_user = db_config.get("db_user", "postgres" if db_type == "postgresql" else "root")
    #         db_password = db_config.get("db_password", "")
    #         db_database = db_config.get("db_database", "postgres")

    #         user = quote_plus(db_user)
    #         password = quote_plus(db_password)
    #         database = quote_plus(db_database)

    #         if db_type == "postgresql":
    #             db_url = f"postgresql+psycopg2://{user}:{password}@127.0.0.1:{local_port}/{database}"
    #         elif db_type in ("mysql", "mariadb"):
    #             db_url = f"mysql+pymysql://{user}:{password}@127.0.0.1:{local_port}/{database}"
    #         else:
    #             return {"listener": name, "success": False, "error": f"Unsupported db_type: {db_type}"}

    #         engine = create_engine(db_url, pool_pre_ping=True)
    #         tables_truncated = []

    #         with engine.begin() as conn:
    #             if db_type == "postgresql":
    #                 result = conn.execute(
    #                     text("SELECT schemaname, tablename FROM pg_tables WHERE tablename = 'audit_log'")
    #                 )
    #                 tables = result.fetchall()
    #                 for schema, table in tables:
    #                     conn.execute(text(f"TRUNCATE TABLE {schema}.{table} RESTART IDENTITY CASCADE"))
    #                     tables_truncated.append(f"{schema}.{table}")
    #             elif db_type in ("mysql", "mariadb"):
    #                 result = conn.execute(
    #                     text(
    #                         "SELECT table_schema, table_name FROM information_schema.tables "
    #                         "WHERE table_name = 'audit_log' AND table_schema = DATABASE()"
    #                     )
    #                 )
    #                 tables = result.fetchall()
    #                 conn.execute(text("SET FOREIGN_KEY_CHECKS = 0"))
    #                 for schema, table in tables:
    #                     conn.execute(text(f"DELETE FROM `{table}`"))
    #                     tables_truncated.append(table)
    #                 conn.execute(text("SET FOREIGN_KEY_CHECKS = 1"))

    #         engine.dispose()
    #         return {"listener": name, "success": True, "tables_truncated": tables_truncated}

    #     async def clear_audit_via_tunnel(name: str, db_config: dict) -> dict:
    #         """Clear audit_log by connecting via proxy tunnel."""
    #         db_type = db_config.get("db_type", "postgresql").lower()
    #         db_port = db_config.get("db_port", 5432 if db_type == "postgresql" else 3306)

    #         local_port = find_free_port()
    #         tunnel = ProxyTunnel(
    #             env_id=job_id,
    #             db_port=db_port,
    #             temp_password="newpass",
    #             host_port=local_port,
    #         )

    #         try:
    #             await tunnel.start()
    #             result = await asyncio.to_thread(_execute_db_cleanup, name, db_config, local_port)
    #             return result
    #         except Exception as e:
    #             return {"listener": name, "success": False, "error": str(e)}
    #         finally:
    #             await tunnel.stop()

    #     async def run_all():
    #         tasks = [clear_audit_via_tunnel(name, db_config) for name, db_config in db_listeners]
    #         return await asyncio.gather(*tasks)

    #     try:
    #         results = asyncio.run(run_all())

    #         # Refresh state cache
    #         if session_id:
    #             try:
    #                 sessions_state.sync(
    #                     client=self._http,
    #                     session_id=session_id,
    #                     x_api_key=self.api_key,
    #                 )
    #             except Exception:
    #                 pass

    #         all_success = all(r["success"] for r in results)
    #         return ClearAuditResult(success=all_success, results=list(results))

    #     except Exception as e:
    #         return ClearAuditResult(success=False, error=str(e))