synkro 0.4.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synkro/__init__.py +179 -0
- synkro/advanced.py +186 -0
- synkro/cli.py +128 -0
- synkro/core/__init__.py +7 -0
- synkro/core/checkpoint.py +250 -0
- synkro/core/dataset.py +402 -0
- synkro/core/policy.py +337 -0
- synkro/errors.py +178 -0
- synkro/examples/__init__.py +148 -0
- synkro/factory.py +276 -0
- synkro/formatters/__init__.py +12 -0
- synkro/formatters/qa.py +98 -0
- synkro/formatters/sft.py +90 -0
- synkro/formatters/tool_call.py +127 -0
- synkro/generation/__init__.py +9 -0
- synkro/generation/follow_ups.py +134 -0
- synkro/generation/generator.py +220 -0
- synkro/generation/golden_responses.py +244 -0
- synkro/generation/golden_scenarios.py +276 -0
- synkro/generation/golden_tool_responses.py +416 -0
- synkro/generation/logic_extractor.py +126 -0
- synkro/generation/multiturn_responses.py +177 -0
- synkro/generation/planner.py +131 -0
- synkro/generation/responses.py +189 -0
- synkro/generation/scenarios.py +90 -0
- synkro/generation/tool_responses.py +376 -0
- synkro/generation/tool_simulator.py +114 -0
- synkro/interactive/__init__.py +12 -0
- synkro/interactive/hitl_session.py +77 -0
- synkro/interactive/logic_map_editor.py +173 -0
- synkro/interactive/rich_ui.py +205 -0
- synkro/llm/__init__.py +7 -0
- synkro/llm/client.py +235 -0
- synkro/llm/rate_limits.py +95 -0
- synkro/models/__init__.py +43 -0
- synkro/models/anthropic.py +26 -0
- synkro/models/google.py +19 -0
- synkro/models/openai.py +31 -0
- synkro/modes/__init__.py +15 -0
- synkro/modes/config.py +66 -0
- synkro/modes/qa.py +18 -0
- synkro/modes/sft.py +18 -0
- synkro/modes/tool_call.py +18 -0
- synkro/parsers.py +442 -0
- synkro/pipeline/__init__.py +20 -0
- synkro/pipeline/phases.py +592 -0
- synkro/pipeline/runner.py +424 -0
- synkro/pipelines.py +123 -0
- synkro/prompts/__init__.py +57 -0
- synkro/prompts/base.py +167 -0
- synkro/prompts/golden_templates.py +474 -0
- synkro/prompts/interactive_templates.py +65 -0
- synkro/prompts/multiturn_templates.py +156 -0
- synkro/prompts/qa_templates.py +97 -0
- synkro/prompts/templates.py +281 -0
- synkro/prompts/tool_templates.py +201 -0
- synkro/quality/__init__.py +14 -0
- synkro/quality/golden_refiner.py +163 -0
- synkro/quality/grader.py +153 -0
- synkro/quality/multiturn_grader.py +150 -0
- synkro/quality/refiner.py +137 -0
- synkro/quality/tool_grader.py +126 -0
- synkro/quality/tool_refiner.py +128 -0
- synkro/quality/verifier.py +228 -0
- synkro/reporting.py +537 -0
- synkro/schemas.py +472 -0
- synkro/types/__init__.py +41 -0
- synkro/types/core.py +126 -0
- synkro/types/dataset_type.py +30 -0
- synkro/types/logic_map.py +345 -0
- synkro/types/tool.py +94 -0
- synkro-0.4.12.data/data/examples/__init__.py +148 -0
- synkro-0.4.12.dist-info/METADATA +258 -0
- synkro-0.4.12.dist-info/RECORD +77 -0
- synkro-0.4.12.dist-info/WHEEL +4 -0
- synkro-0.4.12.dist-info/entry_points.txt +2 -0
- synkro-0.4.12.dist-info/licenses/LICENSE +21 -0
synkro/interactive/logic_map_editor.py
ADDED

@@ -0,0 +1,173 @@
"""Logic Map Editor - LLM-powered interactive refinement of Logic Maps."""

from __future__ import annotations

from typing import TYPE_CHECKING

from synkro.llm.client import LLM
from synkro.models import Model, OpenAI
from synkro.schemas import RefinedLogicMapOutput
from synkro.types.logic_map import LogicMap, Rule, RuleCategory
from synkro.prompts.interactive_templates import LOGIC_MAP_REFINEMENT_PROMPT

if TYPE_CHECKING:
    pass


class LogicMapEditor:
    """
    LLM-powered Logic Map editor that interprets natural language feedback.

    The editor takes user feedback in natural language (e.g., "add a rule for...",
    "remove R005", "merge R002 and R003") and uses an LLM to interpret and apply
    the changes to the Logic Map.

    Examples:
        >>> editor = LogicMapEditor(llm=LLM(model=OpenAI.GPT_4O))
        >>> new_logic_map, changes = await editor.refine(
        ...     logic_map=current_map,
        ...     user_feedback="Add a rule for overtime approval",
        ...     policy_text=policy.text
        ... )
    """

    def __init__(
        self,
        llm: LLM | None = None,
        model: Model = OpenAI.GPT_4O,
    ):
        """
        Initialize the Logic Map Editor.

        Args:
            llm: LLM client to use (creates one if not provided)
            model: Model to use if creating the LLM (default: GPT-4O for accuracy)
        """
        self.llm = llm or LLM(model=model, temperature=0.3)

    async def refine(
        self,
        logic_map: LogicMap,
        user_feedback: str,
        policy_text: str,
    ) -> tuple[LogicMap, str]:
        """
        Refine the Logic Map based on natural language feedback.

        Args:
            logic_map: Current Logic Map to refine
            user_feedback: Natural language instruction from the user
            policy_text: Original policy text for context

        Returns:
            Tuple of (refined LogicMap, changes summary string)

        Raises:
            ValueError: If refinement produces an invalid DAG
        """
        # Format the current Logic Map as a string for the prompt
        current_map_str = self._format_logic_map_for_prompt(logic_map)

        # Format the prompt
        prompt = LOGIC_MAP_REFINEMENT_PROMPT.format(
            current_logic_map=current_map_str,
            policy_text=policy_text,
            user_feedback=user_feedback,
        )

        # Generate structured output
        result = await self.llm.generate_structured(prompt, RefinedLogicMapOutput)

        # Convert to the domain model
        refined_map = self._convert_to_logic_map(result)

        # Validate DAG properties
        if not refined_map.validate_dag():
            raise ValueError(
                "Refined Logic Map contains circular dependencies. "
                "Please try a different modification."
            )

        return refined_map, result.changes_summary

    def _format_logic_map_for_prompt(self, logic_map: LogicMap) -> str:
        """Format a Logic Map as a string for the LLM prompt."""
        lines = []
        lines.append(f"Total Rules: {len(logic_map.rules)}")
        lines.append(f"Root Rules: {', '.join(logic_map.root_rules)}")
        lines.append("")
        lines.append("Rules:")

        for rule in logic_map.rules:
            deps = f" -> {', '.join(rule.dependencies)}" if rule.dependencies else ""
            lines.append(f"  {rule.rule_id}: {rule.text}")
            lines.append(f"    Category: {rule.category.value}")
            lines.append(f"    Condition: {rule.condition}")
            lines.append(f"    Action: {rule.action}")
            if deps:
                lines.append(f"    Dependencies: {', '.join(rule.dependencies)}")
            lines.append("")

        return "\n".join(lines)

    def _convert_to_logic_map(self, output: RefinedLogicMapOutput) -> LogicMap:
        """Convert schema output to the domain model."""
        rules = []
        for rule_out in output.rules:
            # Convert the category string to the enum
            category = RuleCategory(rule_out.category)

            rule = Rule(
                rule_id=rule_out.rule_id,
                text=rule_out.text,
                condition=rule_out.condition,
                action=rule_out.action,
                dependencies=rule_out.dependencies,
                category=category,
            )
            rules.append(rule)

        return LogicMap(
            rules=rules,
            root_rules=output.root_rules,
        )

    def validate_refinement(
        self,
        original: LogicMap,
        refined: LogicMap,
    ) -> tuple[bool, list[str]]:
        """
        Validate that refinement maintains DAG properties and is sensible.

        Args:
            original: Original Logic Map
            refined: Refined Logic Map

        Returns:
            Tuple of (is_valid, list of issue descriptions)
        """
        issues = []

        # Check DAG validity
        if not refined.validate_dag():
            issues.append("Refined Logic Map has circular dependencies")

        # Check that all dependencies reference existing rules
        rule_ids = {r.rule_id for r in refined.rules}
        for rule in refined.rules:
            for dep in rule.dependencies:
                if dep not in rule_ids:
                    issues.append(f"Rule {rule.rule_id} depends on non-existent rule {dep}")

        # Check root_rules consistency
        for root_id in refined.root_rules:
            if root_id not in rule_ids:
                issues.append(f"Root rule {root_id} does not exist")

        # Check that rules with no dependencies are in root_rules
        for rule in refined.rules:
            if not rule.dependencies and rule.rule_id not in refined.root_rules:
                issues.append(f"Rule {rule.rule_id} has no dependencies but is not in root_rules")

        return len(issues) == 0, issues
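Taken together, refine() and validate_refinement() form a propose-then-check loop: the LLM proposes an edited map, and the structural checks gate it before it replaces the current one. A minimal driver sketch under that reading (hypothetical wiring; current_map and policy_text are assumed to come from the caller, and the feedback string is illustrative):

import asyncio

from synkro.interactive.logic_map_editor import LogicMapEditor
from synkro.llm.client import LLM
from synkro.models import OpenAI


async def apply_feedback(current_map, policy_text: str, feedback: str):
    editor = LogicMapEditor(llm=LLM(model=OpenAI.GPT_4O))

    # refine() returns (refined LogicMap, changes summary); a cyclic result raises ValueError
    refined, summary = await editor.refine(
        logic_map=current_map,
        user_feedback=feedback,
        policy_text=policy_text,
    )

    # Belt and braces: re-check DAG shape, dependency targets, and root-rule consistency
    is_valid, issues = editor.validate_refinement(current_map, refined)
    if not is_valid:
        raise ValueError(f"Refinement rejected: {issues}")

    print(summary)
    return refined

# e.g. asyncio.run(apply_feedback(current_map, policy.text, "Add a rule for overtime approval"))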
synkro/interactive/rich_ui.py
ADDED

@@ -0,0 +1,205 @@
"""Rich UI components for Human-in-the-Loop interaction."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from synkro.types.logic_map import LogicMap


class LogicMapDisplay:
    """Rich-based display for Logic Maps."""

    def __init__(self) -> None:
        from rich.console import Console

        self.console = Console()

    def display_full(self, logic_map: "LogicMap") -> None:
        """Display the complete Logic Map with all details."""
        from rich.panel import Panel
        from rich.tree import Tree

        # Build tree view of rules by category
        tree = Tree("[bold cyan]Logic Map[/bold cyan]")

        # Group rules by category
        categories: dict[str, list] = {}
        for rule in logic_map.rules:
            cat = rule.category.value if hasattr(rule.category, "value") else str(rule.category)
            if cat not in categories:
                categories[cat] = []
            categories[cat].append(rule)

        # Add each category as a branch
        for category, rules in sorted(categories.items()):
            branch = tree.add(f"[bold]{category}[/bold] ({len(rules)} rules)")
            for rule in rules:
                rule_text = f"[cyan]{rule.rule_id}[/cyan]: {rule.text[:60]}..."
                if rule.dependencies:
                    rule_text += f" [dim]→ {', '.join(rule.dependencies)}[/dim]"
                branch.add(rule_text)

        # Show root rules
        root_info = f"[dim]Root rules: {', '.join(logic_map.root_rules)}[/dim]"

        self.console.print()
        self.console.print(
            Panel(
                tree,
                title=f"[bold]📜 Logic Map ({len(logic_map.rules)} rules)[/bold]",
                subtitle=root_info,
                border_style="cyan",
            )
        )

    def display_diff(self, before: "LogicMap", after: "LogicMap") -> None:
        """Display all rules with changes highlighted in different colors."""
        from rich.panel import Panel
        from rich.tree import Tree

        before_ids = {r.rule_id for r in before.rules}
        after_ids = {r.rule_id for r in after.rules}

        added = after_ids - before_ids
        removed = before_ids - after_ids
        common = before_ids & after_ids

        # Check for modifications in common rules
        modified: set[str] = set()
        before_map = {r.rule_id: r for r in before.rules}
        after_map = {r.rule_id: r for r in after.rules}

        for rule_id in common:
            if before_map[rule_id].text != after_map[rule_id].text:
                modified.add(rule_id)
            elif before_map[rule_id].dependencies != after_map[rule_id].dependencies:
                modified.add(rule_id)

        if not added and not removed and not modified:
            self.console.print("[dim]No changes detected[/dim]")
            return

        # Build tree view of rules by category (like display_full but with colors)
        tree = Tree("[bold cyan]Logic Map[/bold cyan]")

        # Group rules by category
        categories: dict[str, list] = {}
        for rule in after.rules:
            cat = rule.category.value if hasattr(rule.category, "value") else str(rule.category)
            if cat not in categories:
                categories[cat] = []
            categories[cat].append(rule)

        # Add each category as a branch
        for category, rules in sorted(categories.items()):
            branch = tree.add(f"[bold]{category}[/bold] ({len(rules)} rules)")
            for rule in rules:
                # Determine style based on change type
                if rule.rule_id in added:
                    prefix = "[green]+ "
                    style_close = "[/green]"
                    id_style = "[green]"
                elif rule.rule_id in modified:
                    prefix = "[yellow]~ "
                    style_close = "[/yellow]"
                    id_style = "[yellow]"
                else:
                    prefix = ""
                    style_close = ""
                    id_style = "[cyan]"

                rule_text = f"{prefix}{id_style}{rule.rule_id}[/]: {rule.text[:60]}...{style_close}"
                if rule.dependencies:
                    rule_text += f" [dim]→ {', '.join(rule.dependencies)}[/dim]"
                branch.add(rule_text)

        # Add removed rules section at the bottom
        if removed:
            removed_branch = tree.add("[red][bold]REMOVED[/bold][/red]")
            for rule_id in sorted(removed):
                rule = before_map[rule_id]
                removed_branch.add(f"[red][strike]- {rule_id}: {rule.text[:60]}...[/strike][/red]")

        # Build legend
        legend = "[dim]Legend: [green]+ Added[/green] | [yellow]~ Modified[/yellow] | [red][strike]- Removed[/strike][/red][/dim]"

        self.console.print()
        self.console.print(
            Panel(
                tree,
                title=f"[bold]📜 Logic Map ({len(after.rules)} rules)[/bold]",
                subtitle=legend,
                border_style="cyan",
            )
        )

    def display_rule(self, rule_id: str, logic_map: "LogicMap") -> None:
        """Display details of a specific rule."""
        from rich.panel import Panel

        rule = logic_map.get_rule(rule_id)
        if not rule:
            self.console.print(f"[red]Rule {rule_id} not found[/red]")
            return

        content = f"""[bold]ID:[/bold] {rule.rule_id}
[bold]Text:[/bold] {rule.text}
[bold]Category:[/bold] {rule.category}
[bold]Condition:[/bold] {rule.condition or 'N/A'}
[bold]Action:[/bold] {rule.action or 'N/A'}
[bold]Dependencies:[/bold] {', '.join(rule.dependencies) if rule.dependencies else 'None (root rule)'}"""

        self.console.print(Panel(content, title=f"Rule {rule_id}", border_style="cyan"))

    def show_error(self, message: str) -> None:
        """Display an error message."""
        self.console.print(f"[red]Error:[/red] {message}")

    def show_success(self, message: str) -> None:
        """Display a success message."""
        self.console.print(f"[green]✓[/green] {message}")


class InteractivePrompt:
    """Handles user input for HITL sessions."""

    def __init__(self) -> None:
        from rich.console import Console

        self.console = Console()

    def show_instructions(self) -> None:
        """Display instructions for the HITL session."""
        from rich.panel import Panel

        instructions = """[bold]Commands:[/bold]
• Type feedback to modify the Logic Map (e.g., "add a rule for...", "remove R005")
• [cyan]done[/cyan] - Continue with current Logic Map
• [cyan]undo[/cyan] - Revert last change
• [cyan]reset[/cyan] - Restore original Logic Map
• [cyan]show R001[/cyan] - Show details of a specific rule
• [cyan]help[/cyan] - Show this message"""

        self.console.print()
        self.console.print(
            Panel(
                instructions,
                title="[bold cyan]Interactive Logic Map Editor[/bold cyan]",
                border_style="cyan",
            )
        )

    def get_feedback(self) -> str:
        """Prompt user for feedback on the Logic Map."""
        from rich.prompt import Prompt

        self.console.print()
        return Prompt.ask("[cyan]Enter feedback[/cyan]")

    def confirm_continue(self) -> bool:
        """Ask user if they want to continue with current Logic Map."""
        from rich.prompt import Confirm

        return Confirm.ask("Continue with this Logic Map?", default=True)
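These two classes are the display and input halves of a REPL-style session; the command set in show_instructions() implies a loop roughly like the sketch below. This is hypothetical wiring only: the shipped loop lives in synkro/interactive/hitl_session.py and may differ.

from synkro.interactive.logic_map_editor import LogicMapEditor
from synkro.interactive.rich_ui import InteractivePrompt, LogicMapDisplay


async def edit_session(logic_map, policy_text: str):
    display = LogicMapDisplay()
    prompt = InteractivePrompt()
    editor = LogicMapEditor()

    original = logic_map
    history = [logic_map]  # snapshots for undo/reset

    display.display_full(logic_map)
    prompt.show_instructions()

    while True:
        feedback = prompt.get_feedback().strip()
        if feedback == "done":
            return logic_map
        if feedback == "help":
            prompt.show_instructions()
        elif feedback == "undo" and len(history) > 1:
            history.pop()
            logic_map = history[-1]
        elif feedback == "reset":
            logic_map = original
            history = [original]
        elif feedback.startswith("show "):
            display.display_rule(feedback.split(maxsplit=1)[1], logic_map)
        else:
            try:
                refined, summary = await editor.refine(logic_map, feedback, policy_text)
            except ValueError as exc:
                display.show_error(str(exc))  # e.g. circular dependencies
                continue
            display.display_diff(logic_map, refined)
            display.show_success(summary)
            history.append(refined)
            logic_map = refined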
synkro/llm/__init__.py
ADDED
synkro/llm/client.py
ADDED
@@ -0,0 +1,235 @@
"""Type-safe LLM wrapper using LiteLLM."""

from typing import TypeVar, Type, overload

import litellm
from litellm import acompletion, supports_response_schema
from pydantic import BaseModel

# Configure litellm
litellm.suppress_debug_info = True
litellm.enable_json_schema_validation = True

from synkro.models import OpenAI, Model, get_model_string


T = TypeVar("T", bound=BaseModel)


class LLM:
    """
    Type-safe LLM wrapper using LiteLLM for universal provider support.

    Supports structured outputs via native JSON mode for reliable responses.

    Supported providers: OpenAI, Anthropic, Google (Gemini)

    Examples:
        >>> from synkro import LLM, OpenAI, Anthropic, Google

        # Use OpenAI
        >>> llm = LLM(model=OpenAI.GPT_4O_MINI)
        >>> response = await llm.generate("Hello!")

        # Use Anthropic
        >>> llm = LLM(model=Anthropic.CLAUDE_35_SONNET)

        # Use Google Gemini
        >>> llm = LLM(model=Google.GEMINI_25_FLASH)

        # Structured output
        >>> class Output(BaseModel):
        ...     answer: str
        ...     confidence: float
        >>> result = await llm.generate_structured("What is 2+2?", Output)
        >>> result.answer
        '4'
    """

    def __init__(
        self,
        model: Model = OpenAI.GPT_4O_MINI,
        temperature: float = 0.7,
        max_tokens: int | None = None,
        api_key: str | None = None,
    ):
        """
        Initialize the LLM client.

        Args:
            model: Model to use (enum or string)
            temperature: Sampling temperature (0.0-2.0)
            max_tokens: Maximum tokens to generate (default: None = model's max)
            api_key: Optional API key override
        """
        self.model = get_model_string(model)
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._api_key = api_key

    async def generate(self, prompt: str, system: str | None = None) -> str:
        """
        Generate a text response.

        Args:
            prompt: The user prompt
            system: Optional system prompt

        Returns:
            Generated text response
        """
        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "api_key": self._api_key,
        }
        if self.max_tokens is not None:
            kwargs["max_tokens"] = self.max_tokens

        response = await acompletion(**kwargs)
        return response.choices[0].message.content

    async def generate_batch(
        self, prompts: list[str], system: str | None = None
    ) -> list[str]:
        """
        Generate responses for multiple prompts in parallel.

        Args:
            prompts: List of user prompts
            system: Optional system prompt applied to all of them

        Returns:
            List of generated responses
        """
        import asyncio

        tasks = [self.generate(p, system) for p in prompts]
        return await asyncio.gather(*tasks)

    @overload
    async def generate_structured(
        self,
        prompt: str,
        response_model: Type[T],
        system: str | None = None,
    ) -> T: ...

    @overload
    async def generate_structured(
        self,
        prompt: str,
        response_model: Type[list[T]],
        system: str | None = None,
    ) -> list[T]: ...

    async def generate_structured(
        self,
        prompt: str,
        response_model: Type[T] | Type[list[T]],
        system: str | None = None,
    ) -> T | list[T]:
        """
        Generate a structured response matching a Pydantic model.

        Uses LiteLLM's native JSON mode with response_format for
        reliable structured outputs.

        Args:
            prompt: The user prompt
            response_model: Pydantic model class for the response
            system: Optional system prompt

        Returns:
            Parsed response matching the model

        Example:
            >>> class Analysis(BaseModel):
            ...     sentiment: str
            ...     score: float
            >>> result = await llm.generate_structured(
            ...     "Analyze: I love this product!",
            ...     Analysis
            ... )
            >>> result.sentiment
            'positive'
        """
        # Check if the model supports structured outputs
        if not supports_response_schema(model=self.model, custom_llm_provider=None):
            raise ValueError(
                f"Model '{self.model}' does not support structured outputs (response_format). "
                f"Use a model that supports JSON schema, such as GPT-4o, Gemini 1.5+, or Claude 3.5+."
            )

        messages = []
        if system:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})

        # Use LiteLLM's native response_format with a Pydantic model
        kwargs = {
            "model": self.model,
            "messages": messages,
            "response_format": response_model,
            "temperature": self.temperature,
            "api_key": self._api_key,
        }
        if self.max_tokens is not None:
            kwargs["max_tokens"] = self.max_tokens

        response = await acompletion(**kwargs)
        return response_model.model_validate_json(response.choices[0].message.content)

    async def generate_chat(
        self, messages: list[dict], response_model: Type[T] | None = None
    ) -> str | T:
        """
        Generate a response for a full conversation.

        Args:
            messages: List of message dicts with 'role' and 'content'
            response_model: Optional Pydantic model for structured output

        Returns:
            Generated response (string or structured)
        """
        if response_model:
            # Check if the model supports structured outputs
            if not supports_response_schema(model=self.model, custom_llm_provider=None):
                raise ValueError(
                    f"Model '{self.model}' does not support structured outputs (response_format). "
                    f"Use a model that supports JSON schema, such as GPT-4o, Gemini 1.5+, or Claude 3.5+."
                )

            # Use LiteLLM's native response_format with a Pydantic model
            kwargs = {
                "model": self.model,
                "messages": messages,
                "response_format": response_model,
                "temperature": self.temperature,
                "api_key": self._api_key,
            }
            if self.max_tokens is not None:
                kwargs["max_tokens"] = self.max_tokens

            response = await acompletion(**kwargs)
            return response_model.model_validate_json(response.choices[0].message.content)

        kwargs = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "api_key": self._api_key,
        }
        if self.max_tokens is not None:
            kwargs["max_tokens"] = self.max_tokens

        response = await acompletion(**kwargs)
        return response.choices[0].message.content
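For reference, a minimal end-to-end sketch of the client's two generation paths (assumes the relevant provider key, e.g. OPENAI_API_KEY, is set in the environment; the Verdict model is illustrative):

import asyncio

from pydantic import BaseModel

from synkro.llm.client import LLM
from synkro.models import OpenAI


class Verdict(BaseModel):
    answer: str
    confidence: float


async def main() -> None:
    llm = LLM(model=OpenAI.GPT_4O_MINI, temperature=0.0)

    # Plain-text path
    text = await llm.generate("Summarize DAG validation in one sentence.")
    print(text)

    # Structured path: the JSON response is parsed and validated as a Verdict
    verdict = await llm.generate_structured("Is 7919 prime?", Verdict)
    print(verdict.answer, verdict.confidence)


if __name__ == "__main__":
    asyncio.run(main())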