wafer-core 0.1.37__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. wafer_core/lib/trace_compare/fusion_analyzer.py +2 -0
  2. wafer_core/rollouts/_logging/__init__.py +5 -1
  3. wafer_core/rollouts/_logging/logging_config.py +95 -3
  4. wafer_core/rollouts/_logging/sample_handler.py +66 -0
  5. wafer_core/rollouts/_pytui/__init__.py +114 -0
  6. wafer_core/rollouts/_pytui/app.py +809 -0
  7. wafer_core/rollouts/_pytui/console.py +291 -0
  8. wafer_core/rollouts/_pytui/renderer.py +210 -0
  9. wafer_core/rollouts/_pytui/spinner.py +73 -0
  10. wafer_core/rollouts/_pytui/terminal.py +489 -0
  11. wafer_core/rollouts/_pytui/text.py +470 -0
  12. wafer_core/rollouts/_pytui/theme.py +241 -0
  13. wafer_core/rollouts/evaluation.py +142 -177
  14. wafer_core/rollouts/progress_app.py +395 -0
  15. wafer_core/rollouts/tui/DESIGN.md +251 -115
  16. wafer_core/rollouts/tui/monitor.py +64 -20
  17. wafer_core/tools/compile/__init__.py +30 -0
  18. wafer_core/tools/compile/compiler.py +314 -0
  19. wafer_core/tools/compile/modal_compile.py +359 -0
  20. wafer_core/tools/compile/tests/__init__.py +1 -0
  21. wafer_core/tools/compile/tests/test_compiler.py +675 -0
  22. wafer_core/tools/compile/tests/test_data/utils.cuh +10 -0
  23. wafer_core/tools/compile/tests/test_data/vector_add.cu +7 -0
  24. wafer_core/tools/compile/tests/test_data/with_header.cu +9 -0
  25. wafer_core/tools/compile/tests/test_modal_integration.py +326 -0
  26. wafer_core/tools/compile/types.py +117 -0
  27. {wafer_core-0.1.37.dist-info → wafer_core-0.1.39.dist-info}/METADATA +1 -1
  28. {wafer_core-0.1.37.dist-info → wafer_core-0.1.39.dist-info}/RECORD +29 -12
  29. wafer_core/rollouts/events.py +0 -240
  30. wafer_core/rollouts/progress_display.py +0 -476
  31. wafer_core/utils/event_streaming.py +0 -63
  32. {wafer_core-0.1.37.dist-info → wafer_core-0.1.39.dist-info}/WHEEL +0 -0
@@ -0,0 +1,395 @@
1
+ """Eval progress display built on pytui's Elm-style App.
2
+
3
+ Tails events.jsonl (written by evaluate() via logging), derives display state,
4
+ renders progress via a subprocess (not in the alternate screen: the App is started with alternate_screen=False).
5
+
6
+ Usage:
7
+ with progress_display(output_dir=output_dir):
8
+ await evaluate(dataset, config)
9
+
10
+ Or detached mode (two terminals):
11
+ # Terminal 1: Run eval (writes events.jsonl via logging)
12
+ python run_eval.py --config config.py
13
+
14
+ # Terminal 2: Watch progress
15
+ python -m wafer_core.rollouts.progress_app /path/to/events.jsonl
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ import os
22
+ import signal
23
+ import subprocess
24
+ import sys
25
+ import time
26
+ from collections.abc import Generator
27
+ from contextlib import contextmanager
28
+ from dataclasses import dataclass, field
29
+ from pathlib import Path
30
+ from typing import Any
31
+
32
# ── ANSI colors ──────────────────────────────────────────────────────────────
# SGR escape sequences (ESC == "\033" == "\x1b") used to colour the display.
BOLD = "\033[1m"
RESET = "\033[0m"
DIM = "\033[90m"  # bright black, rendered as grey by most terminals
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
CYAN = "\033[36m"

# Event types the progress display understands. Each JSONL record carries its
# event type in the "message" field; records whose "message" is not listed here
# are ignored by derive_state().
KNOWN_EVENT_TYPES = frozenset(
    (
        # eval lifecycle
        "eval_start",
        "eval_end",
        # per-sample lifecycle
        "sample_start",
        "turn",
        "modal_progress",
        "sample_retry",
        "sample_end",
        # GEPA optimisation loop
        "gepa_iteration",
        "gepa_accepted",
        "gepa_rejected",
    )
)
57
+
58
+
59
+ # ── State ────────────────────────────────────────────────────────────────────
60
+
61
+
62
@dataclass
class SampleState:
    """Display-side record for one sample, built up from that sample's events.

    ``status`` takes the values "started", "retry" or "complete". ``phase`` is the
    free-text activity label shown while the sample is still in flight.
    """

    id: str
    name: str = field(default="")
    turn: int = field(default=0)
    phase: str = field(default="")
    score: float | None = field(default=None)
    status: str = field(default="started")  # "started" | "complete" | "retry"
    retry_attempt: int = field(default=0)
    # Both timestamps default to the wall-clock time at construction.
    start_time: float = field(default_factory=time.time)
    last_update: float = field(default_factory=time.time)
75
+
76
+
77
@dataclass
class RenderState:
    """Aggregate display state produced by folding the event stream."""

    # Header information, taken from the eval_start event.
    eval_name: str = field(default="")
    total: int = field(default=0)
    # Count of finished samples.
    completed: int = field(default=0)
    # Per-sample records keyed by sample id.
    samples: dict[str, SampleState] = field(default_factory=dict)
    start_time: float = field(default_factory=time.time)
    # GEPA progress; all three stay None until a gepa_iteration event is seen.
    gepa_iter: int | None = field(default=None)
    gepa_total: int | None = field(default=None)
    gepa_best: float | None = field(default=None)
    # Scores of finished samples that reported one, in completion order.
    scores: list[float] = field(default_factory=list)
90
+
91
+
92
+ # ── Pure functions ───────────────────────────────────────────────────────────
93
+
94
+
95
+ def _format_time(seconds: float) -> str:
96
+ """Format seconds as H:MM:SS or M:SS."""
97
+ h = int(seconds) // 3600
98
+ m = int(seconds) % 3600 // 60
99
+ s = int(seconds) % 60
100
+ if h:
101
+ return f"{h}:{m:02d}:{s:02d}"
102
+ return f"{m}:{s:02d}"
103
+
104
+
105
+ def _format_bar(progress: float, width: int) -> str:
106
+ """Render a progress bar with unicode blocks."""
107
+ if width <= 0:
108
+ return ""
109
+ filled = progress * width
110
+ full_blocks = int(filled)
111
+ partial_idx = int(8 * filled) % 8
112
+ partial = " ▏▎▍▌▋▊▉"[partial_idx] if partial_idx else ""
113
+ bar = "█" * full_blocks + partial
114
+ return bar.ljust(width)
115
+
116
+
117
def _event_epoch(ts: object) -> float:
    """Convert an event "timestamp" to epoch seconds, or return now if unusable.

    Accepts ISO-8601 strings (a trailing "Z" is treated as UTC) and finite,
    positive numeric epoch values. Missing, empty or malformed values fall back
    to ``time.time()`` instead of crashing the display.
    """
    from datetime import datetime

    if isinstance(ts, (int, float)) and not isinstance(ts, bool) and 0 < ts < float("inf"):
        return float(ts)
    if isinstance(ts, str) and ts:
        try:
            # NOTE(review): a naive ISO string (no offset) is interpreted as local
            # time by .timestamp(); confirm the logger always emits an offset or Z.
            return datetime.fromisoformat(ts.replace("Z", "+00:00")).timestamp()
        except (ValueError, OverflowError, OSError):
            pass
    return time.time()


def derive_state(events: list[dict[str, Any]]) -> RenderState:
    """Fold an event stream into a fresh :class:`RenderState`.

    Each event is a dict whose ``"message"`` field names its type. Non-dict
    entries and types not in ``KNOWN_EVENT_TYPES`` are skipped. No state is kept
    between calls, but the wall clock is read: ``last_update`` of each touched
    sample is set to ``time.time()`` at derivation time. Sorting by it therefore
    gives "most recently seen in the stream" order, not real event times.
    """
    state = RenderState()

    for event in events:
        # The isinstance checks also stop an unhashable "message" (e.g. a JSON
        # list) from raising TypeError in the set lookup.
        if not isinstance(event, dict):
            continue
        event_type = event.get("message")
        if not isinstance(event_type, str) or event_type not in KNOWN_EVENT_TYPES:
            continue

        sample_id = event.get("sample_id", "")
        sample = state.samples.get(sample_id)

        if event_type == "eval_start":
            state.eval_name = event.get("eval_name", "eval")
            # A present-but-null "total" would otherwise break `state.total > 0`.
            state.total = event.get("total") or 0
            state.start_time = _event_epoch(event.get("timestamp"))

        elif event_type == "sample_start":
            sample_name = event.get("sample_name")
            state.samples[sample_id] = SampleState(
                id=sample_id,
                # Coerce to str: rows slice and pad the name, which fails on None/ints.
                name=str(sample_id if sample_name is None else sample_name),
                last_update=time.time(),
                phase="starting",
            )

        elif event_type == "turn":
            if sample is not None:
                if "turn" in event:
                    sample.turn = event["turn"]
                sample.last_update = time.time()
                sample.phase = event.get("status", "running")

        elif event_type == "modal_progress":
            if sample is not None:
                sample.phase = event.get("phase", "")
                sample.last_update = time.time()

        elif event_type == "sample_end":
            # Count only the first sample_end of a started sample. A duplicate
            # would push `completed` past `total` and double-count the score.
            if sample is not None and sample.status != "complete":
                sample.status = "complete"
                score = event.get("score")
                if score is not None:
                    sample.score = score
                    state.scores.append(score)
                state.completed += 1

        elif event_type == "sample_retry":
            if sample is not None:
                sample.status = "retry"
                sample.retry_attempt = event.get("attempt", 1)

        elif event_type == "gepa_iteration":
            state.gepa_iter = event.get("iter")
            state.gepa_total = event.get("total")
            state.gepa_best = event.get("best")

        # eval_end, gepa_accepted and gepa_rejected need no state change here;
        # update() is what quits on eval_end.

    return state
182
+
183
+
184
def _format_sample_row(sample: SampleState, width: int) -> str:
    """Return one status line for a sample.

    Layout: a leading space, the name clipped and padded to 25 columns, the turn
    counter right-aligned in 4 columns, then a status-specific suffix:
    a pass/fail mark with score, "done", a retry notice, or the current phase.
    ``width`` is accepted for interface symmetry but is not used here.
    """
    padded_name = sample.name[:25].ljust(25)
    turn_label = f"T:{sample.turn}"
    lead = f" {padded_name} {turn_label:>4}"

    if sample.status == "complete":
        if sample.score is None:
            tail = f"{GREEN}✓{RESET} done"
        else:
            # Scores strictly above 0.5 count as a pass.
            mark = f"{GREEN}✓{RESET}" if sample.score > 0.5 else f"{RED}✗{RESET}"
            tail = f"{mark} score={sample.score:.2f}"
    elif sample.status == "retry":
        tail = f"{YELLOW}⟳{RESET} retry (attempt {sample.retry_attempt})"
    else:
        label = sample.phase or "running"
        tail = f"{CYAN}{label}...{RESET}"

    return f"{lead} {tail}"
200
+
201
+
202
+ # ── Render function ──────────────────────────────────────────────────────────
203
+
204
+
205
def _truncate_visible(text: str, width: int) -> str:
    """Clip ``text`` to ``width`` visible columns, treating SGR escapes as zero-width.

    Escape sequences are never split. If visible text is cut off and the line
    contained any escape, RESET is appended so colour or bold cannot leak into
    later output. Assumes one column per character (wide glyphs are not measured).
    """
    import re

    remaining = max(width, 0)
    # With a capturing group, re.split puts the escapes at the odd indices.
    pieces = re.split(r"(\x1b\[[0-9;]*m)", text)
    out: list[str] = []
    styled = False
    for index, piece in enumerate(pieces):
        if index % 2:
            out.append(piece)
            styled = True
            continue
        if len(piece) > remaining:
            out.append(piece[:remaining])
            if styled or len(pieces) > 1:
                out.append(RESET)
            return "".join(out)
        out.append(piece)
        remaining -= len(piece)
    return "".join(out)


def render(state: RenderState, width: int, height: int) -> list[str]:
    """Render ``state`` into display lines, without padding or footer.

    Produces, in order:
    - a header: the GEPA iteration summary, or eval progress (bar, percent,
      elapsed and ETA) / a plain count;
    - up to ``max(1, height - 4)`` active sample rows, most recently touched first;
    - an optional "... and N more in flight" line;
    - an optional mean-score line.

    Every line is clipped to ``width`` visible columns. The wall clock is read
    for elapsed time and ETA.
    """
    lines: list[str] = []

    if state.gepa_iter is not None:
        header = f"GEPA iter {state.gepa_iter}/{state.gepa_total or '?'}"
        # Compare with None rather than truthiness: a best score of 0.0 is still
        # worth showing.
        if state.gepa_best is not None:
            header = f"GEPA iter {state.gepa_iter}/{state.gepa_total or '?'} │ best: {state.gepa_best:.0%}"
    else:
        elapsed = time.time() - state.start_time
        if state.total > 0:
            progress = state.completed / state.total
            # The bar gets whatever width is left after ~50 columns of text,
            # capped at 30 cells.
            bar = _format_bar(min(progress, 1.0), min(30, width - 50))
            eta = ""
            if state.completed > 0 and progress < 1.0:
                # Linear extrapolation from the average time per completed sample.
                remaining = (elapsed / progress) - elapsed
                eta = f" <{_format_time(remaining)}"
            header = (
                f"{state.eval_name}: {BOLD}{state.completed}/{state.total}{RESET} "
                f"|{bar}| {100 * progress:3.0f}% "
                f"[{_format_time(elapsed)}{eta}]"
            )
        else:
            header = f"{state.eval_name}: {state.completed} samples [{_format_time(elapsed)}]"

    lines.append(_truncate_visible(header, width))

    # Only samples that are unfinished and have a phase label get a row.
    active = [s for s in state.samples.values() if s.status != "complete" and s.phase]
    active.sort(key=lambda s: s.last_update, reverse=True)

    # Reserve rows for the header, the "more" line, the score line and the
    # footer that view() appends.
    max_samples = max(1, height - 4)
    shown = active[:max_samples]
    for sample in shown:
        lines.append(_truncate_visible(_format_sample_row(sample, width), width))

    total_in_flight = sum(1 for s in state.samples.values() if s.status != "complete")
    hidden = total_in_flight - len(shown)
    if hidden > 0:
        lines.append(_truncate_visible(f"{DIM} ... and {hidden} more in flight{RESET}", width))

    if state.scores:
        mean_score = sum(state.scores) / len(state.scores)
        lines.append(_truncate_visible(f"{DIM}score: {mean_score:.1%}{RESET}", width))

    return lines
254
+
255
+
256
+ # ── Elm architecture (subprocess entry point) ────────────────────────────────
257
+
258
+
259
@dataclass(frozen=True)
class Model:
    """Immutable pytui model: every parsed event so far, plus a completion flag."""

    # Parsed event dicts, in arrival order.
    events: tuple[dict[str, Any], ...] = field(default=())
    # Becomes True once an eval_end event has been received.
    done: bool = field(default=False)
265
+
266
+
267
@dataclass(frozen=True)
class NewEvent:
    """Message carrying one raw, still-unparsed line from the tailed events file."""

    line: str  # JSONL text, decoded later by update()
270
+
271
+
272
+ def _parse_event(line: str) -> dict[str, Any] | None:
273
+ try:
274
+ return json.loads(line)
275
+ except json.JSONDecodeError:
276
+ return None
277
+
278
+
279
+ def update(model: Model, msg: object) -> tuple[Model, Any]:
280
+ from ._pytui import Cmd, KeyPress
281
+
282
+ match msg:
283
+ case KeyPress(key="q" | "\x03"):
284
+ return model, Cmd.quit()
285
+ case NewEvent(line=line):
286
+ event = _parse_event(line)
287
+ if event is None:
288
+ return model, Cmd.none()
289
+ new_events = model.events + (event,)
290
+ if event.get("message") == "eval_end":
291
+ return Model(events=new_events, done=True), Cmd.quit()
292
+ return Model(events=new_events, done=model.done), Cmd.none()
293
+ return model, Cmd.none()
294
+
295
+
296
def view(model: Model, width: int, height: int) -> list[str]:
    """Produce the screen: rendered state, blank padding, then a ``q: quit`` footer.

    The display state is re-derived from the full event history on every call.
    - ``height >= 2``: returns exactly ``height`` lines. The rendered body is
      clipped to ``height - 1`` lines, so the footer always stays on the last row.
    - ``height == 1``: returns only the first rendered line.
    - ``height <= 0``: returns an empty list.
    """
    body = render(derive_state(list(model.events)), width, height)
    if height <= 1:
        return body[: max(height, 0)]

    # Clip before appending the footer; otherwise a tall body pushes the
    # footer past the slice and it disappears on small terminals.
    body = body[: height - 1]
    body.extend([""] * (height - 1 - len(body)))
    body.append(f"{DIM}q: quit{RESET}")
    return body
306
+
307
+
308
+ # ── Context manager (spawns subprocess) ──────────────────────────────────────
309
+
310
+
311
def _stop_subprocess(proc: subprocess.Popen, timeout: float = 5.0) -> None:
    """Send SIGTERM so ``proc`` can restore the terminal; SIGKILL it if it lingers."""
    proc.send_signal(signal.SIGTERM)
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()


def _restore_terminal() -> None:
    """Run ``stty sane`` on the controlling terminal, on a best-effort basis.

    This does nothing when there is no controlling terminal (CI, cron, detached
    containers) or when ``stty`` is unavailable. It runs inside a ``finally``, so
    it must not mask the caller's exception or fail an otherwise successful run.
    """
    try:
        tty_fd = os.open("/dev/tty", os.O_RDWR)
    except OSError:
        return
    try:
        # stty sane resets raw mode, re-enables echo, etc.
        subprocess.run(["stty", "sane"], stdin=tty_fd, check=False)
    except OSError:
        pass  # deliberate: no stty binary means there is nothing we can restore
    finally:
        os.close(tty_fd)


@contextmanager
def progress_display(
    output_dir: Path | str,
    disable: bool = False,
) -> Generator[Path, None, None]:
    """Show a live progress display in a child process for the duration of the block.

    Spawns ``python -m wafer_core.rollouts.progress_app <output_dir>/events.jsonl``.
    The child inherits stdout for rendering and opens /dev/tty for key input.
    The parent keeps writing events.jsonl via logging as usual.

    On exit, whether normal or via an exception:
    - the child receives SIGTERM, then SIGKILL if it is still running after 5 s;
    - the terminal is reset with ``stty sane`` on a best-effort basis.

    Args:
        output_dir: Directory that will contain events.jsonl (created if missing).
        disable: If True, skip the display entirely (e.g. verbose mode). The
            directory is then neither created nor watched.

    Yields:
        ``output_dir`` as a :class:`Path`.
    """
    if disable:
        yield Path(output_dir)
        return

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    events_file = output_dir / "events.jsonl"

    proc = subprocess.Popen(
        [sys.executable, "-m", "wafer_core.rollouts.progress_app", str(events_file)],
        # Don't touch stdout/stderr: the child inherits the terminal.
        # pytui opens /dev/tty directly for input and writes to stdout for rendering.
        #
        # start_new_session puts the child in its own session so it does not
        # receive SIGINT from Ctrl+C. The parent gets SIGINT and terminates the
        # child in the finally block below. Without this, pytui's raw mode would
        # swallow Ctrl+C (the terminal driver does not generate SIGINT in raw
        # mode), and the parent would never see it.
        stdin=subprocess.DEVNULL,
        start_new_session=True,
    )

    try:
        yield output_dir
    finally:
        try:
            _stop_subprocess(proc)
        finally:
            # Runs even if stopping was interrupted (e.g. a second Ctrl+C). This
            # covers a child that never cleaned up (SIGKILL, or a crash before
            # it restored the terminal).
            _restore_terminal()
365
+ os.close(tty_fd)
366
+
367
+
368
+ # ── CLI entry point (subprocess runs this) ────────────────────────────────────
369
+
370
def _main(argv: list[str]) -> None:
    """Run the progress App against the single events-file path in ``argv``.

    This is what ``progress_display()`` runs in its child process. It prints
    usage and exits with status 1 unless exactly one argument is given.

    The App is started with ``alternate_screen=False``.
    NOTE(review): the module and progress_display docs mention the alternate
    screen; confirm which behaviour is intended.
    """
    if len(argv) != 1:
        print("Usage: python -m wafer_core.rollouts.progress_app <events.jsonl>", file=sys.stderr)
        sys.exit(1)

    from ._pytui import App, Cmd, Sub

    (events_path,) = argv

    def subscriptions(model: Model) -> Sub:
        # Tail the events file until eval_end has marked the model done.
        if model.done:
            return Sub.none()
        return Sub.file_tail(events_path, lambda line: NewEvent(line=line))

    App(
        init=(Model(), Cmd.none()),
        update=update,
        view=view,
        subscriptions=subscriptions,
        alternate_screen=False,
    ).run()


# ── CLI entry point (the progress_display subprocess runs this) ──────────────
if __name__ == "__main__":
    _main(sys.argv[1:])