mcp-yieldshell 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,616 @@
1
+ """Process registry and lifecycle management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import os
7
+ import sys
8
+ import time
9
+ import uuid
10
+ from typing import Any
11
+
12
+ from ..config import Config
13
+ from ..security import redact_text
14
+ from ..types import ProcessInfo, ProcessStatus
15
+ from .ring_buffer import RingBuffer
16
+ from .spawn import kill_process, spawn_process, terminate_process
17
+
18
+
19
class ManagedProcess:
    """State kept for one spawned subprocess and the asyncio tasks serving it.

    stdout and stderr each get a RingBuffer of ``max_output_bytes``. Both
    buffers receive the *same* one-element list as ``seq_source``, presumably
    so that sequence numbers are drawn from one counter across both streams
    (confirm in ring_buffer.py). Task handles start as ``None`` and are
    attached by the owner after construction.
    """

    __slots__ = (
        "info", "proc",
        "stdout_buf", "stderr_buf",
        "drain_stdout", "drain_stderr",
        "completion_event", "completion_task", "timeout_task",
        "_seq_source", "_timeout_triggered",
    )

    def __init__(
        self,
        info: ProcessInfo,
        proc: asyncio.subprocess.Process,
        max_output_bytes: int,
    ) -> None:
        self.info = info
        self.proc = proc

        # One mutable counter cell shared by both output buffers.
        shared_seq: list[int] = [1]
        self._seq_source = shared_seq
        self.stdout_buf = RingBuffer(max_output_bytes, seq_source=shared_seq)
        self.stderr_buf = RingBuffer(max_output_bytes, seq_source=shared_seq)

        # Set once the process has finished (or been given up on).
        self.completion_event: asyncio.Event = asyncio.Event()
        # True once the runtime timeout has fired for this process.
        self._timeout_triggered = False

        # Background task handles, filled in after spawn.
        self.drain_stdout: asyncio.Task[None] | None = None
        self.drain_stderr: asyncio.Task[None] | None = None
        self.completion_task: asyncio.Task[None] | None = None
        self.timeout_task: asyncio.Task[None] | None = None
+
52
+
53
class ProcessManager:
    """Registry and lifecycle manager for managed shell processes.

    Every process started by :meth:`exec_command` is registered under a
    generated ``process_id`` together with its output ring buffers, two drain
    tasks, a completion-tracking task and an optional runtime-timeout task.
    A process leaves the RUNNING status exactly once, to COMPLETED,
    TIMED_OUT, STOPPED or FAILED.
    """

    # Seconds to keep waiting for stdout/stderr EOF after the process itself
    # has exited. A backgrounded grandchild (``cmd &``) can inherit the pipes
    # and hold them open indefinitely; without a bound the exit status would
    # never be recorded. Drains are not cancelled, only no longer awaited.
    _DRAIN_GRACE_SEC = 2.0
    # Grace period between graceful termination and force kill on timeout.
    _TIMEOUT_GRACE_SEC = 3.0
    # How long to wait for exit after a force kill before giving up.
    _KILL_WAIT_SEC = 2.0

    def __init__(self, config: Config) -> None:
        self._config = config
        self._processes: dict[str, ManagedProcess] = {}
        # Spawns currently awaiting spawn_process(). They count against
        # max_processes so concurrent exec_command calls cannot overshoot it.
        self._starting = 0
        # process_ids that stop_process() has started stopping. Lets the
        # completion tracker record STOPPED directly instead of COMPLETED.
        self._stop_requested: set[str] = set()

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    def _new_id(self) -> str:
        """Return a fresh random process id: ``proc_`` + 12 hex characters."""
        return f"proc_{uuid.uuid4().hex[:12]}"

    def _max_output(self, requested: int | None) -> int:
        """Clamp a requested output size to ``config.max_output_bytes``.

        ``None`` or a non-positive request selects the configured cap.
        """
        cap = self._config.max_output_bytes
        if requested is None or requested <= 0:
            return cap
        return min(requested, cap)

    def _clamp_yield_ms(self, requested: int | None) -> int:
        """Return the yield window in ms, clamped to ``[0, max_yield_ms]``."""
        if requested is None:
            return self._config.default_yield_ms
        return max(0, min(requested, self._config.max_yield_ms))

    def _clamp_timeout_ms(self, requested: int | None) -> int:
        """Return the runtime timeout in ms (0 disables it; negatives -> 0)."""
        if requested is None:
            return self._config.default_timeout_ms
        return max(0, requested)

    def _running_count(self) -> int:
        """Number of registered processes whose status is RUNNING."""
        return sum(
            1 for mp in self._processes.values()
            if mp.info.status == ProcessStatus.RUNNING
        )

    def _collect_output(self, mp: ManagedProcess, max_bytes: int) -> dict[str, Any]:
        """Snapshot both output buffers (from the start), redacted.

        Returns ``stdout``, ``stderr``, ``next_seq`` and ``truncated``.
        ``truncated`` is computed on the raw buffer reads, before redaction.
        """
        out = mp.stdout_buf.read(max_bytes=max_bytes)
        err = mp.stderr_buf.read(max_bytes=max_bytes)
        return {
            "stdout": redact_text(self._config, out["text"]),
            "stderr": redact_text(self._config, err["text"]),
            "next_seq": max(out["next_seq"], err["next_seq"]),
            "truncated": out["truncated"] or err["truncated"],
        }

    @staticmethod
    def _record_end(mp: ManagedProcess) -> None:
        """Stamp ``ended_at`` and ``duration_ms`` together if not yet set."""
        if mp.info.ended_at is None:
            mp.info.ended_at = time.time()
            mp.info.duration_ms = (time.monotonic() - mp.info.start_monotonic) * 1000

    @staticmethod
    def _display_duration_ms(info: ProcessInfo, now: float) -> float:
        """Return the recorded duration of an ended process, else its live runtime.

        Falls back to the live value when ``ended_at`` is set but
        ``duration_ms`` was never recorded, instead of passing ``None`` to round().
        """
        if info.ended_at is not None and info.duration_ms is not None:
            return info.duration_ms
        return (now - info.start_monotonic) * 1000

    @staticmethod
    async def _wait_completion(mp: ManagedProcess, timeout_sec: float) -> bool:
        """Wait up to *timeout_sec* for the completion event; return whether it fired."""
        try:
            await asyncio.wait_for(mp.completion_event.wait(), timeout=timeout_sec)
        except asyncio.TimeoutError:
            return False
        return True

    @staticmethod
    async def _write_initial_stdin(
        proc: asyncio.subprocess.Process, text: str, timeout_sec: float
    ) -> None:
        """Best-effort write of the initial stdin payload; the pipe stays open.

        ``drain()`` is bounded by *timeout_sec*. If the child is not reading,
        an unbounded drain would block forever; bytes not yet flushed remain
        buffered in the transport and are delivered as the child reads.
        """
        stream = proc.stdin
        if stream is None:
            return
        try:
            stream.write(text.encode("utf-8"))
            await asyncio.wait_for(stream.drain(), timeout=max(0.0, timeout_sec))
        except Exception:
            # Deliberately best-effort: the child may have exited or closed
            # stdin, or not consumed the data yet (TimeoutError).
            pass

    @staticmethod
    def _send_signal(proc: asyncio.subprocess.Process, sig: Any) -> None:
        """Send *sig* to the child's process group, or to the child alone.

        The group is only signalled when it differs from this server's own
        group, so a child that shares our group cannot take the server down
        with it.
        """
        try:
            pgid = os.getpgid(proc.pid)
            if pgid == os.getpgrp():
                proc.send_signal(sig)
            else:
                os.killpg(pgid, sig)
        except (ProcessLookupError, PermissionError):
            try:
                proc.send_signal(sig)
            except Exception:
                # Process already gone / transport closed: nothing to signal.
                pass

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    async def exec_command(
        self,
        command: str,
        cwd: str | None = None,
        env_overlay: dict[str, str] | None = None,
        shell: str | None = None,
        stdin: str | None = None,
        name: str | None = None,
        yield_ms: int | None = None,
        timeout_ms: int | None = None,
        max_output_bytes: int | None = None,
    ) -> dict[str, Any]:
        """Execute a shell command with auto-yield behavior.

        Waits up to ``yield_ms`` (clamped) for the command to finish. If it
        has reached a terminal status by then, that result is returned.
        Otherwise it keeps running in the background and the returned
        ``process_id`` can be used with read/wait/stop.

        Returns:
            A dict whose ``status`` is ``failed_to_start`` (with ``error``),
            ``completed``, ``timed_out``, ``stopped``, ``failed`` or
            ``backgrounded``. Errors are reported in the dict, not raised.
        """
        from ..security import build_env, resolve_cwd, validate_command

        cmd_error = validate_command(self._config, command)
        if cmd_error:
            return {"status": "failed_to_start", "error": cmd_error}

        resolved_cwd, cwd_error = resolve_cwd(self._config, cwd)
        if cwd_error:
            return {"status": "failed_to_start", "error": cwd_error}

        # Spawns still in flight count too; otherwise concurrent calls could
        # all pass this check before any of them is registered.
        if self._running_count() + self._starting >= self._config.max_processes:
            return {
                "status": "failed_to_start",
                "error": f"Maximum process limit ({self._config.max_processes}) reached",
            }

        env = build_env(self._config, env_overlay)
        effective_yield = self._clamp_yield_ms(yield_ms)
        effective_timeout = self._clamp_timeout_ms(timeout_ms)
        effective_max_output = self._max_output(max_output_bytes)

        # NOTE(review): `shell` is accepted but not forwarded to spawn_process;
        # confirm whether spawn_process should receive it.
        self._starting += 1
        try:
            proc = await spawn_process(command, cwd=resolved_cwd, env=env)
        except Exception as exc:
            return {"status": "failed_to_start", "error": str(exc)}
        finally:
            self._starting -= 1

        # No await from here until registration, so the reserved slot is
        # handed over to the registered RUNNING entry atomically.
        process_id = self._new_id()
        start_time = time.monotonic()
        info = ProcessInfo(
            process_id=process_id,
            pid=proc.pid,
            command=command,
            cwd=resolved_cwd,
            name=name,
            status=ProcessStatus.RUNNING,
            started_at=time.time(),
            start_monotonic=start_time,
        )
        mp = ManagedProcess(info, proc, effective_max_output)

        # Drain immediately so the child never blocks on a full pipe buffer.
        mp.drain_stdout = asyncio.create_task(
            self._drain_stream(proc.stdout, mp.stdout_buf), name=f"drain-stdout-{process_id}"
        )
        mp.drain_stderr = asyncio.create_task(
            self._drain_stream(proc.stderr, mp.stderr_buf), name=f"drain-stderr-{process_id}"
        )
        mp.completion_task = asyncio.create_task(
            self._track_completion(proc, mp), name=f"completion-{process_id}"
        )
        self._processes[process_id] = mp

        # Arm the runtime timeout before touching stdin so a child that
        # never reads its input is still bounded by it.
        if effective_timeout > 0:
            mp.timeout_task = asyncio.create_task(
                self._handle_timeout(mp, effective_timeout / 1000.0),
                name=f"timeout-{process_id}",
            )

        deadline = start_time + effective_yield / 1000.0
        if stdin is not None:
            await self._write_initial_stdin(proc, stdin, deadline - time.monotonic())

        remaining = deadline - time.monotonic()
        if remaining > 0 and not mp.completion_event.is_set():
            await self._wait_completion(mp, remaining)

        duration_ms = round((time.monotonic() - start_time) * 1000, 1)
        output = self._collect_output(mp, effective_max_output)

        terminal_labels = {
            ProcessStatus.COMPLETED: "completed",
            ProcessStatus.TIMED_OUT: "timed_out",
            ProcessStatus.STOPPED: "stopped",
            ProcessStatus.FAILED: "failed",
        }
        label = terminal_labels.get(mp.info.status)
        if label is not None:
            result: dict[str, Any] = {"status": label}
            # A cleanly completed command needs no follow-up handle.
            if mp.info.status != ProcessStatus.COMPLETED:
                result["process_id"] = process_id
            result["exit_code"] = mp.info.exit_code
            result["signal"] = mp.info.signal
            result["duration_ms"] = duration_ms
            result["stdout"] = output["stdout"]
            result["stderr"] = output["stderr"]
            result["truncated"] = output["truncated"]
            return result

        # Still running — background it.
        return {
            "status": "backgrounded",
            "process_id": process_id,
            "pid": mp.info.pid,
            "duration_ms": duration_ms,
            "stdout": output["stdout"],
            "stderr": output["stderr"],
            "truncated": output["truncated"],
            "message": "Process is running in the background. Use read/wait/stop with process_id.",
        }

    async def _drain_stream(
        self, stream: asyncio.StreamReader | None, buf: RingBuffer
    ) -> None:
        """Copy a subprocess stream into *buf* in 4 KiB chunks until EOF.

        Read errors end the drain silently (best-effort).
        """
        if stream is None:
            return
        while True:
            try:
                chunk = await stream.read(4096)
            except Exception:
                break
            if not chunk:
                break
            buf.append(chunk)

    def _exit_status(self, mp: ManagedProcess) -> ProcessStatus:
        """Terminal status for a process that exited while RUNNING."""
        if mp._timeout_triggered:
            return ProcessStatus.TIMED_OUT
        if mp.info.process_id in self._stop_requested:
            return ProcessStatus.STOPPED
        return ProcessStatus.COMPLETED

    async def _track_completion(
        self, proc: asyncio.subprocess.Process, mp: ManagedProcess
    ) -> None:
        """Wait for exit, record the result, then set ``completion_event``.

        The terminal status is decided here (TIMED_OUT / STOPPED / COMPLETED)
        rather than patched afterwards. This way no waiter woken by the event
        can observe a transient COMPLETED for a timed-out or stopped process.
        """
        try:
            returncode = await proc.wait()
            # Let buffered output land before waiters read it, but bounded:
            # grandchildren may keep the pipes open after the child exits.
            pending = [
                task for task in (mp.drain_stdout, mp.drain_stderr)
                if task is not None and not task.done()
            ]
            if pending:
                await asyncio.wait(pending, timeout=self._DRAIN_GRACE_SEC)
            mp.info.exit_code = returncode
            mp.info.signal = self._exit_signal(proc)
            # The real exit time is authoritative, even over a give-up stamp.
            mp.info.ended_at = time.time()
            mp.info.duration_ms = (time.monotonic() - mp.info.start_monotonic) * 1000
            if mp.info.status == ProcessStatus.RUNNING:
                mp.info.status = self._exit_status(mp)
        except Exception:
            if mp.info.status == ProcessStatus.RUNNING:
                mp.info.status = ProcessStatus.FAILED
            self._record_end(mp)
        finally:
            # Once the timeout has fired it is escalating; let it finish.
            if (
                mp.timeout_task is not None
                and not mp.timeout_task.done()
                and not mp._timeout_triggered
            ):
                mp.timeout_task.cancel()
            self._stop_requested.discard(mp.info.process_id)
            mp.completion_event.set()

    def _exit_signal(self, proc: asyncio.subprocess.Process) -> str | None:
        """Return the terminating signal's name on POSIX, else ``None``.

        A negative returncode means the process was killed by that signal.
        Unknown signal numbers are rendered as ``SIG<n>``.
        """
        rc = proc.returncode
        if rc is None or sys.platform == "win32" or rc >= 0:
            return None
        import signal as sig_module

        sig_num = -rc
        try:
            return sig_module.Signals(sig_num).name
        except (ValueError, KeyError):
            return f"SIG{sig_num}"

    async def _handle_timeout(self, mp: ManagedProcess, timeout_sec: float) -> None:
        """Enforce the runtime timeout: graceful terminate, then force kill."""
        try:
            await asyncio.sleep(timeout_sec)
        except asyncio.CancelledError:
            return
        # Set before terminating so _track_completion records TIMED_OUT and
        # does not cancel this task mid-escalation.
        mp._timeout_triggered = True
        if mp.info.status != ProcessStatus.RUNNING:
            return
        await terminate_process(mp.proc)
        if not await self._wait_completion(mp, self._TIMEOUT_GRACE_SEC):
            await kill_process(mp.proc)
            await self._wait_completion(mp, self._KILL_WAIT_SEC)
        # RUNNING here means the process survived the force kill. COMPLETED
        # is kept as a safety net.
        if mp.info.status in (ProcessStatus.RUNNING, ProcessStatus.COMPLETED):
            mp.info.status = ProcessStatus.TIMED_OUT
            self._record_end(mp)
        # Unblock waiters even if the process could not be reaped.
        mp.completion_event.set()

    # ------------------------------------------------------------------
    # Interaction with running processes
    # ------------------------------------------------------------------

    async def read_output(
        self,
        process_id: str,
        since_seq: int | None = None,
        max_output_bytes: int | None = None,
        streams: str = "both",
    ) -> dict[str, Any]:
        """Read (redacted) output of a managed process.

        Args:
            process_id: id returned by exec_command.
            since_seq: forwarded to the ring buffers' ``read``.
            max_output_bytes: per-stream cap, clamped to the configured max.
            streams: ``"both"``, ``"stdout"`` or ``"stderr"``.

        Returns:
            A dict with status info, ``next_seq``, ``truncated`` and the
            requested stream keys, or a dict with ``error``.
        """
        mp = self._processes.get(process_id)
        if mp is None:
            return {"process_id": process_id, "error": f"Unknown process_id: {process_id}"}

        if streams not in ("both", "stdout", "stderr"):
            return {"process_id": process_id, "error": f"Invalid streams: {streams!r}"}

        effective_max = self._max_output(max_output_bytes)
        selected = {
            "stdout": mp.stdout_buf if streams in ("both", "stdout") else None,
            "stderr": mp.stderr_buf if streams in ("both", "stderr") else None,
        }

        texts: dict[str, str] = {}
        next_seq = 1
        truncated = False
        for key, buf in selected.items():
            if buf is None:
                continue
            data = buf.read(since_seq=since_seq, max_bytes=effective_max)
            texts[key] = redact_text(self._config, data["text"])
            next_seq = max(next_seq, data["next_seq"])
            truncated = truncated or data["truncated"]

        result: dict[str, Any] = {
            "process_id": process_id,
            "status": mp.info.status.value,
            "exit_code": mp.info.exit_code,
            "signal": mp.info.signal,
            "next_seq": next_seq,
            "truncated": truncated,
        }
        result.update(texts)
        return result

    async def write_input(
        self, process_id: str, input_data: str, newline: bool = False
    ) -> dict[str, Any]:
        """Write UTF-8 *input_data* (plus optional newline) to a process's stdin.

        Returns ``{"process_id", "ok"}`` plus ``error`` on failure.
        NOTE(review): ``drain()`` is unbounded here and waits while the child
        is not consuming input.
        """
        mp = self._processes.get(process_id)
        if mp is None:
            return {
                "process_id": process_id, "ok": False,
                "error": f"Unknown process_id: {process_id}",
            }

        if mp.info.status != ProcessStatus.RUNNING:
            return {
                "process_id": process_id,
                "ok": False,
                "error": f"Process is not running (status: {mp.info.status.value})",
            }

        if mp.proc.stdin is None or mp.proc.stdin.is_closing():
            return {
                "process_id": process_id,
                "ok": False,
                "error": "Process stdin is closed",
            }

        try:
            data = input_data.encode("utf-8")
            if newline:
                data += b"\n"
            mp.proc.stdin.write(data)
            await mp.proc.stdin.drain()
        except Exception as exc:
            return {"process_id": process_id, "ok": False, "error": str(exc)}
        return {"process_id": process_id, "ok": True}

    async def wait_process(
        self,
        process_id: str,
        timeout_ms: int = 30000,
        max_output_bytes: int | None = None,
    ) -> dict[str, Any]:
        """Wait up to *timeout_ms* for a process to exit, without killing it.

        Returns status, exit info and the full (redacted) output snapshot.
        """
        mp = self._processes.get(process_id)
        if mp is None:
            return {"process_id": process_id, "error": f"Unknown process_id: {process_id}"}

        if mp.info.status == ProcessStatus.RUNNING:
            await self._wait_completion(mp, timeout_ms / 1000.0)

        output = self._collect_output(mp, self._max_output(max_output_bytes))
        return {
            "process_id": process_id,
            "status": mp.info.status.value,
            "exit_code": mp.info.exit_code,
            "signal": mp.info.signal,
            **output,
        }

    async def stop_process(
        self,
        process_id: str,
        signal_name: str = "SIGTERM",
        force_after_ms: int = 3000,
    ) -> dict[str, Any]:
        """Stop a running process: send *signal_name*, force kill after a grace period.

        If *signal_name* is unknown or the platform is Windows, the process is
        terminated via ``terminate_process`` instead. The ``signal`` field of
        the result echoes *signal_name* in either case.
        """
        from .spawn import get_signal

        mp = self._processes.get(process_id)
        if mp is None:
            return {
                "process_id": process_id, "stopped": False,
                "error": f"Unknown process_id: {process_id}",
            }

        if mp.info.status != ProcessStatus.RUNNING:
            return {
                "process_id": process_id,
                "stopped": False,
                "error": f"Process is not running (status: {mp.info.status.value})",
            }

        # Record intent first so the completion tracker reports STOPPED.
        self._stop_requested.add(process_id)

        sig = get_signal(signal_name)
        if sig is not None and sys.platform != "win32" and mp.proc.pid is not None:
            self._send_signal(mp.proc, sig)
        else:
            await terminate_process(mp.proc)

        if not await self._wait_completion(mp, force_after_ms / 1000.0):
            await kill_process(mp.proc)
            await self._wait_completion(mp, self._KILL_WAIT_SEC)

        if mp.info.status == ProcessStatus.RUNNING:
            # Process did not exit even after the force kill.
            mp.info.status = ProcessStatus.STOPPED
            self._record_end(mp)
        elif mp.info.status == ProcessStatus.COMPLETED:
            # Safety net: we initiated termination, so STOPPED is correct.
            mp.info.status = ProcessStatus.STOPPED
        self._stop_requested.discard(process_id)

        return {
            "process_id": process_id,
            "stopped": mp.info.status == ProcessStatus.STOPPED,
            "signal": signal_name,
            "error": None,
        }

    # ------------------------------------------------------------------
    # Registry maintenance
    # ------------------------------------------------------------------

    def list_processes(
        self, include_completed: bool = True, limit: int = 50
    ) -> dict[str, Any]:
        """List managed processes, most recently registered first, up to *limit*."""
        now = time.monotonic()
        processes = []
        for mp in list(self._processes.values()):
            info = mp.info
            if not include_completed and info.status in (
                ProcessStatus.COMPLETED,
                ProcessStatus.STOPPED,
                ProcessStatus.TIMED_OUT,
                ProcessStatus.FAILED,
            ):
                continue
            processes.append(
                {
                    "process_id": info.process_id,
                    "pid": info.pid,
                    "name": info.name,
                    "command": info.command,
                    "cwd": info.cwd,
                    "status": info.status.value,
                    "exit_code": info.exit_code,
                    "signal": info.signal,
                    "started_at": info.started_at,
                    "ended_at": info.ended_at,
                    "duration_ms": round(self._display_duration_ms(info, now), 1),
                    "stdout_bytes": mp.stdout_buf.byte_count,
                    "stderr_bytes": mp.stderr_buf.byte_count,
                }
            )
        processes.reverse()  # dicts keep insertion order -> most recent first
        return {"processes": processes[:limit]}

    async def cleanup(
        self,
        completed_older_than_ms: int = 3600000,
        stopped_older_than_ms: int = 3600000,
    ) -> dict[str, Any]:
        """Forget finished processes older than the given ages.

        COMPLETED processes use *completed_older_than_ms*; STOPPED, TIMED_OUT
        and FAILED use *stopped_older_than_ms*. Age is measured from
        ``ended_at`` (or ``started_at`` if unset). Running processes are kept.
        """
        now = time.time()
        to_remove: list[str] = []

        for process_id, mp in self._processes.items():
            status = mp.info.status
            if status == ProcessStatus.RUNNING:
                continue
            age_ms = (now - (mp.info.ended_at or mp.info.started_at)) * 1000
            if status == ProcessStatus.COMPLETED:
                if age_ms > completed_older_than_ms:
                    to_remove.append(process_id)
            elif status in (
                ProcessStatus.STOPPED, ProcessStatus.TIMED_OUT, ProcessStatus.FAILED,
            ):
                if age_ms > stopped_older_than_ms:
                    to_remove.append(process_id)

        for process_id in to_remove:
            del self._processes[process_id]
            self._stop_requested.discard(process_id)

        return {"removed": len(to_remove)}

    def get_process(self, process_id: str) -> ManagedProcess | None:
        """Return the registry entry for *process_id*, or ``None``."""
        return self._processes.get(process_id)