synkro-0.4.12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synkro/__init__.py +179 -0
- synkro/advanced.py +186 -0
- synkro/cli.py +128 -0
- synkro/core/__init__.py +7 -0
- synkro/core/checkpoint.py +250 -0
- synkro/core/dataset.py +402 -0
- synkro/core/policy.py +337 -0
- synkro/errors.py +178 -0
- synkro/examples/__init__.py +148 -0
- synkro/factory.py +276 -0
- synkro/formatters/__init__.py +12 -0
- synkro/formatters/qa.py +98 -0
- synkro/formatters/sft.py +90 -0
- synkro/formatters/tool_call.py +127 -0
- synkro/generation/__init__.py +9 -0
- synkro/generation/follow_ups.py +134 -0
- synkro/generation/generator.py +220 -0
- synkro/generation/golden_responses.py +244 -0
- synkro/generation/golden_scenarios.py +276 -0
- synkro/generation/golden_tool_responses.py +416 -0
- synkro/generation/logic_extractor.py +126 -0
- synkro/generation/multiturn_responses.py +177 -0
- synkro/generation/planner.py +131 -0
- synkro/generation/responses.py +189 -0
- synkro/generation/scenarios.py +90 -0
- synkro/generation/tool_responses.py +376 -0
- synkro/generation/tool_simulator.py +114 -0
- synkro/interactive/__init__.py +12 -0
- synkro/interactive/hitl_session.py +77 -0
- synkro/interactive/logic_map_editor.py +173 -0
- synkro/interactive/rich_ui.py +205 -0
- synkro/llm/__init__.py +7 -0
- synkro/llm/client.py +235 -0
- synkro/llm/rate_limits.py +95 -0
- synkro/models/__init__.py +43 -0
- synkro/models/anthropic.py +26 -0
- synkro/models/google.py +19 -0
- synkro/models/openai.py +31 -0
- synkro/modes/__init__.py +15 -0
- synkro/modes/config.py +66 -0
- synkro/modes/qa.py +18 -0
- synkro/modes/sft.py +18 -0
- synkro/modes/tool_call.py +18 -0
- synkro/parsers.py +442 -0
- synkro/pipeline/__init__.py +20 -0
- synkro/pipeline/phases.py +592 -0
- synkro/pipeline/runner.py +424 -0
- synkro/pipelines.py +123 -0
- synkro/prompts/__init__.py +57 -0
- synkro/prompts/base.py +167 -0
- synkro/prompts/golden_templates.py +474 -0
- synkro/prompts/interactive_templates.py +65 -0
- synkro/prompts/multiturn_templates.py +156 -0
- synkro/prompts/qa_templates.py +97 -0
- synkro/prompts/templates.py +281 -0
- synkro/prompts/tool_templates.py +201 -0
- synkro/quality/__init__.py +14 -0
- synkro/quality/golden_refiner.py +163 -0
- synkro/quality/grader.py +153 -0
- synkro/quality/multiturn_grader.py +150 -0
- synkro/quality/refiner.py +137 -0
- synkro/quality/tool_grader.py +126 -0
- synkro/quality/tool_refiner.py +128 -0
- synkro/quality/verifier.py +228 -0
- synkro/reporting.py +537 -0
- synkro/schemas.py +472 -0
- synkro/types/__init__.py +41 -0
- synkro/types/core.py +126 -0
- synkro/types/dataset_type.py +30 -0
- synkro/types/logic_map.py +345 -0
- synkro/types/tool.py +94 -0
- synkro-0.4.12.data/data/examples/__init__.py +148 -0
- synkro-0.4.12.dist-info/METADATA +258 -0
- synkro-0.4.12.dist-info/RECORD +77 -0
- synkro-0.4.12.dist-info/WHEEL +4 -0
- synkro-0.4.12.dist-info/entry_points.txt +2 -0
- synkro-0.4.12.dist-info/licenses/LICENSE +21 -0
synkro/generation/follow_ups.py
@@ -0,0 +1,134 @@
"""Follow-up question generation for multi-turn conversations."""

from typing import Literal

from synkro.llm.client import LLM
from synkro.models import Model, OpenAI
from synkro.types.core import Message
from synkro.prompts.multiturn_templates import FOLLOW_UP_GENERATION_PROMPT
from synkro.schemas import FollowUpQuestion


QuestionType = Literal["clarification", "edge_case", "what_if", "specificity", "challenge"]

# Question type progression for multi-turn conversations
# Earlier turns focus on clarification, later turns probe deeper
QUESTION_TYPE_BY_TURN = {
    1: "clarification",
    2: "specificity",
    3: "edge_case",
    4: "what_if",
    5: "challenge",
}


class FollowUpGenerator:
    """
    Generates follow-up questions for multi-turn conversations.

    Uses different question types based on turn index:
    - Turn 1: clarification - Ask for more details
    - Turn 2: specificity - Drill into specifics
    - Turn 3: edge_case - Probe boundary conditions
    - Turn 4: what_if - Explore hypotheticals
    - Turn 5+: challenge - Question reasoning

    Examples:
        >>> gen = FollowUpGenerator()
        >>> follow_up = await gen.generate(policy_text, messages, turn_index=2)
        >>> print(follow_up.question)
    """

    def __init__(self, llm: LLM | None = None, model: Model = OpenAI.GPT_4O_MINI):
        """
        Initialize the follow-up generator.

        Args:
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating LLM
        """
        self.llm = llm or LLM(model=model)

    def _select_question_type(self, turn_index: int) -> QuestionType:
        """
        Select question type based on turn index.

        Args:
            turn_index: Which turn this is (1-based, counting user-assistant exchanges)

        Returns:
            Appropriate question type for this turn
        """
        if turn_index in QUESTION_TYPE_BY_TURN:
            return QUESTION_TYPE_BY_TURN[turn_index]
        # For turns beyond 5, cycle through challenging questions
        return "challenge"

    def _format_conversation(self, messages: list[Message]) -> str:
        """Format conversation messages for prompt inclusion."""
        formatted = []
        for msg in messages:
            role = msg.role.upper()
            content = msg.content or "[No content]"
            formatted.append(f"{role}: {content}")
        return "\n\n".join(formatted)

    async def generate(
        self,
        policy_text: str,
        messages: list[Message],
        turn_index: int,
        question_type: QuestionType | None = None,
        scenario_index: int = 0,
    ) -> FollowUpQuestion:
        """
        Generate a follow-up question for the conversation.

        Args:
            policy_text: The policy text for context
            messages: Conversation messages so far
            turn_index: Which turn this is (1-based)
            question_type: Override auto-selected question type
            scenario_index: Index for the scenario (default 0)

        Returns:
            FollowUpQuestion with the generated question
        """
        # Select question type if not specified
        if question_type is None:
            question_type = self._select_question_type(turn_index)

        # Format conversation for prompt
        conversation = self._format_conversation(messages)

        # Build prompt
        prompt = FOLLOW_UP_GENERATION_PROMPT.format(
            question_type=question_type,
            conversation=conversation,
            policy=policy_text,
        )

        try:
            # Generate the follow-up question
            response = await self.llm.generate(prompt)
            question_text = response.strip()

            return FollowUpQuestion(
                index=scenario_index,
                question=question_text,
                question_type=question_type,
            )
        except Exception:
            # Fallback generic follow-up
            fallback_questions = {
                "clarification": "Can you clarify that further?",
                "edge_case": "What about edge cases?",
                "what_if": "What if the situation changes?",
                "specificity": "Can you be more specific?",
                "challenge": "Why is that the best approach?",
            }
            return FollowUpQuestion(
                index=scenario_index,
                question=fallback_questions.get(question_type, "Can you elaborate?"),
                question_type=question_type,
            )
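For orientation, here is a minimal usage sketch of FollowUpGenerator based on the class docstring above. The policy text and Message contents are invented for illustration, and the default model assumes OpenAI credentials are configured; Message construction mirrors how golden_responses.py builds messages below.

import asyncio

from synkro.generation.follow_ups import FollowUpGenerator
from synkro.types.core import Message

async def main() -> None:
    gen = FollowUpGenerator()  # defaults to LLM(model=OpenAI.GPT_4O_MINI)
    messages = [
        Message(role="user", content="Can I get a refund after 30 days?"),
        Message(role="assistant", content="Refunds are only offered within 30 days of purchase."),
    ]
    # turn_index=2 auto-selects the "specificity" question type
    follow_up = await gen.generate(
        policy_text="Refunds are offered within 30 days of purchase.",
        messages=messages,
        turn_index=2,
    )
    print(follow_up.question_type, follow_up.question)

asyncio.run(main())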
synkro/generation/generator.py
@@ -0,0 +1,220 @@
"""Main Generator class orchestrating the full trace generation pipeline."""

import asyncio
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING

from synkro.llm.client import LLM
from synkro.llm.rate_limits import auto_workers
from synkro.models import Model, OpenAI
from synkro.types.dataset_type import DatasetType
from synkro.core.policy import Policy
from synkro.core.dataset import Dataset
from synkro.core.checkpoint import CheckpointManager, hash_policy
from synkro.modes.config import get_mode_config
from synkro.errors import handle_error
from synkro.factory import ComponentFactory
from synkro.reporting import ProgressReporter, RichReporter
from synkro.pipeline.runner import GenerationPipeline, GenerationResult

if TYPE_CHECKING:
    from synkro.types.tool import ToolDefinition


class Generator:
    """
    Main orchestrator for generating training datasets.

    The Generator handles the full pipeline:
    1. Plan: Analyze policy and create category distribution
    2. Generate: Create scenarios and responses
    3. Grade: Evaluate response quality
    4. Refine: Fix failed responses
    5. Return: Dataset of passing traces

    Examples:
        >>> generator = Generator()
        >>> dataset = generator.generate(policy, traces=20)

        >>> # QA dataset
        >>> generator = Generator(dataset_type=DatasetType.QA)
        >>> dataset = generator.generate(policy)

        >>> # Silent mode (no console output)
        >>> from synkro.reporting import SilentReporter
        >>> generator = Generator(reporter=SilentReporter())
        >>> dataset = generator.generate(policy)

        >>> # Tool call dataset
        >>> from synkro import ToolDefinition
        >>> tools = [ToolDefinition(name="search", description="...", parameters={})]
        >>> generator = Generator(dataset_type=DatasetType.TOOL_CALL, tools=tools)
        >>> dataset = generator.generate("Usage guidelines", traces=20)
    """

    def __init__(
        self,
        dataset_type: DatasetType = DatasetType.SFT,
        generation_model: Model = OpenAI.GPT_4O_MINI,
        grading_model: Model = OpenAI.GPT_4O,
        max_iterations: int = 1,
        skip_grading: bool = False,
        reporter: ProgressReporter | None = None,
        tools: list["ToolDefinition"] | None = None,
        turns: int | str = "auto",
        checkpoint_dir: str | Path | None = None,
        enable_hitl: bool = True,
    ):
        """
        Initialize the Generator.

        Args:
            dataset_type: Type of dataset to generate (QA, SFT, or TOOL_CALL)
            generation_model: Model for scenarios/responses (default: gpt-4o-mini)
            grading_model: Model for grading (default: gpt-4o, recommend stronger)
            max_iterations: Max refinement iterations per trace (default: 1, no retries)
            skip_grading: Skip grading phase for faster generation (default: False)
            reporter: Progress reporter (default: RichReporter for console output)
            tools: List of ToolDefinition for TOOL_CALL dataset type
            turns: Conversation turns per trace. Use int for fixed turns, or "auto"
                for policy complexity-driven turns (Simple=1-2, Conditional=3, Complex=5+)
            checkpoint_dir: Directory for checkpoints. If provided, enables resumable
                generation. Progress is saved after each stage.
            enable_hitl: Enable Human-in-the-Loop Logic Map editing. When enabled,
                pauses after Logic Map extraction to allow interactive refinement.
        """
        self.dataset_type = dataset_type
        self.mode_config = get_mode_config(dataset_type)
        self.max_iterations = max_iterations
        self.skip_grading = skip_grading
        self.tools = tools
        self.turns = turns
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None

        # Create checkpoint manager if checkpointing enabled
        self.checkpoint_manager = (
            CheckpointManager(self.checkpoint_dir) if self.checkpoint_dir else None
        )

        # HITL configuration
        self.enable_hitl = enable_hitl

        # Validate tools for TOOL_CALL dataset type
        if dataset_type == DatasetType.TOOL_CALL and not tools:
            raise ValueError("TOOL_CALL dataset type requires tools parameter")

        # Store model info for reporting
        self.generation_model = generation_model
        self.grading_model = grading_model

        # Create LLM clients
        self.generation_llm = LLM(model=generation_model)
        self.grading_llm = LLM(model=grading_model)

        # Create factory for component creation
        self.factory = ComponentFactory(
            generation_llm=self.generation_llm,
            grading_llm=self.grading_llm,
            mode_config=self.mode_config,
            tools=tools,
        )

        # Reporter for progress output
        self.reporter = reporter or RichReporter()

        # Auto-scale workers based on provider
        model_str = generation_model.value if isinstance(generation_model, Enum) else str(generation_model)
        self.workers = auto_workers(model_str)

        # Create HITL editor if enabled
        hitl_editor = self.factory.create_logic_map_editor() if enable_hitl else None

        # Create pipeline
        self.pipeline = GenerationPipeline(
            factory=self.factory,
            reporter=self.reporter,
            workers=self.workers,
            max_iterations=max_iterations,
            skip_grading=skip_grading,
            checkpoint_manager=self.checkpoint_manager,
            enable_hitl=enable_hitl,
            hitl_editor=hitl_editor,
        )

    @handle_error
    def generate(
        self,
        policy: Policy | str,
        traces: int = 20,
        return_logic_map: bool = False,
    ) -> Dataset | GenerationResult:
        """
        Generate a training dataset from a policy.

        Args:
            policy: Policy object or text string
            traces: Target number of traces to generate (default: 20)
            return_logic_map: If True, return GenerationResult with access to
                the Logic Map, scenarios, and distribution (default: False)

        Returns:
            Dataset (default) or GenerationResult if return_logic_map=True

        Examples:
            >>> # Standard usage
            >>> dataset = generator.generate(policy, traces=50)

            >>> # Access Logic Map for inspection
            >>> result = generator.generate(policy, return_logic_map=True)
            >>> print(result.logic_map.rules)  # See extracted rules
            >>> print(result.distribution)  # See scenario type counts
            >>> dataset = result.dataset  # Get the dataset
        """
        if isinstance(policy, str):
            policy = Policy(text=policy)

        # Validate policy has enough content
        policy.validate_length()

        return asyncio.run(self._generate_async(policy, traces, return_logic_map))

    async def _generate_async(
        self,
        policy: Policy,
        traces: int,
        return_logic_map: bool = False,
    ) -> Dataset | GenerationResult:
        """Async implementation of generation pipeline."""
        model_str = self.generation_model.value if isinstance(self.generation_model, Enum) else str(self.generation_model)

        return await self.pipeline.run(
            policy=policy,
            traces=traces,
            model=model_str,
            dataset_type=self.dataset_type.value,
            turns=self.turns,
            return_result=return_logic_map,
        )

    async def generate_async(
        self,
        policy: Policy | str,
        traces: int = 20,
        return_logic_map: bool = False,
    ) -> Dataset | GenerationResult:
        """
        Async version of generate for use in async contexts.

        Args:
            policy: Policy object or text string
            traces: Target number of traces to generate (default: 20)
            return_logic_map: If True, return GenerationResult with Logic Map access

        Returns:
            Dataset (default) or GenerationResult if return_logic_map=True
        """
        if isinstance(policy, str):
            policy = Policy(text=policy)

        return await self._generate_async(policy, traces, return_logic_map)
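A hedged end-to-end sketch combining the docstring examples above. Import paths are the module paths shown in this diff (the top-level synkro package may re-export these names), and the policy string and checkpoint directory are placeholders.

from synkro.generation.generator import Generator
from synkro.reporting import SilentReporter
from synkro.types.dataset_type import DatasetType

generator = Generator(
    dataset_type=DatasetType.SFT,          # default mode
    reporter=SilentReporter(),             # suppress console output
    checkpoint_dir=".synkro_checkpoints",  # opt in to resumable generation
    enable_hitl=False,                     # skip the interactive Logic Map pause
)

# A plain string is wrapped into Policy(text=...) and length-validated internally
result = generator.generate(
    "Refunds are offered within 30 days of purchase...",
    traces=20,
    return_logic_map=True,
)
print(result.logic_map.rules)  # extracted rules
print(result.distribution)     # scenario type counts
dataset = result.dataset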
synkro/generation/golden_responses.py
@@ -0,0 +1,244 @@
"""Golden Response Generator - The Thinker.

Generates traces with grounded Chain-of-Thought reasoning and rule citations.
This is Stage 3 of the Golden Trace pipeline for SFT/QA datasets.
"""

import asyncio
from typing import TYPE_CHECKING

from synkro.llm.client import LLM
from synkro.models import Model, OpenAI
from synkro.schemas import GoldenTraceOutput
from synkro.types.core import Trace, Message, Scenario
from synkro.types.logic_map import (
    LogicMap,
    GoldenScenario,
    ReasoningStep,
)
from synkro.prompts.golden_templates import (
    GOLDEN_TRACE_PROMPT,
    GOLDEN_TRACE_MULTI_TURN_PROMPT,
)


class GoldenResponseGenerator:
    """
    The Thinker - Generates traces with grounded reasoning.

    Produces traces with:
    - Explicit Chain-of-Thought reasoning
    - Rule citations (Rule IDs) for each reasoning step
    - Exclusionary reasoning (why rules DON'T apply)
    - DAG-compliant dependency order

    Examples:
        >>> generator = GoldenResponseGenerator(llm=LLM(model=OpenAI.GPT_4O_MINI))
        >>> trace = await generator.generate_single(
        ...     policy_text="...",
        ...     logic_map=logic_map,
        ...     scenario=scenario,
        ...     target_turns=1,
        ... )
    """

    def __init__(
        self,
        llm: LLM | None = None,
        model: Model = OpenAI.GPT_4O_MINI,
    ):
        """
        Initialize the Golden Response Generator.

        Args:
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating LLM
        """
        self.llm = llm or LLM(model=model, temperature=0.7)

    async def generate_single(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenario: GoldenScenario,
        target_turns: int = 1,
    ) -> Trace:
        """
        Generate a single trace with grounded reasoning.

        Args:
            policy_text: The policy document text
            logic_map: The extracted Logic Map (DAG of rules)
            scenario: The golden scenario to respond to
            target_turns: Number of conversation turns

        Returns:
            Trace with messages and reasoning metadata
        """
        if target_turns > 1:
            return await self._generate_multi_turn(
                policy_text, logic_map, scenario, target_turns
            )

        return await self._generate_single_turn(policy_text, logic_map, scenario)

    async def _generate_single_turn(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenario: GoldenScenario,
    ) -> Trace:
        """Generate a single-turn trace."""
        # Format Logic Map for prompt
        logic_map_str = self._format_logic_map(logic_map)

        # Build prompt
        prompt = GOLDEN_TRACE_PROMPT.format(
            policy_text=policy_text,
            logic_map=logic_map_str,
            scenario_description=scenario.description,
            scenario_context=scenario.context,
            target_rule_ids=", ".join(scenario.target_rule_ids),
            scenario_type=scenario.scenario_type.value.upper(),
            expected_outcome=scenario.expected_outcome,
        )

        # Generate structured output
        result = await self.llm.generate_structured(prompt, GoldenTraceOutput)

        # Convert to Trace
        messages = [
            Message(role=m.role, content=m.content)
            for m in result.messages
        ]

        # Convert GoldenScenario to base Scenario for Trace
        base_scenario = scenario.to_base_scenario()

        # Convert reasoning chain to serializable format
        reasoning_chain = None
        if result.reasoning_chain:
            reasoning_chain = [
                {
                    "rule_id": step.rule_id,
                    "rule_text": step.rule_text,
                    "applies": step.applies,
                    "reasoning": step.reasoning,
                    "exclusions": step.exclusions,
                }
                for step in result.reasoning_chain
            ]

        return Trace(
            messages=messages,
            scenario=base_scenario,
            reasoning_chain=reasoning_chain,
            rules_applied=result.rules_applied,
            rules_excluded=result.rules_excluded,
        )

    async def _generate_multi_turn(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenario: GoldenScenario,
        target_turns: int,
    ) -> Trace:
        """Generate a multi-turn trace."""
        # Format Logic Map for prompt
        logic_map_str = self._format_logic_map(logic_map)

        # Build prompt
        prompt = GOLDEN_TRACE_MULTI_TURN_PROMPT.format(
            policy_text=policy_text,
            logic_map=logic_map_str,
            scenario_description=scenario.description,
            scenario_context=scenario.context,
            target_rule_ids=", ".join(scenario.target_rule_ids),
            scenario_type=scenario.scenario_type.value.upper(),
            target_turns=target_turns,
        )

        # Generate structured output
        result = await self.llm.generate_structured(prompt, GoldenTraceOutput)

        # Convert to Trace
        messages = [
            Message(role=m.role, content=m.content)
            for m in result.messages
        ]

        # Convert GoldenScenario to base Scenario for Trace
        base_scenario = scenario.to_base_scenario()

        # Convert reasoning chain to serializable format
        reasoning_chain = None
        if result.reasoning_chain:
            reasoning_chain = [
                {
                    "rule_id": step.rule_id,
                    "rule_text": step.rule_text,
                    "applies": step.applies,
                    "reasoning": step.reasoning,
                    "exclusions": step.exclusions,
                }
                for step in result.reasoning_chain
            ]

        return Trace(
            messages=messages,
            scenario=base_scenario,
            reasoning_chain=reasoning_chain,
            rules_applied=result.rules_applied,
            rules_excluded=result.rules_excluded,
        )

    def _format_logic_map(self, logic_map: LogicMap) -> str:
        """Format Logic Map for prompt inclusion."""
        lines = []
        lines.append("RULES:")
        for rule in logic_map.rules:
            deps = f" [depends on: {', '.join(rule.dependencies)}]" if rule.dependencies else ""
            lines.append(
                f"  {rule.rule_id} ({rule.category.value}): {rule.text}{deps}"
            )
            lines.append(f"    IF: {rule.condition}")
            lines.append(f"    THEN: {rule.action}")

        lines.append("\nDEPENDENCY ORDER (evaluate in this order):")
        # Show topological order for root rules and their chains
        for root_id in logic_map.root_rules:
            chain = logic_map.get_chain(root_id)
            if chain:
                chain_str = " -> ".join(r.rule_id for r in chain)
                lines.append(f"  {chain_str}")

        return "\n".join(lines)

    async def generate(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenarios: list[GoldenScenario],
        target_turns: int = 1,
    ) -> list[Trace]:
        """
        Generate traces for multiple scenarios.

        Args:
            policy_text: The policy document text
            logic_map: The extracted Logic Map
            scenarios: List of golden scenarios
            target_turns: Number of conversation turns

        Returns:
            List of traces with grounded reasoning
        """
        tasks = [
            self.generate_single(policy_text, logic_map, s, target_turns)
            for s in scenarios
        ]
        return await asyncio.gather(*tasks)


__all__ = ["GoldenResponseGenerator"]
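A sketch of batch generation with GoldenResponseGenerator, based on the docstring example and the generate() method above. The logic_map and scenarios arguments are assumed to come from the upstream logic_extractor and golden_scenarios stages of the pipeline; they are not constructed here.

import asyncio

from synkro.generation.golden_responses import GoldenResponseGenerator
from synkro.llm.client import LLM
from synkro.models import OpenAI

async def build_traces(policy_text, logic_map, scenarios):
    generator = GoldenResponseGenerator(llm=LLM(model=OpenAI.GPT_4O_MINI))
    # All scenarios are generated concurrently via asyncio.gather
    traces = await generator.generate(policy_text, logic_map, scenarios, target_turns=1)
    for trace in traces:
        print(trace.rules_applied, trace.rules_excluded)
    return traces

# traces = asyncio.run(build_traces(policy_text, logic_map, scenarios))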