synkro-0.4.36-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synkro might be problematic. Click here for more details.

Files changed (81)
  1. synkro/__init__.py +331 -0
  2. synkro/advanced.py +184 -0
  3. synkro/cli.py +156 -0
  4. synkro/core/__init__.py +7 -0
  5. synkro/core/checkpoint.py +250 -0
  6. synkro/core/dataset.py +432 -0
  7. synkro/core/policy.py +337 -0
  8. synkro/errors.py +178 -0
  9. synkro/examples/__init__.py +148 -0
  10. synkro/factory.py +291 -0
  11. synkro/formatters/__init__.py +18 -0
  12. synkro/formatters/chatml.py +121 -0
  13. synkro/formatters/langfuse.py +98 -0
  14. synkro/formatters/langsmith.py +98 -0
  15. synkro/formatters/qa.py +112 -0
  16. synkro/formatters/sft.py +90 -0
  17. synkro/formatters/tool_call.py +127 -0
  18. synkro/generation/__init__.py +9 -0
  19. synkro/generation/follow_ups.py +134 -0
  20. synkro/generation/generator.py +314 -0
  21. synkro/generation/golden_responses.py +269 -0
  22. synkro/generation/golden_scenarios.py +333 -0
  23. synkro/generation/golden_tool_responses.py +791 -0
  24. synkro/generation/logic_extractor.py +126 -0
  25. synkro/generation/multiturn_responses.py +177 -0
  26. synkro/generation/planner.py +131 -0
  27. synkro/generation/responses.py +189 -0
  28. synkro/generation/scenarios.py +90 -0
  29. synkro/generation/tool_responses.py +625 -0
  30. synkro/generation/tool_simulator.py +114 -0
  31. synkro/interactive/__init__.py +16 -0
  32. synkro/interactive/hitl_session.py +205 -0
  33. synkro/interactive/intent_classifier.py +94 -0
  34. synkro/interactive/logic_map_editor.py +176 -0
  35. synkro/interactive/rich_ui.py +459 -0
  36. synkro/interactive/scenario_editor.py +198 -0
  37. synkro/llm/__init__.py +7 -0
  38. synkro/llm/client.py +309 -0
  39. synkro/llm/rate_limits.py +99 -0
  40. synkro/models/__init__.py +50 -0
  41. synkro/models/anthropic.py +26 -0
  42. synkro/models/google.py +19 -0
  43. synkro/models/local.py +104 -0
  44. synkro/models/openai.py +31 -0
  45. synkro/modes/__init__.py +13 -0
  46. synkro/modes/config.py +66 -0
  47. synkro/modes/conversation.py +35 -0
  48. synkro/modes/tool_call.py +18 -0
  49. synkro/parsers.py +442 -0
  50. synkro/pipeline/__init__.py +20 -0
  51. synkro/pipeline/phases.py +592 -0
  52. synkro/pipeline/runner.py +769 -0
  53. synkro/pipelines.py +136 -0
  54. synkro/prompts/__init__.py +57 -0
  55. synkro/prompts/base.py +167 -0
  56. synkro/prompts/golden_templates.py +533 -0
  57. synkro/prompts/interactive_templates.py +198 -0
  58. synkro/prompts/multiturn_templates.py +156 -0
  59. synkro/prompts/templates.py +281 -0
  60. synkro/prompts/tool_templates.py +318 -0
  61. synkro/quality/__init__.py +14 -0
  62. synkro/quality/golden_refiner.py +163 -0
  63. synkro/quality/grader.py +153 -0
  64. synkro/quality/multiturn_grader.py +150 -0
  65. synkro/quality/refiner.py +137 -0
  66. synkro/quality/tool_grader.py +126 -0
  67. synkro/quality/tool_refiner.py +128 -0
  68. synkro/quality/verifier.py +228 -0
  69. synkro/reporting.py +464 -0
  70. synkro/schemas.py +521 -0
  71. synkro/types/__init__.py +43 -0
  72. synkro/types/core.py +153 -0
  73. synkro/types/dataset_type.py +33 -0
  74. synkro/types/logic_map.py +348 -0
  75. synkro/types/tool.py +94 -0
  76. synkro-0.4.36.data/data/examples/__init__.py +148 -0
  77. synkro-0.4.36.dist-info/METADATA +507 -0
  78. synkro-0.4.36.dist-info/RECORD +81 -0
  79. synkro-0.4.36.dist-info/WHEEL +4 -0
  80. synkro-0.4.36.dist-info/entry_points.txt +2 -0
  81. synkro-0.4.36.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,769 @@
1
+ """Pipeline runner that orchestrates all phases.
2
+
3
+ Uses the Golden Trace 4-stage pipeline for all dataset types:
4
+ 1. Logic Extraction (The Cartographer)
5
+ 2. Scenario Synthesis (The Adversary)
6
+ 3. Trace Synthesis (The Thinker)
7
+ 4. Verification (The Auditor)
8
+ """
9
+
10
+ import asyncio
11
+ from datetime import datetime
12
+
13
+ from synkro.core.policy import Policy
14
+ from synkro.core.dataset import Dataset
15
+ from synkro.core.checkpoint import CheckpointManager, hash_policy
16
+ from synkro.factory import ComponentFactory
17
+ from synkro.reporting import ProgressReporter
18
+ from synkro.pipeline.phases import (
19
+ PlanPhase,
20
+ LogicExtractionPhase,
21
+ GoldenScenarioPhase,
22
+ GoldenTracePhase,
23
+ GoldenToolCallPhase,
24
+ VerificationPhase,
25
+ )
26
+ from synkro.types.logic_map import LogicMap
27
+
28
+ # Type hints for HITL components (imported dynamically to avoid circular imports)
29
+ from typing import TYPE_CHECKING
30
+ if TYPE_CHECKING:
31
+ from synkro.interactive.logic_map_editor import LogicMapEditor
32
+ from synkro.interactive.scenario_editor import ScenarioEditor
33
+ from synkro.types.core import Plan
34
+ from synkro.types.logic_map import GoldenScenario
35
+
36
+
37
+ class GenerationResult:
38
+ """
39
+ Result of the generation pipeline.
40
+
41
+ Provides access to both the dataset and internal artifacts like the Logic Map.
42
+
43
+ Examples:
44
+ >>> result = await pipeline.run(policy, traces=50, ...)
45
+ >>> dataset = result.dataset
46
+ >>> logic_map = result.logic_map # Inspect extracted rules
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ dataset: "Dataset",
52
+ logic_map: LogicMap | None = None,
53
+ scenarios: list | None = None,
54
+ distribution: dict[str, int] | None = None,
55
+ ):
56
+ self.dataset = dataset
57
+ self.logic_map = logic_map
58
+ self.scenarios = scenarios or []
59
+ self.distribution = distribution or {}
60
+
61
+ # Allow unpacking: dataset, logic_map = result
62
+ def __iter__(self):
63
+ return iter((self.dataset, self.logic_map))
64
+
65
+ # Allow direct Dataset access for backwards compatibility
66
+ def __getattr__(self, name):
67
+ # Delegate to dataset for backwards compatibility
68
+ return getattr(self.dataset, name)
69
+
70
+
71
class ScenariosResult:
    """
    Outcome of a scenario-only run, intended for building eval datasets.

    Holds scenarios with their ground-truth labels; no synthetic responses
    are produced. Pair it with synkro.grade() to score your own model.

    Examples:
        >>> result = synkro.generate_scenarios(policy, count=100)
        >>> for scenario in result.scenarios:
        ...     response = my_model(scenario.user_message)
        ...     grade = synkro.grade(response, scenario, policy)
    """

    def __init__(
        self,
        scenarios: list,
        logic_map: LogicMap,
        distribution: dict[str, int],
    ):
        """
        Args:
            scenarios: Golden scenarios to expose; each is converted to an
                ``EvalScenario``.
            logic_map: The Logic Map the scenarios were generated against.
            distribution: Scenario counts, stored as given.
        """
        from synkro.types.core import EvalScenario

        # Re-shape the internal golden scenarios into the public eval type.
        converted: list[EvalScenario] = []
        for golden in scenarios:
            kind = golden.scenario_type
            # Enum-like values are flattened to their ``.value``; anything
            # else is passed through untouched.
            if hasattr(kind, 'value'):
                kind = kind.value
            converted.append(
                EvalScenario(
                    user_message=golden.description,
                    expected_outcome=golden.expected_outcome,
                    target_rule_ids=golden.target_rule_ids,
                    scenario_type=kind,
                    category=golden.category,
                    context=golden.context,
                )
            )

        self.scenarios: list[EvalScenario] = converted
        self.logic_map = logic_map
        self.distribution = distribution

    def __len__(self) -> int:
        """Number of scenarios held."""
        count = len(self.scenarios)
        return count

    def __iter__(self):
        """Iterate over the converted scenarios."""
        scenario_iter = iter(self.scenarios)
        return scenario_iter
113
+
114
+
115
+ class GenerationPipeline:
116
+ """
117
+ Orchestrates the Golden Trace generation pipeline.
118
+
119
+ All dataset types (CONVERSATION, INSTRUCTION, TOOL_CALL) use the unified 4-stage pipeline:
120
+ - Stage 1: Logic Extraction - Extract rules as DAG
121
+ - Stage 2: Scenario Synthesis - Generate typed scenarios (positive, negative, edge_case, irrelevant)
122
+ - Stage 3: Trace Synthesis - Produce grounded reasoning with rule citations
123
+ - Stage 4: Verification - Cross-reference against Logic Map
124
+
125
+ Examples:
126
+ >>> pipeline = GenerationPipeline(factory, reporter, workers=10)
127
+ >>> dataset = await pipeline.run(policy, traces=50)
128
+ """
129
+
130
+ def __init__(
131
+ self,
132
+ factory: ComponentFactory,
133
+ reporter: ProgressReporter,
134
+ workers: int,
135
+ max_iterations: int = 1,
136
+ skip_grading: bool = False,
137
+ checkpoint_manager: CheckpointManager | None = None,
138
+ enable_hitl: bool = False,
139
+ hitl_editor: "LogicMapEditor | None" = None,
140
+ scenario_editor: "ScenarioEditor | None" = None,
141
+ ):
142
+ """
143
+ Initialize the pipeline.
144
+
145
+ Args:
146
+ factory: ComponentFactory for creating pipeline components
147
+ reporter: ProgressReporter for reporting progress
148
+ workers: Number of concurrent workers (API calls)
149
+ max_iterations: Maximum refinement iterations
150
+ skip_grading: Whether to skip the verification phase
151
+ checkpoint_manager: Optional checkpoint manager for resumable generation
152
+ enable_hitl: Whether to enable Human-in-the-Loop editing (rules + scenarios)
153
+ hitl_editor: Optional LogicMapEditor for HITL sessions
154
+ scenario_editor: Optional ScenarioEditor for scenario editing
155
+ """
156
+ self.factory = factory
157
+ self.reporter = reporter
158
+ self.workers = workers
159
+ self.max_iterations = max_iterations
160
+ self.skip_grading = skip_grading
161
+ self.checkpoint_manager = checkpoint_manager
162
+ self.enable_hitl = enable_hitl
163
+ self.hitl_editor = hitl_editor
164
+ self.scenario_editor = scenario_editor
165
+
166
+ # Golden Trace phases
167
+ self.plan_phase = PlanPhase()
168
+ self.logic_extraction_phase = LogicExtractionPhase()
169
+ self.golden_scenario_phase = GoldenScenarioPhase()
170
+ self.golden_trace_phase = GoldenTracePhase()
171
+ self.golden_tool_call_phase = GoldenToolCallPhase()
172
+ self.verification_phase = VerificationPhase()
173
+
174
+ async def run(
175
+ self,
176
+ policy: Policy,
177
+ traces: int,
178
+ model: str,
179
+ dataset_type: str,
180
+ turns: int | str = "auto",
181
+ return_result: bool = False,
182
+ ) -> Dataset | GenerationResult:
183
+ """
184
+ Run the Golden Trace generation pipeline.
185
+
186
+ All dataset types use the same 4-stage pipeline, with Stage 3
187
+ branching based on whether TOOL_CALL is needed.
188
+
189
+ Args:
190
+ policy: The policy to generate from
191
+ traces: Target number of traces
192
+ model: Model name (for reporting)
193
+ dataset_type: Dataset type (sft, qa, tool_call)
194
+ turns: Conversation turns per trace. Use int for fixed turns, or "auto"
195
+ for policy complexity-driven turns
196
+ return_result: If True, return GenerationResult with logic_map access
197
+
198
+ Returns:
199
+ Dataset (default) or GenerationResult if return_result=True
200
+ """
201
+ start_time = datetime.now()
202
+ semaphore = asyncio.Semaphore(self.workers)
203
+
204
+ # Check if this is a tool_call dataset
205
+ is_tool_call = dataset_type == "tool_call"
206
+
207
+ # Checkpointing setup
208
+ cm = self.checkpoint_manager
209
+ policy_hash = hash_policy(policy.text) if cm else ""
210
+ resuming = False
211
+
212
+ # Check for existing checkpoint
213
+ if cm and cm.has_checkpoint():
214
+ if cm.matches_config(policy_hash, traces, dataset_type):
215
+ resuming = True
216
+ from rich.console import Console
217
+ Console().print(f"[cyan]🔄 Resuming from checkpoint (stage: {cm.stage})[/cyan]")
218
+ else:
219
+ cm.clear() # Config mismatch, start fresh
220
+
221
+ # Report start
222
+ self.reporter.on_start(traces, model, dataset_type)
223
+
224
+ # Create components via factory
225
+ planner = self.factory.create_planner()
226
+ logic_extractor = self.factory.create_logic_extractor()
227
+ golden_scenario_gen = self.factory.create_golden_scenario_generator()
228
+ verifier = self.factory.create_verifier()
229
+ golden_refiner = self.factory.create_golden_refiner()
230
+
231
+ # Create appropriate trace generator based on dataset type
232
+ if is_tool_call and self.factory.has_tools:
233
+ golden_tool_call_gen = self.factory.create_golden_tool_call_generator()
234
+ else:
235
+ golden_response_gen = self.factory.create_golden_response_generator()
236
+
237
+ # Phase 0: Planning (for category distribution)
238
+ analyze_turns = turns == "auto"
239
+ plan = await self.plan_phase.execute(policy, traces, planner, analyze_turns=analyze_turns)
240
+ self.reporter.on_plan_complete(plan)
241
+
242
+ # Determine target turns
243
+ if isinstance(turns, int):
244
+ target_turns = turns
245
+ else:
246
+ target_turns = plan.recommended_turns
247
+
248
+ # =====================================================================
249
+ # STAGE 1: Logic Extraction (The Cartographer)
250
+ # =====================================================================
251
+ if resuming and cm and cm.stage in ("logic_map", "scenarios", "traces", "complete"):
252
+ logic_map = cm.get_logic_map()
253
+ from rich.console import Console
254
+ Console().print("[dim]📂 Loaded Logic Map from checkpoint[/dim]")
255
+ else:
256
+ with self.reporter.spinner("Extracting rules..."):
257
+ logic_map = await self.logic_extraction_phase.execute(policy, logic_extractor)
258
+ if cm:
259
+ cm.save_logic_map(logic_map, policy_hash, traces, dataset_type)
260
+
261
+ self.reporter.on_logic_map_complete(logic_map)
262
+
263
+ # Reset grading LLM call counter after setup phases
264
+ # (planner and logic extractor use grading_llm but aren't "grading" calls)
265
+ self.factory.grading_llm.reset_tracking()
266
+
267
+ # =====================================================================
268
+ # STAGE 2: Scenario Synthesis (The Adversary)
269
+ # =====================================================================
270
+ # Track scenario generation calls
271
+ scenario_calls_start = self.factory.generation_llm.call_count
272
+
273
+ if resuming and cm and cm.stage in ("scenarios", "traces", "complete"):
274
+ golden_scenarios = cm.get_scenarios()
275
+ distribution = cm.load().scenario_distribution
276
+ from rich.console import Console
277
+ Console().print(f"[dim]📂 Loaded {len(golden_scenarios)} scenarios from checkpoint[/dim]")
278
+ else:
279
+ with self.reporter.spinner("Generating scenarios..."):
280
+ golden_scenarios, distribution = await self.golden_scenario_phase.execute(
281
+ policy, logic_map, plan, golden_scenario_gen, semaphore
282
+ )
283
+ if cm:
284
+ cm.save_scenarios(golden_scenarios, distribution)
285
+
286
+ scenario_calls = self.factory.generation_llm.call_count - scenario_calls_start
287
+ self.reporter.on_golden_scenarios_complete(golden_scenarios, distribution)
288
+
289
+ # =====================================================================
290
+ # HUMAN-IN-THE-LOOP: Unified Session (Turns + Rules + Scenarios)
291
+ # =====================================================================
292
+ if self.enable_hitl and self.hitl_editor:
293
+ logic_map, golden_scenarios, distribution, target_turns = await self._run_hitl_session(
294
+ logic_map, golden_scenarios, distribution, policy, plan, target_turns
295
+ )
296
+
297
+ # =====================================================================
298
+ # STAGE 3: Trace Synthesis (The Thinker)
299
+ # =====================================================================
300
+ # Track response generation calls
301
+ response_calls_start = self.factory.generation_llm.call_count
302
+
303
+ if resuming and cm and cm.stage in ("traces", "complete"):
304
+ # Resume from checkpoint - get already completed traces
305
+ existing_traces = cm.get_traces()
306
+ pending_indices = cm.get_pending_scenario_indices(len(golden_scenarios))
307
+
308
+ if pending_indices:
309
+ from rich.console import Console
310
+ Console().print(f"[dim]📂 Resuming: {len(existing_traces)} done, {len(pending_indices)} pending[/dim]")
311
+
312
+ # Generate only pending scenarios
313
+ pending_scenarios = [golden_scenarios[i] for i in pending_indices]
314
+
315
+ with self.reporter.spinner("Generating responses..."):
316
+ if is_tool_call and self.factory.has_tools:
317
+ new_traces = await self.golden_tool_call_phase.execute(
318
+ policy, logic_map, pending_scenarios, golden_tool_call_gen, semaphore, target_turns
319
+ )
320
+ else:
321
+ new_traces = await self.golden_trace_phase.execute(
322
+ policy, logic_map, pending_scenarios, golden_response_gen, semaphore, target_turns
323
+ )
324
+
325
+ # Save new traces to checkpoint
326
+ if cm:
327
+ cm.save_traces_batch(list(new_traces), pending_indices)
328
+
329
+ all_traces = existing_traces + list(new_traces)
330
+ else:
331
+ all_traces = existing_traces
332
+ else:
333
+ with self.reporter.spinner("Generating responses..."):
334
+ if is_tool_call and self.factory.has_tools:
335
+ all_traces = await self.golden_tool_call_phase.execute(
336
+ policy, logic_map, golden_scenarios, golden_tool_call_gen, semaphore, target_turns
337
+ )
338
+ else:
339
+ all_traces = await self.golden_trace_phase.execute(
340
+ policy, logic_map, golden_scenarios, golden_response_gen, semaphore, target_turns
341
+ )
342
+
343
+ # Save all traces to checkpoint
344
+ if cm:
345
+ cm.save_traces_batch(list(all_traces), list(range(len(all_traces))))
346
+
347
+ response_calls = self.factory.generation_llm.call_count - response_calls_start
348
+ self.reporter.on_responses_complete(list(all_traces))
349
+
350
+ # =====================================================================
351
+ # STAGE 4: Verification (The Auditor)
352
+ # =====================================================================
353
+ pass_rate: float | None = None
354
+
355
+ if resuming and cm and cm.stage == "complete":
356
+ final_traces = cm.get_verified_traces()
357
+ passed_count = sum(1 for t in final_traces if t.grade and t.grade.passed)
358
+ pass_rate = (passed_count / len(final_traces) * 100) if final_traces else 0
359
+ from rich.console import Console
360
+ Console().print(f"[dim]📂 Loaded {len(final_traces)} verified traces from checkpoint[/dim]")
361
+ elif self.skip_grading:
362
+ final_traces = list(all_traces)
363
+ self.reporter.on_grading_skipped()
364
+ else:
365
+ with self.reporter.spinner("Verifying responses..."):
366
+ final_traces, pass_rate = await self.verification_phase.execute(
367
+ policy,
368
+ logic_map,
369
+ golden_scenarios,
370
+ list(all_traces),
371
+ verifier,
372
+ golden_refiner,
373
+ self.max_iterations,
374
+ semaphore,
375
+ )
376
+ if cm:
377
+ cm.save_verified_traces(final_traces)
378
+
379
+ self.reporter.on_grading_complete(final_traces, pass_rate)
380
+
381
+ # Report completion with cost tracking
382
+ elapsed = (datetime.now() - start_time).total_seconds()
383
+ total_cost = (
384
+ self.factory.generation_llm.total_cost +
385
+ self.factory.grading_llm.total_cost
386
+ )
387
+ self.reporter.on_complete(
388
+ len(final_traces),
389
+ elapsed,
390
+ pass_rate,
391
+ total_cost=total_cost,
392
+ generation_calls=self.factory.generation_llm.call_count,
393
+ grading_calls=self.factory.grading_llm.call_count,
394
+ scenario_calls=scenario_calls,
395
+ response_calls=response_calls,
396
+ )
397
+
398
+ dataset = Dataset(traces=final_traces)
399
+
400
+ if return_result:
401
+ return GenerationResult(
402
+ dataset=dataset,
403
+ logic_map=logic_map,
404
+ scenarios=golden_scenarios,
405
+ distribution=distribution,
406
+ )
407
+
408
+ return dataset
409
+
410
+ async def run_scenarios_only(
411
+ self,
412
+ policy: Policy,
413
+ count: int,
414
+ model: str,
415
+ ) -> ScenariosResult:
416
+ """
417
+ Run stages 0-2 only, returning scenarios without generating responses.
418
+
419
+ This is the eval-focused pipeline that produces test scenarios with
420
+ ground truth labels but no synthetic responses.
421
+
422
+ Args:
423
+ policy: The policy to generate scenarios from
424
+ count: Target number of scenarios
425
+ model: Model name (for reporting)
426
+
427
+ Returns:
428
+ ScenariosResult with scenarios, logic_map, and distribution
429
+ """
430
+ from datetime import datetime
431
+
432
+ start_time = datetime.now()
433
+ semaphore = asyncio.Semaphore(self.workers)
434
+
435
+ # Report start (using a simplified message)
436
+ self.reporter.on_start(count, model, "scenarios")
437
+
438
+ # Create components via factory
439
+ planner = self.factory.create_planner()
440
+ logic_extractor = self.factory.create_logic_extractor()
441
+ golden_scenario_gen = self.factory.create_golden_scenario_generator()
442
+
443
+ # Phase 0: Planning (for category distribution)
444
+ plan = await self.plan_phase.execute(policy, count, planner, analyze_turns=False)
445
+ self.reporter.on_plan_complete(plan)
446
+
447
+ # =====================================================================
448
+ # STAGE 1: Logic Extraction (The Cartographer)
449
+ # =====================================================================
450
+ with self.reporter.spinner("Extracting rules..."):
451
+ logic_map = await self.logic_extraction_phase.execute(policy, logic_extractor)
452
+
453
+ self.reporter.on_logic_map_complete(logic_map)
454
+
455
+ # =====================================================================
456
+ # STAGE 2: Scenario Synthesis (The Adversary)
457
+ # =====================================================================
458
+ with self.reporter.spinner("Generating scenarios..."):
459
+ golden_scenarios, distribution = await self.golden_scenario_phase.execute(
460
+ policy, logic_map, plan, golden_scenario_gen, semaphore
461
+ )
462
+
463
+ self.reporter.on_golden_scenarios_complete(golden_scenarios, distribution)
464
+
465
+ # =====================================================================
466
+ # HUMAN-IN-THE-LOOP (optional)
467
+ # =====================================================================
468
+ if self.enable_hitl and self.hitl_editor:
469
+ logic_map, golden_scenarios, distribution, _ = await self._run_hitl_session(
470
+ logic_map, golden_scenarios, distribution, policy, plan, 1
471
+ )
472
+
473
+ # Report completion
474
+ elapsed = (datetime.now() - start_time).total_seconds()
475
+ total_cost = self.factory.generation_llm.total_cost
476
+
477
+ self.reporter.on_complete(
478
+ len(golden_scenarios),
479
+ elapsed,
480
+ pass_rate=None,
481
+ total_cost=total_cost,
482
+ generation_calls=self.factory.generation_llm.call_count,
483
+ grading_calls=0,
484
+ scenario_calls=self.factory.generation_llm.call_count,
485
+ response_calls=0,
486
+ )
487
+
488
+ return ScenariosResult(
489
+ scenarios=golden_scenarios,
490
+ logic_map=logic_map,
491
+ distribution=distribution,
492
+ )
493
+
494
+ async def _run_hitl_session(
495
+ self,
496
+ logic_map: LogicMap,
497
+ scenarios: list["GoldenScenario"],
498
+ distribution: dict[str, int],
499
+ policy: Policy,
500
+ plan: "Plan",
501
+ initial_turns: int,
502
+ ) -> tuple[LogicMap, list["GoldenScenario"], dict[str, int], int]:
503
+ """
504
+ Run unified HITL session for turns, Logic Map, and scenario editing.
505
+
506
+ Args:
507
+ logic_map: The extracted Logic Map to edit
508
+ scenarios: The generated scenarios to edit
509
+ distribution: The scenario type distribution
510
+ policy: The policy document (for context in refinements)
511
+ plan: The generation plan (for complexity info)
512
+ initial_turns: Initial target turns setting
513
+
514
+ Returns:
515
+ Tuple of (modified LogicMap, modified scenarios, modified distribution, confirmed target_turns)
516
+ """
517
+ from synkro.interactive.hitl_session import HITLSession
518
+ from synkro.interactive.rich_ui import LogicMapDisplay, InteractivePrompt
519
+ from synkro.interactive.intent_classifier import HITLIntentClassifier
520
+
521
+ session = HITLSession(original_logic_map=logic_map)
522
+ session.set_scenarios(scenarios, distribution)
523
+
524
+ display = LogicMapDisplay()
525
+ prompt = InteractivePrompt()
526
+ classifier = HITLIntentClassifier(llm=self.factory.generation_llm)
527
+
528
+ current_turns = initial_turns
529
+ turns_history: list[int] = [] # For undo support
530
+
531
+ # Show initial state (includes session details with instructions)
532
+ display.display_full_session_state(
533
+ plan,
534
+ session.current_logic_map,
535
+ current_turns,
536
+ session.current_scenarios,
537
+ session.current_distribution,
538
+ )
539
+
540
+ while True:
541
+ feedback = prompt.get_feedback().strip()
542
+
543
+ # Handle explicit commands first (no LLM needed)
544
+ if feedback.lower() == "done":
545
+ break
546
+
547
+ if feedback.lower() == "undo":
548
+ # Undo turns, rules, or scenarios
549
+ if session.can_undo or turns_history:
550
+ restored_map, restored_scenarios, restored_dist = session.undo()
551
+ if turns_history:
552
+ current_turns = turns_history.pop()
553
+ display.show_success("Reverted to previous state")
554
+ display.display_full_session_state(
555
+ plan,
556
+ session.current_logic_map,
557
+ current_turns,
558
+ session.current_scenarios,
559
+ session.current_distribution,
560
+ )
561
+ else:
562
+ display.show_error("Nothing to undo")
563
+ continue
564
+
565
+ if feedback.lower() == "reset":
566
+ session.reset()
567
+ current_turns = initial_turns
568
+ turns_history.clear()
569
+ display.show_success("Reset to original state")
570
+ display.display_full_session_state(
571
+ plan,
572
+ session.current_logic_map,
573
+ current_turns,
574
+ session.current_scenarios,
575
+ session.current_distribution,
576
+ )
577
+ continue
578
+
579
+ if feedback.lower() == "help":
580
+ prompt.show_unified_instructions()
581
+ continue
582
+
583
+ if feedback.lower().startswith("show "):
584
+ target = feedback[5:].strip().upper()
585
+ if target.startswith("S") and target[1:].isdigit():
586
+ # Show scenario
587
+ if session.current_scenarios:
588
+ display.display_scenario(target, session.current_scenarios)
589
+ else:
590
+ # Show rule
591
+ display.display_rule(target, session.current_logic_map)
592
+ continue
593
+
594
+ # Empty input
595
+ if not feedback:
596
+ continue
597
+
598
+ # Classify intent via LLM
599
+ scenario_count = len(session.current_scenarios) if session.current_scenarios else 0
600
+ history = session.get_history_for_prompt()
601
+ with display.spinner("Processing..."):
602
+ intent = await classifier.classify(
603
+ feedback,
604
+ current_turns,
605
+ plan.complexity_level,
606
+ len(session.current_logic_map.rules),
607
+ scenario_count=scenario_count,
608
+ conversation_history=history,
609
+ )
610
+
611
+ if intent.intent_type == "turns" and intent.target_turns is not None:
612
+ # Handle turns change
613
+ turns_history.append(current_turns)
614
+ current_turns = intent.target_turns
615
+ reasoning = intent.turns_reasoning or "User preference"
616
+ summary = f"Set to {current_turns} turns ({reasoning})"
617
+ display.show_success(summary)
618
+ session.record_feedback(feedback, "turns", summary)
619
+ display.display_full_session_state(
620
+ plan,
621
+ session.current_logic_map,
622
+ current_turns,
623
+ session.current_scenarios,
624
+ session.current_distribution,
625
+ )
626
+
627
+ elif intent.intent_type == "rules" and intent.rule_feedback:
628
+ # Handle rule change
629
+ try:
630
+ with display.spinner("Updating rules..."):
631
+ new_map, changes_summary = await self.hitl_editor.refine(
632
+ session.current_logic_map,
633
+ intent.rule_feedback,
634
+ policy.text,
635
+ conversation_history=history,
636
+ )
637
+
638
+ # Validate the refinement
639
+ is_valid, issues = self.hitl_editor.validate_refinement(
640
+ session.current_logic_map,
641
+ new_map,
642
+ )
643
+
644
+ if is_valid:
645
+ display.display_diff(session.current_logic_map, new_map)
646
+ session.apply_change(intent.rule_feedback, new_map)
647
+ display.show_success(changes_summary)
648
+ session.record_feedback(feedback, "rules", changes_summary)
649
+ else:
650
+ display.show_error(f"Invalid refinement: {', '.join(issues)}")
651
+
652
+ except Exception as e:
653
+ display.show_error(f"Failed to apply refinement: {e}")
654
+
655
+ elif intent.intent_type == "scenarios" and self.scenario_editor:
656
+ # Handle scenario change
657
+ try:
658
+ scenario_feedback = intent.scenario_feedback or feedback
659
+ with display.spinner("Updating scenarios..."):
660
+ new_scenarios, new_dist, changes_summary = await self.scenario_editor.refine(
661
+ session.current_scenarios or [],
662
+ session.current_distribution or {},
663
+ scenario_feedback,
664
+ policy.text,
665
+ session.current_logic_map,
666
+ conversation_history=history,
667
+ )
668
+
669
+ # Validate the scenarios
670
+ is_valid, issues = self.scenario_editor.validate_scenarios(
671
+ new_scenarios,
672
+ session.current_logic_map,
673
+ )
674
+
675
+ if is_valid:
676
+ if session.current_scenarios:
677
+ display.display_scenario_diff(session.current_scenarios, new_scenarios)
678
+ session.apply_scenario_change(scenario_feedback, new_scenarios, new_dist)
679
+ display.show_success(changes_summary)
680
+ session.record_feedback(feedback, "scenarios", changes_summary)
681
+ else:
682
+ display.show_error(f"Invalid scenario edit: {', '.join(issues)}")
683
+
684
+ except Exception as e:
685
+ display.show_error(f"Failed to apply scenario edit: {e}")
686
+
687
+ elif intent.intent_type == "scenarios" and not self.scenario_editor:
688
+ display.show_error("Scenario editor not available")
689
+
690
+ elif intent.intent_type == "compound" and intent.rule_feedback and intent.scenario_feedback:
691
+ # Handle compound intent: rules first, then scenarios
692
+ try:
693
+ # Step 1: Apply rule changes
694
+ with display.spinner("Updating rules..."):
695
+ new_map, rule_summary = await self.hitl_editor.refine(
696
+ session.current_logic_map,
697
+ intent.rule_feedback,
698
+ policy.text,
699
+ conversation_history=history,
700
+ )
701
+
702
+ is_valid, issues = self.hitl_editor.validate_refinement(
703
+ session.current_logic_map,
704
+ new_map,
705
+ )
706
+
707
+ if not is_valid:
708
+ display.show_error(f"Invalid rule change: {', '.join(issues)}")
709
+ continue
710
+
711
+ # Show rule diff and apply
712
+ display.display_diff(session.current_logic_map, new_map)
713
+ session.apply_change(intent.rule_feedback, new_map)
714
+ display.show_success(rule_summary)
715
+ session.record_feedback(feedback, "rules", rule_summary)
716
+
717
+ # Step 2: Apply scenario changes (using updated logic map)
718
+ # Get updated history after rule change
719
+ updated_history = session.get_history_for_prompt()
720
+ if self.scenario_editor:
721
+ with display.spinner("Updating scenarios..."):
722
+ new_scenarios, new_dist, scenario_summary = await self.scenario_editor.refine(
723
+ session.current_scenarios or [],
724
+ session.current_distribution or {},
725
+ intent.scenario_feedback,
726
+ policy.text,
727
+ session.current_logic_map, # Now has the new rules
728
+ conversation_history=updated_history,
729
+ )
730
+
731
+ is_valid, issues = self.scenario_editor.validate_scenarios(
732
+ new_scenarios,
733
+ session.current_logic_map,
734
+ )
735
+
736
+ if is_valid:
737
+ if session.current_scenarios:
738
+ display.display_scenario_diff(session.current_scenarios, new_scenarios)
739
+ session.apply_scenario_change(intent.scenario_feedback, new_scenarios, new_dist)
740
+ display.show_success(scenario_summary)
741
+ session.record_feedback(feedback, "scenarios", scenario_summary)
742
+ else:
743
+ display.show_error(f"Invalid scenario edit: {', '.join(issues)}")
744
+ else:
745
+ display.show_error("Scenario editor not available for compound operation")
746
+
747
+ except Exception as e:
748
+ display.show_error(f"Failed to apply compound change: {e}")
749
+
750
+ elif intent.intent_type == "unclear":
751
+ display.show_error("Could not understand feedback. Try 'help' for examples.")
752
+
753
+ # Final summary
754
+ display.console.print(
755
+ f"\n[green]✅ Session complete[/green] - "
756
+ f"{session.rule_change_count} rule change(s), "
757
+ f"{session.scenario_change_count} scenario change(s), "
758
+ f"{current_turns} turns"
759
+ )
760
+
761
+ return (
762
+ session.current_logic_map,
763
+ session.current_scenarios or scenarios,
764
+ session.current_distribution or distribution,
765
+ current_turns,
766
+ )
767
+
768
+
769
+ __all__ = ["GenerationPipeline", "GenerationResult", "ScenariosResult"]