synkro 0.4.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. synkro/__init__.py +179 -0
  2. synkro/advanced.py +186 -0
  3. synkro/cli.py +128 -0
  4. synkro/core/__init__.py +7 -0
  5. synkro/core/checkpoint.py +250 -0
  6. synkro/core/dataset.py +402 -0
  7. synkro/core/policy.py +337 -0
  8. synkro/errors.py +178 -0
  9. synkro/examples/__init__.py +148 -0
  10. synkro/factory.py +276 -0
  11. synkro/formatters/__init__.py +12 -0
  12. synkro/formatters/qa.py +98 -0
  13. synkro/formatters/sft.py +90 -0
  14. synkro/formatters/tool_call.py +127 -0
  15. synkro/generation/__init__.py +9 -0
  16. synkro/generation/follow_ups.py +134 -0
  17. synkro/generation/generator.py +220 -0
  18. synkro/generation/golden_responses.py +244 -0
  19. synkro/generation/golden_scenarios.py +276 -0
  20. synkro/generation/golden_tool_responses.py +416 -0
  21. synkro/generation/logic_extractor.py +126 -0
  22. synkro/generation/multiturn_responses.py +177 -0
  23. synkro/generation/planner.py +131 -0
  24. synkro/generation/responses.py +189 -0
  25. synkro/generation/scenarios.py +90 -0
  26. synkro/generation/tool_responses.py +376 -0
  27. synkro/generation/tool_simulator.py +114 -0
  28. synkro/interactive/__init__.py +12 -0
  29. synkro/interactive/hitl_session.py +77 -0
  30. synkro/interactive/logic_map_editor.py +173 -0
  31. synkro/interactive/rich_ui.py +205 -0
  32. synkro/llm/__init__.py +7 -0
  33. synkro/llm/client.py +235 -0
  34. synkro/llm/rate_limits.py +95 -0
  35. synkro/models/__init__.py +43 -0
  36. synkro/models/anthropic.py +26 -0
  37. synkro/models/google.py +19 -0
  38. synkro/models/openai.py +31 -0
  39. synkro/modes/__init__.py +15 -0
  40. synkro/modes/config.py +66 -0
  41. synkro/modes/qa.py +18 -0
  42. synkro/modes/sft.py +18 -0
  43. synkro/modes/tool_call.py +18 -0
  44. synkro/parsers.py +442 -0
  45. synkro/pipeline/__init__.py +20 -0
  46. synkro/pipeline/phases.py +592 -0
  47. synkro/pipeline/runner.py +424 -0
  48. synkro/pipelines.py +123 -0
  49. synkro/prompts/__init__.py +57 -0
  50. synkro/prompts/base.py +167 -0
  51. synkro/prompts/golden_templates.py +474 -0
  52. synkro/prompts/interactive_templates.py +65 -0
  53. synkro/prompts/multiturn_templates.py +156 -0
  54. synkro/prompts/qa_templates.py +97 -0
  55. synkro/prompts/templates.py +281 -0
  56. synkro/prompts/tool_templates.py +201 -0
  57. synkro/quality/__init__.py +14 -0
  58. synkro/quality/golden_refiner.py +163 -0
  59. synkro/quality/grader.py +153 -0
  60. synkro/quality/multiturn_grader.py +150 -0
  61. synkro/quality/refiner.py +137 -0
  62. synkro/quality/tool_grader.py +126 -0
  63. synkro/quality/tool_refiner.py +128 -0
  64. synkro/quality/verifier.py +228 -0
  65. synkro/reporting.py +537 -0
  66. synkro/schemas.py +472 -0
  67. synkro/types/__init__.py +41 -0
  68. synkro/types/core.py +126 -0
  69. synkro/types/dataset_type.py +30 -0
  70. synkro/types/logic_map.py +345 -0
  71. synkro/types/tool.py +94 -0
  72. synkro-0.4.12.data/data/examples/__init__.py +148 -0
  73. synkro-0.4.12.dist-info/METADATA +258 -0
  74. synkro-0.4.12.dist-info/RECORD +77 -0
  75. synkro-0.4.12.dist-info/WHEEL +4 -0
  76. synkro-0.4.12.dist-info/entry_points.txt +2 -0
  77. synkro-0.4.12.dist-info/licenses/LICENSE +21 -0
synkro/reporting.py ADDED
@@ -0,0 +1,537 @@
1
+ """Progress reporting abstraction for generation pipeline.
2
+
3
+ This module provides a Protocol for progress reporting and implementations
4
+ for different use cases (rich console, silent for testing, etc.).
5
+
6
+ Enhanced for Golden Trace pipeline with:
7
+ - Logic Map logging
8
+ - Scenario type distribution
9
+ - Per-trace category/type logging
10
+ """
11
+
12
+ from typing import Protocol, TYPE_CHECKING, Callable
13
+
14
+ from synkro.types.core import Plan, Scenario, Trace, GradeResult
15
+
16
+ if TYPE_CHECKING:
17
+ from synkro.types.logic_map import LogicMap, GoldenScenario
18
+
19
+
20
class ProgressReporter(Protocol):
    """
    Protocol for reporting generation progress.

    Implement this to customize how progress is displayed or logged.
    All methods are event hooks invoked by the generation pipeline;
    implementations should not raise, since a failing reporter would
    abort generation.

    Examples:
        >>> # Use silent reporter for testing
        >>> generator = Generator(reporter=SilentReporter())

        >>> # Use rich reporter for CLI (default)
        >>> generator = Generator(reporter=RichReporter())
    """

    def on_start(self, traces: int, model: str, dataset_type: str) -> None:
        """Called when generation starts.

        Args:
            traces: Number of traces requested.
            model: Model identifier used for generation.
            dataset_type: Dataset type label (e.g. "sft", "qa").
        """
        ...

    def on_plan_complete(self, plan: Plan) -> None:
        """Called when planning phase completes."""
        ...

    def on_scenario_progress(self, completed: int, total: int) -> None:
        """Called during scenario generation with running counts."""
        ...

    def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
        """Called when all scenarios are generated."""
        ...

    def on_response_progress(self, completed: int, total: int) -> None:
        """Called during response generation with running counts."""
        ...

    def on_responses_complete(self, traces: list[Trace]) -> None:
        """Called when all responses are generated."""
        ...

    def on_grading_progress(self, completed: int, total: int) -> None:
        """Called during grading with running counts."""
        ...

    def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
        """Called when grading completes.

        Args:
            traces: Graded traces.
            pass_rate: Percentage of traces that passed (0-100).
        """
        ...

    def on_refinement_start(self, iteration: int, failed_count: int) -> None:
        """Called when a refinement iteration starts for failed traces."""
        ...

    def on_grading_skipped(self) -> None:
        """Called when grading is skipped."""
        ...

    def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
        """Called when generation is complete.

        Args:
            dataset_size: Number of traces in the final dataset.
            elapsed_seconds: Wall-clock duration of the run.
            pass_rate: Final pass percentage, or None if grading was skipped.
        """
        ...

    def on_logic_map_complete(self, logic_map: "LogicMap") -> None:
        """Called when logic extraction completes (Stage 1)."""
        ...

    def on_golden_scenarios_complete(
        self, scenarios: list["GoldenScenario"], distribution: dict[str, int]
    ) -> None:
        """Called when golden scenarios are generated (Stage 2).

        Args:
            scenarios: Generated golden scenarios.
            distribution: Count of scenarios per scenario-type key.
        """
        ...

    def on_hitl_start(self, rules_count: int) -> None:
        """Called when HITL (human-in-the-loop) session starts."""
        ...

    def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
        """Called after each HITL refinement round."""
        ...

    def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
        """Called when HITL session completes."""
        ...
99
+
100
+
101
class SilentReporter:
    """
    No-op reporter for testing and embedding.

    Satisfies the ProgressReporter protocol with do-nothing handlers,
    so generation can run with zero console output.

    Examples:
        >>> generator = Generator(reporter=SilentReporter())
        >>> dataset = generator.generate(policy)  # No console output
    """

    def on_start(self, traces: int, model: str, dataset_type: str) -> None:
        """Ignore generation start."""

    def on_plan_complete(self, plan: Plan) -> None:
        """Ignore plan completion."""

    def on_scenario_progress(self, completed: int, total: int) -> None:
        """Ignore scenario progress."""

    def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
        """Ignore scenario completion."""

    def on_response_progress(self, completed: int, total: int) -> None:
        """Ignore response progress."""

    def on_responses_complete(self, traces: list[Trace]) -> None:
        """Ignore response completion."""

    def on_grading_progress(self, completed: int, total: int) -> None:
        """Ignore grading progress."""

    def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
        """Ignore grading completion."""

    def on_refinement_start(self, iteration: int, failed_count: int) -> None:
        """Ignore refinement start."""

    def on_grading_skipped(self) -> None:
        """Ignore skipped grading."""

    def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
        """Ignore run completion."""

    def on_logic_map_complete(self, logic_map) -> None:
        """Ignore logic-map completion."""

    def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
        """Ignore golden-scenario completion."""

    def on_hitl_start(self, rules_count: int) -> None:
        """Ignore HITL start."""

    def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
        """Ignore HITL refinement."""

    def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
        """Ignore HITL completion."""
159
+
160
+
161
class RichReporter:
    """
    Rich console reporter with progress bars and formatted output.

    This is the default reporter that provides the familiar synkro CLI
    experience. All rich imports are deferred to call time so importing
    this module stays cheap.

    Fix over the previous revision: ``on_responses_complete`` was defined
    twice in this class; only the second (category-grouped) definition was
    ever bound, so the first was dead code and has been removed. Also
    removed a needless ``f`` prefix on a placeholder-free string in
    ``on_grading_skipped``.
    """

    def __init__(self):
        from rich.console import Console
        self.console = Console()
        # Reserved for a future rich.progress integration; currently unused.
        self._progress = None
        self._current_task = None

    def on_start(self, traces: int, model: str, dataset_type: str) -> None:
        """Print a banner panel announcing the generation run."""
        from rich.panel import Panel

        self.console.print()
        self.console.print(Panel.fit(
            f"[bold]Generating {traces} traces[/bold]\n"
            f"[dim]Type: {dataset_type.upper()} | Model: {model}[/dim]",
            title="[cyan]synkro[/cyan]",
            border_style="cyan"
        ))
        self.console.print()

    def on_plan_complete(self, plan: Plan) -> None:
        """Render the plan's categories as a table."""
        from rich.table import Table

        self.console.print(f"[green]📋 Planning[/green] [dim]{len(plan.categories)} categories[/dim]")

        cat_table = Table(title="Categories", show_header=True, header_style="bold cyan")
        cat_table.add_column("Name")
        cat_table.add_column("Description")
        cat_table.add_column("Count", justify="right")
        for cat in plan.categories:
            cat_table.add_row(cat.name, cat.description, str(cat.count))
        self.console.print(cat_table)
        self.console.print()

    def on_scenario_progress(self, completed: int, total: int) -> None:
        pass  # Progress shown in on_scenarios_complete

    def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
        """List each scenario with a description truncated to 80 chars."""
        self.console.print(f"[green]💡 Scenarios[/green] [dim]{len(scenarios)} created[/dim]")
        for idx, s in enumerate(scenarios, 1):
            desc = s.description[:80] + "..." if len(s.description) > 80 else s.description
            self.console.print(f" [dim]#{idx}[/dim] [yellow]{desc}[/yellow]")

    def on_response_progress(self, completed: int, total: int) -> None:
        pass  # Progress shown in on_responses_complete

    def on_grading_progress(self, completed: int, total: int) -> None:
        pass  # Progress shown in on_grading_complete

    def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
        """Print per-trace pass/fail lines plus the overall pass rate."""
        self.console.print(f"[green]⚖️ Grading[/green] [dim]{pass_rate:.0f}% passed[/dim]")
        for idx, trace in enumerate(traces, 1):
            scenario_preview = trace.scenario.description[:40] + "..." if len(trace.scenario.description) > 40 else trace.scenario.description
            if trace.grade and trace.grade.passed:
                self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [green]✓ Passed[/green]")
            else:
                # A trace may fail with no recorded issues; show a placeholder then.
                issues = ", ".join(trace.grade.issues[:2]) if trace.grade and trace.grade.issues else "No specific issues"
                issues_preview = issues[:40] + "..." if len(issues) > 40 else issues
                self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [red]✗ Failed[/red] [dim]{issues_preview}[/dim]")

    def on_refinement_start(self, iteration: int, failed_count: int) -> None:
        """Announce a refinement pass over failed traces."""
        self.console.print(f" [yellow]↻ Refining {failed_count} failed traces (iteration {iteration})...[/yellow]")

    def on_grading_skipped(self) -> None:
        """Note that grading was skipped."""
        self.console.print(" [dim]⚖️ Grading skipped[/dim]")

    def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
        """Print the final summary panel with size, duration and quality."""
        from rich.panel import Panel
        from rich.table import Table

        elapsed_str = f"{int(elapsed_seconds) // 60}m {int(elapsed_seconds) % 60}s" if elapsed_seconds >= 60 else f"{int(elapsed_seconds)}s"

        self.console.print()
        summary = Table.grid(padding=(0, 2))
        summary.add_column(style="green")
        summary.add_column()
        summary.add_row("✅ Done!", f"Generated {dataset_size} traces in {elapsed_str}")
        if pass_rate is not None:
            summary.add_row("📊 Quality:", f"{pass_rate:.0f}% passed verification")
        self.console.print(Panel(summary, border_style="green", title="[green]Complete[/green]"))
        self.console.print()

    def on_logic_map_complete(self, logic_map) -> None:
        """Display the extracted Logic Map (Stage 1) as a rule tree."""
        from rich.panel import Panel
        from rich.tree import Tree

        self.console.print(f"\n[green]📜 Logic Map[/green] [dim]{len(logic_map.rules)} rules extracted[/dim]")

        # Show rules as a tree, grouped by rule category.
        tree = Tree("[bold cyan]Rules[/bold cyan]")

        by_category = {}
        for rule in logic_map.rules:
            cat = rule.category.value
            by_category.setdefault(cat, []).append(rule)

        for category, rules in by_category.items():
            cat_branch = tree.add(f"[yellow]{category.upper()}[/yellow] ({len(rules)} rules)")
            for rule in rules[:3]:  # Show first 3 per category
                deps = f" → {', '.join(rule.dependencies)}" if rule.dependencies else ""
                rule_text = rule.text[:50] + "..." if len(rule.text) > 50 else rule.text
                cat_branch.add(f"[dim]{rule.rule_id}[/dim]: {rule_text}{deps}")
            if len(rules) > 3:
                cat_branch.add(f"[dim]... and {len(rules) - 3} more[/dim]")

        self.console.print(tree)

        # Show dependency-chain roots, when present.
        if logic_map.root_rules:
            self.console.print(f" [dim]Root rules: {', '.join(logic_map.root_rules)}[/dim]")

    def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
        """Display golden scenarios as a category × type matrix (Stage 2)."""
        from rich.table import Table

        self.console.print(f"\n[green]💡 Golden Scenarios[/green] [dim]{len(scenarios)} created[/dim]")

        # Build category × type matrix of counts.
        matrix: dict[str, dict[str, int]] = {}
        for s in scenarios:
            cat = s.category or "uncategorized"
            # scenario_type may be an enum or a plain string.
            stype = s.scenario_type.value if hasattr(s.scenario_type, 'value') else s.scenario_type
            matrix.setdefault(cat, {"positive": 0, "negative": 0, "edge_case": 0, "irrelevant": 0})
            matrix[cat][stype] += 1

        # Create the combined table.
        table = Table(title="Scenario Distribution", show_header=True, header_style="bold cyan")
        table.add_column("Category", style="cyan")
        table.add_column("[green]✓ Positive[/green]", justify="right")
        table.add_column("[red]✗ Negative[/red]", justify="right")
        table.add_column("[yellow]⚡ Edge[/yellow]", justify="right")
        table.add_column("[dim]○ Irrelevant[/dim]", justify="right")
        table.add_column("Total", justify="right", style="bold")

        # Track column totals.
        totals = {"positive": 0, "negative": 0, "edge_case": 0, "irrelevant": 0}

        for cat_name, counts in matrix.items():
            row_total = sum(counts.values())
            table.add_row(
                cat_name,
                str(counts["positive"]),
                str(counts["negative"]),
                str(counts["edge_case"]),
                str(counts["irrelevant"]),
                str(row_total),
            )
            for stype, count in counts.items():
                totals[stype] += count

        # Add totals row.
        grand_total = sum(totals.values())
        table.add_section()
        table.add_row(
            "[bold]Total[/bold]",
            f"[bold]{totals['positive']}[/bold]",
            f"[bold]{totals['negative']}[/bold]",
            f"[bold]{totals['edge_case']}[/bold]",
            f"[bold]{totals['irrelevant']}[/bold]",
            f"[bold]{grand_total}[/bold]",
        )

        self.console.print(table)

    def on_responses_complete(self, traces: list[Trace]) -> None:
        """Show generated traces grouped by category with type indicators."""
        self.console.print(f"\n[green]✍️ Traces[/green] [dim]{len(traces)} generated[/dim]")

        # Group by category.
        by_category = {}
        for trace in traces:
            cat = trace.scenario.category or "uncategorized"
            by_category.setdefault(cat, []).append(trace)

        for cat_name, cat_traces in by_category.items():
            self.console.print(f"\n [cyan]📁 {cat_name}[/cyan] ({len(cat_traces)} traces)")

            for trace in cat_traces[:3]:  # Show first 3 per category
                # Try to get scenario type if available.
                scenario_type = getattr(trace.scenario, 'scenario_type', None)
                if scenario_type:
                    type_indicator = {
                        "positive": "[green]✓[/green]",
                        "negative": "[red]✗[/red]",
                        "edge_case": "[yellow]⚡[/yellow]",
                        "irrelevant": "[dim]○[/dim]"
                    }.get(scenario_type if isinstance(scenario_type, str) else scenario_type.value, "[white]?[/white]")
                else:
                    type_indicator = "[white]•[/white]"

                user_preview = trace.user_message[:50] + "..." if len(trace.user_message) > 50 else trace.user_message
                self.console.print(f" {type_indicator} [blue]{user_preview}[/blue]")

            if len(cat_traces) > 3:
                self.console.print(f" [dim]... and {len(cat_traces) - 3} more[/dim]")

    def on_hitl_start(self, rules_count: int) -> None:
        """Display HITL session start."""
        from rich.panel import Panel

        self.console.print()
        self.console.print(Panel.fit(
            f"[bold]Interactive Logic Map Editor[/bold]\n"
            f"[dim]Review and refine {rules_count} extracted rules[/dim]",
            title="[cyan]HITL Mode[/cyan]",
            border_style="cyan"
        ))

    def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
        """Display one refinement result (feedback + summary of changes)."""
        feedback_preview = feedback[:60] + "..." if len(feedback) > 60 else feedback
        self.console.print(f" [green]✓[/green] [dim]{feedback_preview}[/dim]")
        self.console.print(f" [cyan]{changes_summary}[/cyan]")

    def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
        """Display HITL session completion."""
        if change_count > 0:
            self.console.print(
                f"\n[green]✅ HITL Complete[/green] - "
                f"Made {change_count} change(s), proceeding with {final_rules_count} rules"
            )
        else:
            self.console.print(
                f"\n[green]✅ HITL Complete[/green] - "
                f"No changes made, proceeding with {final_rules_count} rules"
            )
402
+
403
+
404
+ class CallbackReporter:
405
+ """
406
+ Reporter that invokes user-provided callbacks for progress events.
407
+
408
+ Use this when you need programmatic access to progress events
409
+ (e.g., updating a progress bar, logging to a file, etc.)
410
+
411
+ Examples:
412
+ >>> def on_progress(event: str, data: dict):
413
+ ... print(f"{event}: {data}")
414
+ ...
415
+ >>> reporter = CallbackReporter(on_progress=on_progress)
416
+ >>> generator = Generator(reporter=reporter)
417
+
418
+ >>> # With specific event handlers
419
+ >>> reporter = CallbackReporter(
420
+ ... on_start=lambda traces, model, dtype: print(f"Starting {traces} traces"),
421
+ ... on_complete=lambda size, elapsed, rate: print(f"Done! {size} traces"),
422
+ ... )
423
+ """
424
+
425
+ def __init__(
426
+ self,
427
+ on_progress: "Callable[[str, dict], None] | None" = None,
428
+ on_start: "Callable[[int, str, str], None] | None" = None,
429
+ on_plan_complete: "Callable[[Plan], None] | None" = None,
430
+ on_scenario_progress: "Callable[[int, int], None] | None" = None,
431
+ on_scenarios_complete: "Callable[[list[Scenario]], None] | None" = None,
432
+ on_response_progress: "Callable[[int, int], None] | None" = None,
433
+ on_responses_complete: "Callable[[list[Trace]], None] | None" = None,
434
+ on_grading_progress: "Callable[[int, int], None] | None" = None,
435
+ on_grading_complete: "Callable[[list[Trace], float], None] | None" = None,
436
+ on_complete: "Callable[[int, float, float | None], None] | None" = None,
437
+ ):
438
+ """
439
+ Initialize the callback reporter.
440
+
441
+ Args:
442
+ on_progress: Generic callback for all events. Receives (event_name, data_dict).
443
+ on_start: Called when generation starts (traces, model, dataset_type)
444
+ on_plan_complete: Called when planning completes (plan)
445
+ on_scenario_progress: Called during scenario generation (completed, total)
446
+ on_scenarios_complete: Called when scenarios are done (scenarios list)
447
+ on_response_progress: Called during response generation (completed, total)
448
+ on_responses_complete: Called when responses are done (traces list)
449
+ on_grading_progress: Called during grading (completed, total)
450
+ on_grading_complete: Called when grading is done (traces, pass_rate)
451
+ on_complete: Called when generation completes (dataset_size, elapsed, pass_rate)
452
+ """
453
+ self._on_progress = on_progress
454
+ self._on_start = on_start
455
+ self._on_plan_complete = on_plan_complete
456
+ self._on_scenario_progress = on_scenario_progress
457
+ self._on_scenarios_complete = on_scenarios_complete
458
+ self._on_response_progress = on_response_progress
459
+ self._on_responses_complete = on_responses_complete
460
+ self._on_grading_progress = on_grading_progress
461
+ self._on_grading_complete = on_grading_complete
462
+ self._on_complete_cb = on_complete
463
+
464
+ def _emit(self, event: str, data: dict) -> None:
465
+ """Emit an event to the generic callback."""
466
+ if self._on_progress:
467
+ self._on_progress(event, data)
468
+
469
+ def on_start(self, traces: int, model: str, dataset_type: str) -> None:
470
+ self._emit("start", {"traces": traces, "model": model, "dataset_type": dataset_type})
471
+ if self._on_start:
472
+ self._on_start(traces, model, dataset_type)
473
+
474
+ def on_plan_complete(self, plan: Plan) -> None:
475
+ self._emit("plan_complete", {"categories": len(plan.categories)})
476
+ if self._on_plan_complete:
477
+ self._on_plan_complete(plan)
478
+
479
+ def on_scenario_progress(self, completed: int, total: int) -> None:
480
+ self._emit("scenario_progress", {"completed": completed, "total": total})
481
+ if self._on_scenario_progress:
482
+ self._on_scenario_progress(completed, total)
483
+
484
+ def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
485
+ self._emit("scenarios_complete", {"count": len(scenarios)})
486
+ if self._on_scenarios_complete:
487
+ self._on_scenarios_complete(scenarios)
488
+
489
+ def on_response_progress(self, completed: int, total: int) -> None:
490
+ self._emit("response_progress", {"completed": completed, "total": total})
491
+ if self._on_response_progress:
492
+ self._on_response_progress(completed, total)
493
+
494
+ def on_responses_complete(self, traces: list[Trace]) -> None:
495
+ self._emit("responses_complete", {"count": len(traces)})
496
+ if self._on_responses_complete:
497
+ self._on_responses_complete(traces)
498
+
499
+ def on_grading_progress(self, completed: int, total: int) -> None:
500
+ self._emit("grading_progress", {"completed": completed, "total": total})
501
+ if self._on_grading_progress:
502
+ self._on_grading_progress(completed, total)
503
+
504
+ def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
505
+ self._emit("grading_complete", {"count": len(traces), "pass_rate": pass_rate})
506
+ if self._on_grading_complete:
507
+ self._on_grading_complete(traces, pass_rate)
508
+
509
+ def on_refinement_start(self, iteration: int, failed_count: int) -> None:
510
+ self._emit("refinement_start", {"iteration": iteration, "failed_count": failed_count})
511
+
512
+ def on_grading_skipped(self) -> None:
513
+ self._emit("grading_skipped", {})
514
+
515
+ def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
516
+ self._emit("complete", {"dataset_size": dataset_size, "elapsed_seconds": elapsed_seconds, "pass_rate": pass_rate})
517
+ if self._on_complete_cb:
518
+ self._on_complete_cb(dataset_size, elapsed_seconds, pass_rate)
519
+
520
+ def on_logic_map_complete(self, logic_map) -> None:
521
+ self._emit("logic_map_complete", {"rules_count": len(logic_map.rules)})
522
+
523
+ def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
524
+ self._emit("golden_scenarios_complete", {"count": len(scenarios), "distribution": distribution})
525
+
526
+ def on_hitl_start(self, rules_count: int) -> None:
527
+ self._emit("hitl_start", {"rules_count": rules_count})
528
+
529
+ def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
530
+ self._emit("hitl_refinement", {"feedback": feedback, "changes_summary": changes_summary})
531
+
532
+ def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
533
+ self._emit("hitl_complete", {"change_count": change_count, "final_rules_count": final_rules_count})
534
+
535
+
536
# Public API: the progress protocol plus its three stock implementations.
__all__ = ["ProgressReporter", "SilentReporter", "RichReporter", "CallbackReporter"]
537
+