coding-cli-runtime 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,604 @@
1
+ """Shared async session execution and transcript mirroring helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ import os
8
+ import shutil
9
+ import signal
10
+ import time
11
+ from collections.abc import Awaitable, Callable, Sequence
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+
17
def _default_source_exists(source: Any) -> bool:
    """Report whether *source* is a ``Path`` that currently exists.

    Any non-``Path`` source is treated as not existing.
    """
    if not isinstance(source, Path):
        return False
    return source.exists()
19
+
20
+
21
def _default_describe_source(source: Any) -> str:
    """Return the ``str()`` form of *source* for use in log messages."""
    # Path and non-Path sources are rendered the same way.
    return str(source)
25
+
26
+
27
def _default_copy_source(source: Any, destination: Path) -> None:
    """Copy a ``Path`` transcript source onto *destination* via ``shutil.copy2``.

    Raises:
        TypeError: If *source* is not a ``Path``.
    """
    if isinstance(source, Path):
        shutil.copy2(source, destination)
        return
    raise TypeError(f"Unsupported transcript source type: {type(source)!r}")
31
+
32
+
33
def _default_source_identity(source: Any) -> Any:
    """Use the source object itself as its identity key."""
    identity = source
    return identity
35
+
36
+
37
@dataclass(frozen=True)
class TranscriptMirrorStrategy:
    """Immutable configuration for mirroring a provider transcript into a file.

    The callables let a provider plug in how its transcript source is located,
    identified, checked for existence, described for logs, and copied.
    """

    # Provider name interpolated into log messages.
    provider_label: str
    # File that receives the mirrored transcript.
    destination: Path
    # Returns the current transcript source, or None when none is available.
    locate_source: Callable[[], object | None]
    # Maps a source to a comparable key; a changed key means a source switch.
    source_identity: Callable[[object], object] = _default_source_identity
    # Reports whether a located source can currently be copied.
    source_exists: Callable[[object], bool] = _default_source_exists
    # Renders a source for log messages and progress events.
    describe_source: Callable[[object], str] = _default_describe_source
    # Copies the source onto the destination on each poll.
    copy_source: Callable[[object, Path], None] = _default_copy_source
    # Optional copy used for the final copy; copy_source is used when None.
    final_copy_source: Callable[[object, Path], None] | None = None
    # Seconds slept between polls.
    poll_interval: float = 1.0
    # When True, locate_source is called on every poll, not only until found.
    refresh_source_each_poll: bool = False
    # Optional receiver of SessionProgressEvent notifications.
    progress_callback: Callable[[SessionProgressEvent], None] | None = None
    # Minimum change in mirrored size (bytes) that emits a growth heartbeat.
    activity_byte_threshold: int = 1024
    # Seconds without activity before an idle heartbeat; 0 disables it.
    idle_heartbeat_seconds: float = 15.0
    # Number of final copy attempts after the process exits (at least 1).
    final_retry_attempts: int = 1
    # Delay in seconds before the second final attempt.
    final_retry_initial_delay_seconds: float = 0.0
    # Factor applied to the delay after each final-attempt sleep.
    final_retry_backoff_multiplier: float = 1.0
55
+
56
+
57
@dataclass(frozen=True)
class SessionRetryDecision:
    """Outcome of a retry decider: whether to retry, after what delay, and why."""

    # True to launch another CLI attempt.
    retry: bool
    # Seconds to wait before the next attempt; 0 means no wait.
    delay_seconds: float = 0.0
    # Optional label for the failure, used in the retry log message.
    category: str | None = None
62
+
63
+
64
@dataclass(frozen=True)
class InteractiveCliRunResult:
    """Immutable record of one completed CLI attempt."""

    # Captured standard output of the process.
    stdout: bytes
    # Captured standard error of the process.
    stderr: bytes
    # Exit status of the process, or None if unavailable.
    returncode: int | None
    # Mirrored transcript file, or None when no transcript was captured.
    conversation_path: Path | None
    # 1-based index of the attempt that produced this result.
    cli_attempt_used: int
    # Effective upper bound on attempts for the run.
    max_cli_attempts: int
    # time.time() value taken just before the attempt was launched.
    started_at: float
73
+
74
+
75
@dataclass(frozen=True)
class SessionProgressEvent:
    """Immutable progress notification passed to a progress callback."""

    # Event name, e.g. "transcript_attached" or "provider_call_heartbeat".
    event_type: str
    # Destination file the transcript is mirrored into.
    conversation_path: Path
    # Description of the transcript source, when one is known.
    source_description: str | None = None
    # Finer-grained reason for the event, e.g. "attached" or "idle".
    activity_kind: str | None = None
    # Size in bytes of the mirrored file, when it could be measured.
    mirrored_bytes: int | None = None
    # time.time() value of the most recent activity.
    last_activity_at: float | None = None
83
+
84
+
85
class SessionExecutionTimeoutError(RuntimeError):
    """Raised when a session run exceeds its time limit.

    Carries the transcript path captured so far and the attempt start time so
    callers can still inspect partial output.
    """

    def __init__(
        self,
        message: str,
        *,
        conversation_path: Path | None = None,
        started_at: float | None = None,
    ) -> None:
        """Store *message* on the exception plus the optional run context."""
        super().__init__(message)
        # Both attributes default to None when the caller has no context.
        self.started_at = started_at
        self.conversation_path = conversation_path
96
+
97
+
98
def _mark_isolated_process_group(proc: asyncio.subprocess.Process) -> None:
    """Tag *proc* as having been started in its own process group.

    The tag is best-effort: objects that reject new attributes are left
    untouched.
    """
    try:
        setattr(proc, "_llm_eval_isolated_process_group", True)
    except Exception:
        # Deliberately ignored; an untagged process is simply not treated
        # as isolated later on.
        pass
103
+
104
+
105
def _has_isolated_process_group(proc: asyncio.subprocess.Process) -> bool:
    """Return True when *proc* carries a truthy isolated-process-group tag."""
    marker = getattr(proc, "_llm_eval_isolated_process_group", False)
    return bool(marker)
107
+
108
+
109
def _close_process_transport(
    proc: asyncio.subprocess.Process | None,
    *,
    logger: logging.Logger | None = None,
    label: str = "process",
) -> None:
    """Close the transport held in ``proc._transport``, if there is one.

    Does nothing when *proc* is None or has no transport. Any error raised by
    ``close()`` is swallowed and, when a logger is given, logged at debug
    level.
    """
    transport = None if proc is None else getattr(proc, "_transport", None)
    if transport is None:
        return
    try:
        transport.close()
    except Exception as error:
        if logger:
            logger.debug("Failed to close %s transport: %s", label, error)
125
+
126
+
127
def _safe_file_size(path: Path) -> int | None:
    """Return the size of *path* in bytes, or None if it cannot be stat'ed."""
    try:
        stat_result = path.stat()
    except OSError:
        return None
    return stat_result.st_size
132
+
133
+
134
def _emit_progress_event(
    *,
    callback: Callable[[SessionProgressEvent], None] | None,
    event: SessionProgressEvent,
    logger: logging.Logger,
    job_name: str,
    phase_tag: str,
) -> None:
    """Deliver *event* to *callback* without letting callback errors escape.

    A None callback is a no-op. Exceptions raised by the callback are logged
    at debug level and suppressed.
    """
    if callback is not None:
        try:
            callback(event)
        except Exception as error:
            logger.debug(
                "[%s] %sprogress callback failed for %s: %s",
                job_name,
                phase_tag,
                event.event_type,
                error,
            )
154
+
155
+
156
async def mirror_session_transcript(
    *,
    job_name: str,
    phase_tag: str,
    logger: logging.Logger,
    proc: asyncio.subprocess.Process,
    strategy: TranscriptMirrorStrategy,
) -> Path | None:
    """Mirror a provider transcript into ``strategy.destination`` while *proc* runs.

    Polls every ``strategy.poll_interval`` seconds until ``proc.returncode`` is
    set, copying the located source onto the destination and emitting progress
    events through ``strategy.progress_callback``. After the process exits, a
    final copy is attempted up to ``strategy.final_retry_attempts`` times.

    Args:
        job_name: Job label interpolated into log messages.
        phase_tag: Prefix interpolated into log messages right after the job label.
        logger: Logger that receives info/debug/warning messages.
        proc: Process whose ``returncode`` ends the polling loop.
        strategy: Source location/copy callables and timing configuration.

    Returns:
        The destination path after a successful final copy; otherwise the
        destination if it still exists, else None.

    NOTE(review): the nesting below was reconstructed from a flattened diff
    rendering; confirm the marked spots against the real module.
    """
    destination = strategy.destination
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Start from an empty destination file so stale content is never reported.
    try:
        destination.unlink(missing_ok=True)
    except FileNotFoundError:
        pass
    destination.touch()

    source: object | None = None
    source_key: object | None = None
    # Becomes True after the first successful copy in the polling loop.
    announced = False
    last_mirrored_bytes: int | None = None
    last_activity_at = time.time()
    last_heartbeat_emit_at = 0.0
    # Clamp configuration so a zero/negative threshold or interval is harmless.
    activity_byte_threshold = max(1, int(strategy.activity_byte_threshold))
    idle_heartbeat_seconds = max(0.0, float(strategy.idle_heartbeat_seconds))

    def emit_progress(
        event_type: str,
        *,
        activity_kind: str,
        source_description: str | None = None,
        mirrored_bytes: int | None = None,
        last_activity_ts: float | None = None,
    ) -> None:
        # Build the event for this destination and hand it to the guarded emitter.
        _emit_progress_event(
            callback=strategy.progress_callback,
            event=SessionProgressEvent(
                event_type=event_type,
                conversation_path=destination,
                source_description=source_description,
                activity_kind=activity_kind,
                mirrored_bytes=mirrored_bytes,
                last_activity_at=last_activity_ts,
            ),
            logger=logger,
            job_name=job_name,
            phase_tag=phase_tag,
        )

    while True:
        latest_source = None
        latest_source_key = None
        # Look for a source until one is found, or on every poll when refreshing.
        if source is None or strategy.refresh_source_each_poll:
            latest_source = strategy.locate_source()
            if latest_source is not None:
                latest_source_key = strategy.source_identity(latest_source)
        source_changed = latest_source is not None and latest_source_key != source_key
        if latest_source is not None:
            # Only log a switch when replacing an already-known source.
            if source is not None:
                if source_changed:
                    logger.info(
                        "[%s] %sswitching %s conversation source to %s",
                        job_name,
                        phase_tag,
                        strategy.provider_label,
                        strategy.describe_source(latest_source),
                    )
            source = latest_source
            source_key = latest_source_key
        if source is not None and strategy.source_exists(source):
            source_description = strategy.describe_source(source)
            first_attachment = not announced
            if first_attachment:
                logger.info(
                    "[%s] %sstreaming %s conversation from %s",
                    job_name,
                    phase_tag,
                    strategy.provider_label,
                    source_description,
                )
            try:
                strategy.copy_source(source, destination)
                mirrored_bytes = _safe_file_size(destination)
                now = time.time()
                if first_attachment or source_changed:
                    attach_kind = "attached" if first_attachment else "source_switched"
                    emit_progress(
                        "transcript_attached",
                        activity_kind=attach_kind,
                        source_description=source_description,
                        mirrored_bytes=mirrored_bytes,
                        last_activity_ts=now,
                    )
                    last_heartbeat_emit_at = now
                # Any size change counts as activity; only large changes emit an event.
                if mirrored_bytes is not None and mirrored_bytes != last_mirrored_bytes:
                    bytes_delta = abs(mirrored_bytes - (last_mirrored_bytes or 0))
                    last_mirrored_bytes = mirrored_bytes
                    last_activity_at = now
                    if bytes_delta >= activity_byte_threshold:
                        emit_progress(
                            "provider_call_heartbeat",
                            activity_kind="transcript_growth",
                            source_description=source_description,
                            mirrored_bytes=mirrored_bytes,
                            last_activity_ts=last_activity_at,
                        )
                        last_heartbeat_emit_at = now
                if first_attachment:
                    announced = True
            except OSError as exc:
                # A failed copy is retried on the next poll.
                logger.debug(
                    "[%s] %sfailed mirroring %s transcript: %s",
                    job_name,
                    phase_tag,
                    strategy.provider_label,
                    exc,
                )
        now = time.time()
        # Idle heartbeat: process still running, no recent activity and no
        # recent heartbeat of any kind.
        if (
            proc.returncode is None
            and idle_heartbeat_seconds > 0
            and strategy.progress_callback is not None
            and now - last_activity_at >= idle_heartbeat_seconds
            and now - last_heartbeat_emit_at >= idle_heartbeat_seconds
        ):
            emit_progress(
                "provider_call_heartbeat",
                activity_kind="idle",
                source_description=(
                    strategy.describe_source(source)
                    if source is not None and strategy.source_exists(source)
                    else None
                ),
                mirrored_bytes=(
                    last_mirrored_bytes
                    if last_mirrored_bytes is not None
                    else _safe_file_size(destination)
                ),
                last_activity_ts=last_activity_at,
            )
            last_heartbeat_emit_at = now
        # The exit check comes after the copy, so one poll always runs post-exit.
        if proc.returncode is not None:
            break
        await asyncio.sleep(strategy.poll_interval)

    # Final copy phase, with optional retries and backoff between attempts.
    retry_attempts = max(1, int(strategy.final_retry_attempts))
    retry_delay = max(0.0, float(strategy.final_retry_initial_delay_seconds))
    retry_multiplier = max(1.0, float(strategy.final_retry_backoff_multiplier))
    for attempt_index in range(retry_attempts):
        if attempt_index > 0 and retry_delay > 0:
            await asyncio.sleep(retry_delay)
            # NOTE(review): backoff assumed to apply only after a sleep — confirm.
            retry_delay *= retry_multiplier
        latest_source = strategy.locate_source()
        latest_source_key = (
            strategy.source_identity(latest_source) if latest_source is not None else None
        )
        final_source_changed = latest_source is not None and latest_source_key != source_key
        if latest_source is not None:
            if source is not None:
                if final_source_changed:
                    logger.info(
                        "[%s] %sswitching %s conversation source to %s",
                        job_name,
                        phase_tag,
                        strategy.provider_label,
                        strategy.describe_source(latest_source),
                    )
            source = latest_source
            source_key = latest_source_key
        else:
            # NOTE(review): assuming this `else` pairs with
            # `if latest_source is not None`, latest_source is None here and
            # this only changes `source` when it is falsy — confirm intent.
            source = source or latest_source
        if source is not None and strategy.source_exists(source):
            try:
                final_copy = strategy.final_copy_source or strategy.copy_source
                final_copy(source, destination)
                final_size = _safe_file_size(destination)
                if not announced or final_source_changed:
                    emit_progress(
                        "transcript_attached",
                        activity_kind="source_switched" if final_source_changed else "attached",
                        source_description=strategy.describe_source(source),
                        mirrored_bytes=final_size,
                        last_activity_ts=time.time(),
                    )
                return destination
            except OSError as exc:
                logger.warning(
                    "[%s] %sfailed final copy for %s transcript: %s",
                    job_name,
                    phase_tag,
                    strategy.provider_label,
                    exc,
                )
        else:
            logger.debug(
                "[%s] %s%s transcript not found",
                job_name,
                phase_tag,
                strategy.provider_label,
            )
    # The destination was created by touch() above, so it is normally still
    # present here even when nothing was mirrored.
    return destination if destination.exists() else None
356
+
357
+
358
async def _terminate_process(
    proc: asyncio.subprocess.Process,
    *,
    logger: logging.Logger | None = None,
    label: str = "process",
    timeout: float = 5.0,
) -> None:
    """Stop *proc*, escalating from SIGTERM to SIGKILL, then close its transport.

    When the process was tagged as running in its own process group (non-Windows
    only), the signals are sent to the whole group via ``os.killpg``; otherwise
    ``proc.terminate()`` / ``proc.kill()`` are used.

    Args:
        proc: Process to stop. A process that has already exited only has its
            transport closed.
        logger: Optional logger for the kill-timeout warning and close failures.
        label: Name used for the process in log messages.
        timeout: Seconds to wait for exit after each signal.
    """
    if proc.returncode is not None:
        _close_process_transport(proc, logger=logger, label=label)
        return
    isolated_group = (
        os.name != "nt"
        and _has_isolated_process_group(proc)
        and isinstance(getattr(proc, "pid", None), int)
        and int(proc.pid) > 0
    )
    try:
        if isolated_group:
            os.killpg(int(proc.pid), signal.SIGTERM)
        else:
            proc.terminate()
    except ProcessLookupError:
        # The process vanished between the returncode check and the signal.
        _close_process_transport(proc, logger=logger, label=label)
        return
    try:
        await asyncio.wait_for(proc.wait(), timeout=timeout)
        _close_process_transport(proc, logger=logger, label=label)
        return
    except asyncio.TimeoutError:
        # asyncio.wait_for raises asyncio.TimeoutError, which is only the
        # builtin TimeoutError on Python 3.11+. Catching the asyncio name keeps
        # the SIGKILL escalation below working on older interpreters too.
        pass
    try:
        if isolated_group:
            os.killpg(int(proc.pid), signal.SIGKILL)
        else:
            proc.kill()
    except ProcessLookupError:
        _close_process_transport(proc, logger=logger, label=label)
        return
    try:
        await asyncio.wait_for(proc.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        if logger:
            logger.warning("Timed out while killing %s", label)
    finally:
        _close_process_transport(proc, logger=logger, label=label)
403
+
404
+
405
async def _finalize_transcript_task(
    *,
    transcript_task: asyncio.Task[Path | None] | None,
    logger: logging.Logger,
    job_name: str,
    phase_tag: str,
    timeout: float = 5.0,
) -> Path | None:
    """Wait briefly for the transcript task and return the path it produced.

    Args:
        transcript_task: Task mirroring the transcript, or None when no
            mirroring was requested.
        logger: Logger that receives debug messages about failures.
        job_name: Job label interpolated into log messages.
        phase_tag: Prefix interpolated into log messages after the job label.
        timeout: Seconds to wait for the task before cancelling it.

    Returns:
        The task's result, or None when there is no task, it timed out, or it
        raised.
    """
    if transcript_task is None:
        return None
    try:
        return await asyncio.wait_for(transcript_task, timeout=timeout)
    except asyncio.TimeoutError:
        # asyncio.TimeoutError is only the builtin TimeoutError on Python
        # 3.11+; naming it explicitly keeps this branch reachable on older
        # interpreters instead of falling into the generic handler below.
        transcript_task.cancel()
        try:
            await transcript_task
        except asyncio.CancelledError:
            logger.debug(
                "[%s] %sconversation log capture canceled to avoid blocking",
                job_name,
                phase_tag,
            )
        except Exception as exc:
            logger.debug(
                "[%s] %sconversation log capture canceled with error: %s",
                job_name,
                phase_tag,
                exc,
            )
    except Exception as exc:
        logger.debug(
            "[%s] %sconversation log capture failed: %s",
            job_name,
            phase_tag,
            exc,
        )
    return None
441
+
442
+
443
async def _cancel_transcript_task(
    *,
    transcript_task: asyncio.Task[Path | None] | None,
    logger: logging.Logger,
    job_name: str,
    phase_tag: str,
    reason: str,
) -> None:
    """Cancel a still-running transcript task and wait for it to finish.

    Does nothing when there is no task or it has already completed. The
    outcome of the cancellation is logged at debug level; *reason* is included
    in the message for a clean cancellation.
    """
    if transcript_task is None:
        return
    if transcript_task.done():
        return
    transcript_task.cancel()
    try:
        await transcript_task
    except asyncio.CancelledError:
        logger.debug(
            "[%s] %sconversation log capture canceled after %s",
            job_name,
            phase_tag,
            reason,
        )
    except Exception as error:
        logger.debug(
            "[%s] %sconversation log cancel failed: %s",
            job_name,
            phase_tag,
            error,
        )
470
+
471
+
472
async def run_interactive_session(
    *,
    cmd_parts: Sequence[str],
    cwd: Path,
    stdin_text: str | None,
    logger: logging.Logger,
    provider_label: str = "cli",
    job_name: str = "session",
    phase_tag: str = "",
    process_label: str = "process",
    timeout_seconds: float | None = None,
    env: dict[str, str] | None = None,
    create_process: Callable[..., Awaitable[asyncio.subprocess.Process]] | None = None,
    transcript_factory: Callable[[asyncio.subprocess.Process], Awaitable[Path | None]]
    | None = None,
    max_cli_attempts: int = 1,
    retry_decider: Callable[[InteractiveCliRunResult, int, int], SessionRetryDecision | bool]
    | None = None,
) -> InteractiveCliRunResult:
    """Run a CLI command, optionally mirroring its transcript and retrying.

    Each attempt spawns the command with piped stdout/stderr, feeds
    *stdin_text* (UTF-8 encoded) when given, and waits for it to finish. A
    non-zero exit is retried only while attempts remain and *retry_decider*
    asks for it.

    Args:
        cmd_parts: Program and arguments passed to the process factory.
        cwd: Working directory for the process.
        stdin_text: Text written to stdin; None connects stdin to DEVNULL.
        logger: Logger for timeout errors and retry warnings.
        provider_label: Provider name used in log and error messages.
        job_name: Job label interpolated into log messages.
        phase_tag: Prefix interpolated into log messages after the job label.
        process_label: Name used for the process when terminating it.
        timeout_seconds: Per-attempt time limit; None or a non-positive value
            disables the limit.
        env: Environment passed to the process factory.
        create_process: Process factory; defaults to
            ``asyncio.create_subprocess_exec``.
        transcript_factory: Builds the coroutine that mirrors the transcript
            for a spawned process.
        max_cli_attempts: Upper bound on attempts; values below 1 count as 1.
        retry_decider: Called with ``(result, attempt, attempt_limit)`` after a
            failed attempt; returns a ``SessionRetryDecision`` or a bool.

    Returns:
        The result of the last attempt that ran.

    Raises:
        SessionExecutionTimeoutError: When an attempt exceeds *timeout_seconds*.
        asyncio.CancelledError: Re-raised after cleanup, with
            ``conversation_path`` and ``started_at`` attached when possible.
    """
    process_factory = create_process or asyncio.create_subprocess_exec
    stdin_payload = stdin_text.encode("utf-8") if stdin_text is not None else None
    attempt_limit = max(1, max_cli_attempts)
    last_result: InteractiveCliRunResult | None = None

    for cli_attempt in range(1, attempt_limit + 1):
        started_at = time.time()
        creation_kwargs = {
            "stdout": asyncio.subprocess.PIPE,
            "stderr": asyncio.subprocess.PIPE,
            "stdin": (
                asyncio.subprocess.PIPE if stdin_text is not None else asyncio.subprocess.DEVNULL
            ),
            "cwd": str(cwd),
            "env": env,
        }
        if os.name != "nt":
            # Own session/process group so the whole tree can be signalled.
            creation_kwargs["start_new_session"] = True
        proc = await process_factory(
            *cmd_parts,
            **creation_kwargs,  # type: ignore[arg-type]
        )
        if os.name != "nt":
            _mark_isolated_process_group(proc)
        transcript_task: asyncio.Task[Path | None] | None = (
            asyncio.create_task(transcript_factory(proc))  # type: ignore[arg-type]
            if transcript_factory is not None
            else None
        )
        try:
            if timeout_seconds and timeout_seconds > 0:
                stdout, stderr = await asyncio.wait_for(
                    proc.communicate(input=stdin_payload),
                    timeout=timeout_seconds,
                )
            else:
                stdout, stderr = await proc.communicate(input=stdin_payload)
        except asyncio.CancelledError as exc:
            await _terminate_process(proc, logger=logger, label=process_label)
            conversation_path = await _finalize_transcript_task(
                transcript_task=transcript_task,
                logger=logger,
                job_name=job_name,
                phase_tag=phase_tag,
            )
            # Best-effort: hand the partial context to whoever catches the
            # cancellation.
            try:
                exc.conversation_path = conversation_path  # type: ignore[attr-defined]
                exc.started_at = started_at  # type: ignore[attr-defined]
            except Exception:
                pass
            raise
        except asyncio.TimeoutError:
            # asyncio.wait_for raises asyncio.TimeoutError, which is only the
            # builtin TimeoutError on Python 3.11+. Catching the asyncio name
            # ensures the process is terminated and the timeout is reported as
            # SessionExecutionTimeoutError on older interpreters as well.
            logger.error(
                "[%s] %s%s run timed out after %ss",
                job_name,
                phase_tag,
                provider_label,
                timeout_seconds,
            )
            await _terminate_process(proc, logger=logger, label=process_label)
            conversation_path = await _finalize_transcript_task(
                transcript_task=transcript_task,
                logger=logger,
                job_name=job_name,
                phase_tag=phase_tag,
            )
            raise SessionExecutionTimeoutError(
                f"{provider_label.lower()} run timed out after {timeout_seconds}s",
                conversation_path=conversation_path,
                started_at=started_at,
            ) from None

        conversation_path = await _finalize_transcript_task(
            transcript_task=transcript_task,
            logger=logger,
            job_name=job_name,
            phase_tag=phase_tag,
        )
        result = InteractiveCliRunResult(
            stdout=stdout,
            stderr=stderr,
            returncode=proc.returncode,
            conversation_path=conversation_path,
            cli_attempt_used=cli_attempt,
            max_cli_attempts=attempt_limit,
            started_at=started_at,
        )
        last_result = result
        if proc.returncode == 0:
            return result

        # Failed attempt: stop unless a decider exists and attempts remain.
        if retry_decider is None or cli_attempt >= attempt_limit:
            return result

        decision = retry_decider(result, cli_attempt, attempt_limit)
        if isinstance(decision, bool):
            decision = SessionRetryDecision(retry=decision)
        if not decision.retry:
            return result
        logger.warning(
            "[%s] %s%s retryable failure (%s, %s/%s); retrying in %.1fs",
            job_name,
            phase_tag,
            provider_label,
            decision.category or "retryable",
            cli_attempt,
            attempt_limit,
            decision.delay_seconds,
        )
        if decision.delay_seconds > 0:
            await asyncio.sleep(decision.delay_seconds)

    # Defensive fallback: every loop iteration above returns or raises on the
    # final attempt, so this is not expected to run.
    if last_result is None:
        raise RuntimeError(f"{provider_label.lower()} run failed before launch")
    return last_result