synkro-0.4.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synkro might be problematic.

Files changed (58)
  1. synkro/__init__.py +165 -0
  2. synkro/cli.py +120 -0
  3. synkro/core/__init__.py +7 -0
  4. synkro/core/dataset.py +233 -0
  5. synkro/core/policy.py +337 -0
  6. synkro/errors.py +178 -0
  7. synkro/examples/__init__.py +148 -0
  8. synkro/factory.py +160 -0
  9. synkro/formatters/__init__.py +12 -0
  10. synkro/formatters/qa.py +85 -0
  11. synkro/formatters/sft.py +90 -0
  12. synkro/formatters/tool_call.py +127 -0
  13. synkro/generation/__init__.py +9 -0
  14. synkro/generation/generator.py +163 -0
  15. synkro/generation/planner.py +87 -0
  16. synkro/generation/responses.py +160 -0
  17. synkro/generation/scenarios.py +90 -0
  18. synkro/generation/tool_responses.py +370 -0
  19. synkro/generation/tool_simulator.py +114 -0
  20. synkro/llm/__init__.py +7 -0
  21. synkro/llm/client.py +235 -0
  22. synkro/llm/rate_limits.py +95 -0
  23. synkro/models/__init__.py +43 -0
  24. synkro/models/anthropic.py +26 -0
  25. synkro/models/google.py +19 -0
  26. synkro/models/openai.py +31 -0
  27. synkro/modes/__init__.py +15 -0
  28. synkro/modes/config.py +66 -0
  29. synkro/modes/qa.py +18 -0
  30. synkro/modes/sft.py +18 -0
  31. synkro/modes/tool_call.py +18 -0
  32. synkro/parsers.py +442 -0
  33. synkro/pipeline/__init__.py +20 -0
  34. synkro/pipeline/phases.py +237 -0
  35. synkro/pipeline/runner.py +198 -0
  36. synkro/pipelines.py +105 -0
  37. synkro/prompts/__init__.py +44 -0
  38. synkro/prompts/base.py +167 -0
  39. synkro/prompts/qa_templates.py +97 -0
  40. synkro/prompts/templates.py +281 -0
  41. synkro/prompts/tool_templates.py +201 -0
  42. synkro/quality/__init__.py +14 -0
  43. synkro/quality/grader.py +130 -0
  44. synkro/quality/refiner.py +137 -0
  45. synkro/quality/tool_grader.py +126 -0
  46. synkro/quality/tool_refiner.py +128 -0
  47. synkro/reporting.py +213 -0
  48. synkro/schemas.py +325 -0
  49. synkro/types/__init__.py +41 -0
  50. synkro/types/core.py +113 -0
  51. synkro/types/dataset_type.py +30 -0
  52. synkro/types/tool.py +94 -0
  53. synkro-0.4.5.data/data/examples/__init__.py +148 -0
  54. synkro-0.4.5.dist-info/METADATA +221 -0
  55. synkro-0.4.5.dist-info/RECORD +58 -0
  56. synkro-0.4.5.dist-info/WHEEL +4 -0
  57. synkro-0.4.5.dist-info/entry_points.txt +2 -0
  58. synkro-0.4.5.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,128 @@
+ """Specialized refinement for tool call traces that preserves format."""
+
+ from typing import TYPE_CHECKING
+
+ from synkro.quality.refiner import Refiner
+ from synkro.llm.client import LLM
+ from synkro.models import Model, OpenAI
+ from synkro.types.core import Trace, GradeResult, Scenario
+
+ if TYPE_CHECKING:
+     from synkro.types.tool import ToolDefinition
+     from synkro.generation.tool_simulator import ToolSimulator
+
+
+ class ToolCallRefiner(Refiner):
+     """
+     Specialized refiner for tool call traces.
+
+     Unlike the base Refiner which generates plain text responses, this refiner
+     uses the ToolCallResponseGenerator to regenerate traces, ensuring the
+     tool_calls format is preserved in the output.
+
+     The grading feedback is incorporated into the scenario context so the
+     LLM knows what to fix during regeneration.
+
+     Examples:
+         >>> refiner = ToolCallRefiner(
+         ...     tools=[web_search, db_lookup],
+         ...     simulator=tool_simulator,
+         ... )
+         >>> improved = await refiner.refine(failed_trace, grade, policy_text)
+         >>> # improved trace has proper tool_calls format
+     """
+
+     def __init__(
+         self,
+         tools: list["ToolDefinition"],
+         simulator: "ToolSimulator",
+         llm: LLM | None = None,
+         model: Model = OpenAI.GPT_4O_MINI,
+     ):
+         """
+         Initialize the tool call refiner.
+
+         Args:
+             tools: List of available tool definitions
+             simulator: Tool simulator for generating tool responses
+             llm: LLM client to use (creates one if not provided)
+             model: Model to use if creating LLM
+         """
+         super().__init__(llm=llm, model=model)
+         self.tools = tools
+         self.simulator = simulator
+         self._response_generator = None
+
+     def _get_response_generator(self):
+         """Lazily create the ToolCallResponseGenerator."""
+         if self._response_generator is None:
+             from synkro.generation.tool_responses import ToolCallResponseGenerator
+             self._response_generator = ToolCallResponseGenerator(
+                 tools=self.tools,
+                 llm=self.llm,
+                 simulator=self.simulator,
+             )
+         return self._response_generator
+
+     def _build_enhanced_scenario(
+         self, trace: Trace, grade: GradeResult
+     ) -> Scenario:
+         """
+         Build an enhanced scenario that includes grading feedback.
+
+         The feedback helps the LLM understand what went wrong and how to fix it.
+         """
+         # Build feedback context
+         feedback_parts = []
+         if grade.issues:
+             feedback_parts.append("PREVIOUS ISSUES TO FIX:")
+             for issue in grade.issues:
+                 feedback_parts.append(f" - {issue}")
+         if grade.feedback:
+             feedback_parts.append(f"\nGRADER FEEDBACK: {grade.feedback}")
+
+         feedback_context = "\n".join(feedback_parts) if feedback_parts else ""
+
+         # Enhance the context with feedback
+         enhanced_context = trace.scenario.context
+         if feedback_context:
+             enhanced_context = f"{trace.scenario.context}\n\n--- REFINEMENT GUIDANCE ---\n{feedback_context}"
+
+         return Scenario(
+             description=trace.scenario.description,
+             context=enhanced_context,
+             category=trace.scenario.category,
+         )
+
+     async def refine(
+         self, trace: Trace, grade: GradeResult, policy_text: str
+     ) -> Trace:
+         """
+         Refine a failed tool call trace by regenerating with feedback.
+
+         Uses the ToolCallResponseGenerator to ensure the regenerated trace
+         maintains proper tool_calls format.
+
+         Args:
+             trace: The trace that failed grading
+             grade: The grade result with feedback
+             policy_text: The policy/guidelines text
+
+         Returns:
+             New trace with improved response and preserved tool_calls format
+         """
+         # Create enhanced scenario with grading feedback
+         enhanced_scenario = self._build_enhanced_scenario(trace, grade)
+
+         # Regenerate using ToolCallResponseGenerator (preserves format)
+         generator = self._get_response_generator()
+         refined_trace = await generator.generate_single(policy_text, enhanced_scenario)
+
+         # Preserve the original scenario reference (without the feedback context)
+         refined_trace.scenario = trace.scenario
+
+         return refined_trace
+
+
+ __all__ = ["ToolCallRefiner"]
+
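For orientation, here is a minimal sketch of how this refiner could be driven over a batch of graded traces. It is not part of the package: the traces, tools, simulator, and policy_text objects are assumed to already exist, and each trace is assumed to carry its GradeResult on trace.grade (the same attribute the reporting module below reads).

# Hypothetical grade-and-refine pass; assumes `traces` were already graded
# and expose their GradeResult as `trace.grade`.
from synkro.quality.tool_refiner import ToolCallRefiner

async def refine_failed_traces(traces, tools, simulator, policy_text):
    refiner = ToolCallRefiner(tools=tools, simulator=simulator)
    refined = []
    for trace in traces:
        if trace.grade is not None and not trace.grade.passed:
            # Regenerate only the traces that failed grading.
            trace = await refiner.refine(trace, trace.grade, policy_text)
        refined.append(trace)
    return refined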
synkro/reporting.py ADDED
@@ -0,0 +1,213 @@
+ """Progress reporting abstraction for generation pipeline.
+
+ This module provides a Protocol for progress reporting and implementations
+ for different use cases (rich console, silent for testing, etc.).
+ """
+
+ from typing import Protocol
+
+ from synkro.types.core import Plan, Scenario, Trace, GradeResult
+
+
+ class ProgressReporter(Protocol):
+     """
+     Protocol for reporting generation progress.
+
+     Implement this to customize how progress is displayed or logged.
+
+     Examples:
+         >>> # Use silent reporter for testing
+         >>> generator = Generator(reporter=SilentReporter())
+
+         >>> # Use rich reporter for CLI (default)
+         >>> generator = Generator(reporter=RichReporter())
+     """
+
+     def on_start(self, traces: int, model: str, dataset_type: str) -> None:
+         """Called when generation starts."""
+         ...
+
+     def on_plan_complete(self, plan: Plan) -> None:
+         """Called when planning phase completes."""
+         ...
+
+     def on_scenario_progress(self, completed: int, total: int) -> None:
+         """Called during scenario generation."""
+         ...
+
+     def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
+         """Called when all scenarios are generated."""
+         ...
+
+     def on_response_progress(self, completed: int, total: int) -> None:
+         """Called during response generation."""
+         ...
+
+     def on_responses_complete(self, traces: list[Trace]) -> None:
+         """Called when all responses are generated."""
+         ...
+
+     def on_grading_progress(self, completed: int, total: int) -> None:
+         """Called during grading."""
+         ...
+
+     def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
+         """Called when grading completes."""
+         ...
+
+     def on_refinement_start(self, iteration: int, failed_count: int) -> None:
+         """Called when a refinement iteration starts."""
+         ...
+
+     def on_grading_skipped(self) -> None:
+         """Called when grading is skipped."""
+         ...
+
+     def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
+         """Called when generation is complete."""
+         ...
+
+
+ class SilentReporter:
+     """
+     No-op reporter for testing and embedding.
+
+     Use this when you don't want any console output.
+
+     Examples:
+         >>> generator = Generator(reporter=SilentReporter())
+         >>> dataset = generator.generate(policy) # No console output
+     """
+
+     def on_start(self, traces: int, model: str, dataset_type: str) -> None:
+         pass
+
+     def on_plan_complete(self, plan: Plan) -> None:
+         pass
+
+     def on_scenario_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
+         pass
+
+     def on_response_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_responses_complete(self, traces: list[Trace]) -> None:
+         pass
+
+     def on_grading_progress(self, completed: int, total: int) -> None:
+         pass
+
+     def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
+         pass
+
+     def on_refinement_start(self, iteration: int, failed_count: int) -> None:
+         pass
+
+     def on_grading_skipped(self) -> None:
+         pass
+
+     def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
+         pass
+
+
+ class RichReporter:
+     """
+     Rich console reporter with progress bars and formatted output.
+
+     This is the default reporter that provides the familiar synkro CLI experience.
+     """
+
+     def __init__(self):
+         from rich.console import Console
+         self.console = Console()
+         self._progress = None
+         self._current_task = None
+
+     def on_start(self, traces: int, model: str, dataset_type: str) -> None:
+         from rich.panel import Panel
+
+         self.console.print()
+         self.console.print(Panel.fit(
+             f"[bold]Generating {traces} traces[/bold]\n"
+             f"[dim]Type: {dataset_type.upper()} | Model: {model}[/dim]",
+             title="[cyan]synkro[/cyan]",
+             border_style="cyan"
+         ))
+         self.console.print()
+
+     def on_plan_complete(self, plan: Plan) -> None:
+         from rich.table import Table
+
+         self.console.print(f"[green]📋 Planning[/green] [dim]{len(plan.categories)} categories[/dim]")
+
+         cat_table = Table(title="Categories", show_header=True, header_style="bold cyan")
+         cat_table.add_column("Name")
+         cat_table.add_column("Description")
+         cat_table.add_column("Count", justify="right")
+         for cat in plan.categories:
+             cat_table.add_row(cat.name, cat.description, str(cat.count))
+         self.console.print(cat_table)
+         self.console.print()
+
+     def on_scenario_progress(self, completed: int, total: int) -> None:
+         pass # Progress shown in on_scenarios_complete
+
+     def on_scenarios_complete(self, scenarios: list[Scenario]) -> None:
+         self.console.print(f"[green]💡 Scenarios[/green] [dim]{len(scenarios)} created[/dim]")
+         for idx, s in enumerate(scenarios, 1):
+             desc = s.description[:80] + "..." if len(s.description) > 80 else s.description
+             self.console.print(f" [dim]#{idx}[/dim] [yellow]{desc}[/yellow]")
+
+     def on_response_progress(self, completed: int, total: int) -> None:
+         pass # Progress shown in on_responses_complete
+
+     def on_responses_complete(self, traces: list[Trace]) -> None:
+         self.console.print(f"[green]✍️ Responses[/green] [dim]{len(traces)} generated[/dim]")
+         for idx, trace in enumerate(traces, 1):
+             user_preview = trace.user_message[:60] + "..." if len(trace.user_message) > 60 else trace.user_message
+             asst_preview = trace.assistant_message[:60] + "..." if len(trace.assistant_message) > 60 else trace.assistant_message
+             self.console.print(f" [dim]#{idx}[/dim] [blue]User:[/blue] {user_preview}")
+             self.console.print(f" [green]Assistant:[/green] {asst_preview}")
+
+     def on_grading_progress(self, completed: int, total: int) -> None:
+         pass # Progress shown in on_grading_complete
+
+     def on_grading_complete(self, traces: list[Trace], pass_rate: float) -> None:
+         self.console.print(f"[green]⚖️ Grading[/green] [dim]{pass_rate:.0f}% passed[/dim]")
+         for idx, trace in enumerate(traces, 1):
+             scenario_preview = trace.scenario.description[:40] + "..." if len(trace.scenario.description) > 40 else trace.scenario.description
+             if trace.grade and trace.grade.passed:
+                 self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [green]✓ Passed[/green]")
+             else:
+                 issues = ", ".join(trace.grade.issues[:2]) if trace.grade and trace.grade.issues else "No specific issues"
+                 issues_preview = issues[:40] + "..." if len(issues) > 40 else issues
+                 self.console.print(f" [dim]#{idx}[/dim] [cyan]{scenario_preview}[/cyan] [red]✗ Failed[/red] [dim]{issues_preview}[/dim]")
+
+     def on_refinement_start(self, iteration: int, failed_count: int) -> None:
+         self.console.print(f" [yellow]↻ Refining {failed_count} failed traces (iteration {iteration})...[/yellow]")
+
+     def on_grading_skipped(self) -> None:
+         self.console.print(f" [dim]⚖️ Grading skipped[/dim]")
+
+     def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
+         from rich.panel import Panel
+         from rich.table import Table
+
+         elapsed_str = f"{int(elapsed_seconds) // 60}m {int(elapsed_seconds) % 60}s" if elapsed_seconds >= 60 else f"{int(elapsed_seconds)}s"
+
+         self.console.print()
+         summary = Table.grid(padding=(0, 2))
+         summary.add_column(style="green")
+         summary.add_column()
+         summary.add_row("✅ Done!", f"Generated {dataset_size} traces in {elapsed_str}")
+         if pass_rate is not None:
+             summary.add_row("📊 Quality:", f"{pass_rate:.0f}% passed grading")
+         self.console.print(Panel(summary, border_style="green", title="[green]Complete[/green]"))
+         self.console.print()
+
+
+ __all__ = ["ProgressReporter", "SilentReporter", "RichReporter"]
+
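Since SilentReporter implements every hook of the ProgressReporter protocol as a no-op, a custom reporter can subclass it and override only the callbacks it cares about. A minimal sketch follows; the LogReporter name is illustrative, not part of the package.

from synkro.reporting import SilentReporter

class LogReporter(SilentReporter):
    """Illustrative reporter: stay silent except for the final summary."""

    def on_complete(self, dataset_size: int, elapsed_seconds: float, pass_rate: float | None) -> None:
        # Print a single summary line instead of rich console output.
        rate = f"{pass_rate:.0f}%" if pass_rate is not None else "n/a"
        print(f"generated {dataset_size} traces in {elapsed_seconds:.0f}s (pass rate: {rate})")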
synkro/schemas.py ADDED
@@ -0,0 +1,325 @@
+ """Pydantic schemas for structured LLM outputs and validation."""
+
+ from typing import Literal
+ from pydantic import BaseModel, Field
+
+
+ # =============================================================================
+ # SCENARIO SCHEMAS
+ # =============================================================================
+
+
+ class ScenarioOutput(BaseModel):
+     """Output schema for scenario generation."""
+
+     scenario: str = Field(description="Detailed scenario description")
+     context: str = Field(description="Relevant background information")
+
+
+ class ScenariosArray(BaseModel):
+     """Array of generated scenarios."""
+
+     scenarios: list[ScenarioOutput]
+
+
+ # =============================================================================
+ # POLICY ANALYSIS SCHEMAS
+ # =============================================================================
+
+
+ class PolicyComplexity(BaseModel):
+     """Policy complexity analysis for auto-detecting optimal turns."""
+
+     variable_count: int = Field(
+         description="Number of variables/conditions in the policy (rules, exceptions, conditions)"
+     )
+     complexity_level: Literal["simple", "conditional", "complex"] = Field(
+         description="Overall complexity: simple (1 var), conditional (2-3 vars), complex (4+ vars)"
+     )
+     recommended_turns: int = Field(
+         ge=1, le=6, description="Recommended conversation turns based on complexity"
+     )
+     reasoning: str = Field(description="Brief explanation of the complexity assessment")
+
+
+ class PlanCategory(BaseModel):
+     """A category in the generation plan."""
+
+     name: str = Field(description='Short category name (e.g., "Consent Violations", "Edge Cases")')
+     description: str = Field(description="What this category tests")
+     traces: int = Field(ge=1, description="Number of traces to generate for this category")
+
+
+ class PolicyPlan(BaseModel):
+     """LLM-generated plan for dataset generation."""
+
+     categories: list[PlanCategory] = Field(
+         min_length=2, max_length=10, description="Scenario categories with trace allocations"
+     )
+     reasoning: str = Field(
+         description="Explanation of why these categories were chosen based on policy content"
+     )
+
+
+ # =============================================================================
+ # CHAT MESSAGE SCHEMAS
+ # =============================================================================
+
+
+ class ChatMessage(BaseModel):
+     """A single chat message in OpenAI format."""
+
+     role: Literal["system", "user", "assistant"] = Field(description="Message role")
+     content: str = Field(description="Message content")
+
+
+ class ConversationOutput(BaseModel):
+     """Output from response generation - a complete conversation."""
+
+     index: int = Field(description="Scenario index (0-based)")
+     messages: list[ChatMessage] = Field(
+         description="Full conversation with system, user, and assistant messages"
+     )
+
+
+ class BatchedConversations(BaseModel):
+     """Batch of generated conversations."""
+
+     conversations: list[ConversationOutput]
+
+
+ # =============================================================================
+ # GRADING SCHEMAS
+ # =============================================================================
+
+
+ class GradeOutput(BaseModel):
+     """Grading result for a single response."""
+
+     index: int = Field(description="Scenario index (0-based)")
+     passed: bool = Field(
+         alias="pass", description="Is the response FULLY correct, policy-compliant, and format-valid?"
+     )
+     policy_violations: list[str] = Field(
+         default_factory=list,
+         description="Specific policy rules that were violated or misinterpreted",
+     )
+     missing_citations: list[str] = Field(
+         default_factory=list,
+         description="Policy sections that should have been cited but were not",
+     )
+     incomplete_reasoning: list[str] = Field(
+         default_factory=list, description="Logical gaps or missing steps in the chain of thought"
+     )
+     vague_recommendations: list[str] = Field(
+         default_factory=list,
+         description="Recommendations that need to be more specific or actionable",
+     )
+     feedback: str = Field(description="Summary of how to fix the issues")
+
+     class Config:
+         populate_by_name = True
+
+
+ class BatchedGrades(BaseModel):
+     """Batch of grading results."""
+
+     grades: list[GradeOutput]
+
+
+ # =============================================================================
+ # SINGLE-ITEM SCHEMAS (for parallel generation)
+ # =============================================================================
+
+
+ class SingleResponse(BaseModel):
+     """Single response output for parallel generation."""
+
+     messages: list[ChatMessage] = Field(
+         min_length=3, max_length=3, description="Exactly 3 messages: system, user, assistant"
+     )
+
+
+ class SingleGrade(BaseModel):
+     """Single grade output for parallel generation."""
+
+     passed: bool = Field(
+         alias="pass", description="Is the response FULLY correct, policy-compliant, and format-valid?"
+     )
+     policy_violations: list[str] = Field(
+         default_factory=list, description="Specific policy rules that were violated"
+     )
+     missing_citations: list[str] = Field(
+         default_factory=list, description="Policy sections that should have been cited"
+     )
+     incomplete_reasoning: list[str] = Field(
+         default_factory=list, description="Logical gaps or missing reasoning steps"
+     )
+     vague_recommendations: list[str] = Field(
+         default_factory=list, description="Recommendations that need to be more specific"
+     )
+     feedback: str = Field(description='Summary of issues or "Correct" if passing')
+
+     class Config:
+         populate_by_name = True
+
+
+ # =============================================================================
+ # MULTI-TURN SCHEMAS
+ # =============================================================================
+
+
+ class FollowUpQuestion(BaseModel):
+     """A follow-up question for multi-turn conversations."""
+
+     index: int = Field(description="Scenario index")
+     question: str = Field(description="Follow-up question from the user")
+     question_type: Literal["clarification", "edge_case", "what_if", "specificity", "challenge"] = (
+         Field(description="Type of follow-up")
+     )
+
+
+ class TurnGrade(BaseModel):
+     """Grade for a single turn in a multi-turn conversation."""
+
+     turn_index: int = Field(description="Which turn (0-based, only assistant turns)")
+     passed: bool = Field(alias="pass", description="Does this turn pass all criteria?")
+     policy_violations: list[str] = Field(
+         default_factory=list, description="Policy violations in this turn"
+     )
+     missing_citations: list[str] = Field(
+         default_factory=list, description="Missing citations in this turn"
+     )
+     incomplete_reasoning: list[str] = Field(
+         default_factory=list, description="Reasoning gaps in this turn"
+     )
+     vague_recommendations: list[str] = Field(
+         default_factory=list, description="Vague recommendations in this turn"
+     )
+     feedback: str = Field(description="Specific feedback for this turn")
+
+     class Config:
+         populate_by_name = True
+
+
+ class ConversationGrade(BaseModel):
+     """Full grading for a multi-turn conversation."""
+
+     index: int = Field(description="Scenario index")
+     overall_pass: bool = Field(description="Does the ENTIRE conversation pass?")
+     turn_grades: list[TurnGrade] = Field(description="Grade for each assistant turn")
+     coherence_pass: bool = Field(
+         description="Is the conversation coherent with no contradictions?"
+     )
+     coherence_issues: list[str] = Field(
+         default_factory=list, description="Any contradictions or incoherence across turns"
+     )
+     progressive_depth: bool = Field(
+         description="Does each turn build on previous context appropriately?"
+     )
+     overall_feedback: str = Field(
+         description="Summary of what needs to be fixed across the conversation"
+     )
+
+
+ # =============================================================================
+ # AGENTIC SCHEMAS
+ # =============================================================================
+
+
+ class ToolCall(BaseModel):
+     """A tool call in an agentic trace."""
+
+     tool_name: str = Field(description="Name of the tool to call")
+     arguments: dict[str, str] = Field(description="Arguments to pass to the tool")
+
+
+ class AgenticStep(BaseModel):
+     """A single step in an agentic trace."""
+
+     reasoning: str = Field(description="Reasoning before tool call")
+     tool_name: str = Field(description="Tool to call")
+     tool_args: dict = Field(description="Tool arguments")
+
+
+ class AgenticTrace(BaseModel):
+     """Complete agentic trace with tool usage."""
+
+     index: int = Field(description="Scenario index")
+     steps: list[AgenticStep] = Field(description="Steps of tool usage")
+     final_answer: str = Field(description="Final comprehensive answer")
+
+
+ # =============================================================================
+ # TOOL CALL GRADING SCHEMAS
+ # =============================================================================
+
+
+ class ToolCallGrade(BaseModel):
+     """Grading result for a tool call trace.
+
+     Evaluates tool usage on four criteria:
+     - Tool Selection: Did they use the right tool?
+     - Parameter Accuracy: Were the parameters correct?
+     - Response Synthesis: Did they use tool results correctly?
+     - Timing: Did they call tools at the right time?
+     """
+
+     passed: bool = Field(
+         alias="pass",
+         description="Does the trace pass ALL criteria?"
+     )
+
+     # Criterion 1: Tool Selection
+     tool_selection_correct: bool = Field(
+         description="Did the assistant choose the appropriate tool for the task?"
+     )
+     tool_selection_issues: list[str] = Field(
+         default_factory=list,
+         description="Specific issues with tool selection (wrong tool, missing tool, unnecessary tool)"
+     )
+
+     # Criterion 2: Parameter Accuracy
+     parameters_valid: bool = Field(
+         description="Were the tool parameters correct (types, values, required fields)?"
+     )
+     parameter_issues: list[str] = Field(
+         default_factory=list,
+         description="Specific issues with parameters (wrong type, invalid value, missing required)"
+     )
+
+     # Criterion 3: Response Synthesis
+     synthesis_accurate: bool = Field(
+         description="Did the assistant correctly use tool results without hallucination?"
+     )
+     synthesis_issues: list[str] = Field(
+         default_factory=list,
+         description="Specific issues with synthesis (hallucinated data, ignored results, misinterpreted)"
+     )
+
+     # Criterion 4: Timing
+     timing_appropriate: bool = Field(
+         description="Did the assistant call tools at the right moment?"
+     )
+     timing_issues: list[str] = Field(
+         default_factory=list,
+         description="Specific issues with timing (premature call, delayed call, should have called earlier)"
+     )
+
+     # Overall feedback
+     feedback: str = Field(
+         description="Summary of issues or 'Correct' if passing"
+     )
+
+     class Config:
+         populate_by_name = True
+
+     def get_all_issues(self) -> list[str]:
+         """Get all issues combined."""
+         return (
+             self.tool_selection_issues
+             + self.parameter_issues
+             + self.synthesis_issues
+             + self.timing_issues
+         )
+
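A note on the grading schemas: the field is exposed as passed in Python but serialized under the JSON key "pass" via the alias, and populate_by_name = True accepts either spelling on input. A small sketch of round-tripping a grade, assuming the package uses Pydantic v2's model_validate / model_dump API:

from synkro.schemas import SingleGrade

raw = {"pass": True, "feedback": "Correct"}      # LLM output uses the alias
grade = SingleGrade.model_validate(raw)          # issue lists default to []
assert grade.passed and grade.policy_violations == []
print(grade.model_dump(by_alias=True)["pass"])   # serialize back with the alias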