synkro-0.4.12-py3-none-any.whl
This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- synkro/__init__.py +179 -0
- synkro/advanced.py +186 -0
- synkro/cli.py +128 -0
- synkro/core/__init__.py +7 -0
- synkro/core/checkpoint.py +250 -0
- synkro/core/dataset.py +402 -0
- synkro/core/policy.py +337 -0
- synkro/errors.py +178 -0
- synkro/examples/__init__.py +148 -0
- synkro/factory.py +276 -0
- synkro/formatters/__init__.py +12 -0
- synkro/formatters/qa.py +98 -0
- synkro/formatters/sft.py +90 -0
- synkro/formatters/tool_call.py +127 -0
- synkro/generation/__init__.py +9 -0
- synkro/generation/follow_ups.py +134 -0
- synkro/generation/generator.py +220 -0
- synkro/generation/golden_responses.py +244 -0
- synkro/generation/golden_scenarios.py +276 -0
- synkro/generation/golden_tool_responses.py +416 -0
- synkro/generation/logic_extractor.py +126 -0
- synkro/generation/multiturn_responses.py +177 -0
- synkro/generation/planner.py +131 -0
- synkro/generation/responses.py +189 -0
- synkro/generation/scenarios.py +90 -0
- synkro/generation/tool_responses.py +376 -0
- synkro/generation/tool_simulator.py +114 -0
- synkro/interactive/__init__.py +12 -0
- synkro/interactive/hitl_session.py +77 -0
- synkro/interactive/logic_map_editor.py +173 -0
- synkro/interactive/rich_ui.py +205 -0
- synkro/llm/__init__.py +7 -0
- synkro/llm/client.py +235 -0
- synkro/llm/rate_limits.py +95 -0
- synkro/models/__init__.py +43 -0
- synkro/models/anthropic.py +26 -0
- synkro/models/google.py +19 -0
- synkro/models/openai.py +31 -0
- synkro/modes/__init__.py +15 -0
- synkro/modes/config.py +66 -0
- synkro/modes/qa.py +18 -0
- synkro/modes/sft.py +18 -0
- synkro/modes/tool_call.py +18 -0
- synkro/parsers.py +442 -0
- synkro/pipeline/__init__.py +20 -0
- synkro/pipeline/phases.py +592 -0
- synkro/pipeline/runner.py +424 -0
- synkro/pipelines.py +123 -0
- synkro/prompts/__init__.py +57 -0
- synkro/prompts/base.py +167 -0
- synkro/prompts/golden_templates.py +474 -0
- synkro/prompts/interactive_templates.py +65 -0
- synkro/prompts/multiturn_templates.py +156 -0
- synkro/prompts/qa_templates.py +97 -0
- synkro/prompts/templates.py +281 -0
- synkro/prompts/tool_templates.py +201 -0
- synkro/quality/__init__.py +14 -0
- synkro/quality/golden_refiner.py +163 -0
- synkro/quality/grader.py +153 -0
- synkro/quality/multiturn_grader.py +150 -0
- synkro/quality/refiner.py +137 -0
- synkro/quality/tool_grader.py +126 -0
- synkro/quality/tool_refiner.py +128 -0
- synkro/quality/verifier.py +228 -0
- synkro/reporting.py +537 -0
- synkro/schemas.py +472 -0
- synkro/types/__init__.py +41 -0
- synkro/types/core.py +126 -0
- synkro/types/dataset_type.py +30 -0
- synkro/types/logic_map.py +345 -0
- synkro/types/tool.py +94 -0
- synkro-0.4.12.data/data/examples/__init__.py +148 -0
- synkro-0.4.12.dist-info/METADATA +258 -0
- synkro-0.4.12.dist-info/RECORD +77 -0
- synkro-0.4.12.dist-info/WHEEL +4 -0
- synkro-0.4.12.dist-info/entry_points.txt +2 -0
- synkro-0.4.12.dist-info/licenses/LICENSE +21 -0
synkro/core/checkpoint.py
ADDED

@@ -0,0 +1,250 @@

"""Checkpoint manager for resumable generation."""

import json
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic import BaseModel, Field
from rich.console import Console

if TYPE_CHECKING:
    from synkro.types.core import Trace
    from synkro.types.logic_map import LogicMap, GoldenScenario

console = Console()


class CheckpointData(BaseModel):
    """Data stored in a checkpoint file."""

    # Metadata
    created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
    policy_hash: str = ""
    target_traces: int = 0
    dataset_type: str = ""

    # Stage 1: Logic Map
    logic_map_data: dict | None = None

    # Stage 2: Scenarios
    scenarios_data: list[dict] = Field(default_factory=list)
    scenario_distribution: dict[str, int] = Field(default_factory=dict)

    # Stage 3: Generated traces
    traces_data: list[dict] = Field(default_factory=list)
    completed_scenario_indices: list[int] = Field(default_factory=list)

    # Stage 4: Verified traces
    verified_traces_data: list[dict] = Field(default_factory=list)
    verification_complete: bool = False


class CheckpointManager:
    """
    Manages checkpoints for resumable generation.

    Saves progress after each stage and allows resuming from the last
    successful checkpoint.

    Examples:
        >>> manager = CheckpointManager("./checkpoints")
        >>> manager.save_logic_map(logic_map, policy_hash, 100, "sft")
        >>> manager.save_scenarios(scenarios, distribution)
        >>> manager.save_trace(trace, scenario_index)

        >>> # Resume from checkpoint
        >>> checkpoint = manager.load()
        >>> if checkpoint:
        ...     logic_map = checkpoint.get_logic_map()
        ...     completed = checkpoint.completed_scenario_indices
    """

    def __init__(self, checkpoint_dir: str | Path):
        """
        Initialize the checkpoint manager.

        Args:
            checkpoint_dir: Directory to store checkpoint files
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_file = self.checkpoint_dir / "checkpoint.json"
        self._data: CheckpointData | None = None

    def _load_or_create(self) -> CheckpointData:
        """Load the existing checkpoint or create a new one."""
        if self._data is not None:
            return self._data

        if self.checkpoint_file.exists():
            with open(self.checkpoint_file) as f:
                data = json.load(f)
            self._data = CheckpointData.model_validate(data)
        else:
            self._data = CheckpointData()

        return self._data

    def _save(self) -> None:
        """Save the checkpoint to disk."""
        if self._data is None:
            return

        with open(self.checkpoint_file, "w") as f:
            json.dump(self._data.model_dump(), f, indent=2)

    def has_checkpoint(self) -> bool:
        """Check if a checkpoint exists."""
        return self.checkpoint_file.exists()

    def load(self) -> CheckpointData | None:
        """Load the checkpoint if it exists."""
        if not self.has_checkpoint():
            return None
        return self._load_or_create()

    def matches_config(self, policy_hash: str, target_traces: int, dataset_type: str) -> bool:
        """Check if the checkpoint matches the current generation config."""
        data = self._load_or_create()
        return (
            data.policy_hash == policy_hash
            and data.target_traces == target_traces
            and data.dataset_type == dataset_type
        )

    def save_logic_map(
        self,
        logic_map: "LogicMap",
        policy_hash: str,
        target_traces: int,
        dataset_type: str,
    ) -> None:
        """Save the Logic Map (Stage 1 complete)."""
        data = self._load_or_create()
        data.policy_hash = policy_hash
        data.target_traces = target_traces
        data.dataset_type = dataset_type
        data.logic_map_data = logic_map.model_dump()
        self._save()
        console.print("[dim]💾 Checkpoint: Logic Map saved[/dim]")

    def save_scenarios(
        self,
        scenarios: list["GoldenScenario"],
        distribution: dict[str, int],
    ) -> None:
        """Save scenarios (Stage 2 complete)."""
        data = self._load_or_create()
        data.scenarios_data = [s.model_dump() for s in scenarios]
        data.scenario_distribution = distribution
        self._save()
        console.print("[dim]💾 Checkpoint: Scenarios saved[/dim]")

    def save_trace(self, trace: "Trace", scenario_index: int) -> None:
        """Save a generated trace (incremental Stage 3)."""
        data = self._load_or_create()
        data.traces_data.append(trace.model_dump())
        data.completed_scenario_indices.append(scenario_index)
        self._save()

    def save_traces_batch(self, traces: list["Trace"], indices: list[int]) -> None:
        """Save a batch of traces at once."""
        data = self._load_or_create()
        for trace, idx in zip(traces, indices):
            data.traces_data.append(trace.model_dump())
            data.completed_scenario_indices.append(idx)
        self._save()
        console.print(f"[dim]💾 Checkpoint: {len(traces)} traces saved[/dim]")

    def save_verified_traces(self, traces: list["Trace"]) -> None:
        """Save verified traces (Stage 4 complete)."""
        data = self._load_or_create()
        data.verified_traces_data = [t.model_dump() for t in traces]
        data.verification_complete = True
        self._save()
        console.print("[dim]💾 Checkpoint: Verification complete[/dim]")

    def get_logic_map(self) -> "LogicMap | None":
        """Retrieve the Logic Map from the checkpoint."""
        from synkro.types.logic_map import LogicMap

        data = self._load_or_create()
        if data.logic_map_data:
            return LogicMap.model_validate(data.logic_map_data)
        return None

    def get_scenarios(self) -> list["GoldenScenario"]:
        """Retrieve scenarios from the checkpoint."""
        from synkro.types.logic_map import GoldenScenario

        data = self._load_or_create()
        return [GoldenScenario.model_validate(s) for s in data.scenarios_data]

    def get_traces(self) -> list["Trace"]:
        """Retrieve traces from the checkpoint."""
        from synkro.types.core import Trace

        data = self._load_or_create()
        return [Trace.model_validate(t) for t in data.traces_data]

    def get_verified_traces(self) -> list["Trace"]:
        """Retrieve verified traces from the checkpoint."""
        from synkro.types.core import Trace

        data = self._load_or_create()
        return [Trace.model_validate(t) for t in data.verified_traces_data]

    def get_pending_scenario_indices(self, total: int) -> list[int]:
        """Get indices of scenarios that haven't been processed yet."""
        data = self._load_or_create()
        completed = set(data.completed_scenario_indices)
        return [i for i in range(total) if i not in completed]

    def clear(self) -> None:
        """Clear the checkpoint."""
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()
        self._data = None
        console.print("[dim]🗑️ Checkpoint cleared[/dim]")

    @property
    def stage(self) -> str:
        """Get the current stage based on checkpoint data."""
        data = self._load_or_create()

        if data.verification_complete:
            return "complete"
        if data.traces_data:
            return "traces"  # In progress or done
        if data.scenarios_data:
            return "scenarios"
        if data.logic_map_data:
            return "logic_map"
        return "start"

    def summary(self) -> str:
        """Get a summary of the checkpoint status."""
        data = self._load_or_create()

        lines = [
            "Checkpoint Status",
            "=================",
            f"Stage: {self.stage}",
            f"Target traces: {data.target_traces}",
            f"Logic Map: {'✓' if data.logic_map_data else '✗'}",
            f"Scenarios: {len(data.scenarios_data)}",
            f"Traces: {len(data.traces_data)}/{data.target_traces}",
            f"Verified: {'✓' if data.verification_complete else '✗'}",
        ]

        return "\n".join(lines)


def hash_policy(policy_text: str) -> str:
    """Create a hash of policy text for checkpoint matching."""
    import hashlib

    return hashlib.sha256(policy_text.encode()).hexdigest()[:16]


__all__ = ["CheckpointManager", "CheckpointData", "hash_policy"]
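Taken together, this API makes a long generation run crash-safe: validate the checkpoint against the current config via hash_policy and matches_config, skip completed work via get_pending_scenario_indices, and persist each new trace with save_trace. The sketch below is not part of the package; generate_trace is a hypothetical stand-in for the actual generation step (which presumably lives in synkro/pipeline/runner.py), and policy.md is an assumed input file.

from pathlib import Path

from synkro.core.checkpoint import CheckpointManager, hash_policy

policy_text = Path("policy.md").read_text()  # assumed policy source
manager = CheckpointManager("./checkpoints")

# Discard a stale checkpoint produced by a different policy or run config.
if manager.has_checkpoint() and not manager.matches_config(hash_policy(policy_text), 100, "sft"):
    manager.clear()

# Resume Stage 3: only generate traces for scenarios not yet completed.
scenarios = manager.get_scenarios()
for i in manager.get_pending_scenario_indices(len(scenarios)):
    trace = generate_trace(scenarios[i])  # hypothetical generation call
    manager.save_trace(trace, i)          # progress survives a crash past this point

traces = manager.get_traces()

Note the cost model implied by _save: every save_trace rewrites the whole checkpoint.json, so per-trace checkpointing is O(n) per write; save_traces_batch amortizes that by saving once per batch.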
synkro/core/dataset.py
ADDED

@@ -0,0 +1,402 @@

"""Dataset class for managing generated traces."""

import json
from datetime import datetime
from pathlib import Path
from typing import Iterator

from pydantic import BaseModel, Field
from rich.console import Console

from synkro.types.core import Trace

console = Console()


class Dataset(BaseModel):
    """
    A collection of generated training traces.

    Provides methods for filtering, saving, and exporting traces
    in various formats.

    Examples:
        >>> dataset = generator.generate(policy, traces=100)

        >>> # Filter to only passing traces
        >>> passing = dataset.filter(passed=True)

        >>> # Save to JSONL
        >>> dataset.save("training.jsonl")

        >>> # Push to HuggingFace
        >>> dataset.to_huggingface().push_to_hub("my-org/dataset")
    """

    traces: list[Trace] = Field(default_factory=list)

    class Config:
        arbitrary_types_allowed = True

    def __len__(self) -> int:
        return len(self.traces)

    def __iter__(self) -> Iterator[Trace]:
        return iter(self.traces)

    def __getitem__(self, idx: int) -> Trace:
        return self.traces[idx]

    def filter(
        self,
        passed: bool | None = None,
        category: str | None = None,
        min_length: int | None = None,
    ) -> "Dataset":
        """
        Filter traces by criteria.

        Args:
            passed: Filter by grade pass/fail status
            category: Filter by scenario category
            min_length: Minimum response length in characters

        Returns:
            New Dataset with filtered traces
        """
        filtered = self.traces

        if passed is not None:
            filtered = [t for t in filtered if t.grade and t.grade.passed == passed]

        if category is not None:
            filtered = [t for t in filtered if t.scenario.category == category]

        if min_length is not None:
            filtered = [t for t in filtered if len(t.assistant_message) >= min_length]

        return Dataset(traces=filtered)

    def dedupe(
        self,
        threshold: float = 0.85,
        method: str = "semantic",
        field: str = "user",
    ) -> "Dataset":
        """
        Remove duplicate or near-duplicate traces.

        Args:
            threshold: Similarity threshold (0-1). Higher = stricter dedup.
                Only used for the semantic method. (default: 0.85)
            method: Deduplication method:
                - "exact": Remove exact text duplicates (fast)
                - "semantic": Remove semantically similar traces (requires sentence-transformers)
            field: Which field to dedupe on - "user", "assistant", or "both"

        Returns:
            New Dataset with duplicates removed

        Examples:
            >>> # Remove exact duplicates (fast)
            >>> deduped = dataset.dedupe(method="exact")

            >>> # Remove semantically similar (needs sentence-transformers)
            >>> deduped = dataset.dedupe(threshold=0.9, method="semantic")

            >>> # Dedupe based on assistant responses
            >>> deduped = dataset.dedupe(field="assistant")
        """
        if not self.traces:
            return Dataset(traces=[])

        if method == "exact":
            return self._dedupe_exact(field)
        elif method == "semantic":
            return self._dedupe_semantic(threshold, field)
        else:
            raise ValueError(f"Unknown method: {method}. Use 'exact' or 'semantic'")

    def _dedupe_exact(self, field: str) -> "Dataset":
        """Remove exact text duplicates."""
        seen = set()
        unique_traces = []

        for trace in self.traces:
            if field == "user":
                key = trace.user_message
            elif field == "assistant":
                key = trace.assistant_message
            else:  # both
                key = (trace.user_message, trace.assistant_message)

            if key not in seen:
                seen.add(key)
                unique_traces.append(trace)

        removed = len(self.traces) - len(unique_traces)
        if removed > 0:
            console.print(f"[yellow]🔍 Dedupe:[/yellow] Removed {removed} exact duplicates")

        return Dataset(traces=unique_traces)

    def _dedupe_semantic(self, threshold: float, field: str) -> "Dataset":
        """Remove semantically similar traces using embeddings."""
        try:
            from sentence_transformers import SentenceTransformer
            import numpy as np
        except ImportError:
            raise ImportError(
                "sentence-transformers is required for semantic deduplication. "
                "Install with: pip install sentence-transformers"
            )

        # Get texts to embed
        if field == "user":
            texts = [t.user_message for t in self.traces]
        elif field == "assistant":
            texts = [t.assistant_message for t in self.traces]
        else:  # both
            texts = [f"{t.user_message} {t.assistant_message}" for t in self.traces]

        # Compute embeddings
        console.print("[dim]Computing embeddings for deduplication...[/dim]")
        model = SentenceTransformer("all-MiniLM-L6-v2")
        embeddings = model.encode(texts, show_progress_bar=False)
        embeddings = np.array(embeddings)

        # Normalize for cosine similarity
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / norms

        # Find duplicates using cosine similarity
        unique_indices = []
        duplicate_of = {}  # Maps duplicate index to original index

        for i in range(len(embeddings)):
            is_duplicate = False
            for j in unique_indices:
                similarity = np.dot(embeddings[i], embeddings[j])
                if similarity >= threshold:
                    is_duplicate = True
                    duplicate_of[i] = j
                    break
            if not is_duplicate:
                unique_indices.append(i)

        unique_traces = [self.traces[i] for i in unique_indices]
        removed = len(self.traces) - len(unique_traces)

        if removed > 0:
            console.print(
                f"[yellow]🔍 Dedupe:[/yellow] Removed {removed} semantic duplicates "
                f"(threshold={threshold})"
            )

        return Dataset(traces=unique_traces)

    @property
    def passing_rate(self) -> float:
        """Get the fraction (0.0-1.0) of traces that passed grading."""
        if not self.traces:
            return 0.0

        passed = sum(1 for t in self.traces if t.grade and t.grade.passed)
        return passed / len(self.traces)

    @property
    def categories(self) -> list[str]:
        """Get the unique categories in the dataset."""
        return list(set(t.scenario.category for t in self.traces if t.scenario.category))

    def save(self, path: str | Path | None = None, format: str = "sft") -> "Dataset":
        """
        Save the dataset to a JSONL file.

        Args:
            path: Output file path (auto-generated if not provided)
            format: Output format - "sft", "qa", or "tool_call"

        Returns:
            Self for method chaining

        Example:
            >>> dataset.save()  # Auto-names: synkro_sft_2024-01-15_0930.jsonl
            >>> dataset.save("training.jsonl")
            >>> dataset.save("qa_data.jsonl", format="qa")
            >>> dataset.save("tools.jsonl", format="tool_call")
        """
        from synkro.formatters import SFTFormatter, QAFormatter, ToolCallFormatter

        # Auto-generate filename if not provided
        if path is None:
            timestamp = datetime.now().strftime("%Y-%m-%d_%H%M")
            path = f"synkro_{format}_{timestamp}.jsonl"

        path = Path(path)

        if format == "sft":
            SFTFormatter().save(self.traces, path)
        elif format == "qa":
            QAFormatter().save(self.traces, path)
        elif format == "tool_call":
            ToolCallFormatter().save(self.traces, path)
        else:
            raise ValueError(f"Unknown format: {format}. Use 'sft', 'qa', or 'tool_call'")

        # Print confirmation
        file_size = path.stat().st_size
        size_str = (
            f"{file_size / 1024:.1f} KB"
            if file_size < 1024 * 1024
            else f"{file_size / 1024 / 1024:.1f} MB"
        )
        console.print(f"[green]📁 Saved:[/green] {path} ({size_str})")

        return self

    def to_jsonl(self, format: str = "sft") -> str:
        """
        Convert the dataset to a JSONL string.

        Args:
            format: Output format - "sft", "qa", or "tool_call"

        Returns:
            JSONL-formatted string
        """
        from synkro.formatters import SFTFormatter, QAFormatter, ToolCallFormatter

        if format == "sft":
            return SFTFormatter().to_jsonl(self.traces)
        elif format == "qa":
            return QAFormatter().to_jsonl(self.traces)
        elif format == "tool_call":
            return ToolCallFormatter().to_jsonl(self.traces)
        else:
            raise ValueError(f"Unknown format: {format}. Use 'sft', 'qa', or 'tool_call'")

    def to_hf_dataset(self, format: str = "sft"):
        """
        Convert to a HuggingFace Dataset.

        Args:
            format: Output format - "sft", "qa", or "tool_call"

        Returns:
            HuggingFace datasets.Dataset object

        Example:
            >>> hf_dataset = dataset.to_hf_dataset()
            >>> hf_dataset.push_to_hub("my-org/policy-traces")

            >>> # With train/test split
            >>> hf_dataset = dataset.to_hf_dataset()
            >>> split = hf_dataset.train_test_split(test_size=0.1)
            >>> split.push_to_hub("my-org/policy-traces")
        """
        try:
            from datasets import Dataset as HFDataset
        except ImportError:
            raise ImportError(
                "datasets is required for HuggingFace export. "
                "Install with: pip install datasets"
            )

        from synkro.formatters import SFTFormatter, QAFormatter, ToolCallFormatter

        if format == "sft":
            examples = SFTFormatter(include_metadata=True).format(self.traces)
        elif format == "qa":
            examples = QAFormatter().format(self.traces)
        elif format == "tool_call":
            examples = ToolCallFormatter().format(self.traces)
        else:
            raise ValueError(f"Unknown format: {format}")

        return HFDataset.from_list(examples)

    # Alias for backwards compatibility
    to_huggingface = to_hf_dataset

    def push_to_hub(
        self,
        repo_id: str,
        format: str = "sft",
        private: bool = False,
        split: str = "train",
        token: str | None = None,
    ) -> str:
        """
        Push the dataset directly to the HuggingFace Hub.

        Args:
            repo_id: HuggingFace repo ID (e.g., "my-org/policy-sft")
            format: Output format - "sft", "qa", or "tool_call"
            private: Whether the repo should be private
            split: Dataset split name (default: "train")
            token: HuggingFace token (uses cached token if not provided)

        Returns:
            URL of the uploaded dataset

        Example:
            >>> dataset.push_to_hub("my-org/policy-sft")
            >>> dataset.push_to_hub("my-org/policy-sft", private=True)
        """
        hf_dataset = self.to_hf_dataset(format=format)
        hf_dataset.push_to_hub(
            repo_id,
            private=private,
            split=split,
            token=token,
        )
        url = f"https://huggingface.co/datasets/{repo_id}"
        console.print(f"[green]🤗 Pushed to Hub:[/green] {url}")
        return url

    def to_dict(self) -> dict:
        """
        Convert the dataset to a dictionary.

        Returns:
            Dictionary with trace data and summary stats
        """
        return {
            "traces": [t.model_dump() for t in self.traces],
            "stats": {
                "total": len(self.traces),
                "passing_rate": self.passing_rate,
                "categories": self.categories,
            },
        }

    def summary(self) -> str:
        """
        Get a summary of the dataset.

        Returns:
            Human-readable summary string
        """
        lines = [
            "Dataset Summary",
            "===============",
            f"Total traces: {len(self.traces)}",
            f"Passing rate: {self.passing_rate:.1%}",
            f"Categories: {len(self.categories)}",
        ]

        if self.categories:
            lines.append("")
            lines.append("By category:")
            for cat in self.categories:
                count = sum(1 for t in self.traces if t.scenario.category == cat)
                lines.append(f"  - {cat}: {count}")

        return "\n".join(lines)

    def __str__(self) -> str:
        return f"Dataset(traces={len(self.traces)}, passing={self.passing_rate:.1%})"

    def __repr__(self) -> str:
        return self.__str__()
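Because filter, dedupe, and save all return a Dataset, they compose into a short post-processing chain. The sketch below is a usage example, not part of the package; `dataset` is assumed to come from a generator as in the class docstring, and the repo ID and thresholds are illustrative.

# Hypothetical post-processing chain; every method shown is defined above.
clean = (
    dataset
    .filter(passed=True, min_length=50)    # keep graded passes with some substance
    .dedupe(method="exact", field="both")  # fast pass: drop literal repeats
    .dedupe(threshold=0.9)                 # semantic pass: needs sentence-transformers
)
print(clean.summary())
clean.save("training.jsonl", format="sft")
clean.push_to_hub("my-org/policy-sft", private=True)  # optional Hub upload

The ordering matters: _dedupe_semantic compares each candidate against every kept embedding (quadratic in the worst case), so running the cheap exact pass first shrinks the input to the expensive one.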