plato-sdk-v2 2.7.6__py3-none-any.whl → 2.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- plato/cli/__init__.py +5 -0
- plato/cli/agent.py +1209 -0
- plato/cli/audit_ui.py +316 -0
- plato/cli/chronos.py +817 -0
- plato/cli/main.py +193 -0
- plato/cli/pm.py +1204 -0
- plato/cli/proxy.py +222 -0
- plato/cli/sandbox.py +808 -0
- plato/cli/utils.py +200 -0
- plato/cli/verify.py +690 -0
- plato/cli/world.py +250 -0
- plato/v1/cli/pm.py +4 -1
- plato/v2/__init__.py +2 -0
- plato/v2/models.py +42 -0
- plato/v2/sync/__init__.py +6 -0
- plato/v2/sync/client.py +6 -3
- plato/v2/sync/sandbox.py +1462 -0
- {plato_sdk_v2-2.7.6.dist-info → plato_sdk_v2-2.7.8.dist-info}/METADATA +1 -1
- {plato_sdk_v2-2.7.6.dist-info → plato_sdk_v2-2.7.8.dist-info}/RECORD +21 -9
- {plato_sdk_v2-2.7.6.dist-info → plato_sdk_v2-2.7.8.dist-info}/entry_points.txt +1 -1
- {plato_sdk_v2-2.7.6.dist-info → plato_sdk_v2-2.7.8.dist-info}/WHEEL +0 -0
plato/v2/sync/sandbox.py
ADDED
@@ -0,0 +1,1462 @@
"""Plato SDK v2 - Synchronous Sandbox Client.

The SandboxClient provides methods for sandbox development workflows:
creating sandboxes, managing SSH, syncing files, running flows, etc.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import shutil
import signal
import subprocess
import tempfile
import time
from pathlib import Path
from urllib.parse import quote

import httpx
import yaml
from pydantic import BaseModel
from rich.console import Console

from plato._generated.api.v1.gitea import (
    create_simulator_repository,
    get_accessible_simulators,
    get_gitea_credentials,
    get_simulator_repository,
)
from plato._generated.api.v1.sandbox import start_worker
from plato._generated.api.v2.jobs import get_flows as jobs_get_flows
from plato._generated.api.v2.jobs import state as jobs_state
from plato._generated.api.v2.sessions import add_ssh_key as sessions_add_ssh_key
from plato._generated.api.v2.sessions import close as sessions_close
from plato._generated.api.v2.sessions import connect_network as sessions_connect_network
from plato._generated.api.v2.sessions import get_public_url as sessions_get_public_url
from plato._generated.api.v2.sessions import get_session_details
from plato._generated.api.v2.sessions import snapshot as sessions_snapshot
from plato._generated.api.v2.sessions import state as sessions_state
from plato._generated.models import (
    AddSSHKeyRequest,
    AppApiV2SchemasSessionCreateSnapshotResponse,
    AppSchemasBuildModelsSimConfigDataset,
    CloseSessionResponse,
    CreateCheckpointRequest,
    DatabaseMutationListenerConfig,
    Flow,
    PlatoConfig,
    SessionStateResponse,
    VMManagementRequest,
)
from plato.v2.async_.flow_executor import FlowExecutor
from plato.v2.models import SandboxState
from plato.v2.sync.client import Plato
from plato.v2.types import Env, EnvFromArtifact, EnvFromResource, EnvFromSimulator, SimConfigCompute

logger = logging.getLogger(__name__)


DEFAULT_BASE_URL = "https://plato.so"
DEFAULT_TIMEOUT = 600.0


def _get_plato_dir(working_dir: Path | None = None) -> Path:
    """Get the .plato directory path."""
    base = working_dir or Path.cwd()
    return base / ".plato"


def _generate_ssh_key_pair(prefix: str, working_dir: Path | None = None) -> tuple[str, str]:
    """Generate an SSH key pair and save to .plato/ directory.

    Args:
        prefix: Prefix for key filename.
        working_dir: Working directory for .plato/.

    Returns:
        Tuple of (public_key_content, private_key_path).
    """
    plato_dir = _get_plato_dir(working_dir)
    plato_dir.mkdir(mode=0o700, exist_ok=True)

    key_name = f"ssh_key_{prefix}"
    private_key_path = plato_dir / key_name
    public_key_path = plato_dir / f"{key_name}.pub"

    # Remove existing keys
    if private_key_path.exists():
        private_key_path.unlink()
    if public_key_path.exists():
        public_key_path.unlink()

    # Generate key pair
    subprocess.run(
        [
            "ssh-keygen",
            "-t",
            "ed25519",
            "-f",
            str(private_key_path),
            "-N",
            "",
            "-q",
        ],
        check=True,
    )

    # Read public key
    public_key = public_key_path.read_text().strip()

    return public_key, str(private_key_path)


def _generate_ssh_config(
    job_id: str,
    private_key_path: str,
    working_dir: Path | None = None,
    ssh_host: str = "sandbox",
) -> str:
    """Generate SSH config file for easy access via gateway.

    Args:
        job_id: The job ID for routing.
        private_key_path: Path to private key (absolute or relative).
        working_dir: Working directory for .plato/.
        ssh_host: Host alias in config.

    Returns:
        Path to the generated SSH config file (relative to working_dir).

    Note:
        The IdentityFile in the config uses a path relative to working_dir.
        SSH commands must be run from the workspace root for paths to resolve.
    """
    gateway_host = os.getenv("PLATO_GATEWAY_HOST", "gateway.plato.so")

    # Convert private key path to be relative to working_dir
    # This ensures the config is portable if the workspace moves
    base = working_dir or Path.cwd()
    try:
        relative_key_path = Path(private_key_path).relative_to(base)
    except ValueError:
        # If not relative to working_dir, keep as-is (shouldn't happen normally)
        relative_key_path = Path(private_key_path)

    # SNI format: {job_id}--{port}.{gateway_host} (matches v1 proxy.py)
    ssh_port = 22
    sni = f"{job_id}--{ssh_port}.{gateway_host}"

    config_content = f"""# Plato Sandbox SSH Config
# Generated for job: {job_id}
# NOTE: Run SSH commands from workspace root for relative paths to resolve

Host {ssh_host}
    HostName {job_id}
    User root
    IdentityFile {relative_key_path}
    StrictHostKeyChecking no
    UserKnownHostsFile /dev/null
    LogLevel ERROR
    ProxyCommand openssl s_client -quiet -connect {gateway_host}:443 -servername {sni} 2>/dev/null
"""

    plato_dir = _get_plato_dir(working_dir)
    plato_dir.mkdir(mode=0o700, exist_ok=True)

    config_path = plato_dir / "ssh_config"
    config_path.write_text(config_content)

    return str(config_path)

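For orientation, the two helpers above compose as follows; a minimal sketch, assuming a workspace at the current directory and a placeholder job ID ("job-1234"):

# Sketch: fresh ed25519 key pair under .plato/, plus an SSH config whose
# ProxyCommand tunnels through the TLS gateway using SNI routing.
from pathlib import Path

workspace = Path.cwd()
public_key, private_key_path = _generate_ssh_key_pair("demo", workspace)
config_path = _generate_ssh_config("job-1234", private_key_path, workspace)
# The IdentityFile is workspace-relative, so invoke ssh from the workspace root:
#   ssh -F .plato/ssh_config sandbox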
def _run_ssh_command(
    ssh_config_path: str,
    ssh_host: str,
    command: str,
    cwd: Path | str | None = None,
) -> tuple[int, str, str]:
    """Run a command via SSH.

    Args:
        ssh_config_path: Path to SSH config file (can be relative).
        ssh_host: SSH host alias from config.
        command: Command to execute on remote.
        cwd: Working directory to run SSH from. Required when ssh_config_path
            contains relative paths (e.g., for IdentityFile).

    Returns:
        Tuple of (returncode, stdout, stderr).
    """
    result = subprocess.run(
        ["ssh", "-F", ssh_config_path, ssh_host, command],
        capture_output=True,
        text=True,
        cwd=cwd,
    )
    return result.returncode, result.stdout, result.stderr


# =============================================================================
# HEARTBEAT UTILITIES
# =============================================================================


def _start_heartbeat_process(session_id: str, api_key: str) -> int | None:
    """Start a background process that sends heartbeats.

    Returns:
        PID of the background process, or None if failed.
    """
    log_file = f"/tmp/plato_heartbeat_{session_id}.log"
    base_url = os.getenv("PLATO_BASE_URL", "https://plato.so")
    # Strip trailing /api if present to avoid double /api/api in URL
    if base_url.endswith("/api"):
        base_url = base_url[:-4]
    base_url = base_url.rstrip("/")

    heartbeat_script = f'''
import time
import os
import httpx
from datetime import datetime

session_id = "{session_id}"
api_key = "{api_key}"
base_url = "{base_url}"
log_file = "{log_file}"

def log(msg):
    timestamp = datetime.now().isoformat()
    with open(log_file, "a") as f:
        f.write(f"[{{timestamp}}] {{msg}}\\n")
        f.flush()

log(f"Heartbeat process started for session {{session_id}}")

heartbeat_count = 0
while True:
    heartbeat_count += 1
    try:
        url = f"{{base_url}}/api/v2/sessions/{{session_id}}/heartbeat"
        with httpx.Client(timeout=30) as client:
            response = client.post(
                url,
                headers={{"X-API-Key": api_key}},
            )
        if response.status_code == 200:
            result = response.json()
            success = result.get("success", False)
            log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, success={{success}}")
        else:
            log(f"Heartbeat #{{heartbeat_count}}: status={{response.status_code}}, url={{url}}, body={{response.text[:500]}}")
    except Exception as e:
        log(f"Heartbeat #{{heartbeat_count}} EXCEPTION: {{type(e).__name__}}: {{e}}")
    time.sleep(30)
'''

    try:
        process = subprocess.Popen(
            ["python3", "-c", heartbeat_script],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            stdin=subprocess.DEVNULL,
            start_new_session=True,
        )
        return process.pid
    except Exception:
        return None


def _stop_heartbeat_process(pid: int) -> bool:
    """Stop the heartbeat process.

    Returns:
        True if stopped successfully.
    """
    try:
        os.kill(pid, signal.SIGTERM)
        return True
    except ProcessLookupError:
        return True
    except Exception:
        return False

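The heartbeat pair is meant to bracket a session's lifetime: the detached `python3 -c` child POSTs to the session's heartbeat endpoint every 30 seconds and logs to `/tmp`. A minimal sketch of the intended pairing (the session ID and API key are placeholders):

pid = _start_heartbeat_process("sess-1234", "my-api-key")
try:
    ...  # do work while the session is kept alive
finally:
    if pid is not None:
        # SIGTERM the child; returns True even if it already exited
        _stop_heartbeat_process(pid)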
class SyncResult(BaseModel):
    files_synced: int
    bytes_synced: int


# =============================================================================
# TUNNEL
# =============================================================================

DEFAULT_GATEWAY_HOST = "gateway.plato.so"
DEFAULT_GATEWAY_PORT = 443


def _get_gateway_config() -> tuple[str, int]:
    """Get gateway host and port from environment or defaults."""
    host = os.environ.get("PLATO_GATEWAY_HOST", DEFAULT_GATEWAY_HOST)
    port = int(os.environ.get("PLATO_GATEWAY_PORT", str(DEFAULT_GATEWAY_PORT)))
    return host, port


def _create_tls_connection(
    gateway_host: str,
    gateway_port: int,
    sni: str,
    verify_ssl: bool = True,
):
    """Create a TLS connection to the gateway with the specified SNI."""
    import socket
    import ssl

    context = ssl.create_default_context()
    if not verify_ssl:
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(30)
    ssl_sock = context.wrap_socket(sock, server_hostname=sni)

    try:
        ssl_sock.connect((gateway_host, gateway_port))
    except Exception as e:
        ssl_sock.close()
        raise ConnectionError(f"Failed to connect to gateway: {e}") from e

    return ssl_sock


def _forward_data(src, dst, name: str = "") -> None:
    """Forward data between two sockets until one closes."""
    import socket

    try:
        while True:
            data = src.recv(4096)
            if not data:
                break
            dst.sendall(data)
    except (ConnectionResetError, BrokenPipeError, OSError):
        pass
    finally:
        try:
            dst.shutdown(socket.SHUT_WR)
        except OSError:
            pass


class Tunnel:
    """A TCP tunnel to a remote port on a sandbox VM via the TLS gateway."""

    def __init__(
        self,
        job_id: str,
        remote_port: int,
        local_port: int | None = None,
        bind_address: str = "127.0.0.1",
        verify_ssl: bool = True,
    ):
        self.job_id = job_id
        self.remote_port = remote_port
        self.local_port = local_port or remote_port
        self.bind_address = bind_address
        self.verify_ssl = verify_ssl

        self._server = None
        self._thread = None
        self._running = False

    def start(self) -> int:
        """Start the tunnel. Returns the local port."""
        import socket
        import threading

        gateway_host, gateway_port = _get_gateway_config()
        sni = f"{self.job_id}--{self.remote_port}.{gateway_host}"

        # Create local listener
        self._server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        try:
            self._server.bind((self.bind_address, self.local_port))
            self._server.listen(5)
        except OSError as e:
            raise ValueError(f"Could not bind to {self.bind_address}:{self.local_port}: {e}") from e

        self._running = True

        def handle_client(client_sock, client_addr):
            try:
                gateway_sock = _create_tls_connection(gateway_host, gateway_port, sni, verify_ssl=self.verify_ssl)
                t1 = threading.Thread(
                    target=_forward_data,
                    args=(client_sock, gateway_sock, "client->gateway"),
                    daemon=True,
                )
                t2 = threading.Thread(
                    target=_forward_data,
                    args=(gateway_sock, client_sock, "gateway->client"),
                    daemon=True,
                )
                t1.start()
                t2.start()
                t1.join()
                t2.join()
            except Exception:
                pass
            finally:
                try:
                    client_sock.close()
                except OSError:
                    pass

        def accept_loop():
            server = self._server
            assert server is not None, "Server must be initialized before accept_loop"
            while self._running:
                try:
                    server.settimeout(1.0)
                    client_sock, client_addr = server.accept()
                    threading.Thread(
                        target=handle_client,
                        args=(client_sock, client_addr),
                        daemon=True,
                    ).start()
                except TimeoutError:
                    continue
                except OSError:
                    break

        self._thread = threading.Thread(target=accept_loop, daemon=True)
        self._thread.start()

        return self.local_port

    def stop(self) -> None:
        """Stop the tunnel."""
        self._running = False
        if self._server:
            try:
                self._server.close()
            except OSError:
                pass
        if self._thread:
            self._thread.join(timeout=2.0)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
        return False

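Since Tunnel implements the context-manager protocol, the listener's lifetime can be scoped to a block. A minimal sketch, assuming a reachable gateway and a live job ("job-1234" and port 5432 are placeholders):

# Forward 127.0.0.1:5432 to port 5432 on the sandbox VM.
with Tunnel(job_id="job-1234", remote_port=5432) as tunnel:
    print(f"listening on {tunnel.bind_address}:{tunnel.local_port}")
    # local clients may connect here until the block exits
# __exit__ calls stop(): the listener closes and the accept loop is joined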
# =============================================================================
# SANDBOX CLIENT
# =============================================================================


class SandboxClient:
"""Synchronous client for sandbox development workflows.
|
|
470
|
+
|
|
471
|
+
Supports two modes:
|
|
472
|
+
1. Stateless (working_dir=None): Pure operations, no file I/O
|
|
473
|
+
2. Stateful (working_dir=Path): Persists state to .plato/state.yaml
|
|
474
|
+
|
|
475
|
+
Usage (stateless):
|
|
476
|
+
client = SandboxClient(api_key="...")
|
|
477
|
+
result = client.start(mode="blank", service="myservice")
|
|
478
|
+
client.stop(result.session_id)
|
|
479
|
+
client.close()
|
|
480
|
+
|
|
481
|
+
Usage (stateful - recommended for CLI/scripts):
|
|
482
|
+
client = SandboxClient(api_key="...", working_dir=Path("."))
|
|
483
|
+
client.start(mode="blank", service="myservice") # Saves state
|
|
484
|
+
# Later...
|
|
485
|
+
client = SandboxClient(api_key="...", working_dir=Path(".")) # Loads state
|
|
486
|
+
client.stop() # Uses saved session_id
|
|
487
|
+
"""
|
|
488
|
+
    # State file paths
    PLATO_DIR = ".plato"

    def __init__(
        self,
        working_dir: Path,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float = DEFAULT_TIMEOUT,
        console: Console = Console(),
    ):
        self.api_key = api_key or os.environ.get("PLATO_API_KEY")
        if not self.api_key:
            raise ValueError("API key required. Set PLATO_API_KEY or pass api_key=")

        url = base_url or os.environ.get("PLATO_BASE_URL", DEFAULT_BASE_URL)
        if url.endswith("/api"):
            url = url[:-4]
        self.base_url = url.rstrip("/")
        self.console = console
        self.working_dir = working_dir

        self._http = httpx.Client(
            base_url=self.base_url,
            timeout=httpx.Timeout(timeout),
        )

    def _get_plato_dir(self) -> Path:
        return Path(self.working_dir) / self.PLATO_DIR

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._http.close()

    # -------------------------------------------------------------------------
    # START
    # -------------------------------------------------------------------------

    def start(
        self,
        simulator_name: str | None = None,
        mode: str = "blank",
        # artifact or simulator mode
        artifact_id: str | None = None,
        dataset: str = "base",
        tag: str = "latest",
        # blank or plato-config mode
        cpus: int = 1,
        memory: int = 2048,
        disk: int = 10240,
        app_port: int | None = None,
        messaging_port: int | None = None,
        # common
        connect_network: bool = True,
        timeout: int = 1800,
    ) -> SandboxState:
        """Start a sandbox environment.

        Uses Plato SDK v2 internally for session creation.

        Args:
            mode: Start mode - "blank", "simulator", "artifact", or "config".
            simulator_name: Simulator name.
            artifact_id: Artifact UUID.
            dataset: Dataset name.
            tag: Artifact tag.
            cpus: Number of CPUs.
            memory: Memory in MB.
            disk: Disk in MB.
            app_port: App port.
            messaging_port: Messaging port.
            connect_network: Whether to connect WireGuard network.
            timeout: Seconds to wait for the VM to become ready.

        Returns:
            SandboxState with sandbox info.
        """

        assert self.api_key is not None

        # Build environment config using Env factory
        env_config: EnvFromSimulator | EnvFromArtifact | EnvFromResource

        if mode == "artifact" and artifact_id:
            self.console.print(f"Starting from artifact: {artifact_id}")
            env_config = Env.artifact(artifact_id)
        elif mode == "simulator" and simulator_name:
            self.console.print(f"Starting from simulator: {simulator_name}")
            env_config = Env.simulator(simulator_name, tag=tag, dataset=dataset)
        elif mode == "blank" and simulator_name:
            self.console.print("Starting from blank")
            sim_config = SimConfigCompute(
                cpus=cpus, memory=memory, disk=disk, app_port=app_port, plato_messaging_port=messaging_port
            )
            env_config = Env.resource(simulator_name, sim_config)
        elif mode == "config":
            self.console.print("Starting from config")
            # read plato-config.yml
            plato_config_path = self.working_dir / "plato-config.yml"
            with open(plato_config_path, "rb") as f:
                plato_config = yaml.safe_load(f)
            self.console.print(f"plato-config: {plato_config}")
            plato_config_model = PlatoConfig.model_validate(plato_config)
            dataset_config = plato_config_model.datasets[dataset]
            simulator_name = plato_config_model.service
            if not simulator_name:
                raise ValueError("Service name is required in plato-config.yml")
            if not dataset_config.compute:
                raise ValueError(f"Compute configuration is required for dataset '{dataset}'")
            self.console.print(f"simulator_name: {simulator_name}")
            sim_config = SimConfigCompute(
                cpus=dataset_config.compute.cpus,
                memory=dataset_config.compute.memory,
                disk=dataset_config.compute.disk,
                app_port=dataset_config.compute.app_port,
                plato_messaging_port=dataset_config.compute.plato_messaging_port,
            )
            env_config = Env.resource(simulator_name, sim_config)
            self.console.print(f"env_config: {env_config}")
        else:
            raise ValueError(f"Invalid mode '{mode}' or missing required parameter")

        # Use Plato SDK to create session (handles create, wait, network)
        self.console.print(f"Creating session and waiting for VM to become ready (timeout={timeout}s)...")
        plato = Plato(api_key=self.api_key, http_client=self._http)
        session = plato.sessions.create(
            envs=[env_config],
            connect_network=connect_network,
            timeout=timeout,
        )
        self.console.print(f"session: {session}")
        session_id = session.session_id
        job_id = session.envs[0].job_id if session.envs else None
        if not job_id:
            raise ValueError("No job ID found")
        self.console.print(f"job_id: {job_id}")

        # Get public URL with router target formatting (logic inlined)
        public_url = None
        try:
            url_response = sessions_get_public_url.sync(
                client=self._http,
                session_id=session_id,
                x_api_key=self.api_key,
            )
            if url_response and url_response.results:
                for result in url_response.results.values():
                    url = result.url if hasattr(result, "url") else str(result)
                    if not url:
                        raise ValueError(f"No public URL found in result dict for job ID {job_id}")
                    if "_plato_router_target=" not in url:
                        target_param = f"_plato_router_target={simulator_name}.web.plato.so"
                        if "?" in url:
                            url = f"{url}&{target_param}"
                        else:
                            url = f"{url}?{target_param}"
                    public_url = url
        except Exception as e:
            raise ValueError(f"Error getting public URL: {e}") from e

        # Setup SSH
        ssh_config_path = None
        try:
            public_key, private_key_path = _generate_ssh_key_pair(session_id[:8], Path(self.working_dir))

            add_key_request = AddSSHKeyRequest(public_key=public_key, username="root")
            add_response = sessions_add_ssh_key.sync(
                client=self._http,
                session_id=session_id,
                body=add_key_request,
                x_api_key=self.api_key,
            )

            if add_response.success:
                ssh_config_path = _generate_ssh_config(job_id, private_key_path, Path(self.working_dir))
        except Exception as e:
            logger.warning(f"SSH setup failed: {e}")

        # Start heartbeat
        heartbeat_pid = _start_heartbeat_process(session_id, self.api_key)

        # Parse the simulator_name out of the service field in the session details' jobs.
        session_details = get_session_details.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )
        if hasattr(session_details, "jobs"):
            # session_details["jobs"] if dict, or session_details.jobs if object
            jobs = (
                session_details["jobs"] if isinstance(session_details, dict) else getattr(session_details, "jobs", None)
            )
            if jobs:
                for j in jobs:
                    service = j.get("service") if isinstance(j, dict) else getattr(j, "service", None)
                    if service:
                        simulator_name = service
                        break

        # Convert absolute paths to relative for state storage
        def _to_relative(abs_path: str | None) -> str | None:
            if not abs_path or not self.working_dir:
                return abs_path
            try:
                return str(Path(abs_path).relative_to(self.working_dir))
            except ValueError:
                return abs_path  # Keep absolute if not relative to working_dir

        # Update internal state
        rel_ssh_config = _to_relative(ssh_config_path)
        ssh_host = "sandbox" if ssh_config_path else None
        sandbox_state = SandboxState(
            session_id=session_id,
            job_id=job_id,
            public_url=public_url,
            mode=mode,
            ssh_config_path=rel_ssh_config,
            ssh_host=ssh_host,
            ssh_command=f"ssh -F {rel_ssh_config} {ssh_host}" if rel_ssh_config else None,
            heartbeat_pid=heartbeat_pid,
            simulator_name=simulator_name,
            dataset=dataset,
        )
        if mode == "artifact":
            sandbox_state.artifact_id = artifact_id
        elif mode == "simulator":
            sandbox_state.tag = tag
        elif mode == "blank":
            sandbox_state.cpus = cpus
            sandbox_state.memory = memory
            sandbox_state.disk = disk
            sandbox_state.app_port = app_port
            sandbox_state.messaging_port = messaging_port

        # Save state to working_dir/.plato/state.json
        with open(self.working_dir / self.PLATO_DIR / "state.json", "w") as f:
            json.dump(sandbox_state.model_dump(), f)

        return sandbox_state

    # CHECKED
    def stop(
        self,
        session_id: str,
        heartbeat_pid: int | None = None,
    ) -> CloseSessionResponse:
        if heartbeat_pid:
            _stop_heartbeat_process(heartbeat_pid)

        return sessions_close.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )

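A minimal end-to-end sketch of the lifecycle `start` and `stop` define (the API key, working directory, and service name are placeholders; note that mode="blank" requires a simulator_name):

from pathlib import Path

client = SandboxClient(working_dir=Path("."), api_key="my-api-key")
state = client.start(mode="blank", simulator_name="myservice")
try:
    print(state.public_url, state.ssh_command)
finally:
    # Pass the recorded heartbeat PID so the keep-alive child is terminated
    client.stop(state.session_id, heartbeat_pid=state.heartbeat_pid)
    client.close()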
    # CHECKED
    def status(self, session_id: str) -> dict:
        return get_session_details.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )

    # CHECKED
    def snapshot(
        self,
        session_id: str,
        mode: str,
        dataset: str,
    ) -> AppApiV2SchemasSessionCreateSnapshotResponse:
        checkpoint_request = CreateCheckpointRequest()

        if mode == "config":
            # read plato-config.yml
            plato_config_path = self.working_dir / "plato-config.yml"
            with open(plato_config_path, "rb") as f:
                plato_config = yaml.safe_load(f)
            checkpoint_request.plato_config = plato_config

            # convert plato-config to pydantic model
            plato_config_model = PlatoConfig.model_validate(plato_config)
            dataset_compute = plato_config_model.datasets[dataset].compute
            if not dataset_compute:
                raise ValueError(f"Compute configuration is required for dataset '{dataset}'")
            checkpoint_request.internal_app_port = dataset_compute.app_port
            checkpoint_request.messaging_port = dataset_compute.plato_messaging_port
            # we don't set target

            # read flows.yml
            flows_path = self.working_dir / "flows.yml"
            with open(flows_path, "rb") as f:
                flows = yaml.safe_load(f)
            checkpoint_request.flows = flows

        return sessions_snapshot.sync(
            client=self._http,
            session_id=session_id,
            body=checkpoint_request,
            x_api_key=self.api_key,
        )

    # CHECKED
    def connect_network(self, session_id: str) -> dict:
        return sessions_connect_network.sync(
            client=self._http,
            session_id=session_id,
            x_api_key=self.api_key,
        )

    # CHECKED
    def start_worker(
        self,
        job_id: str,
        simulator: str,
        dataset: str,
        wait_timeout: int = 300,  # 5 minutes
    ) -> None:
        with open(self.working_dir / "plato-config.yml", "rb") as f:
            plato_config = yaml.safe_load(f)
        plato_config_model = PlatoConfig.model_validate(plato_config)
        dataset_config = plato_config_model.datasets[dataset]

        # Convert AppApiV2SchemasArtifactSimConfigDataset to AppSchemasBuildModelsSimConfigDataset
        # They have compatible fields but different nested types
        dataset_config_dict = dataset_config.model_dump(exclude_none=True)

        _ = start_worker.sync(
            client=self._http,
            public_id=job_id,
            body=VMManagementRequest(
                service=simulator,
                dataset=dataset,
                plato_dataset_config=AppSchemasBuildModelsSimConfigDataset.model_validate(dataset_config_dict),
            ),
            x_api_key=self.api_key,
        )

        if wait_timeout > 0:
            start_time = time.time()
            poll_interval = 10

            while time.time() - start_time < wait_timeout:
                try:
                    state_response = jobs_state.sync(
                        client=self._http,
                        job_id=job_id,
                        x_api_key=self.api_key,
                    )
                    if state_response:
                        state_dict = (
                            state_response.model_dump() if hasattr(state_response, "model_dump") else state_response
                        )
                        if isinstance(state_dict, dict) and "error" not in state_dict.get("state", {}):
                            return
                except Exception:
                    pass

                time.sleep(poll_interval)

    # CHECKED
    def sync(
        self,
        session_id: str,
        simulator: str,
        timeout: int = 120,
    ) -> SyncResult:
        """Sync local files to sandbox using rsync over SSH.

        Uses the SSH config from .plato/ssh_config for fast, reliable file transfer.
        """
        local_path = self.working_dir
        remote_path = f"/home/plato/worktree/{simulator}"

        # Load SSH config from state
        state_file = self.working_dir / ".plato" / "state.json"
        if not state_file.exists():
            raise ValueError("No state file found - run 'plato sandbox start' first")

        with open(state_file) as f:
            state = json.load(f)

        ssh_config_path = state.get("ssh_config_path")
        ssh_host = state.get("ssh_host", "sandbox")

        if not ssh_config_path:
            raise ValueError("No SSH config in state - run 'plato sandbox start' first")

        exclude_patterns = [
            "__pycache__",
            "*.pyc",
            ".git",
            ".venv",
            "venv",
            "node_modules",
            ".sandbox.yaml",
            "*.egg-info",
            ".pytest_cache",
            ".mypy_cache",
            ".DS_Store",
            "*.swp",
            "*.swo",
            ".plato",
        ]

        # Build rsync command
        rsync_cmd = [
            "rsync",
            "-avz",
            "--delete",
            "-e",
            f"ssh -F {ssh_config_path}",
        ]

        # Add excludes
        for pattern in exclude_patterns:
            rsync_cmd.extend(["--exclude", pattern])

        # Source and destination
        rsync_cmd.append(f"{local_path}/")
        rsync_cmd.append(f"{ssh_host}:{remote_path}/")

        self.console.print(f"[dim]rsync -> {ssh_host}:{remote_path}/[/dim]")

        # Ensure rsync is installed on the VM and create remote directory
        setup_result = subprocess.run(
            [
                "ssh",
                "-F",
                ssh_config_path,
                ssh_host,
                f"which rsync >/dev/null 2>&1 || (apt-get update -qq && apt-get install -y -qq rsync) && mkdir -p {remote_path}",
            ],
            capture_output=True,
            text=True,
            cwd=self.working_dir,
        )
        if setup_result.returncode != 0:
            raise ValueError(f"Failed to setup remote: {setup_result.stderr}")

        # Run rsync
        result = subprocess.run(
            rsync_cmd,
            capture_output=True,
            text=True,
            cwd=self.working_dir,
            timeout=timeout,
        )

        if result.returncode != 0:
            raise ValueError(f"rsync failed: {result.stderr}")

        # Count synced files from rsync output
        lines = result.stdout.strip().split("\n") if result.stdout else []
        file_count = len(
            [
                line
                for line in lines
                if line
                and not line.startswith("sending")
                and not line.startswith("sent")
                and not line.startswith("total")
            ]
        )

        # Get bytes from rsync output (e.g., "sent 1,234 bytes")
        bytes_synced = 0
        for line in lines:
            if "sent" in line and "bytes" in line:
                import re

                match = re.search(r"sent ([\d,]+) bytes", line)
                if match:
                    bytes_synced = int(match.group(1).replace(",", ""))
                break

        return SyncResult(
            files_synced=file_count,
            bytes_synced=bytes_synced,
        )

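Once `start()` has written `.plato/state.json`, syncing is a single call; a sketch continuing the earlier example (`client` and `state` as above, "myservice" still a placeholder):

result = client.sync(session_id=state.session_id, simulator="myservice")
print(f"synced {result.files_synced} files ({result.bytes_synced} bytes)")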
    def tunnel(
        self,
        job_id: str,
        remote_port: int,
        local_port: int | None = None,
        bind_address: str = "127.0.0.1",
    ) -> Tunnel:
        return Tunnel(
            job_id=job_id,
            remote_port=remote_port,
            local_port=local_port,
            bind_address=bind_address,
        )

    def run_audit_ui(
        self,
        job_id: str | None = None,
        dataset: str = "base",
        no_tunnel: bool = False,
    ) -> None:
        if not shutil.which("streamlit"):
            raise ValueError("streamlit not installed. Run: pip install streamlit psycopg2-binary pymysql")

        ui_file = Path(__file__).resolve().parent.parent.parent / "cli" / "audit_ui.py"
        if not ui_file.exists():
            raise ValueError(f"UI file not found: {ui_file}")

        # Get DB listener from plato-config.yml
        db_listener: DatabaseMutationListenerConfig | None = None
        for config_path in [self.working_dir / "plato-config.yml", self.working_dir / "plato-config.yaml"]:
            if config_path.exists():
                with open(config_path) as f:
                    plato_config = PlatoConfig.model_validate(yaml.safe_load(f))
                dataset_config = plato_config.datasets.get(dataset)
                if dataset_config and dataset_config.listeners:
                    for listener in dataset_config.listeners.values():
                        if isinstance(listener, DatabaseMutationListenerConfig):
                            db_listener = listener
                            break
                break
        tunnel = None

        if db_listener and job_id and not no_tunnel:
            self.console.print(f"Starting tunnel to {db_listener.db_type} on port {db_listener.db_port}...")
            tunnel = self.tunnel(job_id, db_listener.db_port)
            tunnel.start()
            time.sleep(1)  # Let tunnel stabilize
            self.console.print(
                f"[green]Tunnel open:[/green] localhost:{db_listener.db_port} -> VM:{db_listener.db_port}"
            )

        # Pass db config via environment variables
        env = os.environ.copy()
        if db_listener:
            env["PLATO_DB_HOST"] = "127.0.0.1"
            env["PLATO_DB_PORT"] = str(db_listener.db_port)
            env["PLATO_DB_USER"] = db_listener.db_user
            env["PLATO_DB_PASSWORD"] = db_listener.db_password or ""
            env["PLATO_DB_NAME"] = db_listener.db_database
            env["PLATO_DB_TYPE"] = str(db_listener.db_type)
            self.console.print(
                f"[dim]DB config: {db_listener.db_user}@127.0.0.1:{db_listener.db_port}/{db_listener.db_database}[/dim]"
            )

        try:
            subprocess.run(["streamlit", "run", str(ui_file)], env=env)
        finally:
            if tunnel:
                tunnel.stop()
                self.console.print("[yellow]Tunnel closed[/yellow]")

    def run_flow(
        self,
        url: str,
        flow_name: str,
        dataset: str,
        use_api: bool = False,
        job_id: str | None = None,
    ) -> None:
        flow_obj: Flow | None = None
        screenshots_dir = self.working_dir / "screenshots"

        if use_api:
            # Fetch from API
            if not job_id:
                raise ValueError("job_id required when use_api=True")

            self.console.print("[cyan]Flow source: API[/cyan]")
            flows_response = jobs_get_flows.sync(
                client=self._http,
                job_id=job_id,
                x_api_key=self.api_key,
            )

            if flows_response:
                for flow_data in flows_response:
                    if isinstance(flow_data, dict):
                        if flow_data.get("name") == flow_name:
                            flow_obj = Flow.model_validate(flow_data)
                            break
                    elif hasattr(flow_data, "name") and flow_data.name == flow_name:
                        flow_obj = (
                            flow_data if isinstance(flow_data, Flow) else Flow.model_validate(flow_data.model_dump())
                        )
                        break

            if not flow_obj:
                available = [
                    f.get("name") if isinstance(f, dict) else getattr(f, "name", "?") for f in (flows_response or [])
                ]
                raise ValueError(f"Flow '{flow_name}' not found in API. Available: {available}")
        else:
            # Use local flows
            config_paths = [
                self.working_dir / "plato-config.yml",
                self.working_dir / "plato-config.yaml",
            ]

            for config_path in config_paths:
                if config_path.exists():
                    with open(config_path) as f:
                        plato_config = PlatoConfig.model_validate(yaml.safe_load(f))

                    dataset_config = plato_config.datasets.get(dataset)
                    if dataset_config and dataset_config.metadata:
                        flows_path = dataset_config.metadata.flows_path

                        if flows_path:
                            flow_file = (
                                config_path.parent / flows_path
                                if not Path(flows_path).is_absolute()
                                else Path(flows_path)
                            )

                            if flow_file.exists():
                                with open(flow_file) as f:
                                    flow_dict = yaml.safe_load(f)
                                flow_obj = next(
                                    (
                                        Flow.model_validate(fl)
                                        for fl in flow_dict.get("flows", [])
                                        if fl.get("name") == flow_name
                                    ),
                                    None,
                                )
                                if flow_obj:
                                    screenshots_dir = flow_file.parent / "screenshots"
                                    self.console.print(f"[cyan]Flow source: local ({flow_file})[/cyan]")
                    break

            if not flow_obj:
                raise ValueError(f"Flow '{flow_name}' not found in local config")

        # Assert for type narrowing in nested function (checked above in both branches)
        assert flow_obj is not None
        validated_flow: Flow = flow_obj

        # Run the flow with Playwright
        async def _run_flow():
            from playwright.async_api import async_playwright

            async with async_playwright() as p:
                browser = await p.chromium.launch(headless=False)
                try:
                    page = await browser.new_page()
                    await page.goto(url)
                    executor = FlowExecutor(page, validated_flow, screenshots_dir)
                    await executor.execute()
                finally:
                    await browser.close()

        asyncio.run(_run_flow())

    # CHECKED
    def state(self, session_id: str) -> SessionStateResponse:
        response = sessions_state.sync(
            client=self._http,
            session_id=session_id,
            merge_mutations=True,
            x_api_key=self.api_key,
        )
        return response

    # CHECKED
    def start_services(
        self,
        simulator_name: str,
        ssh_config_path: str,
        ssh_host: str,
        dataset: str,
    ) -> list[dict[str, str]]:
        # Get Gitea credentials
        creds = get_gitea_credentials.sync(client=self._http, x_api_key=self.api_key)

        # Get accessible simulators
        simulators = get_accessible_simulators.sync(client=self._http, x_api_key=self.api_key)
        simulator = None
        for sim in simulators:
            sim_name = sim.get("name") if isinstance(sim, dict) else getattr(sim, "name", None)
            if sim_name and sim_name.lower() == simulator_name.lower():
                simulator = sim
                break
        if not simulator:
            raise ValueError(f"Simulator '{simulator_name}' not found in gitea accessible simulators")

        # Get or create repo
        sim_id = simulator.get("id") if isinstance(simulator, dict) else getattr(simulator, "id", None)
        has_repo = simulator.get("has_repo") if isinstance(simulator, dict) else getattr(simulator, "has_repo", False)
        if has_repo:
            repo = get_simulator_repository.sync(client=self._http, simulator_id=sim_id, x_api_key=self.api_key)  # type: ignore
        else:
            repo = create_simulator_repository.sync(client=self._http, simulator_id=sim_id, x_api_key=self.api_key)  # type: ignore

        clone_url = repo.clone_url
        if not clone_url:
            raise ValueError("No clone URL available for gitea repository")

        # Build authenticated URL
        encoded_username = quote(creds.username, safe="")
        encoded_password = quote(creds.password, safe="")
        auth_clone_url = clone_url.replace("https://", f"https://{encoded_username}:{encoded_password}@", 1)

        repo_dir = f"/home/plato/worktree/{simulator_name}"
        branch_name = f"workspace-{int(time.time())}"

        # Clone, copy, push
        with tempfile.TemporaryDirectory(prefix="plato-hub-") as temp_dir:
            temp_repo = Path(temp_dir) / "repo"
            git_env = os.environ.copy()
            git_env["GIT_TERMINAL_PROMPT"] = "0"
            git_env["GIT_ASKPASS"] = ""

            subprocess.run(
                ["git", "clone", auth_clone_url, str(temp_repo)], capture_output=True, env=git_env, check=True
            )
            subprocess.run(
                ["git", "checkout", "-b", branch_name], cwd=temp_repo, capture_output=True, env=git_env, check=True
            )

            # Copy files
            current_dir = Path(self.working_dir)

            def _copy_files(src_dir: Path, dst_dir: Path) -> None:
                """Copy files, skipping .git/ and .plato-hub.json."""
                for src_path in src_dir.rglob("*"):
                    rel_path = src_path.relative_to(src_dir)
                    if ".git" in rel_path.parts or rel_path.name == ".plato-hub.json":
                        continue
                    dst_path = dst_dir / rel_path
                    if src_path.is_dir():
                        dst_path.mkdir(parents=True, exist_ok=True)
                    else:
                        dst_path.parent.mkdir(parents=True, exist_ok=True)
                        shutil.copy2(src_path, dst_path)

            _copy_files(current_dir, temp_repo)

            subprocess.run(["git", "add", "."], cwd=temp_repo, capture_output=True, env=git_env)
            result = subprocess.run(
                ["git", "status", "--porcelain"], cwd=temp_repo, capture_output=True, text=True, env=git_env
            )

            if result.stdout.strip():
                subprocess.run(
                    ["git", "commit", "-m", "Sync from local workspace"],
                    cwd=temp_repo,
                    capture_output=True,
                    env=git_env,
                )

            subprocess.run(
                ["git", "remote", "set-url", "origin", auth_clone_url],
                cwd=temp_repo,
                capture_output=True,
                env=git_env,
            )
            subprocess.run(
                ["git", "push", "-u", "origin", branch_name],
                cwd=temp_repo,
                capture_output=True,
                env=git_env,
                check=True,
            )

        # Clone on VM - first verify SSH works
        # Debug: show SSH config being used
        ssh_config_full_path = (
            Path(self.working_dir) / ssh_config_path
            if not Path(ssh_config_path).is_absolute()
            else Path(ssh_config_path)
        )
        if not ssh_config_full_path.exists():
            raise ValueError(f"SSH config file not found: {ssh_config_full_path}")

        self.console.print(f"[dim]SSH config: {ssh_config_full_path}[/dim]")
        self.console.print(f"[dim]SSH host: {ssh_host}[/dim]")

        # Run SSH with verbose output to see what's happening
        ssh_cmd = ["ssh", "-v", "-F", ssh_config_path, ssh_host, "echo 'SSH connection OK'"]
        self.console.print(f"[dim]Running: {' '.join(ssh_cmd)}[/dim]")
        self.console.print(f"[dim]Working dir: {self.working_dir}[/dim]")

        result = subprocess.run(
            ssh_cmd,
            capture_output=True,
            text=True,
            cwd=self.working_dir,
        )
        ret, stdout, stderr = result.returncode, result.stdout, result.stderr

        if ret != 0:
            # Show SSH config contents for debugging
            try:
                config_content = ssh_config_full_path.read_text()
                self.console.print(f"[yellow]SSH config contents:[/yellow]\n{config_content}")
            except Exception:
                pass
            # Show SSH verbose output
            self.console.print(f"[yellow]SSH stderr (verbose):[/yellow]\n{stderr}")
            error_output = stderr or stdout or "(no output)"
            raise ValueError(f"SSH connection failed (exit {ret}): {error_output}")

        _run_ssh_command(ssh_config_path, ssh_host, "mkdir -p /home/plato/worktree", cwd=self.working_dir)
        _run_ssh_command(ssh_config_path, ssh_host, f"rm -rf {repo_dir}", cwd=self.working_dir)

        # Clone repo - mask credentials in error output
        ret, stdout, stderr = _run_ssh_command(
            ssh_config_path,
            ssh_host,
            f"git clone -b {branch_name} {auth_clone_url} {repo_dir}",
            cwd=self.working_dir,
        )
        if ret != 0:
            # Mask credentials in error output
            safe_url = clone_url  # Use non-authenticated URL in error
            error_output = stderr or stdout or "(no output)"
            error_output = error_output.replace(creds.username, "***").replace(creds.password, "***")
            raise ValueError(f"Clone failed (exit {ret}) for {safe_url} branch {branch_name}: {error_output}")

        # ECR auth
        ecr_result = subprocess.run(
            ["aws", "ecr", "get-login-password", "--region", "us-west-1"], capture_output=True, text=True
        )
        if ecr_result.returncode == 0:
            ecr_token = ecr_result.stdout.strip()
            ecr_registry = "383806609161.dkr.ecr.us-west-1.amazonaws.com"
            _run_ssh_command(
                ssh_config_path,
                ssh_host,
                f"echo '{ecr_token}' | docker login --username AWS --password-stdin {ecr_registry}",
                cwd=self.working_dir,
            )

        # Start services
        services_started = []
        with open(self.working_dir / "plato-config.yml", "rb") as f:
            plato_config = yaml.safe_load(f)
        plato_config_model = PlatoConfig.model_validate(plato_config)
        services_config = plato_config_model.datasets[dataset].services
        if not services_config:
            self.console.print("[yellow]No services configured, skipping service startup[/yellow]")
            return services_started
        for svc_name, svc_config in services_config.items():
            # svc_config is a Pydantic model (DockerComposeServiceConfig), use getattr
            svc_type = getattr(svc_config, "type", "")
            if svc_type == "docker-compose":
                compose_file = getattr(svc_config, "file", "docker-compose.yml")
                compose_cmd = f"cd {repo_dir} && docker compose -f {compose_file} up -d"
                ret, _, stderr = _run_ssh_command(ssh_config_path, ssh_host, compose_cmd, cwd=self.working_dir)
                if ret != 0:
                    raise ValueError(f"Failed to start {svc_name}: {stderr}")
                services_started.append({"name": svc_name, "type": "docker-compose", "file": compose_file})
            else:
                raise ValueError(f"Unsupported service type: {svc_type}")

        return services_started

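`start_services` expects the SSH paths that `start()` recorded plus a `plato-config.yml` whose dataset defines docker-compose services. A sketch under those assumptions, continuing the earlier example (`client` and `state` as above):

services = client.start_services(
    simulator_name="myservice",
    ssh_config_path=state.ssh_config_path,  # workspace-relative, as stored
    ssh_host=state.ssh_host or "sandbox",
    dataset="base",
)
for svc in services:
    print(f"started {svc['name']} from {svc['file']}")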
    # # -------------------------------------------------------------------------
    # # CLEAR AUDIT
    # # -------------------------------------------------------------------------

    # def clear_audit(
    #     self,
    #     job_id: str,
    #     session_id: str | None = None,
    #     db_listeners: list[tuple[str, dict]] | None = None,
    # ) -> ClearAuditResult:
    #     """Clear audit_log tables in sandbox databases.

    #     Args:
    #         job_id: Job ID for the sandbox.
    #         session_id: Session ID for refreshing state cache.
    #         db_listeners: List of (name, config) tuples for database listeners.

    #     Returns:
    #         ClearAuditResult with cleanup status.
    #     """
    #     if not db_listeners:
    #         return ClearAuditResult(success=False, error="No database listeners provided")

    #     def _execute_db_cleanup(name: str, db_config: dict, local_port: int) -> dict:
    #         """Execute DB cleanup using sync SQLAlchemy."""
    #         db_type = db_config.get("db_type", "postgresql").lower()
    #         db_user = db_config.get("db_user", "postgres" if db_type == "postgresql" else "root")
    #         db_password = db_config.get("db_password", "")
    #         db_database = db_config.get("db_database", "postgres")

    #         user = quote_plus(db_user)
    #         password = quote_plus(db_password)
    #         database = quote_plus(db_database)

    #         if db_type == "postgresql":
    #             db_url = f"postgresql+psycopg2://{user}:{password}@127.0.0.1:{local_port}/{database}"
    #         elif db_type in ("mysql", "mariadb"):
    #             db_url = f"mysql+pymysql://{user}:{password}@127.0.0.1:{local_port}/{database}"
    #         else:
    #             return {"listener": name, "success": False, "error": f"Unsupported db_type: {db_type}"}

    #         engine = create_engine(db_url, pool_pre_ping=True)
    #         tables_truncated = []

    #         with engine.begin() as conn:
    #             if db_type == "postgresql":
    #                 result = conn.execute(
    #                     text("SELECT schemaname, tablename FROM pg_tables WHERE tablename = 'audit_log'")
    #                 )
    #                 tables = result.fetchall()
    #                 for schema, table in tables:
    #                     conn.execute(text(f"TRUNCATE TABLE {schema}.{table} RESTART IDENTITY CASCADE"))
    #                     tables_truncated.append(f"{schema}.{table}")
    #             elif db_type in ("mysql", "mariadb"):
    #                 result = conn.execute(
    #                     text(
    #                         "SELECT table_schema, table_name FROM information_schema.tables "
    #                         "WHERE table_name = 'audit_log' AND table_schema = DATABASE()"
    #                     )
    #                 )
    #                 tables = result.fetchall()
    #                 conn.execute(text("SET FOREIGN_KEY_CHECKS = 0"))
    #                 for schema, table in tables:
    #                     conn.execute(text(f"DELETE FROM `{table}`"))
    #                     tables_truncated.append(table)
    #                 conn.execute(text("SET FOREIGN_KEY_CHECKS = 1"))

    #         engine.dispose()
    #         return {"listener": name, "success": True, "tables_truncated": tables_truncated}

    #     async def clear_audit_via_tunnel(name: str, db_config: dict) -> dict:
    #         """Clear audit_log by connecting via proxy tunnel."""
    #         db_type = db_config.get("db_type", "postgresql").lower()
    #         db_port = db_config.get("db_port", 5432 if db_type == "postgresql" else 3306)

    #         local_port = find_free_port()
    #         tunnel = ProxyTunnel(
    #             env_id=job_id,
    #             db_port=db_port,
    #             temp_password="newpass",
    #             host_port=local_port,
    #         )

    #         try:
    #             await tunnel.start()
    #             result = await asyncio.to_thread(_execute_db_cleanup, name, db_config, local_port)
    #             return result
    #         except Exception as e:
    #             return {"listener": name, "success": False, "error": str(e)}
    #         finally:
    #             await tunnel.stop()

    #     async def run_all():
    #         tasks = [clear_audit_via_tunnel(name, db_config) for name, db_config in db_listeners]
    #         return await asyncio.gather(*tasks)

    #     try:
    #         results = asyncio.run(run_all())

    #         # Refresh state cache
    #         if session_id:
    #             try:
    #                 sessions_state.sync(
    #                     client=self._http,
    #                     session_id=session_id,
    #                     x_api_key=self.api_key,
    #                 )
    #             except Exception:
    #                 pass

    #         all_success = all(r["success"] for r in results)
    #         return ClearAuditResult(success=all_success, results=list(results))

    #     except Exception as e:
    #         return ClearAuditResult(success=False, error=str(e))