synkro 0.4.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synkro might be problematic.

Files changed (81)
  1. synkro/__init__.py +331 -0
  2. synkro/advanced.py +184 -0
  3. synkro/cli.py +156 -0
  4. synkro/core/__init__.py +7 -0
  5. synkro/core/checkpoint.py +250 -0
  6. synkro/core/dataset.py +432 -0
  7. synkro/core/policy.py +337 -0
  8. synkro/errors.py +178 -0
  9. synkro/examples/__init__.py +148 -0
  10. synkro/factory.py +291 -0
  11. synkro/formatters/__init__.py +18 -0
  12. synkro/formatters/chatml.py +121 -0
  13. synkro/formatters/langfuse.py +98 -0
  14. synkro/formatters/langsmith.py +98 -0
  15. synkro/formatters/qa.py +112 -0
  16. synkro/formatters/sft.py +90 -0
  17. synkro/formatters/tool_call.py +127 -0
  18. synkro/generation/__init__.py +9 -0
  19. synkro/generation/follow_ups.py +134 -0
  20. synkro/generation/generator.py +314 -0
  21. synkro/generation/golden_responses.py +269 -0
  22. synkro/generation/golden_scenarios.py +333 -0
  23. synkro/generation/golden_tool_responses.py +791 -0
  24. synkro/generation/logic_extractor.py +126 -0
  25. synkro/generation/multiturn_responses.py +177 -0
  26. synkro/generation/planner.py +131 -0
  27. synkro/generation/responses.py +189 -0
  28. synkro/generation/scenarios.py +90 -0
  29. synkro/generation/tool_responses.py +625 -0
  30. synkro/generation/tool_simulator.py +114 -0
  31. synkro/interactive/__init__.py +16 -0
  32. synkro/interactive/hitl_session.py +205 -0
  33. synkro/interactive/intent_classifier.py +94 -0
  34. synkro/interactive/logic_map_editor.py +176 -0
  35. synkro/interactive/rich_ui.py +459 -0
  36. synkro/interactive/scenario_editor.py +198 -0
  37. synkro/llm/__init__.py +7 -0
  38. synkro/llm/client.py +309 -0
  39. synkro/llm/rate_limits.py +99 -0
  40. synkro/models/__init__.py +50 -0
  41. synkro/models/anthropic.py +26 -0
  42. synkro/models/google.py +19 -0
  43. synkro/models/local.py +104 -0
  44. synkro/models/openai.py +31 -0
  45. synkro/modes/__init__.py +13 -0
  46. synkro/modes/config.py +66 -0
  47. synkro/modes/conversation.py +35 -0
  48. synkro/modes/tool_call.py +18 -0
  49. synkro/parsers.py +442 -0
  50. synkro/pipeline/__init__.py +20 -0
  51. synkro/pipeline/phases.py +592 -0
  52. synkro/pipeline/runner.py +769 -0
  53. synkro/pipelines.py +136 -0
  54. synkro/prompts/__init__.py +57 -0
  55. synkro/prompts/base.py +167 -0
  56. synkro/prompts/golden_templates.py +533 -0
  57. synkro/prompts/interactive_templates.py +198 -0
  58. synkro/prompts/multiturn_templates.py +156 -0
  59. synkro/prompts/templates.py +281 -0
  60. synkro/prompts/tool_templates.py +318 -0
  61. synkro/quality/__init__.py +14 -0
  62. synkro/quality/golden_refiner.py +163 -0
  63. synkro/quality/grader.py +153 -0
  64. synkro/quality/multiturn_grader.py +150 -0
  65. synkro/quality/refiner.py +137 -0
  66. synkro/quality/tool_grader.py +126 -0
  67. synkro/quality/tool_refiner.py +128 -0
  68. synkro/quality/verifier.py +228 -0
  69. synkro/reporting.py +464 -0
  70. synkro/schemas.py +521 -0
  71. synkro/types/__init__.py +43 -0
  72. synkro/types/core.py +153 -0
  73. synkro/types/dataset_type.py +33 -0
  74. synkro/types/logic_map.py +348 -0
  75. synkro/types/tool.py +94 -0
  76. synkro-0.4.36.data/data/examples/__init__.py +148 -0
  77. synkro-0.4.36.dist-info/METADATA +507 -0
  78. synkro-0.4.36.dist-info/RECORD +81 -0
  79. synkro-0.4.36.dist-info/WHEEL +4 -0
  80. synkro-0.4.36.dist-info/entry_points.txt +2 -0
  81. synkro-0.4.36.dist-info/licenses/LICENSE +21 -0
synkro/quality/verifier.py ADDED
@@ -0,0 +1,228 @@
+ """Trace Verifier - The Auditor.
+
+ Verifies generated traces against the Logic Map to ensure:
+ - No skipped rules
+ - No hallucinated rules
+ - No contradictions
+ - DAG compliance
+
+ This is Stage 4 of the Golden Trace pipeline.
+ """
+
+ from synkro.llm.client import LLM
+ from synkro.models import Model, OpenAI
+ from synkro.schemas import VerificationOutput
+ from synkro.types.core import Trace, GradeResult
+ from synkro.types.logic_map import LogicMap, GoldenScenario, VerificationResult
+ from synkro.prompts.golden_templates import VERIFICATION_PROMPT
+
+
+ class TraceVerifier:
+     """
+     The Auditor - Verifies traces against the Logic Map.
+
+     Performs strict verification to ensure:
+     1. No Skipped Rules: All target rules were evaluated
+     2. No Hallucinated Rules: Only valid rules were cited
+     3. No Contradictions: Reasoning is internally consistent
+     4. DAG Compliance: Dependency order was followed
+     5. Outcome Alignment: Response matches expected outcome
+
+     Examples:
+         >>> verifier = TraceVerifier(llm=LLM(model=OpenAI.GPT_4O))
+         >>> result = await verifier.verify(trace, logic_map, scenario)
+         >>> if result.passed:
+         ...     print("Trace verified successfully")
+     """
+
+     def __init__(
+         self,
+         llm: LLM | None = None,
+         model: Model = OpenAI.GPT_4O,
+     ):
+         """
+         Initialize the Trace Verifier.
+
+         Args:
+             llm: LLM client to use (creates one if not provided)
+             model: Model to use if creating LLM (default: GPT-4O for accuracy)
+         """
+         self.llm = llm or LLM(model=model, temperature=0.1)
+
+     async def verify(
+         self,
+         trace: Trace,
+         logic_map: LogicMap,
+         scenario: GoldenScenario,
+         reasoning_chain: list | None = None,
+         rules_applied: list[str] | None = None,
+         rules_excluded: list[str] | None = None,
+     ) -> VerificationResult:
+         """
+         Verify a trace against the Logic Map.
+
+         Args:
+             trace: The trace to verify
+             logic_map: The Logic Map (ground truth)
+             scenario: The golden scenario
+             reasoning_chain: Optional reasoning chain from generation
+             rules_applied: Optional list of rules claimed applied
+             rules_excluded: Optional list of rules claimed excluded
+
+         Returns:
+             VerificationResult with pass/fail and detailed issues
+         """
+         # Format inputs for prompt
+         logic_map_str = self._format_logic_map(logic_map)
+         trace_messages_str = self._format_trace_messages(trace)
+         reasoning_str = self._format_reasoning_chain(reasoning_chain) if reasoning_chain else "Not provided"
+
+         # Build prompt
+         prompt = VERIFICATION_PROMPT.format(
+             logic_map=logic_map_str,
+             scenario_type=scenario.scenario_type.value.upper(),
+             scenario_description=scenario.description,
+             target_rule_ids=", ".join(scenario.target_rule_ids),
+             expected_outcome=scenario.expected_outcome,
+             trace_messages=trace_messages_str,
+             reasoning_chain=reasoning_str,
+             rules_applied=", ".join(rules_applied) if rules_applied else "Not specified",
+             rules_excluded=", ".join(rules_excluded) if rules_excluded else "Not specified",
+         )
+
+         # Generate structured output
+         result = await self.llm.generate_structured(prompt, VerificationOutput)
+
+         # Convert to domain model
+         return VerificationResult(
+             passed=result.passed,
+             issues=result.issues,
+             skipped_rules=result.skipped_rules,
+             hallucinated_rules=result.hallucinated_rules,
+             contradictions=result.contradictions,
+             rules_verified=result.rules_verified,
+         )
+
+     def _format_logic_map(self, logic_map: LogicMap) -> str:
+         """Format Logic Map for verification prompt."""
+         lines = []
+         lines.append("RULES:")
+         for rule in logic_map.rules:
+             deps = f" [depends on: {', '.join(rule.dependencies)}]" if rule.dependencies else ""
+             lines.append(
+                 f" {rule.rule_id} ({rule.category.value}): {rule.text}{deps}"
+             )
+             lines.append(f" IF: {rule.condition}")
+             lines.append(f" THEN: {rule.action}")
+
+         lines.append("\nROOT RULES (Entry Points):")
+         lines.append(f" {', '.join(logic_map.root_rules)}")
+
+         return "\n".join(lines)
+
+     def _format_trace_messages(self, trace: Trace) -> str:
+         """Format trace messages for verification prompt."""
+         lines = []
+         for msg in trace.messages:
+             role = msg.role.upper()
+             content = msg.content or "(no content)"
+
+             # Handle tool calls
+             if msg.tool_calls:
+                 tool_info = []
+                 for tc in msg.tool_calls:
+                     if hasattr(tc, 'function'):
+                         tool_info.append(f" - {tc.function.name}({tc.function.arguments})")
+                     elif isinstance(tc, dict):
+                         func = tc.get('function', {})
+                         tool_info.append(f" - {func.get('name', 'unknown')}({func.get('arguments', '{}')})")
+                 content = "Tool calls:\n" + "\n".join(tool_info)
+
+             # Handle tool responses
+             if msg.tool_call_id:
+                 role = f"TOOL (call_id: {msg.tool_call_id})"
+
+             lines.append(f"[{role}] {content}")
+
+         return "\n\n".join(lines)
+
+     def _format_reasoning_chain(self, reasoning_chain: list) -> str:
+         """Format reasoning chain for verification prompt."""
+         lines = []
+         for i, step in enumerate(reasoning_chain, 1):
+             if hasattr(step, 'rule_id'):
+                 applies = "APPLIES" if step.applies else "DOES NOT APPLY"
+                 lines.append(f"Step {i}: {step.rule_id} - {applies}")
+                 lines.append(f" Rule: {step.rule_text}")
+                 lines.append(f" Reasoning: {step.reasoning}")
+                 if step.exclusions:
+                     lines.append(f" Excludes: {', '.join(step.exclusions)}")
+             else:
+                 # Handle dict format
+                 applies = "APPLIES" if step.get('applies', False) else "DOES NOT APPLY"
+                 lines.append(f"Step {i}: {step.get('rule_id', 'unknown')} - {applies}")
+                 lines.append(f" Reasoning: {step.get('reasoning', 'N/A')}")
+
+         return "\n".join(lines)
+
+     async def verify_and_grade(
+         self,
+         trace: Trace,
+         logic_map: LogicMap,
+         scenario: GoldenScenario,
+     ) -> tuple[VerificationResult, GradeResult]:
+         """
+         Verify a trace and convert to GradeResult for pipeline compatibility.
+
+         Args:
+             trace: The trace to verify
+             logic_map: The Logic Map
+             scenario: The golden scenario
+
+         Returns:
+             Tuple of (VerificationResult, GradeResult)
+         """
+         # Extract reasoning chain metadata from trace (if present)
+         reasoning_chain = getattr(trace, 'reasoning_chain', None)
+         rules_applied = getattr(trace, 'rules_applied', None)
+         rules_excluded = getattr(trace, 'rules_excluded', None)
+
+         verification = await self.verify(
+             trace, logic_map, scenario,
+             reasoning_chain=reasoning_chain,
+             rules_applied=rules_applied,
+             rules_excluded=rules_excluded,
+         )
+
+         # Convert to GradeResult for pipeline compatibility
+         grade = GradeResult(
+             passed=verification.passed,
+             issues=verification.issues,
+             feedback=self._create_feedback(verification),
+         )
+
+         return verification, grade
+
+     def _create_feedback(self, verification: VerificationResult) -> str:
+         """Create feedback string from verification result."""
+         if verification.passed:
+             return f"Verified. Rules correctly applied: {', '.join(verification.rules_verified)}"
+
+         feedback_parts = []
+
+         if verification.skipped_rules:
+             feedback_parts.append(f"Skipped rules: {', '.join(verification.skipped_rules)}")
+
+         if verification.hallucinated_rules:
+             feedback_parts.append(f"Hallucinated rules: {', '.join(verification.hallucinated_rules)}")
+
+         if verification.contradictions:
+             feedback_parts.append(f"Contradictions: {'; '.join(verification.contradictions)}")
+
+         if verification.issues:
+             feedback_parts.append(f"Other issues: {'; '.join(verification.issues)}")
+
+         return " | ".join(feedback_parts) if feedback_parts else "Verification failed"
+
+
+ __all__ = ["TraceVerifier"]
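
Usage note (not part of the package diff): a minimal sketch of how TraceVerifier might be driven, assuming the imports above resolve in an installed synkro environment. The `trace`, `logic_map`, and `scenario` names are hypothetical placeholders for objects produced earlier in the pipeline.

import asyncio

from synkro.llm.client import LLM
from synkro.models import OpenAI
from synkro.quality.verifier import TraceVerifier

async def audit_one(trace, logic_map, scenario):
    """Verify a single trace and return the pipeline-compatible grade."""
    verifier = TraceVerifier(llm=LLM(model=OpenAI.GPT_4O))
    verification, grade = await verifier.verify_and_grade(trace, logic_map, scenario)
    if not verification.passed:
        # The feedback string concatenates skipped/hallucinated/contradiction details.
        print(grade.feedback)
    return grade

# With concrete objects in hand (hypothetical here):
# grade = asyncio.run(audit_one(trace, logic_map, scenario))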
synkro/reporting.py ADDED
@@ -0,0 +1,464 @@
+ """Progress reporting abstraction for generation pipeline.
+
+ This module provides a Protocol for progress reporting and implementations
+ for different use cases (rich console, silent for testing, etc.).
+
+ Enhanced for Golden Trace pipeline with:
+ - Logic Map logging
+ - Scenario type distribution
+ - Per-trace category/type logging
+ """
+
+ from typing import Protocol, TYPE_CHECKING, Callable
+
+ from synkro.types.core import Plan, Scenario, Trace, GradeResult
+
+ if TYPE_CHECKING:
+     from synkro.types.logic_map import LogicMap, GoldenScenario
+
+
+ class ProgressReporter(Protocol):
+     """
+     Protocol for reporting generation progress.
+
+     Implement this to customize how progress is displayed or logged.
+
+     Examples:
+         >>> # Use silent reporter for testing
+         >>> generator = Generator(reporter=SilentReporter())
+
+         >>> # Use rich reporter for CLI (default)
+         >>> generator = Generator(reporter=RichReporter())
+     """
+
+     def on_start(self, traces: int, model: str, dataset_type: str) -> None:
+         """Called when generation starts."""
+         ...
+
+     def on_plan_complete(self, plan: Plan) -> None:
+         """Called when planning phase completes."""
+         ...
+
+     def on_scenario_progress(self, completed: int, total: int) -> None:
+         """Called during scenario generation."""
+         ...
+
+     def on_response_progress(self, completed: int, total: int) -> None:
+         """Called during response generation."""
+         ...
+
+     def on_responses_complete(self, traces: list[Trace]) -> None:
+         """Called when all responses are generated."""
+         ...
+
+     def on_grading_progress(self, completed: int, total: int) -> None:
+         """Called during grading."""
+         ...
+
+     def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
+         """Called when grading completes."""
+         ...
+
+     def on_refinement_start(self, iteration: int, failed_count: int) -> None:
+         """Called when a refinement iteration starts."""
+         ...
+
+     def on_grading_skipped(self) -> None:
+         """Called when grading is skipped."""
+         ...
+
+     def on_complete(
+         self,
+         dataset_size: int,
+         elapsed_seconds: float,
+         pass_rate: float | None,
+         total_cost: float | None = None,
+         generation_calls: int | None = None,
+         grading_calls: int | None = None,
+         scenario_calls: int | None = None,
+         response_calls: int | None = None,
+     ) -> None:
+         """Called when generation is complete."""
+         ...
+
+     def on_logic_map_complete(self, logic_map: "LogicMap") -> None:
+         """Called when logic extraction completes (Stage 1)."""
+         ...
+
+     def on_golden_scenarios_complete(
+         self, scenarios: list["GoldenScenario"], distribution: dict[str, int]
+     ) -> None:
+         """Called when golden scenarios are generated (Stage 2)."""
+         ...
+
+
+ class _NoOpContextManager:
+     """No-op context manager for SilentReporter spinner."""
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *args):
+         pass
+
+
+ class SilentReporter:
+     """
+     No-op reporter for testing and embedding.
+
+     Use this when you don't want any console output.
+
+     Examples:
+         >>> generator = Generator(reporter=SilentReporter())
+         >>> dataset = generator.generate(policy)  # No console output
+     """
+
+     def spinner(self, message: str = "Thinking..."):
+         """No-op spinner for silent mode."""
+         return _NoOpContextManager()
+
+     def on_start(self, traces: int, model: str, dataset_type: str) -> None:
+         pass
+
+     def on_plan_complete(self, plan: Plan) -> None:
+         pass
+
+     def on_scenario_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_response_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_responses_complete(self, traces: list[Trace]) -> None:
+         pass
+
+     def on_grading_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
+         pass
+
+     def on_refinement_start(self, iteration: int, failed_count: int) -> None:
+         pass
+
+     def on_grading_skipped(self) -> None:
+         pass
+
+     def on_complete(
+         self,
+         dataset_size: int,
+         elapsed_seconds: float,
+         pass_rate: float | None,
+         total_cost: float | None = None,
+         generation_calls: int | None = None,
+         grading_calls: int | None = None,
+         scenario_calls: int | None = None,
+         response_calls: int | None = None,
+     ) -> None:
+         pass
+
+     def on_logic_map_complete(self, logic_map) -> None:
+         pass
+
+     def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
+         pass
+
+
+ class RichReporter:
+     """
+     Rich console reporter with progress bars and formatted output.
+
+     This is the default reporter that provides the familiar synkro CLI experience.
+     """
+
+     def __init__(self):
+         from rich.console import Console
+
+         self.console = Console()
+         self._progress = None
+         self._current_task = None
+
+     def spinner(self, message: str = "Thinking..."):
+         """Context manager that shows a loading spinner.
+
+         Usage:
+             with reporter.spinner("Thinking..."):
+                 await some_llm_call()
+         """
+         from rich.status import Status
+
+         self.console.print()  # Add space above spinner
+         return Status(f"[cyan]{message}[/cyan]", spinner="dots", console=self.console)
+
+     def on_start(self, traces: int, model: str, dataset_type: str) -> None:
+         self.console.print()
+         self.console.print(
+             f"[cyan]⚡ Generating {traces} traces[/cyan] "
+             f"[dim]| {dataset_type.upper()} | {model}[/dim]"
+         )
+
+     def on_plan_complete(self, plan: Plan) -> None:
+         from rich.table import Table
+
+         self.console.print(f"[green]📋 Planning[/green] [dim]{len(plan.categories)} categories[/dim]")
+
+         cat_table = Table(title="Categories", show_header=True, header_style="bold cyan")
+         cat_table.add_column("Name")
+         cat_table.add_column("Description")
+         cat_table.add_column("Count", justify="right")
+         for cat in plan.categories:
+             cat_table.add_row(cat.name, cat.description, str(cat.count))
+         self.console.print(cat_table)
+         self.console.print()
+
+     def on_scenario_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_response_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_grading_progress(self, completed: int, total: int) -> None:
+         pass  # Progress shown in on_grading_complete
+
+     def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
+         self.console.print(f"[green]⚖️ Grading[/green] [dim]{pass_rate:.0f}% passed[/dim]")
+         for idx, trace in enumerate(traces, 1):
+             scenario_preview = (
+                 trace.scenario.description[:40] + "..."
+                 if len(trace.scenario.description) > 40
+                 else trace.scenario.description
+             )
+             if trace.grade and trace.grade.passed:
+                 self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [green]✓ Passed[/green]")
+             else:
+                 issues = ", ".join(trace.grade.issues[:2]) if trace.grade and trace.grade.issues else "No specific issues"
+                 issues_preview = issues[:40] + "..." if len(issues) > 40 else issues
+                 self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [red]✗ Failed[/red] [dim]{issues_preview}[/dim]")
+
+     def on_refinement_start(self, iteration: int, failed_count: int) -> None:
+         self.console.print(f" [yellow]↻ Refining {failed_count} failed traces (iteration {iteration})...[/yellow]")
+
+     def on_grading_skipped(self) -> None:
+         self.console.print(" [dim]⚖️ Grading skipped[/dim]")
+
+     def on_complete(
+         self,
+         dataset_size: int,
+         elapsed_seconds: float,
+         pass_rate: float | None,
+         total_cost: float | None = None,
+         generation_calls: int | None = None,
+         grading_calls: int | None = None,
+         scenario_calls: int | None = None,
+         response_calls: int | None = None,
+     ) -> None:
+         from rich.panel import Panel
+         from rich.table import Table
+
+         elapsed_str = (
+             f"{int(elapsed_seconds) // 60}m {int(elapsed_seconds) % 60}s"
+             if elapsed_seconds >= 60
+             else f"{int(elapsed_seconds)}s"
+         )
+
+         self.console.print()
+         summary = Table.grid(padding=(0, 2))
+         summary.add_column(style="green")
+         summary.add_column()
+         summary.add_row("✅ Done!", f"Generated {dataset_size} traces in {elapsed_str}")
+         if pass_rate is not None:
+             summary.add_row("📊 Quality:", f"{pass_rate:.0f}% passed verification")
+         if total_cost is not None and total_cost > 0:
+             summary.add_row("💰 Cost:", f"${total_cost:.4f}")
+         if scenario_calls is not None and response_calls is not None:
+             calls_str = f"{scenario_calls} scenario + {response_calls} response"
+             if grading_calls is not None and grading_calls > 0:
+                 calls_str += f" + {grading_calls} grading"
+             summary.add_row("🔄 LLM Calls:", calls_str)
+         elif generation_calls is not None:
+             calls_str = f"{generation_calls} generation"
+             if grading_calls is not None and grading_calls > 0:
+                 calls_str += f" + {grading_calls} grading"
+             summary.add_row("🔄 LLM Calls:", calls_str)
+         self.console.print(Panel(summary, border_style="green", title="[green]Complete[/green]"))
+         self.console.print()
+
+     def on_logic_map_complete(self, logic_map) -> None:
+         """Logic Map details shown in HITL session."""
+         pass
+
+     def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
+         """Scenario details shown in HITL session."""
+         pass
+
+     def on_responses_complete(self, traces: list[Trace]) -> None:
+         """Enhanced to show category and type for each trace."""
+         self.console.print(f"\n[green]✍️ Traces[/green] [dim]{len(traces)} generated[/dim]")
+
+         # Group by category
+         by_category = {}
+         for trace in traces:
+             cat = trace.scenario.category or "uncategorized"
+             by_category.setdefault(cat, []).append(trace)
+
+         for cat_name, cat_traces in by_category.items():
+             self.console.print(f"\n [cyan]📁 {cat_name}[/cyan] ({len(cat_traces)} traces)")
+
+             for trace in cat_traces[:3]:  # Show first 3 per category
+                 # Try to get scenario type if available
+                 scenario_type = getattr(trace.scenario, 'scenario_type', None)
+                 if scenario_type:
+                     type_indicator = {
+                         "positive": "[green]✓[/green]",
+                         "negative": "[red]✗[/red]",
+                         "edge_case": "[yellow]⚡[/yellow]",
+                         "irrelevant": "[dim]○[/dim]",
+                     }.get(scenario_type if isinstance(scenario_type, str) else scenario_type.value, "[white]?[/white]")
+                 else:
+                     type_indicator = "[white]•[/white]"
+
+                 user_preview = trace.user_message[:50] + "..." if len(trace.user_message) > 50 else trace.user_message
+                 self.console.print(f" {type_indicator} [blue]{user_preview}[/blue]")
+
+             if len(cat_traces) > 3:
+                 self.console.print(f" [dim]... and {len(cat_traces) - 3} more[/dim]")
+
+
+ class CallbackReporter:
+     """
+     Reporter that invokes user-provided callbacks for progress events.
+
+     Use this when you need programmatic access to progress events
+     (e.g., updating a progress bar, logging to a file, etc.).
+
+     Examples:
+         >>> def on_progress(event: str, data: dict):
+         ...     print(f"{event}: {data}")
+         ...
+         >>> reporter = CallbackReporter(on_progress=on_progress)
+         >>> generator = Generator(reporter=reporter)
+
+         >>> # With specific event handlers
+         >>> reporter = CallbackReporter(
+         ...     on_start=lambda traces, model, dtype: print(f"Starting {traces} traces"),
+         ...     on_complete=lambda size, elapsed, rate: print(f"Done! {size} traces"),
+         ... )
+     """
+
+     def __init__(
+         self,
+         on_progress: "Callable[[str, dict], None] | None" = None,
+         on_start: "Callable[[int, str, str], None] | None" = None,
+         on_plan_complete: "Callable[[Plan], None] | None" = None,
+         on_scenario_progress: "Callable[[int, int], None] | None" = None,
+         on_scenarios_complete: "Callable[[list[Scenario]], None] | None" = None,
+         on_response_progress: "Callable[[int, int], None] | None" = None,
+         on_responses_complete: "Callable[[list[Trace]], None] | None" = None,
+         on_grading_progress: "Callable[[int, int], None] | None" = None,
+         on_grading_complete: "Callable[[list[Trace], float], None] | None" = None,
+         on_complete: "Callable[[int, float, float | None], None] | None" = None,
+     ):
+         """
+         Initialize the callback reporter.
+
+         Args:
+             on_progress: Generic callback for all events. Receives (event_name, data_dict).
+             on_start: Called when generation starts (traces, model, dataset_type)
+             on_plan_complete: Called when planning completes (plan)
+             on_scenario_progress: Called during scenario generation (completed, total)
+             on_scenarios_complete: Called when scenarios are done (scenarios list)
+             on_response_progress: Called during response generation (completed, total)
+             on_responses_complete: Called when responses are done (traces list)
+             on_grading_progress: Called during grading (completed, total)
+             on_grading_complete: Called when grading is done (traces, pass_rate)
+             on_complete: Called when generation completes (dataset_size, elapsed, pass_rate)
+         """
+         self._on_progress = on_progress
+         self._on_start = on_start
+         self._on_plan_complete = on_plan_complete
+         self._on_scenario_progress = on_scenario_progress
+         self._on_scenarios_complete = on_scenarios_complete
+         self._on_response_progress = on_response_progress
+         self._on_responses_complete = on_responses_complete
+         self._on_grading_progress = on_grading_progress
+         self._on_grading_complete = on_grading_complete
+         self._on_complete_cb = on_complete
+
+     def _emit(self, event: str, data: dict) -> None:
+         """Emit an event to the generic callback."""
+         if self._on_progress:
+             self._on_progress(event, data)
+
+     def spinner(self, message: str = "Thinking..."):
+         """No-op spinner for callback mode."""
+         return _NoOpContextManager()
+
+     def on_start(self, traces: int, model: str, dataset_type: str) -> None:
+         self._emit("start", {"traces": traces, "model": model, "dataset_type": dataset_type})
+         if self._on_start:
+             self._on_start(traces, model, dataset_type)
+
+     def on_plan_complete(self, plan: Plan) -> None:
+         self._emit("plan_complete", {"categories": len(plan.categories)})
+         if self._on_plan_complete:
+             self._on_plan_complete(plan)
+
+     def on_scenario_progress(self, completed: int, total: int) -> None:
+         self._emit("scenario_progress", {"completed": completed, "total": total})
+         if self._on_scenario_progress:
+             self._on_scenario_progress(completed, total)
+
+     def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
+         self._emit("scenarios_complete", {"count": len(scenarios)})
+         if self._on_scenarios_complete:
+             self._on_scenarios_complete(scenarios)
+
+     def on_response_progress(self, completed: int, total: int) -> None:
+         self._emit("response_progress", {"completed": completed, "total": total})
+         if self._on_response_progress:
+             self._on_response_progress(completed, total)
+
+     def on_responses_complete(self, traces: list[Trace]) -> None:
+         self._emit("responses_complete", {"count": len(traces)})
+         if self._on_responses_complete:
+             self._on_responses_complete(traces)
+
+     def on_grading_progress(self, completed: int, total: int) -> None:
+         self._emit("grading_progress", {"completed": completed, "total": total})
+         if self._on_grading_progress:
+             self._on_grading_progress(completed, total)
+
+     def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
+         self._emit("grading_complete", {"count": len(traces), "pass_rate": pass_rate})
+         if self._on_grading_complete:
+             self._on_grading_complete(traces, pass_rate)
+
+     def on_refinement_start(self, iteration: int, failed_count: int) -> None:
+         self._emit("refinement_start", {"iteration": iteration, "failed_count": failed_count})
+
+     def on_grading_skipped(self) -> None:
+         self._emit("grading_skipped", {})
+
+     def on_complete(
+         self,
+         dataset_size: int,
+         elapsed_seconds: float,
+         pass_rate: float | None,
+         total_cost: float | None = None,
+         generation_calls: int | None = None,
+         grading_calls: int | None = None,
+         scenario_calls: int | None = None,
+         response_calls: int | None = None,
+     ) -> None:
+         self._emit("complete", {
+             "dataset_size": dataset_size,
+             "elapsed_seconds": elapsed_seconds,
+             "pass_rate": pass_rate,
+             "total_cost": total_cost,
+             "generation_calls": generation_calls,
+             "grading_calls": grading_calls,
+             "scenario_calls": scenario_calls,
+             "response_calls": response_calls,
+         })
+         if self._on_complete_cb:
+             self._on_complete_cb(dataset_size, elapsed_seconds, pass_rate)
+
+     def on_logic_map_complete(self, logic_map) -> None:
+         self._emit("logic_map_complete", {"rules_count": len(logic_map.rules)})
+
+     def on_golden_scenarios_complete(self, scenarios, distribution) -> None:
+         self._emit("golden_scenarios_complete", {"count": len(scenarios), "distribution": distribution})
+
+
+ __all__ = ["ProgressReporter", "SilentReporter", "RichReporter", "CallbackReporter"]
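
Usage note (not part of the package diff): a minimal sketch of routing progress events to a JSON-lines log via CallbackReporter's generic callback, assuming an installed synkro environment. The `progress.jsonl` path is a hypothetical placeholder, and the commented `Generator` line mirrors the module's own docstring examples.

import json

from synkro.reporting import CallbackReporter

def log_event(event: str, data: dict) -> None:
    """Append each progress event as one JSON line."""
    with open("progress.jsonl", "a") as f:  # hypothetical log path
        f.write(json.dumps({"event": event, **data}, default=str) + "\n")

reporter = CallbackReporter(on_progress=log_event)
# generator = Generator(reporter=reporter)  # as in the module's docstring examples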