synkro 0.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synkro/__init__.py +179 -0
- synkro/advanced.py +186 -0
- synkro/cli.py +128 -0
- synkro/core/__init__.py +7 -0
- synkro/core/checkpoint.py +250 -0
- synkro/core/dataset.py +402 -0
- synkro/core/policy.py +337 -0
- synkro/errors.py +178 -0
- synkro/examples/__init__.py +148 -0
- synkro/factory.py +276 -0
- synkro/formatters/__init__.py +12 -0
- synkro/formatters/qa.py +98 -0
- synkro/formatters/sft.py +90 -0
- synkro/formatters/tool_call.py +127 -0
- synkro/generation/__init__.py +9 -0
- synkro/generation/follow_ups.py +134 -0
- synkro/generation/generator.py +220 -0
- synkro/generation/golden_responses.py +244 -0
- synkro/generation/golden_scenarios.py +276 -0
- synkro/generation/golden_tool_responses.py +416 -0
- synkro/generation/logic_extractor.py +126 -0
- synkro/generation/multiturn_responses.py +177 -0
- synkro/generation/planner.py +131 -0
- synkro/generation/responses.py +189 -0
- synkro/generation/scenarios.py +90 -0
- synkro/generation/tool_responses.py +376 -0
- synkro/generation/tool_simulator.py +114 -0
- synkro/interactive/__init__.py +12 -0
- synkro/interactive/hitl_session.py +77 -0
- synkro/interactive/logic_map_editor.py +173 -0
- synkro/interactive/rich_ui.py +205 -0
- synkro/llm/__init__.py +7 -0
- synkro/llm/client.py +235 -0
- synkro/llm/rate_limits.py +95 -0
- synkro/models/__init__.py +43 -0
- synkro/models/anthropic.py +26 -0
- synkro/models/google.py +19 -0
- synkro/models/openai.py +31 -0
- synkro/modes/__init__.py +15 -0
- synkro/modes/config.py +66 -0
- synkro/modes/qa.py +18 -0
- synkro/modes/sft.py +18 -0
- synkro/modes/tool_call.py +18 -0
- synkro/parsers.py +442 -0
- synkro/pipeline/__init__.py +20 -0
- synkro/pipeline/phases.py +592 -0
- synkro/pipeline/runner.py +424 -0
- synkro/pipelines.py +123 -0
- synkro/prompts/__init__.py +57 -0
- synkro/prompts/base.py +167 -0
- synkro/prompts/golden_templates.py +474 -0
- synkro/prompts/interactive_templates.py +65 -0
- synkro/prompts/multiturn_templates.py +156 -0
- synkro/prompts/qa_templates.py +97 -0
- synkro/prompts/templates.py +281 -0
- synkro/prompts/tool_templates.py +201 -0
- synkro/quality/__init__.py +14 -0
- synkro/quality/golden_refiner.py +163 -0
- synkro/quality/grader.py +153 -0
- synkro/quality/multiturn_grader.py +150 -0
- synkro/quality/refiner.py +137 -0
- synkro/quality/tool_grader.py +126 -0
- synkro/quality/tool_refiner.py +128 -0
- synkro/quality/verifier.py +228 -0
- synkro/reporting.py +537 -0
- synkro/schemas.py +472 -0
- synkro/types/__init__.py +41 -0
- synkro/types/core.py +126 -0
- synkro/types/dataset_type.py +30 -0
- synkro/types/logic_map.py +345 -0
- synkro/types/tool.py +94 -0
- synkro-0.4.12.data/data/examples/__init__.py +148 -0
- synkro-0.4.12.dist-info/METADATA +258 -0
- synkro-0.4.12.dist-info/RECORD +77 -0
- synkro-0.4.12.dist-info/WHEEL +4 -0
- synkro-0.4.12.dist-info/entry_points.txt +2 -0
- synkro-0.4.12.dist-info/licenses/LICENSE +21 -0
synkro/reporting.py
ADDED
|
@@ -0,0 +1,537 @@
|
|
|
1
|
+
"""Progress reporting abstraction for generation pipeline.
|
|
2
|
+
|
|
3
|
+
This module provides a Protocol for progress reporting and implementations
|
|
4
|
+
for different use cases (rich console, silent for testing, etc.).
|
|
5
|
+
|
|
6
|
+
Enhanced for Golden Trace pipeline with:
|
|
7
|
+
- Logic Map logging
|
|
8
|
+
- Scenario type distribution
|
|
9
|
+
- Per-trace category/type logging
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import Protocol, TYPE_CHECKING, Callable
|
|
13
|
+
|
|
14
|
+
from synkro.types.core import Plan, Scenario, Trace, GradeResult
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from synkro.types.logic_map import LogicMap, GoldenScenario
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ProgressReporter(Protocol):
    """
    Structural interface for observing generation progress.

    Supply your own implementation to control how progress is surfaced
    (console, log file, GUI, message queue, ...). Any object providing
    these callbacks satisfies the protocol.

    Examples:
        >>> # Use silent reporter for testing
        >>> generator = Generator(reporter=SilentReporter())

        >>> # Use rich reporter for CLI (default)
        >>> generator = Generator(reporter=RichReporter())
    """

    def on_start(self, traces: int, model: str, dataset_type: str) -> None:
        """Generation is about to begin."""
        ...

    def on_plan_complete(self, plan: Plan) -> None:
        """The planning phase has finished."""
        ...

    def on_scenario_progress(self, completed: int, total: int) -> None:
        """A scenario-generation step has completed."""
        ...

    def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
        """Every scenario has been generated."""
        ...

    def on_response_progress(self, completed: int, total: int) -> None:
        """A response-generation step has completed."""
        ...

    def on_responses_complete(self, traces: list[Trace]) -> None:
        """Every response has been generated."""
        ...

    def on_grading_progress(self, completed: int, total: int) -> None:
        """A grading step has completed."""
        ...

    def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
        """Grading has finished for all traces."""
        ...

    def on_refinement_start(self, iteration: int, failed_count: int) -> None:
        """A refinement iteration is starting for the failed traces."""
        ...

    def on_grading_skipped(self) -> None:
        """Grading was skipped entirely."""
        ...

    def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
        """The full generation run has finished."""
        ...

    def on_logic_map_complete(self, logic_map: "LogicMap") -> None:
        """Logic extraction (Stage 1) has finished."""
        ...

    def on_golden_scenarios_complete(
        self, scenarios: list["GoldenScenario"], distribution: dict[str, int]
    ) -> None:
        """Golden scenarios (Stage 2) have been generated."""
        ...

    def on_hitl_start(self, rules_count: int) -> None:
        """An interactive (HITL) review session is starting."""
        ...

    def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
        """A single HITL refinement step has been applied."""
        ...

    def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
        """The HITL session has ended."""
        ...
+
|
|
101
|
+
class SilentReporter:
|
|
102
|
+
"""
|
|
103
|
+
No-op reporter for testing and embedding.
|
|
104
|
+
|
|
105
|
+
Use this when you don't want any console output.
|
|
106
|
+
|
|
107
|
+
Examples:
|
|
108
|
+
>>> generator = Generator(reporter=SilentReporter())
|
|
109
|
+
>>> dataset = generator.generate(policy) # No console output
|
|
110
|
+
"""
|
|
111
|
+
|
|
112
|
+
def on_start(self, traces: int, model: str, dataset_type: str) -> None:
|
|
113
|
+
pass
|
|
114
|
+
|
|
115
|
+
def on_plan_complete(self, plan: Plan) -> None:
|
|
116
|
+
pass
|
|
117
|
+
|
|
118
|
+
def on_scenario_progress(self, completed: int, total: int) -> None:
|
|
119
|
+
pass
|
|
120
|
+
|
|
121
|
+
def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
|
|
122
|
+
pass
|
|
123
|
+
|
|
124
|
+
def on_response_progress(self, completed: int, total: int) -> None:
|
|
125
|
+
pass
|
|
126
|
+
|
|
127
|
+
def on_responses_complete(self, traces: list[Trace]) -> None:
|
|
128
|
+
pass
|
|
129
|
+
|
|
130
|
+
def on_grading_progress(self, completed: int, total: int) -> None:
|
|
131
|
+
pass
|
|
132
|
+
|
|
133
|
+
def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
|
|
134
|
+
pass
|
|
135
|
+
|
|
136
|
+
def on_refinement_start(self, iteration: int, failed_count: int) -> None:
|
|
137
|
+
pass
|
|
138
|
+
|
|
139
|
+
def on_grading_skipped(self) -> None:
|
|
140
|
+
pass
|
|
141
|
+
|
|
142
|
+
def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
|
|
143
|
+
pass
|
|
144
|
+
|
|
145
|
+
def on_logic_map_complete(self, logic_map) -> None:
|
|
146
|
+
pass
|
|
147
|
+
|
|
148
|
+
def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
|
|
149
|
+
pass
|
|
150
|
+
|
|
151
|
+
def on_hitl_start(self, rules_count: int) -> None:
|
|
152
|
+
pass
|
|
153
|
+
|
|
154
|
+
def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
|
|
155
|
+
pass
|
|
156
|
+
|
|
157
|
+
def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
|
|
158
|
+
pass
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class RichReporter:
    """
    Rich console reporter with formatted, colorized output.

    This is the default reporter that provides the familiar synkro CLI
    experience.

    Note: a previous revision defined ``on_responses_complete`` twice; the
    later (category-grouped) definition silently shadowed the first, so only
    that effective behavior is kept here.
    """

    def __init__(self):
        from rich.console import Console
        self.console = Console()
        # Reserved for a future rich.Progress integration; currently unused.
        self._progress = None
        self._current_task = None

    def on_start(self, traces: int, model: str, dataset_type: str) -> None:
        """Print the opening banner with trace count, dataset type, and model."""
        from rich.panel import Panel

        self.console.print()
        self.console.print(Panel.fit(
            f"[bold]Generating {traces} traces[/bold]\n"
            f"[dim]Type: {dataset_type.upper()} | Model: {model}[/dim]",
            title="[cyan]synkro[/cyan]",
            border_style="cyan"
        ))
        self.console.print()

    def on_plan_complete(self, plan: Plan) -> None:
        """Render the plan's categories as a table."""
        from rich.table import Table

        self.console.print(f"[green]📋 Planning[/green] [dim]{len(plan.categories)} categories[/dim]")

        cat_table = Table(title="Categories", show_header=True, header_style="bold cyan")
        cat_table.add_column("Name")
        cat_table.add_column("Description")
        cat_table.add_column("Count", justify="right")
        for cat in plan.categories:
            cat_table.add_row(cat.name, cat.description, str(cat.count))
        self.console.print(cat_table)
        self.console.print()

    def on_scenario_progress(self, completed: int, total: int) -> None:
        pass  # Progress shown in on_scenarios_complete

    def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
        """List each generated scenario with a truncated description."""
        self.console.print(f"[green]💡 Scenarios[/green] [dim]{len(scenarios)} created[/dim]")
        for idx, s in enumerate(scenarios, 1):
            desc = s.description[:80] + "..." if len(s.description) > 80 else s.description
            self.console.print(f" [dim]#{idx}[/dim] [yellow]{desc}[/yellow]")

    def on_response_progress(self, completed: int, total: int) -> None:
        pass  # Progress shown in on_responses_complete

    def on_responses_complete(self, traces: list[Trace]) -> None:
        """Show generated traces grouped by category with per-type indicators."""
        self.console.print(f"\n[green]✍️ Traces[/green] [dim]{len(traces)} generated[/dim]")

        # Group traces by scenario category.
        by_category = {}
        for trace in traces:
            cat = trace.scenario.category or "uncategorized"
            by_category.setdefault(cat, []).append(trace)

        for cat_name, cat_traces in by_category.items():
            self.console.print(f"\n [cyan]📁 {cat_name}[/cyan] ({len(cat_traces)} traces)")

            for trace in cat_traces[:3]:  # Show first 3 per category
                # scenario_type only exists on golden scenarios, hence getattr.
                scenario_type = getattr(trace.scenario, 'scenario_type', None)
                if scenario_type:
                    type_indicator = {
                        "positive": "[green]✓[/green]",
                        "negative": "[red]✗[/red]",
                        "edge_case": "[yellow]⚡[/yellow]",
                        "irrelevant": "[dim]○[/dim]"
                    }.get(scenario_type if isinstance(scenario_type, str) else scenario_type.value, "[white]?[/white]")
                else:
                    type_indicator = "[white]•[/white]"

                user_preview = trace.user_message[:50] + "..." if len(trace.user_message) > 50 else trace.user_message
                self.console.print(f" {type_indicator} [blue]{user_preview}[/blue]")

            if len(cat_traces) > 3:
                self.console.print(f" [dim]... and {len(cat_traces) - 3} more[/dim]")

    def on_grading_progress(self, completed: int, total: int) -> None:
        pass  # Progress shown in on_grading_complete

    def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
        """Print a per-trace pass/fail line plus the overall pass rate."""
        self.console.print(f"[green]⚖️ Grading[/green] [dim]{pass_rate:.0f}% passed[/dim]")
        for idx, trace in enumerate(traces, 1):
            scenario_preview = trace.scenario.description[:40] + "..." if len(trace.scenario.description) > 40 else trace.scenario.description
            if trace.grade and trace.grade.passed:
                self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [green]✓ Passed[/green]")
            else:
                issues = ", ".join(trace.grade.issues[:2]) if trace.grade and trace.grade.issues else "No specific issues"
                issues_preview = issues[:40] + "..." if len(issues) > 40 else issues
                self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [red]✗ Failed[/red] [dim]{issues_preview}[/dim]")

    def on_refinement_start(self, iteration: int, failed_count: int) -> None:
        """Announce a refinement pass over the failed traces."""
        self.console.print(f" [yellow]↻ Refining {failed_count} failed traces (iteration {iteration})...[/yellow]")

    def on_grading_skipped(self) -> None:
        """Note that grading was skipped."""
        # Was an f-string with no placeholders; plain literal is equivalent.
        self.console.print(" [dim]⚖️ Grading skipped[/dim]")

    def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
        """Print the closing summary panel with size, elapsed time, and quality."""
        from rich.panel import Panel
        from rich.table import Table

        elapsed_str = f"{int(elapsed_seconds) // 60}m {int(elapsed_seconds) % 60}s" if elapsed_seconds >= 60 else f"{int(elapsed_seconds)}s"

        self.console.print()
        summary = Table.grid(padding=(0, 2))
        summary.add_column(style="green")
        summary.add_column()
        summary.add_row("✅ Done!", f"Generated {dataset_size} traces in {elapsed_str}")
        if pass_rate is not None:
            summary.add_row("📊 Quality:", f"{pass_rate:.0f}% passed verification")
        self.console.print(Panel(summary, border_style="green", title="[green]Complete[/green]"))
        self.console.print()

    def on_logic_map_complete(self, logic_map) -> None:
        """Display the extracted Logic Map (Stage 1) as a category tree."""
        from rich.tree import Tree

        self.console.print(f"\n[green]📜 Logic Map[/green] [dim]{len(logic_map.rules)} rules extracted[/dim]")

        # Show rules as a tree, grouped by category.
        tree = Tree("[bold cyan]Rules[/bold cyan]")

        by_category = {}
        for rule in logic_map.rules:
            cat = rule.category.value
            by_category.setdefault(cat, []).append(rule)

        for category, rules in by_category.items():
            cat_branch = tree.add(f"[yellow]{category.upper()}[/yellow] ({len(rules)} rules)")
            for rule in rules[:3]:  # Show first 3 per category
                deps = f" → {', '.join(rule.dependencies)}" if rule.dependencies else ""
                rule_text = rule.text[:50] + "..." if len(rule.text) > 50 else rule.text
                cat_branch.add(f"[dim]{rule.rule_id}[/dim]: {rule_text}{deps}")
            if len(rules) > 3:
                cat_branch.add(f"[dim]... and {len(rules) - 3} more[/dim]")

        self.console.print(tree)

        # Show dependency chain roots, if any.
        if logic_map.root_rules:
            self.console.print(f" [dim]Root rules: {', '.join(logic_map.root_rules)}[/dim]")

    def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
        """Display golden scenarios as a category × type matrix (Stage 2)."""
        from rich.table import Table

        self.console.print(f"\n[green]💡 Golden Scenarios[/green] [dim]{len(scenarios)} created[/dim]")

        # Build category × type matrix. Counting is defensive: an unexpected
        # scenario type no longer raises KeyError (it still contributes to
        # the row/grand totals even though it has no dedicated column).
        matrix: dict[str, dict[str, int]] = {}
        for s in scenarios:
            cat = s.category or "uncategorized"
            stype = s.scenario_type.value if hasattr(s.scenario_type, 'value') else s.scenario_type
            counts = matrix.setdefault(cat, {"positive": 0, "negative": 0, "edge_case": 0, "irrelevant": 0})
            counts[stype] = counts.get(stype, 0) + 1

        # Create the combined table.
        table = Table(title="Scenario Distribution", show_header=True, header_style="bold cyan")
        table.add_column("Category", style="cyan")
        table.add_column("[green]✓ Positive[/green]", justify="right")
        table.add_column("[red]✗ Negative[/red]", justify="right")
        table.add_column("[yellow]⚡ Edge[/yellow]", justify="right")
        table.add_column("[dim]○ Irrelevant[/dim]", justify="right")
        table.add_column("Total", justify="right", style="bold")

        # Track column totals (defensive for unexpected types, as above).
        totals = {"positive": 0, "negative": 0, "edge_case": 0, "irrelevant": 0}

        for cat_name, counts in matrix.items():
            row_total = sum(counts.values())
            table.add_row(
                cat_name,
                str(counts["positive"]),
                str(counts["negative"]),
                str(counts["edge_case"]),
                str(counts["irrelevant"]),
                str(row_total),
            )
            for stype, count in counts.items():
                totals[stype] = totals.get(stype, 0) + count

        # Add totals row.
        grand_total = sum(totals.values())
        table.add_section()
        table.add_row(
            "[bold]Total[/bold]",
            f"[bold]{totals['positive']}[/bold]",
            f"[bold]{totals['negative']}[/bold]",
            f"[bold]{totals['edge_case']}[/bold]",
            f"[bold]{totals['irrelevant']}[/bold]",
            f"[bold]{grand_total}[/bold]",
        )

        self.console.print(table)

    def on_hitl_start(self, rules_count: int) -> None:
        """Display the HITL session start banner."""
        from rich.panel import Panel

        self.console.print()
        self.console.print(Panel.fit(
            f"[bold]Interactive Logic Map Editor[/bold]\n"
            f"[dim]Review and refine {rules_count} extracted rules[/dim]",
            title="[cyan]HITL Mode[/cyan]",
            border_style="cyan"
        ))

    def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
        """Display one refinement result (feedback preview + change summary)."""
        feedback_preview = feedback[:60] + "..." if len(feedback) > 60 else feedback
        self.console.print(f" [green]✓[/green] [dim]{feedback_preview}[/dim]")
        self.console.print(f" [cyan]{changes_summary}[/cyan]")

    def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
        """Display HITL session completion, noting whether changes were made."""
        if change_count > 0:
            self.console.print(
                f"\n[green]✅ HITL Complete[/green] - "
                f"Made {change_count} change(s), proceeding with {final_rules_count} rules"
            )
        else:
            self.console.print(
                f"\n[green]✅ HITL Complete[/green] - "
                f"No changes made, proceeding with {final_rules_count} rules"
            )
|
404
|
+
class CallbackReporter:
    """
    Reporter that forwards progress events to user-supplied callables.

    Useful for programmatic progress handling: driving a progress bar,
    appending to a log file, pushing updates over a queue, etc. Every event
    is first sent to the generic ``on_progress`` callback (if given), then
    to its event-specific callback (if given).

    Examples:
        >>> def on_progress(event: str, data: dict):
        ...     print(f"{event}: {data}")
        ...
        >>> reporter = CallbackReporter(on_progress=on_progress)
        >>> generator = Generator(reporter=reporter)

        >>> # With specific event handlers
        >>> reporter = CallbackReporter(
        ...     on_start=lambda traces, model, dtype: print(f"Starting {traces} traces"),
        ...     on_complete=lambda size, elapsed, rate: print(f"Done! {size} traces"),
        ... )
    """

    def __init__(
        self,
        on_progress: "Callable[[str, dict], None] | None" = None,
        on_start: "Callable[[int, str, str], None] | None" = None,
        on_plan_complete: "Callable[[Plan], None] | None" = None,
        on_scenario_progress: "Callable[[int, int], None] | None" = None,
        on_scenarios_complete: "Callable[[list[Scenario]], None] | None" = None,
        on_response_progress: "Callable[[int, int], None] | None" = None,
        on_responses_complete: "Callable[[list[Trace]], None] | None" = None,
        on_grading_progress: "Callable[[int, int], None] | None" = None,
        on_grading_complete: "Callable[[list[Trace], float], None] | None" = None,
        on_complete: "Callable[[int, float, float | None], None] | None" = None,
    ):
        """
        Initialize the callback reporter.

        Args:
            on_progress: Generic callback for all events. Receives (event_name, data_dict).
            on_start: Called when generation starts (traces, model, dataset_type)
            on_plan_complete: Called when planning completes (plan)
            on_scenario_progress: Called during scenario generation (completed, total)
            on_scenarios_complete: Called when scenarios are done (scenarios list)
            on_response_progress: Called during response generation (completed, total)
            on_responses_complete: Called when responses are done (traces list)
            on_grading_progress: Called during grading (completed, total)
            on_grading_complete: Called when grading is done (traces, pass_rate)
            on_complete: Called when generation completes (dataset_size, elapsed, pass_rate)
        """
        self._generic_cb = on_progress
        self._start_cb = on_start
        self._plan_cb = on_plan_complete
        self._scenario_tick_cb = on_scenario_progress
        self._scenarios_done_cb = on_scenarios_complete
        self._response_tick_cb = on_response_progress
        self._responses_done_cb = on_responses_complete
        self._grading_tick_cb = on_grading_progress
        self._grading_done_cb = on_grading_complete
        self._finished_cb = on_complete

    def _fire(self, event: str, data: dict) -> None:
        """Forward an event to the generic callback, if one was supplied."""
        if self._generic_cb:
            self._generic_cb(event, data)

    def on_start(self, traces: int, model: str, dataset_type: str) -> None:
        self._fire("start", {"traces": traces, "model": model, "dataset_type": dataset_type})
        if self._start_cb:
            self._start_cb(traces, model, dataset_type)

    def on_plan_complete(self, plan: Plan) -> None:
        self._fire("plan_complete", {"categories": len(plan.categories)})
        if self._plan_cb:
            self._plan_cb(plan)

    def on_scenario_progress(self, completed: int, total: int) -> None:
        self._fire("scenario_progress", {"completed": completed, "total": total})
        if self._scenario_tick_cb:
            self._scenario_tick_cb(completed, total)

    def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
        self._fire("scenarios_complete", {"count": len(scenarios)})
        if self._scenarios_done_cb:
            self._scenarios_done_cb(scenarios)

    def on_response_progress(self, completed: int, total: int) -> None:
        self._fire("response_progress", {"completed": completed, "total": total})
        if self._response_tick_cb:
            self._response_tick_cb(completed, total)

    def on_responses_complete(self, traces: list[Trace]) -> None:
        self._fire("responses_complete", {"count": len(traces)})
        if self._responses_done_cb:
            self._responses_done_cb(traces)

    def on_grading_progress(self, completed: int, total: int) -> None:
        self._fire("grading_progress", {"completed": completed, "total": total})
        if self._grading_tick_cb:
            self._grading_tick_cb(completed, total)

    def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
        self._fire("grading_complete", {"count": len(traces), "pass_rate": pass_rate})
        if self._grading_done_cb:
            self._grading_done_cb(traces, pass_rate)

    def on_refinement_start(self, iteration: int, failed_count: int) -> None:
        # Generic-callback-only event (no dedicated handler parameter).
        self._fire("refinement_start", {"iteration": iteration, "failed_count": failed_count})

    def on_grading_skipped(self) -> None:
        self._fire("grading_skipped", {})

    def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
        self._fire("complete", {"dataset_size": dataset_size, "elapsed_seconds": elapsed_seconds, "pass_rate": pass_rate})
        if self._finished_cb:
            self._finished_cb(dataset_size, elapsed_seconds, pass_rate)

    def on_logic_map_complete(self, logic_map) -> None:
        self._fire("logic_map_complete", {"rules_count": len(logic_map.rules)})

    def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
        self._fire("golden_scenarios_complete", {"count": len(scenarios), "distribution": distribution})

    def on_hitl_start(self, rules_count: int) -> None:
        self._fire("hitl_start", {"rules_count": rules_count})

    def on_hitl_refinement(self, feedback: str, changes_summary: str) -> None:
        self._fire("hitl_refinement", {"feedback": feedback, "changes_summary": changes_summary})

    def on_hitl_complete(self, change_count: int, final_rules_count: int) -> None:
        self._fire("hitl_complete", {"change_count": change_count, "final_rules_count": final_rules_count})
|
+
|
|
536
|
+
# Explicit public API: the protocol plus the three built-in implementations.
__all__ = ["ProgressReporter", "SilentReporter", "RichReporter", "CallbackReporter"]
|
|
537
|
+
|