synkro 0.4.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synkro might be problematic.
- synkro/__init__.py +331 -0
- synkro/advanced.py +184 -0
- synkro/cli.py +156 -0
- synkro/core/__init__.py +7 -0
- synkro/core/checkpoint.py +250 -0
- synkro/core/dataset.py +432 -0
- synkro/core/policy.py +337 -0
- synkro/errors.py +178 -0
- synkro/examples/__init__.py +148 -0
- synkro/factory.py +291 -0
- synkro/formatters/__init__.py +18 -0
- synkro/formatters/chatml.py +121 -0
- synkro/formatters/langfuse.py +98 -0
- synkro/formatters/langsmith.py +98 -0
- synkro/formatters/qa.py +112 -0
- synkro/formatters/sft.py +90 -0
- synkro/formatters/tool_call.py +127 -0
- synkro/generation/__init__.py +9 -0
- synkro/generation/follow_ups.py +134 -0
- synkro/generation/generator.py +314 -0
- synkro/generation/golden_responses.py +269 -0
- synkro/generation/golden_scenarios.py +333 -0
- synkro/generation/golden_tool_responses.py +791 -0
- synkro/generation/logic_extractor.py +126 -0
- synkro/generation/multiturn_responses.py +177 -0
- synkro/generation/planner.py +131 -0
- synkro/generation/responses.py +189 -0
- synkro/generation/scenarios.py +90 -0
- synkro/generation/tool_responses.py +625 -0
- synkro/generation/tool_simulator.py +114 -0
- synkro/interactive/__init__.py +16 -0
- synkro/interactive/hitl_session.py +205 -0
- synkro/interactive/intent_classifier.py +94 -0
- synkro/interactive/logic_map_editor.py +176 -0
- synkro/interactive/rich_ui.py +459 -0
- synkro/interactive/scenario_editor.py +198 -0
- synkro/llm/__init__.py +7 -0
- synkro/llm/client.py +309 -0
- synkro/llm/rate_limits.py +99 -0
- synkro/models/__init__.py +50 -0
- synkro/models/anthropic.py +26 -0
- synkro/models/google.py +19 -0
- synkro/models/local.py +104 -0
- synkro/models/openai.py +31 -0
- synkro/modes/__init__.py +13 -0
- synkro/modes/config.py +66 -0
- synkro/modes/conversation.py +35 -0
- synkro/modes/tool_call.py +18 -0
- synkro/parsers.py +442 -0
- synkro/pipeline/__init__.py +20 -0
- synkro/pipeline/phases.py +592 -0
- synkro/pipeline/runner.py +769 -0
- synkro/pipelines.py +136 -0
- synkro/prompts/__init__.py +57 -0
- synkro/prompts/base.py +167 -0
- synkro/prompts/golden_templates.py +533 -0
- synkro/prompts/interactive_templates.py +198 -0
- synkro/prompts/multiturn_templates.py +156 -0
- synkro/prompts/templates.py +281 -0
- synkro/prompts/tool_templates.py +318 -0
- synkro/quality/__init__.py +14 -0
- synkro/quality/golden_refiner.py +163 -0
- synkro/quality/grader.py +153 -0
- synkro/quality/multiturn_grader.py +150 -0
- synkro/quality/refiner.py +137 -0
- synkro/quality/tool_grader.py +126 -0
- synkro/quality/tool_refiner.py +128 -0
- synkro/quality/verifier.py +228 -0
- synkro/reporting.py +464 -0
- synkro/schemas.py +521 -0
- synkro/types/__init__.py +43 -0
- synkro/types/core.py +153 -0
- synkro/types/dataset_type.py +33 -0
- synkro/types/logic_map.py +348 -0
- synkro/types/tool.py +94 -0
- synkro-0.4.36.data/data/examples/__init__.py +148 -0
- synkro-0.4.36.dist-info/METADATA +507 -0
- synkro-0.4.36.dist-info/RECORD +81 -0
- synkro-0.4.36.dist-info/WHEEL +4 -0
- synkro-0.4.36.dist-info/entry_points.txt +2 -0
- synkro-0.4.36.dist-info/licenses/LICENSE +21 -0
synkro/generation/generator.py
@@ -0,0 +1,314 @@
"""Main Generator class orchestrating the full trace generation pipeline."""

import asyncio
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING

from synkro.llm.client import LLM
from synkro.llm.rate_limits import auto_workers
from synkro.models import Model, OpenAI
from synkro.types.dataset_type import DatasetType
from synkro.core.policy import Policy
from synkro.core.dataset import Dataset
from synkro.core.checkpoint import CheckpointManager, hash_policy
from synkro.modes.config import get_mode_config
from synkro.errors import handle_error
from synkro.factory import ComponentFactory
from synkro.reporting import ProgressReporter, RichReporter
from synkro.pipeline.runner import GenerationPipeline, GenerationResult, ScenariosResult

if TYPE_CHECKING:
    from synkro.types.tool import ToolDefinition


class Generator:
    """
    Main orchestrator for generating training datasets.

    The Generator handles the full pipeline:
    1. Plan: Analyze policy and create category distribution
    2. Generate: Create scenarios and responses
    3. Grade: Evaluate response quality
    4. Refine: Fix failed responses
    5. Return: Dataset of passing traces

    Examples:
        >>> generator = Generator()
        >>> dataset = generator.generate(policy, traces=20)

        >>> # Conversation dataset (default, multi-turn)
        >>> generator = Generator(dataset_type=DatasetType.CONVERSATION)
        >>> dataset = generator.generate(policy)

        >>> # Instruction dataset (single-turn)
        >>> generator = Generator(dataset_type=DatasetType.INSTRUCTION)
        >>> dataset = generator.generate(policy)

        >>> # Silent mode (no console output)
        >>> from synkro.reporting import SilentReporter
        >>> generator = Generator(reporter=SilentReporter())
        >>> dataset = generator.generate(policy)

        >>> # Tool call dataset
        >>> from synkro import ToolDefinition
        >>> tools = [ToolDefinition(name="search", description="...", parameters={})]
        >>> generator = Generator(dataset_type=DatasetType.TOOL_CALL, tools=tools)
        >>> dataset = generator.generate("Usage guidelines", traces=20)

        >>> # Eval dataset with low temperature for deterministic outputs
        >>> generator = Generator(dataset_type=DatasetType.EVALUATION, temperature=0.2)
        >>> dataset = generator.generate(policy, traces=50)
    """

    def __init__(
        self,
        dataset_type: DatasetType = DatasetType.CONVERSATION,
        generation_model: Model = OpenAI.GPT_4O_MINI,
        grading_model: Model = OpenAI.GPT_4O,
        max_iterations: int = 1,
        skip_grading: bool = False,
        reporter: ProgressReporter | None = None,
        tools: list["ToolDefinition"] | None = None,
        turns: int | str = "auto",
        checkpoint_dir: str | Path | None = None,
        enable_hitl: bool = True,
        base_url: str | None = None,
        thinking: bool = False,
        temperature: float = 0.7,
    ):
        """
        Initialize the Generator.

        Args:
            dataset_type: Type of dataset to generate (CONVERSATION, INSTRUCTION, or TOOL_CALL)
            generation_model: Model for scenarios/responses (default: gpt-4o-mini)
            grading_model: Model for grading (default: gpt-4o, recommend stronger)
            max_iterations: Max refinement iterations per trace (default: 1, no retries)
            skip_grading: Skip grading phase for faster generation (default: False)
            reporter: Progress reporter (default: RichReporter for console output)
            tools: List of ToolDefinition for TOOL_CALL dataset type
            turns: Conversation turns per trace. Use int for fixed turns, or "auto"
                for policy complexity-driven turns (Simple=1-2, Conditional=3, Complex=5+)
            checkpoint_dir: Directory for checkpoints. If provided, enables resumable
                generation. Progress is saved after each stage.
            enable_hitl: Enable Human-in-the-Loop Logic Map editing. When enabled,
                pauses after Logic Map extraction to allow interactive refinement.
            base_url: Optional API base URL for local LLM providers (Ollama, vLLM, etc.)
            thinking: Enable thinking mode with <think> tags in responses (default: False).
                When enabled, assistant responses will include reasoning wrapped in
                <think>...</think> tags, compatible with Qwen3 and DeepSeek-R1 formats.
            temperature: Sampling temperature for generation (0.0-2.0, default: 0.7).
                Lower values (0.1-0.3) produce more deterministic outputs for eval datasets.
                Higher values (0.7-1.0) produce more diverse outputs for training data.
        """
        self.dataset_type = dataset_type
        self.mode_config = get_mode_config(dataset_type)
        self.max_iterations = max_iterations
        self.skip_grading = skip_grading
        self.tools = tools
        self.turns = turns
        self.thinking = thinking
        self.checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir else None

        # Create checkpoint manager if checkpointing enabled
        self.checkpoint_manager = (
            CheckpointManager(self.checkpoint_dir) if self.checkpoint_dir else None
        )

        # HITL configuration
        self.enable_hitl = enable_hitl

        # Validate tools for TOOL_CALL dataset type
        if dataset_type == DatasetType.TOOL_CALL and not tools:
            raise ValueError("TOOL_CALL dataset type requires tools parameter")

        # Force turns=1 for INSTRUCTION and EVALUATION types
        if dataset_type in (DatasetType.INSTRUCTION, DatasetType.EVALUATION):
            self.turns = 1

        # Store model info for reporting
        self.generation_model = generation_model
        self.grading_model = grading_model

        # Create LLM clients
        self.generation_llm = LLM(model=generation_model, base_url=base_url, temperature=temperature)
        self.grading_llm = LLM(model=grading_model, base_url=base_url)

        # Create factory for component creation
        self.factory = ComponentFactory(
            generation_llm=self.generation_llm,
            grading_llm=self.grading_llm,
            mode_config=self.mode_config,
            tools=tools,
            thinking=thinking,
        )

        # Reporter for progress output
        self.reporter = reporter or RichReporter()

        # Auto-scale workers based on provider
        model_str = generation_model.value if isinstance(generation_model, Enum) else str(generation_model)
        self.workers = auto_workers(model_str)

        # Create HITL editors if enabled
        hitl_editor = self.factory.create_logic_map_editor() if enable_hitl else None
        scenario_editor = self.factory.create_scenario_editor() if enable_hitl else None

        # Create pipeline
        self.pipeline = GenerationPipeline(
            factory=self.factory,
            reporter=self.reporter,
            workers=self.workers,
            max_iterations=max_iterations,
            skip_grading=skip_grading,
            checkpoint_manager=self.checkpoint_manager,
            enable_hitl=enable_hitl,
            hitl_editor=hitl_editor,
            scenario_editor=scenario_editor,
        )

    @handle_error
    def generate(
        self,
        policy: Policy | str,
        traces: int = 20,
        return_logic_map: bool = False,
    ) -> Dataset | GenerationResult:
        """
        Generate a training dataset from a policy.

        Args:
            policy: Policy object or text string
            traces: Target number of traces to generate (default: 20)
            return_logic_map: If True, return GenerationResult with access to
                the Logic Map, scenarios, and distribution (default: False)

        Returns:
            Dataset (default) or GenerationResult if return_logic_map=True

        Examples:
            >>> # Standard usage
            >>> dataset = generator.generate(policy, traces=50)

            >>> # Access Logic Map for inspection
            >>> result = generator.generate(policy, return_logic_map=True)
            >>> print(result.logic_map.rules)  # See extracted rules
            >>> print(result.distribution)  # See scenario type counts
            >>> dataset = result.dataset  # Get the dataset
        """
        if isinstance(policy, str):
            policy = Policy(text=policy)

        # Validate policy has enough content
        policy.validate_length()

        return asyncio.run(self._generate_async(policy, traces, return_logic_map))

    async def _generate_async(
        self,
        policy: Policy,
        traces: int,
        return_logic_map: bool = False,
    ) -> Dataset | GenerationResult:
        """Async implementation of generation pipeline."""
        model_str = self.generation_model.value if isinstance(self.generation_model, Enum) else str(self.generation_model)

        return await self.pipeline.run(
            policy=policy,
            traces=traces,
            model=model_str,
            dataset_type=self.dataset_type.value,
            turns=self.turns,
            return_result=return_logic_map,
        )

    async def generate_async(
        self,
        policy: Policy | str,
        traces: int = 20,
        return_logic_map: bool = False,
    ) -> Dataset | GenerationResult:
        """
        Async version of generate for use in async contexts.

        Args:
            policy: Policy object or text string
            traces: Target number of traces to generate (default: 20)
            return_logic_map: If True, return GenerationResult with Logic Map access

        Returns:
            Dataset (default) or GenerationResult if return_logic_map=True
        """
        if isinstance(policy, str):
            policy = Policy(text=policy)

        return await self._generate_async(policy, traces, return_logic_map)

    @handle_error
    def generate_scenarios(
        self,
        policy: Policy | str,
        count: int = 20,
    ) -> ScenariosResult:
        """
        Generate eval scenarios without synthetic responses.

        This runs stages 0-2 of the pipeline (planning, logic extraction,
        scenario synthesis) but skips response generation. Use this for
        creating eval datasets where you want to test your own model.

        Args:
            policy: Policy object or text string
            count: Target number of scenarios to generate (default: 20)

        Returns:
            ScenariosResult with scenarios, logic_map, and distribution

        Examples:
            >>> result = generator.generate_scenarios(policy, count=100)
            >>> for scenario in result.scenarios:
            ...     response = my_model(scenario.user_message)
            ...     grade = synkro.grade(response, scenario, policy)
        """
        if isinstance(policy, str):
            policy = Policy(text=policy)

        # Validate policy has enough content
        policy.validate_length()

        return asyncio.run(self._generate_scenarios_async(policy, count))

    async def _generate_scenarios_async(
        self,
        policy: Policy,
        count: int,
    ) -> ScenariosResult:
        """Async implementation of scenario-only generation."""
        model_str = self.generation_model.value if isinstance(self.generation_model, Enum) else str(self.generation_model)

        return await self.pipeline.run_scenarios_only(
            policy=policy,
            count=count,
            model=model_str,
        )

    async def generate_scenarios_async(
        self,
        policy: Policy | str,
        count: int = 20,
    ) -> ScenariosResult:
        """
        Async version of generate_scenarios for use in async contexts.

        Args:
            policy: Policy object or text string
            count: Target number of scenarios to generate (default: 20)

        Returns:
            ScenariosResult with scenarios, logic_map, and distribution
        """
        if isinstance(policy, str):
            policy = Policy(text=policy)

        return await self._generate_scenarios_async(policy, count)
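The module above supports both training-data and eval workflows. A minimal usage sketch follows, assuming Generator and DatasetType are re-exported from the top-level synkro package (the 331-line __init__.py in the listing makes this plausible, but it is not confirmed by this diff); the policy text and checkpoint path are illustrative only.

# Hypothetical usage sketch; top-level re-exports are assumed, not confirmed.
from synkro import Generator, DatasetType  # assumed re-exports

# Illustrative policy text; it must be long enough to pass
# Policy.validate_length(), which generate() calls before running.
policy_text = (
    "Refunds are allowed within 30 days of purchase. "
    "Orders over $500 require manager approval."
)

# checkpoint_dir enables resumable, per-stage checkpoints;
# enable_hitl=False skips the interactive Logic Map editing pause,
# which matters when running unattended.
generator = Generator(
    dataset_type=DatasetType.CONVERSATION,
    checkpoint_dir="./checkpoints",
    enable_hitl=False,
)
dataset = generator.generate(policy_text, traces=10)

# Eval workflow: scenarios only (stages 0-2), no synthetic responses.
result = generator.generate_scenarios(policy_text, count=25)
for scenario in result.scenarios:
    ...  # run your own model on each scenario and grade it yourself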
synkro/generation/golden_responses.py
@@ -0,0 +1,269 @@
"""Golden Response Generator - The Thinker.

Generates traces with grounded Chain-of-Thought reasoning and rule citations.
This is Stage 3 of the Golden Trace pipeline for CONVERSATION/INSTRUCTION datasets.
"""

import asyncio
from typing import TYPE_CHECKING

from synkro.llm.client import LLM
from synkro.models import Model, OpenAI
from synkro.schemas import GoldenTraceOutput
from synkro.types.core import Trace, Message, Scenario
from synkro.types.logic_map import (
    LogicMap,
    GoldenScenario,
    ReasoningStep,
)
from synkro.prompts.golden_templates import (
    GOLDEN_TRACE_PROMPT,
    GOLDEN_TRACE_MULTI_TURN_PROMPT,
)


class GoldenResponseGenerator:
    """
    The Thinker - Generates traces with grounded reasoning.

    Produces traces with:
    - Explicit Chain-of-Thought reasoning
    - Rule citations (Rule IDs) for each reasoning step
    - Exclusionary reasoning (why rules DON'T apply)
    - DAG-compliant dependency order

    Examples:
        >>> generator = GoldenResponseGenerator(llm=LLM(model=OpenAI.GPT_4O_MINI))
        >>> trace = await generator.generate_single(
        ...     policy_text="...",
        ...     logic_map=logic_map,
        ...     scenario=scenario,
        ...     target_turns=1,
        ... )
    """

    # Instruction to inject when thinking mode is enabled
    THINKING_INSTRUCTION = """
THINKING MODE:
Your assistant response MUST include reasoning wrapped in <think> and </think> tags.
Place your step-by-step reasoning inside the think tags BEFORE your actual response.

Format:
<think>
[Your reasoning about which rules apply, why they apply/don't apply, etc.]
</think>

[Your actual response to the user]
"""

    def __init__(
        self,
        llm: LLM | None = None,
        model: Model = OpenAI.GPT_4O_MINI,
        thinking: bool = False,
    ):
        """
        Initialize the Golden Response Generator.

        Args:
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating LLM
            thinking: Enable thinking mode with <think> tags in responses
        """
        self.llm = llm or LLM(model=model, temperature=0.7)
        self.thinking = thinking

    async def generate_single(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenario: GoldenScenario,
        target_turns: int = 1,
    ) -> Trace:
        """
        Generate a single trace with grounded reasoning.

        Args:
            policy_text: The policy document text
            logic_map: The extracted Logic Map (DAG of rules)
            scenario: The golden scenario to respond to
            target_turns: Number of conversation turns

        Returns:
            Trace with messages and reasoning metadata
        """
        if target_turns > 1:
            return await self._generate_multi_turn(
                policy_text, logic_map, scenario, target_turns
            )

        return await self._generate_single_turn(policy_text, logic_map, scenario)

    async def _generate_single_turn(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenario: GoldenScenario,
    ) -> Trace:
        """Generate a single-turn trace."""
        # Format Logic Map for prompt
        logic_map_str = self._format_logic_map(logic_map)

        # Build prompt
        prompt = GOLDEN_TRACE_PROMPT.format(
            policy_text=policy_text,
            logic_map=logic_map_str,
            scenario_description=scenario.description,
            scenario_context=scenario.context,
            target_rule_ids=", ".join(scenario.target_rule_ids),
            scenario_type=scenario.scenario_type.value.upper(),
            expected_outcome=scenario.expected_outcome,
        )

        # Inject thinking instruction if enabled
        if self.thinking:
            prompt = prompt + self.THINKING_INSTRUCTION

        # Generate structured output
        result = await self.llm.generate_structured(prompt, GoldenTraceOutput)

        # Convert to Trace
        messages = [
            Message(role=m.role, content=m.content)
            for m in result.messages
        ]

        # Convert GoldenScenario to base Scenario for Trace
        base_scenario = scenario.to_base_scenario()

        # Convert reasoning chain to serializable format
        reasoning_chain = None
        if result.reasoning_chain:
            reasoning_chain = [
                {
                    "rule_id": step.rule_id,
                    "rule_text": step.rule_text,
                    "applies": step.applies,
                    "reasoning": step.reasoning,
                    "exclusions": step.exclusions,
                }
                for step in result.reasoning_chain
            ]

        return Trace(
            messages=messages,
            scenario=base_scenario,
            reasoning_chain=reasoning_chain,
            rules_applied=result.rules_applied,
            rules_excluded=result.rules_excluded,
        )

    async def _generate_multi_turn(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenario: GoldenScenario,
        target_turns: int,
    ) -> Trace:
        """Generate a multi-turn trace."""
        # Format Logic Map for prompt
        logic_map_str = self._format_logic_map(logic_map)

        # Build prompt
        prompt = GOLDEN_TRACE_MULTI_TURN_PROMPT.format(
            policy_text=policy_text,
            logic_map=logic_map_str,
            scenario_description=scenario.description,
            scenario_context=scenario.context,
            target_rule_ids=", ".join(scenario.target_rule_ids),
            scenario_type=scenario.scenario_type.value.upper(),
            target_turns=target_turns,
        )

        # Inject thinking instruction if enabled
        if self.thinking:
            prompt = prompt + self.THINKING_INSTRUCTION

        # Generate structured output
        result = await self.llm.generate_structured(prompt, GoldenTraceOutput)

        # Convert to Trace
        messages = [
            Message(role=m.role, content=m.content)
            for m in result.messages
        ]

        # Convert GoldenScenario to base Scenario for Trace
        base_scenario = scenario.to_base_scenario()

        # Convert reasoning chain to serializable format
        reasoning_chain = None
        if result.reasoning_chain:
            reasoning_chain = [
                {
                    "rule_id": step.rule_id,
                    "rule_text": step.rule_text,
                    "applies": step.applies,
                    "reasoning": step.reasoning,
                    "exclusions": step.exclusions,
                }
                for step in result.reasoning_chain
            ]

        return Trace(
            messages=messages,
            scenario=base_scenario,
            reasoning_chain=reasoning_chain,
            rules_applied=result.rules_applied,
            rules_excluded=result.rules_excluded,
        )

    def _format_logic_map(self, logic_map: LogicMap) -> str:
        """Format Logic Map for prompt inclusion."""
        lines = []
        lines.append("RULES:")
        for rule in logic_map.rules:
            deps = f" [depends on: {', '.join(rule.dependencies)}]" if rule.dependencies else ""
            lines.append(
                f"  {rule.rule_id} ({rule.category.value}): {rule.text}{deps}"
            )
            lines.append(f"    IF: {rule.condition}")
            lines.append(f"    THEN: {rule.action}")

        lines.append("\nDEPENDENCY ORDER (evaluate in this order):")
        # Show topological order for root rules and their chains
        for root_id in logic_map.root_rules:
            chain = logic_map.get_chain(root_id)
            if chain:
                chain_str = " -> ".join(r.rule_id for r in chain)
                lines.append(f"  {chain_str}")

        return "\n".join(lines)

    async def generate(
        self,
        policy_text: str,
        logic_map: LogicMap,
        scenarios: list[GoldenScenario],
        target_turns: int = 1,
    ) -> list[Trace]:
        """
        Generate traces for multiple scenarios.

        Args:
            policy_text: The policy document text
            logic_map: The extracted Logic Map
            scenarios: List of golden scenarios
            target_turns: Number of conversation turns

        Returns:
            List of traces with grounded reasoning
        """
        tasks = [
            self.generate_single(policy_text, logic_map, s, target_turns)
            for s in scenarios
        ]
        return await asyncio.gather(*tasks)


__all__ = ["GoldenResponseGenerator"]
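For completeness, a sketch of driving this stage directly rather than through GenerationPipeline, using only names defined in the file above. It assumes a logic_map and a scenarios list already produced by the earlier extraction and scenario-synthesis stages (their construction is not shown here); build_traces is a hypothetical wrapper name.

import asyncio

from synkro.llm.client import LLM
from synkro.models import OpenAI
from synkro.generation.golden_responses import GoldenResponseGenerator


async def build_traces(policy_text, logic_map, scenarios):
    # logic_map / scenarios are assumed inputs from earlier pipeline stages.
    generator = GoldenResponseGenerator(
        llm=LLM(model=OpenAI.GPT_4O_MINI),
        thinking=True,  # responses will carry <think>...</think> reasoning
    )
    # generate() fans out one asyncio task per scenario (see above).
    traces = await generator.generate(
        policy_text, logic_map, scenarios, target_turns=3
    )
    for trace in traces:
        print(trace.rules_applied, trace.rules_excluded)
    return traces

# asyncio.run(build_traces(policy_text, logic_map, scenarios))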