synkro-0.4.12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synkro/__init__.py +179 -0
- synkro/advanced.py +186 -0
- synkro/cli.py +128 -0
- synkro/core/__init__.py +7 -0
- synkro/core/checkpoint.py +250 -0
- synkro/core/dataset.py +402 -0
- synkro/core/policy.py +337 -0
- synkro/errors.py +178 -0
- synkro/examples/__init__.py +148 -0
- synkro/factory.py +276 -0
- synkro/formatters/__init__.py +12 -0
- synkro/formatters/qa.py +98 -0
- synkro/formatters/sft.py +90 -0
- synkro/formatters/tool_call.py +127 -0
- synkro/generation/__init__.py +9 -0
- synkro/generation/follow_ups.py +134 -0
- synkro/generation/generator.py +220 -0
- synkro/generation/golden_responses.py +244 -0
- synkro/generation/golden_scenarios.py +276 -0
- synkro/generation/golden_tool_responses.py +416 -0
- synkro/generation/logic_extractor.py +126 -0
- synkro/generation/multiturn_responses.py +177 -0
- synkro/generation/planner.py +131 -0
- synkro/generation/responses.py +189 -0
- synkro/generation/scenarios.py +90 -0
- synkro/generation/tool_responses.py +376 -0
- synkro/generation/tool_simulator.py +114 -0
- synkro/interactive/__init__.py +12 -0
- synkro/interactive/hitl_session.py +77 -0
- synkro/interactive/logic_map_editor.py +173 -0
- synkro/interactive/rich_ui.py +205 -0
- synkro/llm/__init__.py +7 -0
- synkro/llm/client.py +235 -0
- synkro/llm/rate_limits.py +95 -0
- synkro/models/__init__.py +43 -0
- synkro/models/anthropic.py +26 -0
- synkro/models/google.py +19 -0
- synkro/models/openai.py +31 -0
- synkro/modes/__init__.py +15 -0
- synkro/modes/config.py +66 -0
- synkro/modes/qa.py +18 -0
- synkro/modes/sft.py +18 -0
- synkro/modes/tool_call.py +18 -0
- synkro/parsers.py +442 -0
- synkro/pipeline/__init__.py +20 -0
- synkro/pipeline/phases.py +592 -0
- synkro/pipeline/runner.py +424 -0
- synkro/pipelines.py +123 -0
- synkro/prompts/__init__.py +57 -0
- synkro/prompts/base.py +167 -0
- synkro/prompts/golden_templates.py +474 -0
- synkro/prompts/interactive_templates.py +65 -0
- synkro/prompts/multiturn_templates.py +156 -0
- synkro/prompts/qa_templates.py +97 -0
- synkro/prompts/templates.py +281 -0
- synkro/prompts/tool_templates.py +201 -0
- synkro/quality/__init__.py +14 -0
- synkro/quality/golden_refiner.py +163 -0
- synkro/quality/grader.py +153 -0
- synkro/quality/multiturn_grader.py +150 -0
- synkro/quality/refiner.py +137 -0
- synkro/quality/tool_grader.py +126 -0
- synkro/quality/tool_refiner.py +128 -0
- synkro/quality/verifier.py +228 -0
- synkro/reporting.py +537 -0
- synkro/schemas.py +472 -0
- synkro/types/__init__.py +41 -0
- synkro/types/core.py +126 -0
- synkro/types/dataset_type.py +30 -0
- synkro/types/logic_map.py +345 -0
- synkro/types/tool.py +94 -0
- synkro-0.4.12.data/data/examples/__init__.py +148 -0
- synkro-0.4.12.dist-info/METADATA +258 -0
- synkro-0.4.12.dist-info/RECORD +77 -0
- synkro-0.4.12.dist-info/WHEEL +4 -0
- synkro-0.4.12.dist-info/entry_points.txt +2 -0
- synkro-0.4.12.dist-info/licenses/LICENSE +21 -0
synkro/pipeline/runner.py
ADDED
@@ -0,0 +1,424 @@
"""Pipeline runner that orchestrates all phases.

Uses the Golden Trace 4-stage pipeline for all dataset types:
1. Logic Extraction (The Cartographer)
2. Scenario Synthesis (The Adversary)
3. Trace Synthesis (The Thinker)
4. Verification (The Auditor)
"""

import asyncio
from datetime import datetime

from synkro.core.policy import Policy
from synkro.core.dataset import Dataset
from synkro.core.checkpoint import CheckpointManager, hash_policy
from synkro.factory import ComponentFactory
from synkro.reporting import ProgressReporter
from synkro.pipeline.phases import (
    PlanPhase,
    LogicExtractionPhase,
    GoldenScenarioPhase,
    GoldenTracePhase,
    GoldenToolCallPhase,
    VerificationPhase,
)
from synkro.types.logic_map import LogicMap

# Type hints for HITL components (imported dynamically to avoid circular imports)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from synkro.interactive.logic_map_editor import LogicMapEditor


class GenerationResult:
    """
    Result of the generation pipeline.

    Provides access to both the dataset and internal artifacts like the Logic Map.

    Examples:
        >>> result = await pipeline.run(policy, traces=50, ...)
        >>> dataset = result.dataset
        >>> logic_map = result.logic_map  # Inspect extracted rules
    """

    def __init__(
        self,
        dataset: "Dataset",
        logic_map: LogicMap | None = None,
        scenarios: list | None = None,
        distribution: dict[str, int] | None = None,
    ):
        self.dataset = dataset
        self.logic_map = logic_map
        self.scenarios = scenarios or []
        self.distribution = distribution or {}

    # Allow unpacking: dataset, logic_map = result
    def __iter__(self):
        return iter((self.dataset, self.logic_map))

    # Allow direct Dataset access for backwards compatibility
    def __getattr__(self, name):
        # Delegate to dataset for backwards compatibility
        return getattr(self.dataset, name)


class GenerationPipeline:
    """
    Orchestrates the Golden Trace generation pipeline.

    All dataset types (SFT, QA, TOOL_CALL) use the unified 4-stage pipeline:
    - Stage 1: Logic Extraction - Extract rules as DAG
    - Stage 2: Scenario Synthesis - Generate typed scenarios (positive, negative, edge_case, irrelevant)
    - Stage 3: Trace Synthesis - Produce grounded reasoning with rule citations
    - Stage 4: Verification - Cross-reference against Logic Map

    Examples:
        >>> pipeline = GenerationPipeline(factory, reporter, workers=10)
        >>> dataset = await pipeline.run(policy, traces=50)
    """

    def __init__(
        self,
        factory: ComponentFactory,
        reporter: ProgressReporter,
        workers: int,
        max_iterations: int = 1,
        skip_grading: bool = False,
        checkpoint_manager: CheckpointManager | None = None,
        enable_hitl: bool = False,
        hitl_editor: "LogicMapEditor | None" = None,
    ):
        """
        Initialize the pipeline.

        Args:
            factory: ComponentFactory for creating pipeline components
            reporter: ProgressReporter for reporting progress
            workers: Number of concurrent workers (API calls)
            max_iterations: Maximum refinement iterations
            skip_grading: Whether to skip the verification phase
            checkpoint_manager: Optional checkpoint manager for resumable generation
            enable_hitl: Whether to enable Human-in-the-Loop Logic Map editing
            hitl_editor: Optional LogicMapEditor for HITL sessions
        """
        self.factory = factory
        self.reporter = reporter
        self.workers = workers
        self.max_iterations = max_iterations
        self.skip_grading = skip_grading
        self.checkpoint_manager = checkpoint_manager
        self.enable_hitl = enable_hitl
        self.hitl_editor = hitl_editor

        # Golden Trace phases
        self.plan_phase = PlanPhase()
        self.logic_extraction_phase = LogicExtractionPhase()
        self.golden_scenario_phase = GoldenScenarioPhase()
        self.golden_trace_phase = GoldenTracePhase()
        self.golden_tool_call_phase = GoldenToolCallPhase()
        self.verification_phase = VerificationPhase()

    async def run(
        self,
        policy: Policy,
        traces: int,
        model: str,
        dataset_type: str,
        turns: int | str = "auto",
        return_result: bool = False,
    ) -> Dataset | GenerationResult:
        """
        Run the Golden Trace generation pipeline.

        All dataset types use the same 4-stage pipeline, with Stage 3
        branching based on whether TOOL_CALL is needed.

        Args:
            policy: The policy to generate from
            traces: Target number of traces
            model: Model name (for reporting)
            dataset_type: Dataset type (sft, qa, tool_call)
            turns: Conversation turns per trace. Use int for fixed turns, or "auto"
                for policy complexity-driven turns
            return_result: If True, return GenerationResult with logic_map access

        Returns:
            Dataset (default) or GenerationResult if return_result=True
        """
        start_time = datetime.now()
        semaphore = asyncio.Semaphore(self.workers)

        # Check if this is a tool_call dataset
        is_tool_call = dataset_type == "tool_call"

        # Checkpointing setup
        cm = self.checkpoint_manager
        policy_hash = hash_policy(policy.text) if cm else ""
        resuming = False

        # Check for existing checkpoint
        if cm and cm.has_checkpoint():
            if cm.matches_config(policy_hash, traces, dataset_type):
                resuming = True
                from rich.console import Console
                Console().print(f"[cyan]🔄 Resuming from checkpoint (stage: {cm.stage})[/cyan]")
            else:
                cm.clear()  # Config mismatch, start fresh

        # Report start
        self.reporter.on_start(traces, model, dataset_type)

        # Create components via factory
        planner = self.factory.create_planner()
        logic_extractor = self.factory.create_logic_extractor()
        golden_scenario_gen = self.factory.create_golden_scenario_generator()
        verifier = self.factory.create_verifier()
        golden_refiner = self.factory.create_golden_refiner()

        # Create appropriate trace generator based on dataset type
        if is_tool_call and self.factory.has_tools:
            golden_tool_call_gen = self.factory.create_golden_tool_call_generator()
        else:
            golden_response_gen = self.factory.create_golden_response_generator()

        # Phase 0: Planning (for category distribution)
        analyze_turns = turns == "auto"
        plan = await self.plan_phase.execute(policy, traces, planner, analyze_turns=analyze_turns)
        self.reporter.on_plan_complete(plan)

        # Determine target turns
        if isinstance(turns, int):
            target_turns = turns
        else:
            target_turns = plan.recommended_turns

        # =====================================================================
        # STAGE 1: Logic Extraction (The Cartographer)
        # =====================================================================
        if resuming and cm and cm.stage in ("logic_map", "scenarios", "traces", "complete"):
            logic_map = cm.get_logic_map()
            from rich.console import Console
            Console().print("[dim]📂 Loaded Logic Map from checkpoint[/dim]")
        else:
            logic_map = await self.logic_extraction_phase.execute(policy, logic_extractor)
            if cm:
                cm.save_logic_map(logic_map, policy_hash, traces, dataset_type)

        self.reporter.on_logic_map_complete(logic_map)

        # =====================================================================
        # HUMAN-IN-THE-LOOP: Logic Map Editing (Optional)
        # =====================================================================
        if self.enable_hitl and self.hitl_editor:
            logic_map = await self._run_hitl_session(logic_map, policy)

        # =====================================================================
        # STAGE 2: Scenario Synthesis (The Adversary)
        # =====================================================================
        if resuming and cm and cm.stage in ("scenarios", "traces", "complete"):
            golden_scenarios = cm.get_scenarios()
            distribution = cm.load().scenario_distribution
            from rich.console import Console
            Console().print(f"[dim]📂 Loaded {len(golden_scenarios)} scenarios from checkpoint[/dim]")
        else:
            golden_scenarios, distribution = await self.golden_scenario_phase.execute(
                policy, logic_map, plan, golden_scenario_gen, semaphore
            )
            if cm:
                cm.save_scenarios(golden_scenarios, distribution)

        self.reporter.on_golden_scenarios_complete(golden_scenarios, distribution)

        # =====================================================================
        # STAGE 3: Trace Synthesis (The Thinker)
        # =====================================================================
        if resuming and cm and cm.stage in ("traces", "complete"):
            # Resume from checkpoint - get already completed traces
            existing_traces = cm.get_traces()
            pending_indices = cm.get_pending_scenario_indices(len(golden_scenarios))

            if pending_indices:
                from rich.console import Console
                Console().print(f"[dim]📂 Resuming: {len(existing_traces)} done, {len(pending_indices)} pending[/dim]")

                # Generate only pending scenarios
                pending_scenarios = [golden_scenarios[i] for i in pending_indices]

                if is_tool_call and self.factory.has_tools:
                    new_traces = await self.golden_tool_call_phase.execute(
                        policy, logic_map, pending_scenarios, golden_tool_call_gen, semaphore, target_turns
                    )
                else:
                    new_traces = await self.golden_trace_phase.execute(
                        policy, logic_map, pending_scenarios, golden_response_gen, semaphore, target_turns
                    )

                # Save new traces to checkpoint
                if cm:
                    cm.save_traces_batch(list(new_traces), pending_indices)

                all_traces = existing_traces + list(new_traces)
            else:
                all_traces = existing_traces
        else:
            if is_tool_call and self.factory.has_tools:
                all_traces = await self.golden_tool_call_phase.execute(
                    policy, logic_map, golden_scenarios, golden_tool_call_gen, semaphore, target_turns
                )
            else:
                all_traces = await self.golden_trace_phase.execute(
                    policy, logic_map, golden_scenarios, golden_response_gen, semaphore, target_turns
                )

            # Save all traces to checkpoint
            if cm:
                cm.save_traces_batch(list(all_traces), list(range(len(all_traces))))

        self.reporter.on_responses_complete(list(all_traces))

        # =====================================================================
        # STAGE 4: Verification (The Auditor)
        # =====================================================================
        pass_rate: float | None = None

        if resuming and cm and cm.stage == "complete":
            final_traces = cm.get_verified_traces()
            passed_count = sum(1 for t in final_traces if t.grade and t.grade.passed)
            pass_rate = (passed_count / len(final_traces) * 100) if final_traces else 0
            from rich.console import Console
            Console().print(f"[dim]📂 Loaded {len(final_traces)} verified traces from checkpoint[/dim]")
        elif self.skip_grading:
            final_traces = list(all_traces)
            self.reporter.on_grading_skipped()
        else:
            final_traces, pass_rate = await self.verification_phase.execute(
                policy,
                logic_map,
                golden_scenarios,
                list(all_traces),
                verifier,
                golden_refiner,
                self.max_iterations,
                semaphore,
            )
            if cm:
                cm.save_verified_traces(final_traces)

        self.reporter.on_grading_complete(final_traces, pass_rate)

        # Report completion
        elapsed = (datetime.now() - start_time).total_seconds()
        self.reporter.on_complete(len(final_traces), elapsed, pass_rate)

        dataset = Dataset(traces=final_traces)

        if return_result:
            return GenerationResult(
                dataset=dataset,
                logic_map=logic_map,
                scenarios=golden_scenarios,
                distribution=distribution,
            )

        return dataset

    async def _run_hitl_session(self, logic_map: LogicMap, policy: Policy) -> LogicMap:
        """
        Run an interactive Human-in-the-Loop session for Logic Map editing.

        Args:
            logic_map: The extracted Logic Map to edit
            policy: The policy document (for context in refinements)

        Returns:
            The (potentially modified) Logic Map
        """
        from synkro.interactive.hitl_session import HITLSession
        from synkro.interactive.rich_ui import LogicMapDisplay, InteractivePrompt

        session = HITLSession(original_logic_map=logic_map)
        display = LogicMapDisplay()
        prompt = InteractivePrompt()

        # Show instructions and initial Logic Map
        prompt.show_instructions()
        display.display_full(session.current_logic_map)

        while True:
            feedback = prompt.get_feedback().strip()

            # Handle commands
            if feedback.lower() == "done":
                break

            if feedback.lower() == "undo":
                if session.can_undo:
                    session.undo()
                    display.show_success("Reverted to previous state")
                    display.display_full(session.current_logic_map)
                else:
                    display.show_error("Nothing to undo")
                continue

            if feedback.lower() == "reset":
                session.reset()
                display.show_success("Reset to original Logic Map")
                display.display_full(session.current_logic_map)
                continue

            if feedback.lower() == "help":
                prompt.show_instructions()
                continue

            if feedback.lower().startswith("show "):
                rule_id = feedback[5:].strip().upper()
                display.display_rule(rule_id, session.current_logic_map)
                continue

            # Empty input
            if not feedback:
                continue

            # Apply LLM-based refinement
            try:
                new_map, changes_summary = await self.hitl_editor.refine(
                    session.current_logic_map,
                    feedback,
                    policy.text,
                )

                # Validate the refinement
                is_valid, issues = self.hitl_editor.validate_refinement(
                    session.current_logic_map,
                    new_map,
                )

                if is_valid:
                    display.display_diff(session.current_logic_map, new_map)
                    session.apply_change(feedback, new_map)
                    display.show_success(changes_summary)
                else:
                    display.show_error(f"Invalid refinement: {', '.join(issues)}")

            except Exception as e:
                display.show_error(f"Failed to apply refinement: {e}")

        # Final summary
        if session.change_count > 0:
            display.console.print(
                f"\n[green]✅ HITL Complete[/green] - "
                f"Made {session.change_count} change(s), proceeding with {len(session.current_logic_map.rules)} rules"
            )
        else:
            display.console.print(
                f"\n[green]✅ HITL Complete[/green] - "
                f"No changes made, proceeding with {len(session.current_logic_map.rules)} rules"
            )

        return session.current_logic_map


__all__ = ["GenerationPipeline", "GenerationResult"]
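Editor's note: a minimal sketch of the GenerationResult contract defined above, not taken from the package's own tests. Tuple unpacking relies on __iter__, and attribute access for names GenerationResult does not define falls through to the wrapped Dataset via __getattr__. Dataset(traces=[...]) mirrors the construction at the end of run(); whether Dataset accepts an empty trace list, and the presence of Dataset.save (mentioned in the pipelines.py docstring below), are assumptions.

    from synkro.core.dataset import Dataset
    from synkro.pipeline.runner import GenerationResult

    dataset = Dataset(traces=[])              # assumption: an empty trace list is accepted
    result = GenerationResult(dataset=dataset)

    ds, logic_map = result                    # unpacking via __iter__ -> (dataset, logic_map)
    assert ds is dataset and logic_map is None

    # Unknown attributes delegate to the wrapped Dataset via __getattr__,
    # so e.g. "save" resolves to Dataset.save (assumed per the pipelines.py docstring).
    print(hasattr(result, "save"))

In normal use, GenerationPipeline and its ComponentFactory are wired up by create_pipeline in synkro/pipelines.py below rather than constructed by hand.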
synkro/pipelines.py
ADDED
@@ -0,0 +1,123 @@
"""Pipeline creation utilities.

Usage:
    from synkro.pipelines import create_pipeline
    from synkro.models.openai import OpenAI
    from synkro.types import DatasetType

    pipeline = create_pipeline(
        model=OpenAI.GPT_5_MINI,
        dataset_type=DatasetType.SFT,
    )
    dataset = pipeline.generate("policy text", traces=50)

    # Tool calling pipeline
    from synkro import ToolDefinition

    web_search = ToolDefinition(
        name="web_search",
        description="Search the web",
        parameters={"type": "object", "properties": {"query": {"type": "string"}}}
    )

    pipeline = create_pipeline(
        dataset_type=DatasetType.TOOL_CALL,
        tools=[web_search],
    )
    dataset = pipeline.generate("Search guidelines", traces=50)
"""

from typing import TYPE_CHECKING

from synkro.generation.generator import Generator
from synkro.types import DatasetType
from synkro.models import Model, OpenAI
from synkro.reporting import ProgressReporter

if TYPE_CHECKING:
    from synkro.types.tool import ToolDefinition


def create_pipeline(
    model: Model = OpenAI.GPT_5_MINI,
    dataset_type: DatasetType = DatasetType.SFT,
    grading_model: Model = OpenAI.GPT_52,
    max_iterations: int = 3,
    skip_grading: bool = False,
    reporter: ProgressReporter | None = None,
    tools: list["ToolDefinition"] | None = None,
    turns: int | str = "auto",
    checkpoint_dir: str | None = None,
    enable_hitl: bool = True,
) -> Generator:
    """
    Create a pipeline for generating training datasets.

    Args:
        model: Model enum for generation (default: OpenAI.GPT_5_MINI)
        dataset_type: Type of dataset - QA, SFT, or TOOL_CALL (default: SFT)
        grading_model: Model enum for grading (default: OpenAI.GPT_52)
        max_iterations: Max refinement iterations per trace (default: 3)
        skip_grading: Skip grading phase for faster generation (default: False)
        reporter: Progress reporter (default: RichReporter for console output)
        tools: List of ToolDefinition for TOOL_CALL dataset type
        turns: Conversation turns per trace. Use int for fixed turns, or "auto"
            for policy complexity-driven turns (Simple=1-2, Conditional=3, Complex=5+)
        checkpoint_dir: Directory for checkpoints. Enables resumable generation.
        enable_hitl: Enable Human-in-the-Loop Logic Map editing (default: True)

    Returns:
        Generator instance ready to use

    Example:
        >>> from synkro.pipelines import create_pipeline
        >>> from synkro.models.openai import OpenAI
        >>> from synkro.types import DatasetType
        >>>
        >>> pipeline = create_pipeline(
        ...     model=OpenAI.GPT_5_MINI,
        ...     dataset_type=DatasetType.SFT,
        ... )
        >>> dataset = pipeline.generate("policy text", traces=50)
        >>> dataset.save("training.jsonl")

        >>> # Multi-turn with fixed 3 turns
        >>> pipeline = create_pipeline(turns=3)
        >>> dataset = pipeline.generate("policy text", traces=50)

        >>> # Silent mode for embedding
        >>> from synkro.reporting import SilentReporter
        >>> pipeline = create_pipeline(reporter=SilentReporter())

        >>> # Interactive Logic Map editing
        >>> pipeline = create_pipeline(enable_hitl=True)
        >>> dataset = pipeline.generate("policy text", traces=50)

        >>> # Tool calling dataset
        >>> from synkro import ToolDefinition
        >>> search_tool = ToolDefinition(
        ...     name="web_search",
        ...     description="Search the web for information",
        ...     parameters={"type": "object", "properties": {"query": {"type": "string"}}}
        ... )
        >>> pipeline = create_pipeline(
        ...     dataset_type=DatasetType.TOOL_CALL,
        ...     tools=[search_tool],
        ... )
        >>> dataset = pipeline.generate("Search guidelines", traces=50)
    """
    return Generator(
        dataset_type=dataset_type,
        generation_model=model,
        grading_model=grading_model,
        max_iterations=max_iterations,
        skip_grading=skip_grading,
        reporter=reporter,
        tools=tools,
        turns=turns,
        checkpoint_dir=checkpoint_dir,
        enable_hitl=enable_hitl,
    )


__all__ = ["create_pipeline"]
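Editor's note: a complementary sketch of the checkpointing path implemented in runner.py above. Passing checkpoint_dir makes a run resumable; re-running the same call with the same checkpoint directory, policy text, trace count, and dataset type resumes from the last saved stage, and the runner clears the checkpoint on a config mismatch. The checkpoint directory name below is arbitrary, and the exact resume behaviour for partially completed Stage 3 batches is whatever CheckpointManager implements (not shown in this diff).

    from synkro.pipelines import create_pipeline
    from synkro.reporting import SilentReporter
    from synkro.types import DatasetType

    pipeline = create_pipeline(
        dataset_type=DatasetType.SFT,
        reporter=SilentReporter(),      # no console output, per the "silent mode" example above
        checkpoint_dir=".synkro_ckpt",  # arbitrary path; enables resumable generation
        enable_hitl=False,              # skip the interactive Logic Map editing step
    )

    # If this process is interrupted, re-running the same call with the same
    # checkpoint_dir, policy text, traces, and dataset type resumes from the
    # last completed stage instead of starting over.
    dataset = pipeline.generate("policy text", traces=50)
    dataset.save("training.jsonl")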
synkro/prompts/__init__.py
ADDED
@@ -0,0 +1,57 @@
"""Prompt templates and customizable prompt classes for Synkro."""

from synkro.prompts.base import (
    SystemPrompt,
    ScenarioPrompt,
    ResponsePrompt,
    GradePrompt,
    RefinePrompt,
    PlanPrompt,
)
from synkro.prompts.templates import (
    SYSTEM_PROMPT,
    SCENARIO_GENERATOR_PROMPT,
    CATEGORY_SCENARIO_PROMPT,
    POLICY_PLANNING_PROMPT,
    POLICY_COMPLEXITY_PROMPT,
    BATCHED_RESPONSE_PROMPT,
    BATCHED_GRADER_PROMPT,
    BATCHED_REFINER_PROMPT,
    SINGLE_RESPONSE_PROMPT,
    SINGLE_GRADE_PROMPT,
)
from synkro.prompts.multiturn_templates import (
    FOLLOW_UP_GENERATION_PROMPT,
    MULTI_TURN_RESPONSE_PROMPT,
    MULTI_TURN_INITIAL_PROMPT,
    MULTI_TURN_GRADE_PROMPT,
    MULTI_TURN_REFINE_PROMPT,
)

__all__ = [
    # Prompt classes
    "SystemPrompt",
    "ScenarioPrompt",
    "ResponsePrompt",
    "GradePrompt",
    "RefinePrompt",
    "PlanPrompt",
    # Raw templates
    "SYSTEM_PROMPT",
    "SCENARIO_GENERATOR_PROMPT",
    "CATEGORY_SCENARIO_PROMPT",
    "POLICY_PLANNING_PROMPT",
    "POLICY_COMPLEXITY_PROMPT",
    "BATCHED_RESPONSE_PROMPT",
    "BATCHED_GRADER_PROMPT",
    "BATCHED_REFINER_PROMPT",
    "SINGLE_RESPONSE_PROMPT",
    "SINGLE_GRADE_PROMPT",
    # Multi-turn templates
    "FOLLOW_UP_GENERATION_PROMPT",
    "MULTI_TURN_RESPONSE_PROMPT",
    "MULTI_TURN_INITIAL_PROMPT",
    "MULTI_TURN_GRADE_PROMPT",
    "MULTI_TURN_REFINE_PROMPT",
]
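Editor's note: everything re-exported by this __init__ can be imported directly from synkro.prompts, as sketched below. The snippet only exercises the imports; treating SYSTEM_PROMPT as a plain prompt string is an assumption, since its actual type is defined in synkro/prompts/templates.py, which is not shown in this hunk.

    from synkro.prompts import (
        SystemPrompt,               # customizable prompt classes (synkro/prompts/base.py)
        ResponsePrompt,
        SYSTEM_PROMPT,              # raw templates (synkro/prompts/templates.py)
        MULTI_TURN_RESPONSE_PROMPT,  # multi-turn templates (multiturn_templates.py)
    )

    # Assumption: the raw templates are prompt text that can be inspected or
    # adapted before being passed to the customizable prompt classes.
    print(type(SYSTEM_PROMPT), str(SYSTEM_PROMPT)[:80])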