wafer-core 0.1.38__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32):
  1. wafer_core/lib/trace_compare/fusion_analyzer.py +2 -0
  2. wafer_core/rollouts/_logging/__init__.py +5 -1
  3. wafer_core/rollouts/_logging/logging_config.py +95 -3
  4. wafer_core/rollouts/_logging/sample_handler.py +66 -0
  5. wafer_core/rollouts/_pytui/__init__.py +114 -0
  6. wafer_core/rollouts/_pytui/app.py +809 -0
  7. wafer_core/rollouts/_pytui/console.py +291 -0
  8. wafer_core/rollouts/_pytui/renderer.py +210 -0
  9. wafer_core/rollouts/_pytui/spinner.py +73 -0
  10. wafer_core/rollouts/_pytui/terminal.py +489 -0
  11. wafer_core/rollouts/_pytui/text.py +470 -0
  12. wafer_core/rollouts/_pytui/theme.py +241 -0
  13. wafer_core/rollouts/evaluation.py +142 -177
  14. wafer_core/rollouts/progress_app.py +395 -0
  15. wafer_core/rollouts/tui/DESIGN.md +251 -115
  16. wafer_core/rollouts/tui/monitor.py +64 -20
  17. wafer_core/tools/compile/__init__.py +30 -0
  18. wafer_core/tools/compile/compiler.py +314 -0
  19. wafer_core/tools/compile/modal_compile.py +359 -0
  20. wafer_core/tools/compile/tests/__init__.py +1 -0
  21. wafer_core/tools/compile/tests/test_compiler.py +675 -0
  22. wafer_core/tools/compile/tests/test_data/utils.cuh +10 -0
  23. wafer_core/tools/compile/tests/test_data/vector_add.cu +7 -0
  24. wafer_core/tools/compile/tests/test_data/with_header.cu +9 -0
  25. wafer_core/tools/compile/tests/test_modal_integration.py +326 -0
  26. wafer_core/tools/compile/types.py +117 -0
  27. {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/METADATA +1 -1
  28. {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/RECORD +29 -12
  29. wafer_core/rollouts/events.py +0 -240
  30. wafer_core/rollouts/progress_display.py +0 -476
  31. wafer_core/utils/event_streaming.py +0 -63
  32. {wafer_core-0.1.38.dist-info → wafer_core-0.1.39.dist-info}/WHEEL +0 -0
@@ -1,476 +0,0 @@
1
- """Clean progress display using alternate screen and JSONL event stream.
2
-
3
- Replaces MultiProgress with a cleaner approach:
4
- - Alternate screen (no scrollback pollution)
5
- - Stateless renderer (derives state from events.jsonl file)
6
- - Works with existing EventEmitter infrastructure
7
-
8
- Usage:
9
- with progress_display(output_dir=output_dir):
10
- await evaluate(dataset, config)
11
-
12
- Or detached mode (two terminals):
13
- # Terminal 1: Run eval (writes to events.jsonl)
14
- python eval_script.py --output-dir ./results
15
-
16
- # Terminal 2: Watch progress
17
- python -m rollouts.progress_watch ./results/events.jsonl
18
-
19
- Note: Events are written by EventEmitter (rollouts/events.py), not Python logging.
20
- The progress display just reads the events.jsonl file that evaluate() already produces.
21
- """
22
-
23
- from __future__ import annotations
24
-
25
- import atexit
26
- import json
27
- import os
28
- import signal
29
- import sys
30
- import threading
31
- import time
32
- from collections.abc import Generator
33
- from contextlib import contextmanager
34
- from dataclasses import dataclass, field
35
- from pathlib import Path
36
- from types import FrameType
37
- from typing import TextIO
38
-
39
# ANSI escape codes
HIDE_CURSOR = "\x1b[?25l"
SHOW_CURSOR = "\x1b[?25h"
CLEAR_LINE = "\x1b[2K"
# Template: fill in the number of rows with CURSOR_UP.format(n=...)
CURSOR_UP = "\x1b[{n}A"
# Synchronized output (reduces flicker in supported terminals)
SYNC_START = "\x1b[?2026h"
SYNC_END = "\x1b[?2026l"

# Colors (SGR sequences; RESET clears any attribute set before it)
DIM = "\x1b[90m"
RESET = "\x1b[0m"
GREEN = "\x1b[32m"
RED = "\x1b[31m"
YELLOW = "\x1b[33m"
CYAN = "\x1b[36m"
BOLD = "\x1b[1m"

# Known event types for validation.
# derive_state() raises ValueError for any event whose "type" is not listed here.
KNOWN_EVENT_TYPES = frozenset({
    "eval_start",
    "eval_end",
    "sample_start",
    "sample_end",
    "turn",
    "modal_progress",
    "gepa_iteration",
    "gepa_accepted",
    "gepa_rejected",
    "sample_retry",
})
70
-
71
-
72
- def _format_time(seconds: float) -> str:
73
- """Format seconds as H:MM:SS or M:SS."""
74
- h = int(seconds) // 3600
75
- m = int(seconds) % 3600 // 60
76
- s = int(seconds) % 60
77
- if h:
78
- return f"{h}:{m:02d}:{s:02d}"
79
- return f"{m}:{s:02d}"
80
-
81
-
82
- def _format_bar(progress: float, width: int) -> str:
83
- """Render a progress bar with unicode blocks."""
84
- if width <= 0:
85
- return ""
86
- filled = progress * width
87
- full_blocks = int(filled)
88
- partial_idx = int(8 * filled) % 8
89
- partial = " ▏▎▍▌▋▊▉"[partial_idx] if partial_idx else ""
90
- bar = "█" * full_blocks + partial
91
- return bar.ljust(width)
92
-
93
-
94
@dataclass
class SampleState:
    """State of a single sample, as reconstructed from the event stream.

    Instances are created by derive_state() on ``sample_start`` and updated by
    the later ``turn``, ``modal_progress``, ``sample_end`` and ``sample_retry``
    events that carry the same id.
    """

    id: str
    name: str = ""  # display name; derive_state() falls back to the id
    turn: int = 0  # last turn number reported by a "turn" event
    phase: str = ""  # streaming, compiling, checking, etc.
    score: float | None = None  # set from "sample_end" when a score is present
    status: str = "started"  # started, complete, retry
    retry_attempt: int = 0  # attempt number from the latest "sample_retry"
    start_time: float = field(default_factory=time.time)  # defaults to creation time
    last_update: float = field(default_factory=time.time)  # For sorting by recency
107
-
108
-
109
@dataclass
class RenderState:
    """Display state derived from events.

    Built from scratch by derive_state() and consumed by render().
    """

    eval_name: str = ""  # from "eval_start"
    total: int = 0  # expected number of samples; 0 means unknown
    completed: int = 0  # incremented once per "sample_end" for a known sample
    samples: dict[str, SampleState] = field(default_factory=dict)  # keyed by sample id
    start_time: float = field(default_factory=time.time)  # epoch seconds

    # GEPA state (all None until a "gepa_iteration" event is seen)
    gepa_iter: int | None = None
    gepa_total: int | None = None
    gepa_best: float | None = None

    # Scores collected from "sample_end" events; render() shows their mean.
    # (Original note: histogram data, scores per metric.)
    scores: list[float] = field(default_factory=list)
126
-
127
-
128
def derive_state(events: list[dict]) -> RenderState:
    """Derive the current display state from the full event stream.

    Stateless: a fresh ``RenderState`` is built on every call by replaying
    ``events`` in order.

    Args:
        events: Decoded event dicts. Events without a ``"type"`` key are
            ignored.

    Returns:
        The ``RenderState`` after all events have been applied.

    Raises:
        ValueError: If an event has a type not in ``KNOWN_EVENT_TYPES``
            (deliberately strict).
        KeyError: If a per-sample event has no ``"id"``.
    """
    # Function-scope import (hoisted out of the loop) keeps this block
    # self-contained without touching the module's import section.
    from datetime import datetime

    state = RenderState()

    for event in events:
        event_type = event.get("type")
        if event_type is None:
            continue

        # Strict validation: crash on unknown event types
        if event_type not in KNOWN_EVENT_TYPES:
            raise ValueError(f"Unknown event type: {event_type}. Known types: {KNOWN_EVENT_TYPES}")

        if event_type == "eval_start":
            state.eval_name = event.get("name", "eval")
            # `or 0` also covers an explicit null total, which would otherwise
            # break the `state.total > 0` comparison in render().
            state.total = event.get("total") or 0
            # Use the event timestamp if it is a parseable ISO-8601 string,
            # otherwise fall back to now instead of raising.
            state.start_time = time.time()
            ts = event.get("timestamp")
            if isinstance(ts, str) and ts:
                try:
                    state.start_time = datetime.fromisoformat(
                        ts.replace("Z", "+00:00")
                    ).timestamp()
                except ValueError:
                    pass  # malformed timestamp: keep the "now" fallback

        elif event_type == "sample_start":
            sample_id = event["id"]
            state.samples[sample_id] = SampleState(
                id=sample_id,
                name=event.get("name", sample_id),
                last_update=time.time(),
                phase="starting",  # Default phase so samples show immediately
            )

        elif event_type == "turn":
            sample = state.samples.get(event["id"])
            if sample is not None:
                # Only update turn if explicitly provided (status updates don't include turn)
                if "turn" in event:
                    sample.turn = event["turn"]
                sample.last_update = time.time()
                # Phase comes from the turn status, defaulting to "running" so
                # samples show up even without modal_progress events.
                sample.phase = event.get("status", "running")

        elif event_type == "modal_progress":
            sample = state.samples.get(event["id"])
            if sample is not None:
                sample.phase = event.get("phase", "")
                sample.last_update = time.time()

        elif event_type == "sample_end":
            sample = state.samples.get(event["id"])
            if sample is not None:
                sample.status = "complete"
                score = event.get("score")
                if score is not None:
                    sample.score = score
                    state.scores.append(score)
                state.completed += 1

        elif event_type == "sample_retry":
            sample = state.samples.get(event["id"])
            if sample is not None:
                sample.status = "retry"
                sample.retry_attempt = event.get("attempt", 1)

        elif event_type == "gepa_iteration":
            state.gepa_iter = event.get("iter")
            state.gepa_total = event.get("total")
            state.gepa_best = event.get("best")

        # "eval_end", "gepa_accepted" and "gepa_rejected" carry no display
        # state; the render loop watches for "eval_end" itself.

    return state
205
-
206
-
207
def render(state: RenderState, width: int, height: int) -> list[str]:
    """Render state to a list of display lines.

    Args:
        state: Display state produced by ``derive_state``.
        width: Terminal width in columns; only the header is cut to it.
        height: Terminal height in rows; bounds how many samples are listed.

    Returns:
        The header line, one row per visible active sample, an optional
        "... and N more in flight" line and an optional mean-score line.
    """
    lines: list[str] = []

    # Header with GEPA or eval progress
    if state.gepa_iter is not None:
        header = f"GEPA iter {state.gepa_iter}/{state.gepa_total or '?'}"
        # `is not None`, not truthiness: a best score of 0.0 is still a score.
        if state.gepa_best is not None:
            header += f" │ best: {state.gepa_best:.0%}"
    else:
        # Guard against a start timestamp that lies in the future.
        elapsed = max(0.0, time.time() - state.start_time)
        if state.total > 0:
            # Clamp so more completions than `total` cannot show >100%.
            progress = min(1.0, state.completed / state.total)
            bar_width = min(30, width - 50)
            bar = _format_bar(progress, bar_width)
            eta = ""
            if state.completed > 0 and progress < 1.0:
                # Linear extrapolation from the pace so far.
                remaining = (elapsed / progress) - elapsed
                eta = f" <{_format_time(remaining)}"
            header = (
                f"{state.eval_name}: {BOLD}{state.completed}/{state.total}{RESET} "
                f"|{bar}| {100 * progress:3.0f}% "
                f"[{_format_time(elapsed)}{eta}]"
            )
        else:
            header = f"{state.eval_name}: {state.completed} samples [{_format_time(elapsed)}]"

    lines.append(header[:width])

    # Sample list - only show active (have a phase), sorted by most recently updated
    active = [s for s in state.samples.values() if s.status != "complete" and s.phase]
    active.sort(key=lambda s: s.last_update, reverse=True)

    # Reserve rows for the header, the "... and X more" line and the score summary
    max_samples = max(1, height - 4)
    visible = active[:max_samples]
    lines.extend(_format_sample_row(sample, width) for sample in visible)

    # Show how many more are in flight (including non-active)
    total_in_flight = sum(1 for s in state.samples.values() if s.status != "complete")
    hidden = total_in_flight - len(visible)
    if hidden > 0:
        lines.append(f"{DIM} ... and {hidden} more in flight{RESET}")

    # Score summary (compact)
    if state.scores:
        mean_score = sum(state.scores) / len(state.scores)
        lines.append(f"{DIM}score: {mean_score:.1%}{RESET}")

    return lines
262
-
263
-
264
def _format_sample_row(sample: SampleState, width: int) -> str:
    """Build the one-line display row for ``sample``.

    The row is a 25-character name column, a right-aligned turn counter and a
    status section that depends on ``sample.status``. ``width`` is accepted
    but not used.
    """
    label = sample.name[:25].ljust(25)
    turn_col = f"T:{sample.turn}"
    prefix = f" {label} {turn_col:>4} "

    if sample.status == "retry":
        return prefix + f"{YELLOW}⟳{RESET} retry (attempt {sample.retry_attempt})"

    if sample.status != "complete":
        current_phase = sample.phase or "running"
        return prefix + f"{CYAN}{current_phase}...{RESET}"

    if sample.score is None:
        return prefix + f"{GREEN}✓{RESET} done"

    # Completed with a score: green tick above 0.5, red cross otherwise.
    if sample.score > 0.5:
        mark = f"{GREEN}✓{RESET}"
    else:
        mark = f"{RED}✗{RESET}"
    return prefix + f"{mark} score={sample.score:.2f}"
280
-
281
-
282
class ProgressDisplay:
    """In-place progress display that reads from a JSONL file.

    Renders progress by overwriting lines in place (no alternate screen).
    Each render replays every event read so far through ``derive_state`` and
    ``render``.
    """

    def __init__(
        self,
        events_file: Path,
        poll_interval: float = 0.2,
        output_stream: TextIO | None = None,
    ) -> None:
        """Create a display.

        Args:
            events_file: Path of the JSONL event file to follow.
            poll_interval: Seconds to wait between polls of the file.
            output_stream: Stream to draw on; defaults to ``sys.stdout``.
        """
        self.events_file = events_file
        self.poll_interval = poll_interval
        self._output: TextIO = output_stream or sys.stdout
        self._stop_event = threading.Event()
        self._file_pos = 0  # byte offset just past the last complete line consumed
        self._events: list[dict] = []
        self._lines_rendered = 0
        self._old_sigwinch = None

    def start(self) -> None:
        """Start rendering (hide cursor, install resize handler, register cleanup)."""
        self._output.write(HIDE_CURSOR)
        self._output.flush()

        # SIGWINCH does not exist on every platform, and signal.signal() may
        # only be called from the main thread. Skip the resize handler rather
        # than fail after the cursor has already been hidden.
        sigwinch = getattr(signal, "SIGWINCH", None)
        if sigwinch is not None and threading.current_thread() is threading.main_thread():
            self._old_sigwinch = signal.signal(sigwinch, self._handle_resize)

        # Register atexit for cleanup
        atexit.register(self._cleanup)

    def stop(self) -> None:
        """Stop rendering and restore terminal."""
        self._stop_event.set()
        self._cleanup()

    def _cleanup(self) -> None:
        """Restore terminal state. Safe to call more than once."""
        self._output.write(SHOW_CURSOR)
        self._output.flush()

        # Take the saved handler so a repeated cleanup does not restore twice.
        previous, self._old_sigwinch = self._old_sigwinch, None
        if previous is not None:
            try:
                signal.signal(signal.SIGWINCH, previous)
            except ValueError:
                pass  # not on the main thread: the handler cannot be restored here

        atexit.unregister(self._cleanup)

    def _handle_resize(self, signum: int, frame: FrameType | None) -> None:
        """Handle terminal resize by redrawing immediately."""
        # NOTE(review): this may run while the polling thread is inside
        # _render(); confirm whether overlapping writes matter in practice.
        self._render()

    def run(self) -> None:
        """Main render loop: poll the file, render, exit on ``eval_end`` or stop."""
        while not self._stop_event.is_set():
            self._poll_events()
            self._render()

            # The render above already included eval_end, so it was the final frame.
            if any(event.get("type") == "eval_end" for event in self._events):
                return

            # Event.wait() returns as soon as stop() is called, unlike sleep().
            self._stop_event.wait(self.poll_interval)

    def _poll_events(self) -> None:
        """Append newly written, complete event lines to ``self._events``.

        Only data up to the last newline is consumed. A trailing line that the
        writer has not finished yet stays in the file for the next poll
        instead of being discarded as malformed JSON. Lines that are not valid
        JSON, or that do not decode to an object, are skipped.
        """
        try:
            with open(self.events_file, "rb") as f:
                f.seek(self._file_pos)
                chunk = f.read()
        except OSError:
            return  # File might not exist yet, or might be locked

        end = chunk.rfind(b"\n")
        if end < 0:
            return  # nothing complete to read yet
        complete = chunk[: end + 1]
        self._file_pos += len(complete)

        for raw in complete.split(b"\n"):
            raw = raw.strip()
            if not raw:
                continue
            try:
                event = json.loads(raw)
            except ValueError:
                continue  # Skip malformed lines (bad JSON or bad encoding)
            if isinstance(event, dict):
                self._events.append(event)

    def _render(self) -> None:
        """Render current state by overwriting the previously drawn lines."""
        try:
            size = os.get_terminal_size()
            width, height = size.columns, size.lines
        except OSError:
            width, height = 80, 24

        state = derive_state(self._events)
        lines = render(state, width, height)

        # Fixed display height (capped at 12 lines) to avoid jumpiness; never
        # negative, which would otherwise slice lines off the end.
        display_height = max(0, min(12, height - 2))

        # Truncate or pad to the fixed height
        lines = lines[:display_height]
        lines.extend([""] * (display_height - len(lines)))

        # Build the complete output buffer inside a synchronized update
        buf = [SYNC_START]

        # Move cursor up to overwrite previous render
        if self._lines_rendered > 0:
            buf.append(CURSOR_UP.format(n=self._lines_rendered))

        buf.extend(CLEAR_LINE + line + "\n" for line in lines)
        buf.append(SYNC_END)

        # Single write + flush
        self._output.write("".join(buf))
        self._output.flush()
        self._lines_rendered = len(lines)
410
-
411
-
412
@contextmanager
def progress_display(
    output_dir: Path | str,
    disable: bool = False,
    suppress_output: bool = True,
) -> Generator[Path, None, None]:
    """Context manager for clean progress display.

    Follows ``events.jsonl`` in ``output_dir`` from a background thread and
    redraws an in-place progress display on the original stdout.

    Args:
        output_dir: Directory containing events.jsonl (created if missing)
        disable: If True, skip progress display (verbose mode)
        suppress_output: If True, redirect stdout/stderr to ``output.log`` in
            ``output_dir`` to prevent display glitches

    Yields:
        ``output_dir`` as a ``Path``.

    Usage:
        # The output_dir must match what you pass to EvalConfig
        with progress_display(output_dir=config.output_dir):
            await evaluate(dataset, config)
    """
    if disable:
        yield Path(output_dir)
        return

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    events_file = output_dir / "events.jsonl"

    # Save original stdout/stderr for display rendering
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    log_file = None
    display = None
    display_thread = None

    # Everything that changes process state sits inside the try, so a failure
    # during setup still restores stdout/stderr and closes the log file.
    try:
        if suppress_output:
            # The progress display writes directly to original_stdout
            log_file = open(output_dir / "output.log", "w", encoding="utf-8")
            sys.stdout = log_file
            sys.stderr = log_file

        display = ProgressDisplay(events_file, output_stream=original_stdout)
        display.start()

        # Run display in background thread
        display_thread = threading.Thread(target=display.run, daemon=True)
        display_thread.start()

        yield output_dir
    finally:
        # Runs on normal exit, exceptions and KeyboardInterrupt alike.
        if display is not None:
            display.stop()
        if display_thread is not None:
            # Wait briefly so the renderer is not still drawing while the
            # streams are being swapped back.
            display_thread.join(timeout=display.poll_interval + 1.0)
        if suppress_output:
            sys.stdout = original_stdout
            sys.stderr = original_stderr
        if log_file is not None:
            log_file.close()
@@ -1,63 +0,0 @@
1
- """Event streaming utilities for frontend live updates.
2
-
3
- Provides utilities for emitting structured events to JSONL files for frontend consumption.
4
- Used across all benchmarks for real-time evaluation monitoring.
5
- """
6
-
7
- import json
8
- from collections.abc import Callable
9
- from datetime import datetime, timezone
10
- from pathlib import Path
11
-
12
-
13
def create_event_emitter(events_file: Path) -> Callable:
    """Create an emit_event callback that writes structured events to JSONL.

    Used for frontend live streaming - writes sample/turn/token events.

    The file is opened once in append mode and stays open for the lifetime of
    the callback. Call ``emit_event.close()`` on the returned callback when
    done; without it the handle is only released at interpreter shutdown.

    Args:
        events_file: Path to events.jsonl file

    Returns:
        Async callback ``emit_event(event_type, data)`` that appends one JSON
        line per call. It also carries a ``close`` attribute that closes the
        underlying file.
    """
    f = events_file.open("a", buffering=1, encoding="utf-8")

    async def emit_event(event_type: str, data: dict) -> None:
        """Emit a structured event to events.jsonl."""
        assert event_type is not None and isinstance(event_type, str)
        assert data is not None and isinstance(data, dict)

        # Keys in `data` are spread last, so they override "type"/"timestamp".
        event = {"type": event_type, "timestamp": datetime.now(timezone.utc).isoformat(), **data}
        f.write(json.dumps(event) + "\n")
        f.flush()

    # Backward-compatible way to let callers release the file handle.
    emit_event.close = f.close
    return emit_event
36
-
37
-
38
def create_streaming_on_chunk(emit_event: Callable, original_on_chunk: Callable | None = None) -> Callable:
    """Build an ``on_chunk`` handler that mirrors stream chunks into events.

    Args:
        emit_event: Async event emitter, e.g. from create_event_emitter()
        original_on_chunk: Optional existing on_chunk handler; when given it is
            awaited first for every chunk

    Returns:
        Async on_chunk handler that emits a ``token``, ``tool_call`` or
        ``tool_result`` event for the matching chunk types and nothing for
        any other type
    """
    from wafer_core.rollouts.dtypes import StreamChunk

    async def on_chunk(chunk: StreamChunk) -> None:
        # Preserve the caller's behaviour before doing anything else.
        if original_on_chunk is not None:
            await original_on_chunk(chunk)

        # Pick the frontend event for this chunk; other chunk types emit nothing.
        kind = chunk.type
        if kind == "token":
            event_name = "token"
            payload = {"content": chunk.data["text"]}
        elif kind == "tool_call_complete":
            event_name = "tool_call"
            payload = {"name": chunk.data["name"], "args": chunk.data["args"]}
        elif kind == "tool_result":
            event_name = "tool_result"
            payload = {"ok": chunk.data["ok"], "content": chunk.data["content"]}
        else:
            return

        await emit_event(event_name, payload)

    return on_chunk