synkro 0.4.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synkro might be problematic. Click here for more details.
- synkro/__init__.py +331 -0
- synkro/advanced.py +184 -0
- synkro/cli.py +156 -0
- synkro/core/__init__.py +7 -0
- synkro/core/checkpoint.py +250 -0
- synkro/core/dataset.py +432 -0
- synkro/core/policy.py +337 -0
- synkro/errors.py +178 -0
- synkro/examples/__init__.py +148 -0
- synkro/factory.py +291 -0
- synkro/formatters/__init__.py +18 -0
- synkro/formatters/chatml.py +121 -0
- synkro/formatters/langfuse.py +98 -0
- synkro/formatters/langsmith.py +98 -0
- synkro/formatters/qa.py +112 -0
- synkro/formatters/sft.py +90 -0
- synkro/formatters/tool_call.py +127 -0
- synkro/generation/__init__.py +9 -0
- synkro/generation/follow_ups.py +134 -0
- synkro/generation/generator.py +314 -0
- synkro/generation/golden_responses.py +269 -0
- synkro/generation/golden_scenarios.py +333 -0
- synkro/generation/golden_tool_responses.py +791 -0
- synkro/generation/logic_extractor.py +126 -0
- synkro/generation/multiturn_responses.py +177 -0
- synkro/generation/planner.py +131 -0
- synkro/generation/responses.py +189 -0
- synkro/generation/scenarios.py +90 -0
- synkro/generation/tool_responses.py +625 -0
- synkro/generation/tool_simulator.py +114 -0
- synkro/interactive/__init__.py +16 -0
- synkro/interactive/hitl_session.py +205 -0
- synkro/interactive/intent_classifier.py +94 -0
- synkro/interactive/logic_map_editor.py +176 -0
- synkro/interactive/rich_ui.py +459 -0
- synkro/interactive/scenario_editor.py +198 -0
- synkro/llm/__init__.py +7 -0
- synkro/llm/client.py +309 -0
- synkro/llm/rate_limits.py +99 -0
- synkro/models/__init__.py +50 -0
- synkro/models/anthropic.py +26 -0
- synkro/models/google.py +19 -0
- synkro/models/local.py +104 -0
- synkro/models/openai.py +31 -0
- synkro/modes/__init__.py +13 -0
- synkro/modes/config.py +66 -0
- synkro/modes/conversation.py +35 -0
- synkro/modes/tool_call.py +18 -0
- synkro/parsers.py +442 -0
- synkro/pipeline/__init__.py +20 -0
- synkro/pipeline/phases.py +592 -0
- synkro/pipeline/runner.py +769 -0
- synkro/pipelines.py +136 -0
- synkro/prompts/__init__.py +57 -0
- synkro/prompts/base.py +167 -0
- synkro/prompts/golden_templates.py +533 -0
- synkro/prompts/interactive_templates.py +198 -0
- synkro/prompts/multiturn_templates.py +156 -0
- synkro/prompts/templates.py +281 -0
- synkro/prompts/tool_templates.py +318 -0
- synkro/quality/__init__.py +14 -0
- synkro/quality/golden_refiner.py +163 -0
- synkro/quality/grader.py +153 -0
- synkro/quality/multiturn_grader.py +150 -0
- synkro/quality/refiner.py +137 -0
- synkro/quality/tool_grader.py +126 -0
- synkro/quality/tool_refiner.py +128 -0
- synkro/quality/verifier.py +228 -0
- synkro/reporting.py +464 -0
- synkro/schemas.py +521 -0
- synkro/types/__init__.py +43 -0
- synkro/types/core.py +153 -0
- synkro/types/dataset_type.py +33 -0
- synkro/types/logic_map.py +348 -0
- synkro/types/tool.py +94 -0
- synkro-0.4.36.data/data/examples/__init__.py +148 -0
- synkro-0.4.36.dist-info/METADATA +507 -0
- synkro-0.4.36.dist-info/RECORD +81 -0
- synkro-0.4.36.dist-info/WHEEL +4 -0
- synkro-0.4.36.dist-info/entry_points.txt +2 -0
- synkro-0.4.36.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Multi-turn conversation grading with per-turn and overall evaluation."""
|
|
2
|
+
|
|
3
|
+
from synkro.llm.client import LLM
|
|
4
|
+
from synkro.models import Model, OpenAI
|
|
5
|
+
from synkro.types.core import Trace, Message, GradeResult
|
|
6
|
+
from synkro.prompts.multiturn_templates import MULTI_TURN_GRADE_PROMPT
|
|
7
|
+
from synkro.schemas import ConversationGrade, TurnGrade
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class MultiTurnGrader:
    """
    Grades multi-turn conversations using per-turn and overall criteria.

    Uses existing schemas:
    - TurnGrade: Per-turn policy violations, citations, reasoning
    - ConversationGrade: Overall pass, coherence, progressive depth

    Examples:
        >>> grader = MultiTurnGrader()
        >>> result = await grader.grade(trace, policy_text)
        >>> print(result.passed, result.feedback)
    """

    def __init__(self, llm: LLM | None = None, model: Model = OpenAI.GPT_4O):
        """
        Initialize the multi-turn grader.

        Args:
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating LLM (recommend stronger model)
        """
        self.llm = llm or LLM(model=model)

    def _count_assistant_turns(self, trace: Trace) -> int:
        """Count the number of assistant messages (turns) in a trace."""
        return sum(1 for m in trace.messages if m.role == "assistant")

    def _format_conversation(self, messages: list[Message]) -> str:
        """Format conversation messages for prompt inclusion.

        Each message is rendered as "ROLE: content"; a message with empty
        content gets the "[No content]" placeholder so turn boundaries stay
        visible to the grading model.
        """
        formatted = []
        for msg in messages:
            role = msg.role.upper()
            content = msg.content or "[No content]"
            formatted.append(f"{role}: {content}")
        return "\n\n".join(formatted)

    def _extract_all_issues(self, conversation_grade: ConversationGrade) -> list[str]:
        """Extract all issues from a conversation grade into a flat list.

        Combines conversation-level coherence issues with every per-turn
        issue category (violations, citations, reasoning, vagueness).
        """
        issues = []

        # Add coherence issues
        issues.extend(conversation_grade.coherence_issues)

        # Add per-turn issues
        for turn_grade in conversation_grade.turn_grades:
            issues.extend(turn_grade.policy_violations)
            issues.extend(turn_grade.missing_citations)
            issues.extend(turn_grade.incomplete_reasoning)
            issues.extend(turn_grade.vague_recommendations)

        return issues

    async def _grade_conversation(
        self,
        trace: Trace,
        policy_text: str,
    ) -> ConversationGrade:
        """
        Grade the full conversation using the ConversationGrade schema.

        Args:
            trace: The trace to grade
            policy_text: The policy for evaluation

        Returns:
            ConversationGrade with per-turn and overall assessment
        """
        conversation = self._format_conversation(trace.messages)

        # .format() already returns the complete prompt string; wrapping it
        # in an extra f-string (as before) added nothing.
        prompt = MULTI_TURN_GRADE_PROMPT.format(
            conversation=conversation,
            policy=policy_text,
        )

        try:
            return await self.llm.generate_structured(prompt, ConversationGrade)
        except Exception:
            # Fallback - create a failing grade so one unparseable grading
            # response does not abort the whole pipeline run.
            num_turns = self._count_assistant_turns(trace)
            turn_grades = [
                TurnGrade(
                    turn_index=i,
                    passed=False,
                    policy_violations=[],
                    missing_citations=[],
                    incomplete_reasoning=[],
                    vague_recommendations=[],
                    feedback="Unable to grade - parsing error",
                )
                for i in range(num_turns)
            ]
            return ConversationGrade(
                index=0,
                overall_pass=False,
                turn_grades=turn_grades,
                coherence_pass=False,
                coherence_issues=["Unable to evaluate - grading error"],
                progressive_depth=False,
                overall_feedback="Grading failed - please retry",
            )

    async def grade(self, trace: Trace, policy_text: str) -> GradeResult:
        """
        Grade a multi-turn conversation.

        Args:
            trace: The trace to grade
            policy_text: The policy for evaluation

        Returns:
            GradeResult with pass/fail, issues, and feedback
        """
        # Get full conversation grade
        conversation_grade = await self._grade_conversation(trace, policy_text)

        # Convert to standard GradeResult
        return GradeResult(
            passed=conversation_grade.overall_pass,
            issues=self._extract_all_issues(conversation_grade),
            feedback=conversation_grade.overall_feedback,
        )

    async def grade_detailed(
        self,
        trace: Trace,
        policy_text: str,
    ) -> ConversationGrade:
        """
        Get detailed per-turn grading for a conversation.

        Use this when you need access to individual turn grades.

        Args:
            trace: The trace to grade
            policy_text: The policy for evaluation

        Returns:
            ConversationGrade with full per-turn breakdown
        """
        return await self._grade_conversation(trace, policy_text)
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
"""Refinement of failed traces based on grader feedback."""
|
|
2
|
+
|
|
3
|
+
from synkro.llm.client import LLM
|
|
4
|
+
from synkro.models import Model, OpenAI
|
|
5
|
+
from synkro.types.core import Trace, GradeResult, Message
|
|
6
|
+
from synkro.prompts.templates import BATCHED_REFINER_PROMPT, SYSTEM_PROMPT
|
|
7
|
+
from synkro.parsers import parse_single_response, extract_content
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Refiner:
    """
    Refines traces that failed grading.

    Takes failed traces and their grader feedback and generates
    improved versions that address the issues.

    Examples:
        >>> refiner = Refiner()
        >>> improved = await refiner.refine(failed_trace, grade_result, policy.text)
    """

    def __init__(self, llm: LLM | None = None, model: Model = OpenAI.GPT_4O_MINI):
        """
        Initialize the refiner.

        Args:
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating LLM
        """
        self.llm = llm or LLM(model=model)
        self.prompt_template = BATCHED_REFINER_PROMPT

    async def refine(
        self, trace: Trace, grade: GradeResult, policy_text: str
    ) -> Trace:
        """
        Refine a failed trace based on grader feedback.

        Args:
            trace: The trace that failed grading
            grade: The grade result with feedback
            policy_text: The policy text

        Returns:
            New trace with improved response
        """
        prompt = self._build_prompt(trace, grade, policy_text)

        response = await self.llm.generate(prompt)
        parsed = parse_single_response(response)

        # A well-formed refinement has at least system/user/assistant messages.
        if parsed and len(parsed.messages) >= 3:
            messages = [
                Message(role=m.role, content=m.content) for m in parsed.messages
            ]
        else:
            # Fallback: construct from response
            content = extract_content(response)
            messages = [
                Message(role="system", content=SYSTEM_PROMPT),
                Message(
                    role="user",
                    content=f"Scenario: {trace.scenario.description}\n\nContext: {trace.scenario.context}",
                ),
                Message(role="assistant", content=content),
            ]

        return Trace(messages=messages, scenario=trace.scenario)

    def _build_prompt(
        self, trace: Trace, grade: GradeResult, policy_text: str
    ) -> str:
        """Build the refinement prompt."""
        return f"""You are improving a response that failed quality checks.

SCENARIO:
{trace.scenario.description}

CONTEXT:
{trace.scenario.context}

ORIGINAL RESPONSE:
{trace.assistant_message}

GRADER FEEDBACK:
Issues: {', '.join(grade.issues) if grade.issues else 'None listed'}
Summary: {grade.feedback}

POLICY:
{policy_text}

Generate an IMPROVED response that fixes all the issues. Output a JSON object:
{{
  "messages": [
    {{"role": "system", "content": "<system prompt>"}},
    {{"role": "user", "content": "<the scenario>"}},
    {{"role": "assistant", "content": "<your IMPROVED response>"}}
  ]
}}

The improved response must:
- Fix all policy violations
- Add missing citations
- Complete reasoning with no gaps
- Make recommendations specific and actionable
- Keep what was correct from the original

Respond with ONLY the JSON object."""

    async def refine_batch(
        self,
        traces: list[Trace],
        grades: list[GradeResult],
        policy_text: str,
    ) -> list[Trace]:
        """
        Refine multiple failed traces.

        Traces whose grade passed are returned unchanged; failed traces
        are refined individually.

        Args:
            traces: List of traces that failed grading
            grades: Corresponding grade results (same length as traces)
            policy_text: The policy text

        Returns:
            List of refined traces, in the same order as the input

        Raises:
            ValueError: If ``traces`` and ``grades`` differ in length.
        """
        refined = []

        # strict=True guards against silently dropping traces when the two
        # lists are accidentally mismatched (plain zip would truncate).
        for trace, grade in zip(traces, grades, strict=True):
            if not grade.passed:
                improved = await self.refine(trace, grade, policy_text)
                refined.append(improved)
            else:
                refined.append(trace)

        return refined
|
|
137
|
+
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Specialized grading for tool call traces."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from synkro.quality.grader import Grader
|
|
7
|
+
from synkro.llm.client import LLM
|
|
8
|
+
from synkro.models import Model, OpenAI
|
|
9
|
+
from synkro.types.core import Trace, GradeResult
|
|
10
|
+
from synkro.schemas import ToolCallGrade
|
|
11
|
+
from synkro.prompts.tool_templates import TOOL_GRADE_PROMPT
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from synkro.types.tool import ToolDefinition
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ToolCallGrader(Grader):
    """
    Grader specialised for traces that contain tool calls.

    Tool usage is judged on four axes:
    - Tool Selection: was the right tool chosen?
    - Parameter Accuracy: were the arguments correct?
    - Response Synthesis: were tool results used correctly?
    - Timing: were tools invoked at the right moment?

    Examples:
        >>> grader = ToolCallGrader(tools=[web_search, db_lookup])
        >>> result = await grader.grade(trace, policy_text)
        >>> if not result.passed:
        ...     print(f"Issues: {result.issues}")
    """

    def __init__(
        self,
        tools: list["ToolDefinition"],
        llm: LLM | None = None,
        model: Model = OpenAI.GPT_52,
    ):
        """
        Initialize the tool call grader.

        Args:
            tools: List of available tool definitions (for context)
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating LLM (recommend stronger model)
        """
        super().__init__(llm=llm, model=model)
        self.tools = tools

    def _get_tools_description(self) -> str:
        """Return every tool's system-prompt description, blank-line separated."""
        return "\n\n".join(tool.to_system_prompt() for tool in self.tools)

    def _format_conversation(self, trace: Trace) -> str:
        """Render the trace transcript (tool calls included) for the grading prompt."""
        rendered = []
        for msg in trace.messages:
            if msg.role == "system":
                rendered.append(f"[SYSTEM]\n{msg.content}")
            elif msg.role == "user":
                rendered.append(f"[USER]\n{msg.content}")
            elif msg.role == "assistant":
                if msg.tool_calls:
                    # Assistant turn that invokes tools: list each call as
                    # "name(arguments)" instead of message content.
                    calls = "\n".join(
                        f"  - {tc.function.name}({tc.function.arguments})"
                        for tc in msg.tool_calls
                    )
                    rendered.append("[ASSISTANT - TOOL CALLS]\n" + calls)
                else:
                    rendered.append(f"[ASSISTANT]\n{msg.content}")
            elif msg.role == "tool":
                rendered.append(f"[TOOL RESULT - {msg.tool_call_id}]\n{msg.content}")
        return "\n\n".join(rendered)

    async def grade(self, trace: Trace, policy_text: str) -> GradeResult:
        """
        Grade a tool call trace using tool-specific criteria.

        Args:
            trace: The trace to grade
            policy_text: The policy/guidelines text

        Returns:
            GradeResult with pass/fail and detailed feedback
        """
        prompt = TOOL_GRADE_PROMPT.format(
            TOOLS_DESCRIPTION=self._get_tools_description(),
            GUIDELINES=policy_text,
            SCENARIO=trace.scenario.description,
            CONVERSATION=self._format_conversation(trace),
        )

        try:
            # Structured output keeps the grading shape consistent; the
            # conversion to GradeResult stays inside the try so any schema
            # surprise also falls through to the failing grade below.
            verdict = await self.llm.generate_structured(prompt, ToolCallGrade)
            return GradeResult(
                passed=verdict.passed,
                issues=verdict.get_all_issues(),
                feedback=verdict.feedback,
            )
        except Exception:
            # Fallback: assume fail if we can't parse
            return GradeResult(
                passed=False,
                issues=["Unable to parse grade response"],
                feedback="Grading failed - unable to parse response",
            )
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
__all__ = ["ToolCallGrader"]
|
|
126
|
+
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Specialized refinement for tool call traces that preserves format."""
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from synkro.quality.refiner import Refiner
|
|
6
|
+
from synkro.llm.client import LLM
|
|
7
|
+
from synkro.models import Model, OpenAI
|
|
8
|
+
from synkro.types.core import Trace, GradeResult, Scenario
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from synkro.types.tool import ToolDefinition
|
|
12
|
+
from synkro.generation.tool_simulator import ToolSimulator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ToolCallRefiner(Refiner):
    """
    Refiner specialised for tool call traces.

    Where the base Refiner generates plain-text responses, this refiner
    regenerates the whole trace through ToolCallResponseGenerator so the
    tool_calls message format survives refinement.

    Grading feedback is folded into the scenario context so the LLM knows
    what to correct during regeneration.

    Examples:
        >>> refiner = ToolCallRefiner(
        ...     tools=[web_search, db_lookup],
        ...     simulator=tool_simulator,
        ... )
        >>> improved = await refiner.refine(failed_trace, grade, policy_text)
        >>> # improved trace has proper tool_calls format
    """

    def __init__(
        self,
        tools: list["ToolDefinition"],
        simulator: "ToolSimulator",
        llm: LLM | None = None,
        model: Model = OpenAI.GPT_4O_MINI,
    ):
        """
        Initialize the tool call refiner.

        Args:
            tools: List of available tool definitions
            simulator: Tool simulator for generating tool responses
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating LLM
        """
        super().__init__(llm=llm, model=model)
        self.tools = tools
        self.simulator = simulator
        # Built lazily on first use; see _get_response_generator().
        self._response_generator = None

    def _get_response_generator(self):
        """Create the ToolCallResponseGenerator on first use and cache it."""
        if self._response_generator is None:
            # Imported here rather than at module top — presumably to avoid
            # an import cycle with the generation package; verify before moving.
            from synkro.generation.tool_responses import ToolCallResponseGenerator

            self._response_generator = ToolCallResponseGenerator(
                tools=self.tools,
                llm=self.llm,
                simulator=self.simulator,
            )
        return self._response_generator

    def _build_enhanced_scenario(
        self, trace: Trace, grade: GradeResult
    ) -> Scenario:
        """
        Build a scenario whose context carries the grading feedback.

        The appended feedback helps the LLM understand what went wrong
        and how to fix it.
        """
        # Assemble the feedback notes to append.
        notes: list[str] = []
        if grade.issues:
            notes.append("PREVIOUS ISSUES TO FIX:")
            notes.extend(f"  - {issue}" for issue in grade.issues)
        if grade.feedback:
            notes.append(f"\nGRADER FEEDBACK: {grade.feedback}")

        context = trace.scenario.context
        if notes:
            guidance = "\n".join(notes)
            context = f"{trace.scenario.context}\n\n--- REFINEMENT GUIDANCE ---\n{guidance}"

        return Scenario(
            description=trace.scenario.description,
            context=context,
            category=trace.scenario.category,
        )

    async def refine(
        self, trace: Trace, grade: GradeResult, policy_text: str
    ) -> Trace:
        """
        Refine a failed tool call trace by regenerating it with feedback.

        Regeneration goes through the ToolCallResponseGenerator so the
        resulting trace keeps a proper tool_calls format.

        Args:
            trace: The trace that failed grading
            grade: The grade result with feedback
            policy_text: The policy/guidelines text

        Returns:
            New trace with improved response and preserved tool_calls format
        """
        # Fold grader feedback into the scenario the generator sees.
        scenario_with_feedback = self._build_enhanced_scenario(trace, grade)

        # Regenerate using the format-preserving generator.
        regenerated = await self._get_response_generator().generate_single(
            policy_text, scenario_with_feedback
        )

        # Point the result back at the original scenario, i.e. without the
        # refinement guidance appended to its context.
        regenerated.scenario = trace.scenario
        return regenerated
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
__all__ = ["ToolCallRefiner"]
|
|
128
|
+
|