plato-sdk-v2 2.3.6__py3-none-any.whl → 2.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
plato/agents/runner.py CHANGED
@@ -1,4 +1,11 @@
1
- """Agent runner - run agents in Docker containers."""
1
+ """Agent runner - run agents in Docker containers.
2
+
3
+ Agents emit their own OTel spans for trajectory events. This runner:
4
+ 1. Runs agents in Docker containers
5
+ 2. Streams stdout/stderr for logging
6
+ 3. Passes OTel environment variables for trace context propagation
7
+ 4. Uploads artifacts to S3 when complete
8
+ """
2
9
 
3
10
  from __future__ import annotations
4
11
 
@@ -8,12 +15,10 @@ import logging
8
15
  import os
9
16
  import platform
10
17
  import tempfile
11
- from pathlib import Path
12
18
 
13
19
  from opentelemetry import trace
14
20
 
15
21
  from plato.agents.artifacts import upload_artifacts
16
- from plato.agents.otel import get_tracer
17
22
 
18
23
  logger = logging.getLogger(__name__)
19
24
 
@@ -37,310 +42,142 @@ async def run_agent(
37
42
  workspace: Host directory to mount as /workspace
38
43
  logs_dir: Host directory for logs (temp dir if None)
39
44
  pull: Whether to pull the image first
45
+
46
+ Note: Agents handle their own OTel tracing. This runner only passes
47
+ the trace context (TRACEPARENT) so agent spans link to the parent step.
40
48
  """
41
49
  logs_dir = logs_dir or tempfile.mkdtemp(prefix="agent_logs_")
42
- agent_name = image.split("/")[-1].split(":")[0]
43
50
 
44
51
  # Get session info from environment variables
45
52
  session_id = os.environ.get("SESSION_ID")
46
53
  otel_url = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
47
54
  upload_url = os.environ.get("UPLOAD_URL")
48
55
 
49
- tracer = get_tracer("plato.agent")
50
-
51
- with tracer.start_as_current_span(agent_name) as agent_span:
52
- agent_span.set_attribute("span.type", "agent")
53
- agent_span.set_attribute("source", "agent")
54
- agent_span.set_attribute("image", image)
55
- agent_span.set_attribute("content", f"Starting agent: {agent_name}")
56
-
57
- # Pull image if requested
58
- if pull:
59
- with tracer.start_as_current_span("docker_pull") as pull_span:
60
- pull_span.set_attribute("span.type", "docker_pull")
61
- pull_span.set_attribute("image", image)
62
- pull_proc = await asyncio.create_subprocess_exec(
63
- "docker",
64
- "pull",
65
- image,
66
- stdout=asyncio.subprocess.PIPE,
67
- stderr=asyncio.subprocess.STDOUT,
56
+ # Pull image if requested
57
+ if pull:
58
+ pull_proc = await asyncio.create_subprocess_exec(
59
+ "docker",
60
+ "pull",
61
+ image,
62
+ stdout=asyncio.subprocess.PIPE,
63
+ stderr=asyncio.subprocess.STDOUT,
64
+ )
65
+ await pull_proc.wait()
66
+
67
+ # Setup
68
+ os.makedirs(os.path.join(logs_dir, "agent"), exist_ok=True)
69
+ config_file = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
70
+ json.dump(config, config_file)
71
+ config_file.close()
72
+
73
+ try:
74
+ # Build docker command
75
+ docker_cmd = ["docker", "run", "--rm", "--privileged"]
76
+
77
+ # Determine if we need host networking
78
+ use_host_network = False
79
+ is_macos = platform.system() == "Darwin"
80
+
81
+ if not is_macos:
82
+ try:
83
+ proc = await asyncio.create_subprocess_exec(
84
+ "iptables",
85
+ "-L",
86
+ "-n",
87
+ stdout=asyncio.subprocess.DEVNULL,
88
+ stderr=asyncio.subprocess.DEVNULL,
68
89
  )
69
- await pull_proc.wait()
70
-
71
- # Setup
72
- os.makedirs(os.path.join(logs_dir, "agent"), exist_ok=True)
73
- config_file = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
74
- json.dump(config, config_file)
75
- config_file.close()
76
-
77
- try:
78
- # Build docker command
79
- docker_cmd = ["docker", "run", "--rm"]
80
-
81
- # Determine if we need host networking
82
- use_host_network = False
83
- is_macos = platform.system() == "Darwin"
84
-
85
- if not is_macos:
86
- try:
87
- proc = await asyncio.create_subprocess_exec(
88
- "iptables",
89
- "-L",
90
- "-n",
91
- stdout=asyncio.subprocess.DEVNULL,
92
- stderr=asyncio.subprocess.DEVNULL,
93
- )
94
- await proc.wait()
95
- has_iptables = proc.returncode == 0
96
- except (FileNotFoundError, PermissionError):
97
- has_iptables = False
98
-
99
- use_host_network = not has_iptables
100
-
101
- if use_host_network:
102
- docker_cmd.extend(["--network=host", "--add-host=localhost:127.0.0.1"])
103
-
90
+ await proc.wait()
91
+ has_iptables = proc.returncode == 0
92
+ except (FileNotFoundError, PermissionError):
93
+ has_iptables = False
94
+
95
+ use_host_network = not has_iptables
96
+
97
+ if use_host_network:
98
+ docker_cmd.extend(["--network=host", "--add-host=localhost:127.0.0.1"])
99
+
100
+ docker_cmd.extend(
101
+ [
102
+ "-v",
103
+ f"{workspace}:/workspace",
104
+ "-v",
105
+ f"{logs_dir}:/logs",
106
+ "-v",
107
+ f"{config_file.name}:/config.json:ro",
108
+ "-v",
109
+ "/var/run/docker.sock:/var/run/docker.sock",
110
+ "-w",
111
+ "/workspace",
112
+ ]
113
+ )
114
+
115
+ # Pass session info to agent
116
+ if otel_url:
117
+ traces_endpoint = f"{otel_url.rstrip('/')}/v1/traces"
118
+ docker_cmd.extend(["-e", f"OTEL_EXPORTER_OTLP_ENDPOINT={otel_url}"])
119
+ docker_cmd.extend(["-e", f"OTEL_EXPORTER_OTLP_TRACES_ENDPOINT={traces_endpoint}"])
120
+ docker_cmd.extend(["-e", "OTEL_EXPORTER_OTLP_PROTOCOL=http/protobuf"])
121
+ if session_id:
122
+ docker_cmd.extend(["-e", f"SESSION_ID={session_id}"])
123
+ if upload_url:
124
+ docker_cmd.extend(["-e", f"UPLOAD_URL={upload_url}"])
125
+
126
+ # Pass trace context to agent for parent linking
127
+ # Agent spans will be children of the current step span
128
+ current_span = trace.get_current_span()
129
+ span_context = current_span.get_span_context()
130
+ if span_context.is_valid:
131
+ trace_id = format(span_context.trace_id, "032x")
132
+ span_id = format(span_context.span_id, "016x")
133
+ # W3C Trace Context format for TRACEPARENT
134
+ traceparent = f"00-{trace_id}-{span_id}-01"
104
135
  docker_cmd.extend(
105
136
  [
106
- "-v",
107
- f"{workspace}:/workspace",
108
- "-v",
109
- f"{logs_dir}:/logs",
110
- "-v",
111
- f"{config_file.name}:/config.json:ro",
112
- "-v",
113
- "/var/run/docker.sock:/var/run/docker.sock",
114
- "-w",
115
- "/workspace",
137
+ "-e",
138
+ f"TRACEPARENT={traceparent}",
139
+ "-e",
140
+ f"OTEL_TRACE_ID={trace_id}",
141
+ "-e",
142
+ f"OTEL_PARENT_SPAN_ID={span_id}",
116
143
  ]
117
144
  )
118
145
 
119
- # Pass session info to agent
120
- if otel_url:
121
- docker_cmd.extend(["-e", f"OTEL_EXPORTER_OTLP_ENDPOINT={otel_url}"])
122
- # Use JSON protocol (not protobuf) for OTLP exports
123
- docker_cmd.extend(["-e", "OTEL_EXPORTER_OTLP_PROTOCOL=http/json"])
124
- if session_id:
125
- docker_cmd.extend(["-e", f"SESSION_ID={session_id}"])
126
- if upload_url:
127
- docker_cmd.extend(["-e", f"UPLOAD_URL={upload_url}"])
128
-
129
- # Pass trace context to agent for parent linking
130
- current_span = trace.get_current_span()
131
- span_context = current_span.get_span_context()
132
- if span_context.is_valid:
133
- trace_id = format(span_context.trace_id, "032x")
134
- span_id = format(span_context.span_id, "016x")
135
- docker_cmd.extend(
136
- [
137
- "-e",
138
- f"OTEL_TRACE_ID={trace_id}",
139
- "-e",
140
- f"OTEL_PARENT_SPAN_ID={span_id}",
141
- ]
142
- )
143
-
144
- for key, value in secrets.items():
145
- docker_cmd.extend(["-e", f"{key.upper()}={value}"])
146
-
147
- docker_cmd.append(image)
148
-
149
- # Pass instruction via CLI arg
150
- docker_cmd.extend(["--instruction", instruction])
151
-
152
- # Run container and stream output
153
- with tracer.start_as_current_span("agent_execution") as exec_span:
154
- exec_span.set_attribute("span.type", "agent_execution")
155
- exec_span.set_attribute("content", f"Running {agent_name}")
156
-
157
- process = await asyncio.create_subprocess_exec(
158
- *docker_cmd,
159
- stdout=asyncio.subprocess.PIPE,
160
- stderr=asyncio.subprocess.STDOUT,
161
- )
162
-
163
- # Stream output line by line
164
- output_lines: list[str] = []
165
- turn_count = 0
166
- assert process.stdout is not None
167
- while True:
168
- line = await process.stdout.readline()
169
- if not line:
170
- break
171
- decoded_line = line.decode().rstrip()
172
- output_lines.append(decoded_line)
173
-
174
- # Try to parse JSON output from agent for structured trajectory spans
175
- try:
176
- data = json.loads(decoded_line)
177
- event_type = data.get("type", "")
178
-
179
- if event_type == "assistant":
180
- # Agent response - create a turn span
181
- turn_count += 1
182
- msg = data.get("message", {})
183
- content_items = msg.get("content", [])
184
-
185
- # Extract text and tool calls with full details
186
- text_parts = []
187
- tool_calls = []
188
- for item in content_items:
189
- if item.get("type") == "text":
190
- text_parts.append(item.get("text", "")[:2000])
191
- elif item.get("type") == "tool_use":
192
- tool_input = item.get("input", {})
193
- # Truncate large inputs
194
- input_str = json.dumps(tool_input) if tool_input else ""
195
- if len(input_str) > 2000:
196
- input_str = input_str[:2000] + "..."
197
- tool_calls.append(
198
- {
199
- "tool": item.get("name"),
200
- "id": item.get("id"),
201
- "input": input_str,
202
- }
203
- )
204
-
205
- with tracer.start_as_current_span(f"turn_{turn_count}") as turn_span:
206
- turn_span.set_attribute("span.type", "agent_turn")
207
- turn_span.set_attribute("source", "agent")
208
- turn_span.set_attribute("turn_number", turn_count)
209
- turn_span.set_attribute("model", msg.get("model", "unknown"))
210
-
211
- if text_parts:
212
- turn_span.set_attribute("content", "\n".join(text_parts)[:4000])
213
- if tool_calls:
214
- turn_span.set_attribute("tool_calls", json.dumps(tool_calls))
215
- # If no text content, show tool calls summary
216
- if not text_parts:
217
- turn_span.set_attribute(
218
- "content", f"Tool calls: {', '.join(t['tool'] for t in tool_calls)}"
219
- )
220
-
221
- # Usage info
222
- usage = msg.get("usage", {})
223
- if usage:
224
- turn_span.set_attribute("input_tokens", usage.get("input_tokens", 0))
225
- turn_span.set_attribute("output_tokens", usage.get("output_tokens", 0))
226
-
227
- elif event_type == "user":
228
- # Tool result
229
- tool_results = data.get("message", {}).get("content", [])
230
- for result in tool_results:
231
- if result.get("type") == "tool_result":
232
- tool_id = result.get("tool_use_id", "")
233
- content = result.get("content", "")
234
- # Handle content that might be a list of content blocks
235
- if isinstance(content, list):
236
- text_parts = []
237
- for item in content:
238
- if isinstance(item, dict) and item.get("type") == "text":
239
- text_parts.append(item.get("text", ""))
240
- elif isinstance(item, str):
241
- text_parts.append(item)
242
- content = "\n".join(text_parts)
243
- if isinstance(content, str):
244
- content = content[:2000] # Truncate large results
245
- with tracer.start_as_current_span("tool_result") as tr_span:
246
- tr_span.set_attribute("span.type", "tool_result")
247
- tr_span.set_attribute("source", "agent")
248
- tr_span.set_attribute("tool_use_id", tool_id)
249
- tr_span.set_attribute("content", f"Tool result for {tool_id}")
250
- tr_span.set_attribute("result", content if content else "")
251
-
252
- elif event_type == "result":
253
- # Final result
254
- result_text = data.get("result", "")[:1000]
255
- is_error = data.get("is_error", False)
256
- duration_ms = data.get("duration_ms", 0)
257
- total_cost = data.get("total_cost_usd", 0)
258
-
259
- with tracer.start_as_current_span("agent_result") as res_span:
260
- res_span.set_attribute("span.type", "agent_result")
261
- res_span.set_attribute("source", "agent")
262
- res_span.set_attribute("content", result_text if result_text else "Agent completed")
263
- res_span.set_attribute("is_error", is_error)
264
- res_span.set_attribute("duration_ms", duration_ms)
265
- res_span.set_attribute("total_cost_usd", total_cost)
266
- res_span.set_attribute("num_turns", data.get("num_turns", turn_count))
267
-
268
- elif event_type == "system" and data.get("subtype") == "init":
269
- # Agent initialization
270
- with tracer.start_as_current_span("agent_init") as init_span:
271
- init_span.set_attribute("span.type", "agent_init")
272
- init_span.set_attribute("source", "agent")
273
- init_span.set_attribute("model", data.get("model", "unknown"))
274
- init_span.set_attribute("tools", json.dumps(data.get("tools", [])))
275
- init_span.set_attribute("content", f"Agent initialized: {data.get('model', 'unknown')}")
276
-
277
- else:
278
- # Other output - just log it without creating a span
279
- logger.debug(f"[agent] {decoded_line}")
280
- continue
281
-
282
- except json.JSONDecodeError:
283
- # Not JSON - just log it
284
- logger.info(f"[agent] {decoded_line}")
285
-
286
- await process.wait()
287
-
288
- exit_code = process.returncode or 0
289
- if exit_code != 0:
290
- error_context = "\n".join(output_lines[-50:]) if output_lines else "No output captured"
291
-
292
- exec_span.set_attribute("error", True)
293
- exec_span.set_attribute("exit_code", exit_code)
294
- exec_span.add_event(
295
- "agent_error",
296
- {
297
- "exit_code": exit_code,
298
- "output": error_context[:4000],
299
- },
300
- )
146
+ for key, value in secrets.items():
147
+ docker_cmd.extend(["-e", f"{key.upper()}={value}"])
301
148
 
302
- agent_span.set_attribute("error", True)
303
- agent_span.set_attribute("exit_code", exit_code)
149
+ docker_cmd.append(image)
304
150
 
305
- raise RuntimeError(f"Agent failed with exit code {exit_code}")
151
+ # Pass instruction via CLI arg
152
+ docker_cmd.extend(["--instruction", instruction])
306
153
 
307
- exec_span.set_attribute("success", True)
154
+ # Run container - agents emit their own OTel spans
155
+ process = await asyncio.create_subprocess_exec(
156
+ *docker_cmd,
157
+ stdout=asyncio.subprocess.PIPE,
158
+ stderr=asyncio.subprocess.STDOUT,
159
+ )
308
160
 
309
- finally:
310
- os.unlink(config_file.name)
161
+ # Capture output for error reporting
162
+ output_lines: list[str] = []
163
+ assert process.stdout is not None
164
+ while True:
165
+ line = await process.stdout.readline()
166
+ if not line:
167
+ break
168
+ decoded_line = line.decode().rstrip()
169
+ output_lines.append(decoded_line)
311
170
 
312
- # Load trajectory and log as event
313
- trajectory_path = Path(logs_dir) / "agent" / "trajectory.json"
314
- if trajectory_path.exists():
315
- try:
316
- with open(trajectory_path) as f:
317
- trajectory = json.load(f)
318
- if isinstance(trajectory, dict) and "schema_version" in trajectory:
319
- # Add agent image to trajectory
320
- agent_data = trajectory.get("agent", {})
321
- extra = agent_data.get("extra") or {}
322
- extra["image"] = image
323
- agent_data["extra"] = extra
324
- trajectory["agent"] = agent_data
171
+ await process.wait()
325
172
 
326
- # Log trajectory as span event
327
- with tracer.start_as_current_span("trajectory") as traj_span:
328
- traj_span.set_attribute("span.type", "trajectory")
329
- traj_span.set_attribute("log_type", "atif")
330
- traj_span.set_attribute("source", "agent")
331
- # Store trajectory in span (truncated for OTel limits)
332
- traj_json = json.dumps(trajectory)
333
- if len(traj_json) > 10000:
334
- traj_span.set_attribute("trajectory_truncated", True)
335
- traj_span.set_attribute("trajectory_size", len(traj_json))
336
- else:
337
- traj_span.set_attribute("trajectory", traj_json)
338
- except Exception as e:
339
- logger.warning(f"Failed to load trajectory: {e}")
173
+ exit_code = process.returncode or 0
174
+ if exit_code != 0:
175
+ error_context = "\n".join(output_lines[-50:]) if output_lines else "No output captured"
176
+ raise RuntimeError(f"Agent failed with exit code {exit_code}\n\nAgent output:\n{error_context}")
340
177
 
341
- # Upload artifacts if we have upload URL configured
342
- if upload_url:
343
- await upload_artifacts(upload_url, logs_dir)
178
+ finally:
179
+ os.unlink(config_file.name)
344
180
 
345
- agent_span.set_attribute("success", True)
346
- agent_span.set_attribute("content", f"Agent {agent_name} completed successfully")
181
+ # Upload artifacts if we have upload URL configured
182
+ if upload_url:
183
+ await upload_artifacts(upload_url, logs_dir)
plato/v1/cli/sandbox.py CHANGED
@@ -131,6 +131,9 @@ def sandbox_start(
131
131
  timeout: int = typer.Option(1800, "--timeout", help="VM lifetime in seconds (default: 30 minutes)"),
132
132
  no_reset: bool = typer.Option(False, "--no-reset", help="Skip initial reset after ready"),
133
133
  json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
134
+ working_dir: Path = typer.Option(
135
+ None, "--working-dir", "-w", help="Working directory for .sandbox.yaml and .plato/"
136
+ ),
134
137
  ):
135
138
  """
136
139
  Start a sandbox environment.
@@ -377,7 +380,7 @@ def sandbox_start(
377
380
  console.print("[cyan] Generating SSH key pair...[/cyan]")
378
381
 
379
382
  base_url = os.getenv("PLATO_BASE_URL", "https://plato.so")
380
- ssh_info = setup_ssh_for_sandbox(base_url, job_id, username=ssh_username)
383
+ ssh_info = setup_ssh_for_sandbox(base_url, job_id, username=ssh_username, working_dir=working_dir)
381
384
  ssh_host = ssh_info["ssh_host"]
382
385
  ssh_config_path = ssh_info["config_path"]
383
386
  ssh_private_key_path = ssh_info["private_key_path"]
@@ -489,7 +492,7 @@ def sandbox_start(
489
492
  # Add heartbeat PID
490
493
  if heartbeat_pid:
491
494
  state["heartbeat_pid"] = heartbeat_pid
492
- save_sandbox_state(state)
495
+ save_sandbox_state(state, working_dir)
493
496
 
494
497
  # Close the plato client (heartbeat process keeps session alive)
495
498
  plato.close()
plato/v1/cli/ssh.py CHANGED
@@ -9,21 +9,21 @@ from cryptography.hazmat.primitives import serialization
9
9
  from cryptography.hazmat.primitives.asymmetric import ed25519
10
10
 
11
11
 
12
- def get_plato_dir() -> Path:
12
+ def get_plato_dir(working_dir: Path | str | None = None) -> Path:
13
13
  """Get the directory for plato config/SSH files.
14
14
 
15
- Uses /workspace/.plato if /workspace exists (container environment),
16
- otherwise uses ~/.plato (local development).
15
+ Args:
16
+ working_dir: If provided, returns working_dir/.plato (for container/agent use).
17
+ If None, returns ~/.plato (local development).
17
18
  """
18
- workspace = Path("/workspace")
19
- if workspace.exists() and workspace.is_dir():
20
- return workspace / ".plato"
19
+ if working_dir is not None:
20
+ return Path(working_dir) / ".plato"
21
21
  return Path.home() / ".plato"
22
22
 
23
23
 
24
- def get_next_sandbox_number() -> int:
24
+ def get_next_sandbox_number(working_dir: Path | str | None = None) -> int:
25
25
  """Find next available sandbox number by checking existing config files."""
26
- plato_dir = get_plato_dir()
26
+ plato_dir = get_plato_dir(working_dir)
27
27
  if not plato_dir.exists():
28
28
  return 1
29
29
 
@@ -41,13 +41,13 @@ def get_next_sandbox_number() -> int:
41
41
  return max_num + 1
42
42
 
43
43
 
44
- def generate_ssh_key_pair(sandbox_num: int) -> tuple[str, str]:
44
+ def generate_ssh_key_pair(sandbox_num: int, working_dir: Path | str | None = None) -> tuple[str, str]:
45
45
  """
46
46
  Generate a new ed25519 SSH key pair for a specific sandbox.
47
47
 
48
48
  Returns (public_key_str, private_key_path).
49
49
  """
50
- plato_dir = get_plato_dir()
50
+ plato_dir = get_plato_dir(working_dir)
51
51
  plato_dir.mkdir(mode=0o700, exist_ok=True)
52
52
 
53
53
  private_key_path = plato_dir / f"ssh_{sandbox_num}_key"
@@ -136,6 +136,7 @@ def create_ssh_config(
136
136
  username: str,
137
137
  private_key_path: str,
138
138
  sandbox_num: int,
139
+ working_dir: Path | str | None = None,
139
140
  ) -> str:
140
141
  """
141
142
  Create a temporary SSH config file for a specific sandbox.
@@ -172,7 +173,7 @@ def create_ssh_config(
172
173
  TCPKeepAlive yes
173
174
  """
174
175
 
175
- plato_dir = get_plato_dir()
176
+ plato_dir = get_plato_dir(working_dir)
176
177
  plato_dir.mkdir(mode=0o700, exist_ok=True)
177
178
 
178
179
  config_path = plato_dir / f"ssh_{sandbox_num}.conf"
@@ -182,7 +183,12 @@ def create_ssh_config(
182
183
  return str(config_path)
183
184
 
184
185
 
185
- def setup_ssh_for_sandbox(base_url: str, job_public_id: str, username: str = "plato") -> dict:
186
+ def setup_ssh_for_sandbox(
187
+ base_url: str,
188
+ job_public_id: str,
189
+ username: str = "plato",
190
+ working_dir: Path | str | None = None,
191
+ ) -> dict:
186
192
  """
187
193
  Set up SSH access for a sandbox - generates keys and creates config.
188
194
 
@@ -190,14 +196,14 @@ def setup_ssh_for_sandbox(base_url: str, job_public_id: str, username: str = "pl
190
196
 
191
197
  Returns dict with: ssh_host, config_path, public_key, private_key_path
192
198
  """
193
- sandbox_num = get_next_sandbox_number()
199
+ sandbox_num = get_next_sandbox_number(working_dir)
194
200
  ssh_host = f"sandbox-{sandbox_num}"
195
201
 
196
202
  # Choose random port between 2200 and 2299
197
203
  local_port = random.randint(2200, 2299)
198
204
 
199
205
  # Generate SSH key pair
200
- public_key, private_key_path = generate_ssh_key_pair(sandbox_num)
206
+ public_key, private_key_path = generate_ssh_key_pair(sandbox_num, working_dir)
201
207
 
202
208
  # Create SSH config file
203
209
  config_path = create_ssh_config(
@@ -208,6 +214,7 @@ def setup_ssh_for_sandbox(base_url: str, job_public_id: str, username: str = "pl
208
214
  username=username,
209
215
  private_key_path=private_key_path,
210
216
  sandbox_num=sandbox_num,
217
+ working_dir=working_dir,
211
218
  )
212
219
 
213
220
  return {
plato/v1/cli/utils.py CHANGED
@@ -15,32 +15,52 @@ console = Console()
15
15
  SANDBOX_FILE = ".sandbox.yaml"
16
16
 
17
17
 
18
- def get_sandbox_state() -> dict | None:
19
- """Read sandbox state from .sandbox.yaml in current directory."""
20
- sandbox_file = Path.cwd() / SANDBOX_FILE
18
+ def get_sandbox_state(working_dir: Path | str | None = None) -> dict | None:
19
+ """Read sandbox state from .sandbox.yaml.
20
+
21
+ Args:
22
+ working_dir: Directory containing .sandbox.yaml. If None, uses cwd.
23
+ """
24
+ base_dir = Path(working_dir) if working_dir else Path.cwd()
25
+ sandbox_file = base_dir / SANDBOX_FILE
21
26
  if not sandbox_file.exists():
22
27
  return None
23
28
  with open(sandbox_file) as f:
24
29
  return yaml.safe_load(f)
25
30
 
26
31
 
27
- def save_sandbox_state(state: dict) -> None:
28
- """Save sandbox state to .sandbox.yaml in current directory."""
29
- sandbox_file = Path.cwd() / SANDBOX_FILE
32
+ def save_sandbox_state(state: dict, working_dir: Path | str | None = None) -> None:
33
+ """Save sandbox state to .sandbox.yaml.
34
+
35
+ Args:
36
+ state: State dict to save.
37
+ working_dir: Directory to save .sandbox.yaml in. If None, uses cwd.
38
+ """
39
+ base_dir = Path(working_dir) if working_dir else Path.cwd()
40
+ sandbox_file = base_dir / SANDBOX_FILE
30
41
  with open(sandbox_file, "w") as f:
31
42
  yaml.dump(state, f, default_flow_style=False)
32
43
 
33
44
 
34
- def remove_sandbox_state() -> None:
35
- """Remove .sandbox.yaml from current directory."""
36
- sandbox_file = Path.cwd() / SANDBOX_FILE
45
+ def remove_sandbox_state(working_dir: Path | str | None = None) -> None:
46
+ """Remove .sandbox.yaml.
47
+
48
+ Args:
49
+ working_dir: Directory containing .sandbox.yaml. If None, uses cwd.
50
+ """
51
+ base_dir = Path(working_dir) if working_dir else Path.cwd()
52
+ sandbox_file = base_dir / SANDBOX_FILE
37
53
  if sandbox_file.exists():
38
54
  sandbox_file.unlink()
39
55
 
40
56
 
41
- def require_sandbox_state() -> dict:
42
- """Get sandbox state or exit with error."""
43
- state = get_sandbox_state()
57
+ def require_sandbox_state(working_dir: Path | str | None = None) -> dict:
58
+ """Get sandbox state or exit with error.
59
+
60
+ Args:
61
+ working_dir: Directory containing .sandbox.yaml. If None, uses cwd.
62
+ """
63
+ state = get_sandbox_state(working_dir)
44
64
  if not state:
45
65
  console.print("[red]No sandbox found in current directory[/red]")
46
66
  console.print("\n[yellow]Start a sandbox with:[/yellow]")