wafer-core 0.1.43__py3-none-any.whl → 0.1.44__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,7 @@ import sys
25
25
  import time
26
26
  from collections.abc import Generator
27
27
  from contextlib import contextmanager
28
- from dataclasses import dataclass, field
28
+ from dataclasses import dataclass, field, replace
29
29
  from pathlib import Path
30
30
  from typing import Any
31
31
 
@@ -38,6 +38,8 @@ RED = "\x1b[31m"
38
38
YELLOW = "\x1b[33m"  # SGR foreground: yellow
CYAN = "\x1b[36m"  # SGR foreground: cyan
BOLD = "\x1b[1m"  # SGR bold / increased intensity
BG_SELECTED = "\x1b[48;5;236m"  # SGR 256-colour background index 236 (dark gray) for the selected row
BLUE = "\x1b[34m"  # SGR foreground: blue
41
43
 
42
44
  # Event types the progress display understands.
43
45
  # The "message" field in the JSONL record is the event type discriminator.
@@ -55,6 +57,19 @@ KNOWN_EVENT_TYPES = frozenset({
55
57
  "sample_retry",
56
58
  })
57
59
 
60
# Event types that can appear in a per-sample JSONL file (samples/<sample_id>.jsonl).
# NOTE(review): not referenced by the code visible in this chunk.
SAMPLE_EVENT_TYPES = frozenset(
    [
        # Streaming / final model output
        "text_delta",
        "thinking_delta",
        "assistant_message",
        # Tool and provider activity
        "tool_execution",
        "llm_call",
        "modal_progress",
        # Lifecycle markers
        "turn",
        "sample_start",
        "sample_end",
    ]
)
72
+
58
73
 
59
74
  # ── State ────────────────────────────────────────────────────────────────────
60
75
 
@@ -181,7 +196,7 @@ def derive_state(events: list[dict[str, Any]]) -> RenderState:
181
196
  return state
182
197
 
183
198
 
184
- def _format_sample_row(sample: SampleState, width: int) -> str:
199
+ def _format_sample_row(sample: SampleState, width: int, selected: bool = False) -> str:
185
200
  """Format a single sample row."""
186
201
  name = sample.name[:25].ljust(25)
187
202
  turn_info = f"T:{sample.turn}"
@@ -189,24 +204,89 @@ def _format_sample_row(sample: SampleState, width: int) -> str:
189
204
  if sample.status == "complete":
190
205
  if sample.score is not None:
191
206
  icon = f"{GREEN}✓{RESET}" if sample.score > 0.5 else f"{RED}✗{RESET}"
192
- return f" {name} {turn_info:>4} {icon} score={sample.score:.2f}"
207
+ row = f" {name} {turn_info:>4} {icon} score={sample.score:.2f}"
193
208
  else:
194
- return f" {name} {turn_info:>4} {GREEN}✓{RESET} done"
209
+ row = f" {name} {turn_info:>4} {GREEN}✓{RESET} done"
195
210
  elif sample.status == "retry":
196
- return f" {name} {turn_info:>4} {YELLOW}⟳{RESET} retry (attempt {sample.retry_attempt})"
211
+ row = f" {name} {turn_info:>4} {YELLOW}⟳{RESET} retry (attempt {sample.retry_attempt})"
197
212
  else:
198
213
  phase = sample.phase or "running"
199
- return f" {name} {turn_info:>4} {CYAN}{phase}...{RESET}"
214
+ row = f" {name} {turn_info:>4} {CYAN}{phase}...{RESET}"
200
215
 
216
+ if selected:
217
+ # Add selection highlight
218
+ row = f"{BG_SELECTED}{row}{RESET}"
201
219
 
202
- # ── Render function ──────────────────────────────────────────────────────────
220
+ return row
203
221
 
204
222
 
205
- def render(state: RenderState, width: int, height: int) -> list[str]:
206
- """Render state to list of lines. Pure function."""
207
- lines: list[str] = []
223
+ # ── Viewport formatting ──────────────────────────────────────────────────────
208
224
 
209
- # Header
225
+
226
+ def format_sample_event(
227
+ event: dict[str, Any], last_turn: int | None = None
228
+ ) -> tuple[str | None, int | None]:
229
+ """Format a per-sample JSONL event into a readable line.
230
+
231
+ Returns (formatted_line, new_turn_number).
232
+ formatted_line is None if the event should not be displayed.
233
+ new_turn_number is the turn number if this event changes it, else None.
234
+ """
235
+ msg = event.get("message", "")
236
+
237
+ if msg == "text_delta":
238
+ # Streaming text — return raw text for accumulation
239
+ return event.get("text", ""), None
240
+
241
+ elif msg == "thinking_delta":
242
+ # Extended thinking — show dimmed
243
+ text = event.get("text", "")
244
+ return (f"{DIM}{text}{RESET}" if text else None), None
245
+
246
+ elif msg == "assistant_message":
247
+ # Final assistant message — skip if we've been streaming
248
+ # The caller handles dedup
249
+ return None, None
250
+
251
+ elif msg == "tool_execution":
252
+ tool = event.get("tool_name", "tool")
253
+ duration = event.get("duration_ms", 0)
254
+ status = event.get("result_summary", "")
255
+ is_error = event.get("is_error", False)
256
+ color = RED if is_error else CYAN
257
+ return f"{color}→ {tool}{RESET} ({duration / 1000:.1f}s) {status}", None
258
+
259
+ elif msg == "llm_call":
260
+ provider = event.get("provider", "")
261
+ model = event.get("model", "")
262
+ tokens_in = event.get("tokens_in", 0)
263
+ tokens_out = event.get("tokens_out", 0)
264
+ duration = event.get("duration_ms", 0)
265
+ model_short = model.split("/")[-1] if model else provider
266
+ return (
267
+ f"{DIM}◆ {model_short} {tokens_in}→{tokens_out} tok ({duration / 1000:.1f}s){RESET}",
268
+ None,
269
+ )
270
+
271
+ elif msg == "modal_progress":
272
+ phase = event.get("phase", "")
273
+ return (f"{YELLOW}⟳ {phase}...{RESET}" if phase else None), None
274
+
275
+ elif msg == "turn":
276
+ turn = event.get("turn")
277
+ # Only show turn separator when turn number changes
278
+ if turn is not None and turn != last_turn:
279
+ return f"{DIM}── turn {turn} ──{RESET}", turn
280
+ return None, turn if turn is not None else last_turn
281
+
282
+ return None, None
283
+
284
+
285
+ # ── Render function ──────────────────────────────────────────────────────────
286
+
287
+
288
+ def render_header(state: RenderState, width: int) -> str:
289
+ """Render the progress header line."""
210
290
  if state.gepa_iter is not None:
211
291
  header = (
212
292
  f"GEPA iter {state.gepa_iter}/{state.gepa_total or '?'} │ best: {state.gepa_best:.0%}"
@@ -231,7 +311,14 @@ def render(state: RenderState, width: int, height: int) -> list[str]:
231
311
  else:
232
312
  header = f"{state.eval_name}: {state.completed} samples [{_format_time(elapsed)}]"
233
313
 
234
- lines.append(header[:width])
314
+ return header[:width]
315
+
316
+
317
+ def render(state: RenderState, width: int, height: int) -> list[str]:
318
+ """Render state to list of lines. Pure function."""
319
+ lines: list[str] = []
320
+
321
+ lines.append(render_header(state, width))
235
322
 
236
323
  # Active samples, sorted by most recently updated
237
324
  active = [s for s in state.samples.values() if s.status != "complete" and s.phase]
@@ -263,9 +350,32 @@ class Model:
263
350
  events: tuple[dict[str, Any], ...] = ()
264
351
  done: bool = False
265
352
 
353
+ # Viewport state
354
+ focus: str = "list" # "list" or "viewport"
355
+ selected_idx: int = 0 # Index into active samples list
356
+ viewport_lines: tuple[str, ...] = () # Formatted lines for viewport
357
+ viewport_scroll: int = 0 # Scroll position in viewport
358
+ viewport_auto_follow: bool = True # Auto-scroll to bottom on new content
359
+ tailing_sample_id: str | None = None # Which sample file we're tailing
360
+
361
+ # For streaming text accumulation
362
+ streaming_text: str = ""
363
+
364
+ # For deduping turn separators
365
+ last_turn: int | None = None
366
+
266
367
 
267
368
@dataclass(frozen=True)
class NewEvent:
    """Message carrying one raw line tailed from events.jsonl (parsed later in ``update``)."""

    line: str  # unparsed JSONL text
373
+
374
+
375
@dataclass(frozen=True)
class SampleEvent:
    """Message carrying one raw line tailed from the selected sample's JSONL file."""

    line: str  # unparsed JSONL text
270
380
 
271
381
 
@@ -276,32 +386,249 @@ def _parse_event(line: str) -> dict[str, Any] | None:
276
386
  return None
277
387
 
278
388
 
389
def _get_active_samples(events: tuple[dict[str, Any], ...]) -> list[SampleState]:
    """Return in-progress samples, most recently updated first.

    A sample is active when its status is not ``"complete"`` and it has a
    non-empty phase. State is re-derived from the full event history on every
    call.
    """
    samples = derive_state(list(events)).samples.values()
    return sorted(
        (sample for sample in samples if sample.phase and sample.status != "complete"),
        key=lambda sample: sample.last_update,
        reverse=True,
    )
395
+
396
+
397
def _get_selected_sample_id(events: tuple[dict[str, Any], ...], selected_idx: int) -> str | None:
    """Return the id of the active sample at ``selected_idx``, or None when out of range."""
    active = _get_active_samples(events)
    if selected_idx < 0 or selected_idx >= len(active):
        return None
    return active[selected_idx].id
403
+
404
+
279
405
  def update(model: Model, msg: object) -> tuple[Model, Any]:
280
406
  from ._pytui import Cmd, KeyPress
281
407
 
282
408
  match msg:
283
409
  case KeyPress(key="q" | "\x03"):
284
410
  return model, Cmd.quit()
411
+
412
+ # Focus switching
413
+ case KeyPress(key="\t"): # Tab
414
+ new_focus = "viewport" if model.focus == "list" else "list"
415
+ return replace(model, focus=new_focus), Cmd.none()
416
+
417
+ # Navigation in list focus
418
+ case KeyPress(key="j" | "\x1b[B") if model.focus == "list": # Down
419
+ active = _get_active_samples(model.events)
420
+ new_idx = min(model.selected_idx + 1, len(active) - 1) if active else 0
421
+ new_sample_id = _get_selected_sample_id(model.events, new_idx)
422
+ # Reset viewport when selection changes
423
+ if new_sample_id != model.tailing_sample_id:
424
+ return replace(
425
+ model,
426
+ selected_idx=new_idx,
427
+ tailing_sample_id=new_sample_id,
428
+ viewport_lines=(),
429
+ viewport_scroll=0,
430
+ viewport_auto_follow=True,
431
+ streaming_text="",
432
+ last_turn=None,
433
+ ), Cmd.none()
434
+ return replace(model, selected_idx=new_idx), Cmd.none()
435
+
436
+ case KeyPress(key="k" | "\x1b[A") if model.focus == "list": # Up
437
+ new_idx = max(model.selected_idx - 1, 0)
438
+ new_sample_id = _get_selected_sample_id(model.events, new_idx)
439
+ if new_sample_id != model.tailing_sample_id:
440
+ return replace(
441
+ model,
442
+ selected_idx=new_idx,
443
+ tailing_sample_id=new_sample_id,
444
+ viewport_lines=(),
445
+ viewport_scroll=0,
446
+ viewport_auto_follow=True,
447
+ streaming_text="",
448
+ last_turn=None,
449
+ ), Cmd.none()
450
+ return replace(model, selected_idx=new_idx), Cmd.none()
451
+
452
+ # Navigation in viewport focus
453
+ case KeyPress(key="j" | "\x1b[B") if model.focus == "viewport": # Down
454
+ new_scroll = model.viewport_scroll + 1
455
+ return replace(
456
+ model, viewport_scroll=new_scroll, viewport_auto_follow=False
457
+ ), Cmd.none()
458
+
459
+ case KeyPress(key="k" | "\x1b[A") if model.focus == "viewport": # Up
460
+ new_scroll = max(0, model.viewport_scroll - 1)
461
+ return replace(
462
+ model, viewport_scroll=new_scroll, viewport_auto_follow=False
463
+ ), Cmd.none()
464
+
465
+ case KeyPress(key="G") if model.focus == "viewport": # Jump to bottom
466
+ new_scroll = max(0, len(model.viewport_lines) - 1)
467
+ return replace(model, viewport_scroll=new_scroll, viewport_auto_follow=True), Cmd.none()
468
+
469
+ case KeyPress(key="g"): # gg = top (simplified, just g for now)
470
+ if model.focus == "viewport":
471
+ return replace(model, viewport_scroll=0, viewport_auto_follow=False), Cmd.none()
472
+
473
+ # Events from events.jsonl
285
474
  case NewEvent(line=line):
286
475
  event = _parse_event(line)
287
476
  if event is None:
288
477
  return model, Cmd.none()
289
478
  new_events = model.events + (event,)
290
479
  if event.get("message") == "eval_end":
291
- return Model(events=new_events, done=True), Cmd.quit()
292
- return Model(events=new_events, done=model.done), Cmd.none()
480
+ return replace(model, events=new_events, done=True), Cmd.quit()
481
+
482
+ # Auto-select first sample if none selected
483
+ new_model = replace(model, events=new_events)
484
+ if model.tailing_sample_id is None:
485
+ sample_id = _get_selected_sample_id(new_events, model.selected_idx)
486
+ if sample_id:
487
+ new_model = replace(new_model, tailing_sample_id=sample_id)
488
+
489
+ return new_model, Cmd.none()
490
+
491
+ # Events from per-sample JSONL file
492
+ case SampleEvent(line=line):
493
+ event = _parse_event(line)
494
+ if event is None:
495
+ return model, Cmd.none()
496
+
497
+ formatted, new_turn = format_sample_event(event, model.last_turn)
498
+ msg_type = event.get("message", "")
499
+
500
+ # Handle streaming text accumulation
501
+ if msg_type == "text_delta":
502
+ new_streaming = model.streaming_text + (formatted or "")
503
+ return replace(model, streaming_text=new_streaming), Cmd.none()
504
+
505
+ # When we get a non-delta event, flush accumulated streaming text
506
+ new_lines = list(model.viewport_lines)
507
+ if model.streaming_text:
508
+ # Wrap streaming text to reasonable width and add
509
+ for line in model.streaming_text.split("\n"):
510
+ if line:
511
+ new_lines.append(line)
512
+ new_streaming = ""
513
+ else:
514
+ new_streaming = model.streaming_text
515
+
516
+ # Add the formatted event line
517
+ if formatted and msg_type != "text_delta":
518
+ new_lines.append(formatted)
519
+
520
+ # Auto-scroll if enabled
521
+ new_scroll = model.viewport_scroll
522
+ if model.viewport_auto_follow:
523
+ new_scroll = max(0, len(new_lines) - 1)
524
+
525
+ return replace(
526
+ model,
527
+ viewport_lines=tuple(new_lines),
528
+ viewport_scroll=new_scroll,
529
+ streaming_text=new_streaming,
530
+ last_turn=new_turn if new_turn is not None else model.last_turn,
531
+ ), Cmd.none()
532
+
293
533
  return model, Cmd.none()
294
534
 
295
535
 
296
536
def _clip_ansi(line: str, width: int) -> str:
    """Clip *line* to at most *width* visible columns without splitting ANSI escapes.

    CSI escape sequences (ESC ``[`` ... final byte in ``@``..``~``) are copied
    whole and do not count toward the width. Every other character counts as
    one column; wide characters are not special-cased. Lines that already fit
    are returned unchanged. A clipped line ends with "..." (inside the width
    budget) followed by RESET, so a colour left open by the cut cannot bleed
    onto later rows.
    """
    # Split into (text, is_escape) pieces.
    pieces: list[tuple[str, bool]] = []
    i, n = 0, len(line)
    while i < n:
        if line[i] == "\x1b" and i + 1 < n and line[i + 1] == "[":
            j = i + 2
            # Parameter/intermediate bytes continue until the final byte.
            while j < n and not ("@" <= line[j] <= "~"):
                j += 1
            pieces.append((line[i : j + 1], True))
            i = j + 1
        else:
            pieces.append((line[i], False))
            i += 1

    if sum(1 for _, is_escape in pieces if not is_escape) <= width:
        return line

    budget = max(0, width - 3)  # leave room for the ellipsis
    kept: list[str] = []
    for text, is_escape in pieces:
        if not is_escape:
            if budget == 0:
                break
            budget -= 1
        kept.append(text)
    return "".join(kept) + "..."[: max(0, width)] + RESET


def view(model: Model, width: int, height: int) -> list[str]:
    """Render the split view: sample list on top, viewport on bottom.

    Rows, top to bottom:
      * the progress header;
      * the active-sample list (the header counts toward ``list_height``);
      * a separator naming the focused pane and the tailed sample;
      * the viewport with the tailed sample's formatted output;
      * a key-help footer on the last row.

    Lines are clipped to ``width`` visible columns where practical, and at most
    ``height`` lines are returned.
    """
    state = derive_state(list(model.events))
    lines: list[str] = [render_header(state, width)]

    # Header + sample list get ~1/3 of the space (min 4, max 10 lines). The
    # viewport gets the rest minus separator and footer. It is clamped at 0 so
    # tiny terminals cannot produce a negative slice.
    list_height = max(4, min(height // 3, 10))
    viewport_height = max(0, height - list_height - 2)

    active = _get_active_samples(model.events)

    # ── Sample list ──
    capacity = list_height - 1  # rows available below the header
    overflow = len(active) > capacity
    # On overflow the last list row is reserved for the "... and N more" marker.
    shown = capacity - 1 if overflow else capacity
    # Scroll the list window so the selected row is never hidden.
    start = 0
    if overflow and model.selected_idx >= shown:
        start = min(model.selected_idx - shown + 1, len(active) - shown)
    for idx, sample in enumerate(active[start : start + shown], start=start):
        row = _format_sample_row(sample, width, selected=(idx == model.selected_idx))
        lines.append(_clip_ansi(row, width))
    while len(lines) < list_height:
        lines.append("")
    if overflow:
        hidden = len(active) - shown
        lines[list_height - 1] = f"{DIM} ... and {hidden} more{RESET}"

    # ── Separator: focus indicator + name of the sample the viewport shows ──
    selected_name = ""
    if model.tailing_sample_id:
        tailed = next(
            (s for s in state.samples.values() if s.id == model.tailing_sample_id),
            None,
        )
        if tailed is not None:
            selected_name = f" [{tailed.name[:20]}]"
    label = "samples" if model.focus == "list" else "viewport"
    prefix = f"─── {label}{selected_name} "
    fill = "─" * max(0, width - len(prefix))
    lines.append(_clip_ansi(f"{DIM}{prefix}{fill}{RESET}", width))

    # ── Viewport ──
    viewport_end = list_height + 1 + viewport_height
    if model.tailing_sample_id:
        display_lines = list(model.viewport_lines)
        if model.streaming_text:
            # Show in-flight streaming text (last 200 chars) as a virtual last line.
            stream_preview = model.streaming_text[-200:]
            if len(model.streaming_text) > 200:
                stream_preview = "..." + stream_preview
            # Flatten line breaks so the preview occupies exactly one row.
            stream_preview = " ".join(stream_preview.splitlines())
            display_lines.append(f"{CYAN}{stream_preview}{RESET}")

        if display_lines:
            max_scroll = max(0, len(display_lines) - viewport_height)
            scroll = min(model.viewport_scroll, max_scroll)
            for line in display_lines[scroll : scroll + viewport_height]:
                lines.append(_clip_ansi(line, width))
        elif viewport_height > 0:
            lines.append(f"{DIM} (waiting for events...){RESET}")
    elif viewport_height > 0:
        lines.append(f"{DIM} (select a sample to view its output){RESET}")
    while len(lines) < viewport_end:
        lines.append("")

    # ── Footer ──
    if model.focus == "list":
        footer = f"{DIM}j/k: select Tab: focus viewport q: quit{RESET}"
    else:
        follow_status = "follow" if model.viewport_auto_follow else "manual"
        footer = f"{DIM}j/k: scroll G: bottom Tab: focus list [{follow_status}] q: quit{RESET}"

    while len(lines) < height - 1:
        lines.append("")
    lines.append(_clip_ansi(footer, width))

    return lines[:height]
306
633
 
307
634
 
@@ -375,16 +702,22 @@ if __name__ == "__main__":
375
702
  from ._pytui import App, Cmd, Sub
376
703
 
377
704
  events_file = sys.argv[1]
705
+ output_dir = Path(events_file).parent
706
+ samples_dir = output_dir / "samples"
378
707
 
379
- # Model needs the path for subscriptions. We use a non-frozen wrapper
380
- # to thread the path through — subscriptions() needs it but it's static.
381
- _events_path = events_file
382
-
383
- # Patch subscriptions to close over the path
384
708
  def _subscriptions(model: Model) -> Sub:
385
709
  if model.done:
386
710
  return Sub.none()
387
- return Sub.file_tail(_events_path, lambda line: NewEvent(line=line))
711
+
712
+ subs = [Sub.file_tail(events_file, lambda line: NewEvent(line=line))]
713
+
714
+ # Tail the selected sample's file if we have one
715
+ if model.tailing_sample_id:
716
+ sample_file = samples_dir / f"{model.tailing_sample_id}.jsonl"
717
+ if sample_file.exists():
718
+ subs.append(Sub.file_tail(str(sample_file), lambda line: SampleEvent(line=line)))
719
+
720
+ return Sub.batch(*subs)
388
721
 
389
722
  App(
390
723
  init=(Model(), Cmd.none()),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-core
3
- Version: 0.1.43
3
+ Version: 0.1.44
4
4
  Summary: Core utilities and environments for Wafer GPU kernel optimization
5
5
  Requires-Python: >=3.10
6
6
  Requires-Dist: aiohttp>=3.9.0
@@ -362,7 +362,7 @@ wafer_core/rollouts/models.py,sha256=BrMnUpTA9_HnOghpedzdLUm9Di3FyJouzOkK4PPBg_k
362
362
  wafer_core/rollouts/paths.py,sha256=9XtrA9ylhb5LttMFe2DE7X0IHeUMjuGUerII9OscYec,3436
363
363
  wafer_core/rollouts/pipeline.py,sha256=vlJTYE3ZX2XScpF9pmtv91K8Q0g8uLmcbI5jn6b5Hzg,15319
364
364
  wafer_core/rollouts/progress.py,sha256=szA9cvWT2xUxGVhF9BaAqJMmKDqMAUlxImxcOpcnqbY,29228
365
- wafer_core/rollouts/progress_app.py,sha256=ir3fk8OZd09Zs7e7EC2WjQs_H-6W-zCbZAjSCDIcKEo,13517
365
+ wafer_core/rollouts/progress_app.py,sha256=JSx4FKn_H8Thi2IK19mr2RvpH15BZuUi3nC0h6qMcD8,26218
366
366
  wafer_core/rollouts/prompt.py,sha256=EDmGb0rhWwke7tokIcO8dukc3q5c8x0n5Omi5CpAQmA,11022
367
367
  wafer_core/rollouts/providers.py,sha256=dcGJh1p30hstVbCDDtJ902lyafkg81DKjcOzb0uuKS0,1400
368
368
  wafer_core/rollouts/remote.py,sha256=cAYpRCONlsTeRxzLiegAUfjZWGtqBNwZTHehMhk5ldA,8816
@@ -723,6 +723,6 @@ wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdh
723
723
  wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
724
724
  wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
725
725
  wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
726
- wafer_core-0.1.43.dist-info/METADATA,sha256=mvNNv5AucqwdnYMP2OO2NaUmPe1xAk5re_YSwjIu4s8,1477
727
- wafer_core-0.1.43.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
728
- wafer_core-0.1.43.dist-info/RECORD,,
726
+ wafer_core-0.1.44.dist-info/METADATA,sha256=onunSBewljqBQ2pzzBa7iv3KYK5sAVUHDLByQORNqLc,1477
727
+ wafer_core-0.1.44.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
728
+ wafer_core-0.1.44.dist-info/RECORD,,