synkro-0.4.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. synkro/__init__.py +179 -0
  2. synkro/advanced.py +186 -0
  3. synkro/cli.py +128 -0
  4. synkro/core/__init__.py +7 -0
  5. synkro/core/checkpoint.py +250 -0
  6. synkro/core/dataset.py +402 -0
  7. synkro/core/policy.py +337 -0
  8. synkro/errors.py +178 -0
  9. synkro/examples/__init__.py +148 -0
  10. synkro/factory.py +276 -0
  11. synkro/formatters/__init__.py +12 -0
  12. synkro/formatters/qa.py +98 -0
  13. synkro/formatters/sft.py +90 -0
  14. synkro/formatters/tool_call.py +127 -0
  15. synkro/generation/__init__.py +9 -0
  16. synkro/generation/follow_ups.py +134 -0
  17. synkro/generation/generator.py +220 -0
  18. synkro/generation/golden_responses.py +244 -0
  19. synkro/generation/golden_scenarios.py +276 -0
  20. synkro/generation/golden_tool_responses.py +416 -0
  21. synkro/generation/logic_extractor.py +126 -0
  22. synkro/generation/multiturn_responses.py +177 -0
  23. synkro/generation/planner.py +131 -0
  24. synkro/generation/responses.py +189 -0
  25. synkro/generation/scenarios.py +90 -0
  26. synkro/generation/tool_responses.py +376 -0
  27. synkro/generation/tool_simulator.py +114 -0
  28. synkro/interactive/__init__.py +12 -0
  29. synkro/interactive/hitl_session.py +77 -0
  30. synkro/interactive/logic_map_editor.py +173 -0
  31. synkro/interactive/rich_ui.py +205 -0
  32. synkro/llm/__init__.py +7 -0
  33. synkro/llm/client.py +235 -0
  34. synkro/llm/rate_limits.py +95 -0
  35. synkro/models/__init__.py +43 -0
  36. synkro/models/anthropic.py +26 -0
  37. synkro/models/google.py +19 -0
  38. synkro/models/openai.py +31 -0
  39. synkro/modes/__init__.py +15 -0
  40. synkro/modes/config.py +66 -0
  41. synkro/modes/qa.py +18 -0
  42. synkro/modes/sft.py +18 -0
  43. synkro/modes/tool_call.py +18 -0
  44. synkro/parsers.py +442 -0
  45. synkro/pipeline/__init__.py +20 -0
  46. synkro/pipeline/phases.py +592 -0
  47. synkro/pipeline/runner.py +424 -0
  48. synkro/pipelines.py +123 -0
  49. synkro/prompts/__init__.py +57 -0
  50. synkro/prompts/base.py +167 -0
  51. synkro/prompts/golden_templates.py +474 -0
  52. synkro/prompts/interactive_templates.py +65 -0
  53. synkro/prompts/multiturn_templates.py +156 -0
  54. synkro/prompts/qa_templates.py +97 -0
  55. synkro/prompts/templates.py +281 -0
  56. synkro/prompts/tool_templates.py +201 -0
  57. synkro/quality/__init__.py +14 -0
  58. synkro/quality/golden_refiner.py +163 -0
  59. synkro/quality/grader.py +153 -0
  60. synkro/quality/multiturn_grader.py +150 -0
  61. synkro/quality/refiner.py +137 -0
  62. synkro/quality/tool_grader.py +126 -0
  63. synkro/quality/tool_refiner.py +128 -0
  64. synkro/quality/verifier.py +228 -0
  65. synkro/reporting.py +537 -0
  66. synkro/schemas.py +472 -0
  67. synkro/types/__init__.py +41 -0
  68. synkro/types/core.py +126 -0
  69. synkro/types/dataset_type.py +30 -0
  70. synkro/types/logic_map.py +345 -0
  71. synkro/types/tool.py +94 -0
  72. synkro-0.4.12.data/data/examples/__init__.py +148 -0
  73. synkro-0.4.12.dist-info/METADATA +258 -0
  74. synkro-0.4.12.dist-info/RECORD +77 -0
  75. synkro-0.4.12.dist-info/WHEEL +4 -0
  76. synkro-0.4.12.dist-info/entry_points.txt +2 -0
  77. synkro-0.4.12.dist-info/licenses/LICENSE +21 -0
synkro/core/checkpoint.py ADDED
@@ -0,0 +1,250 @@
+ """Checkpoint manager for resumable generation."""
+
+ import json
+ from datetime import datetime
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ from pydantic import BaseModel, Field
+ from rich.console import Console
+
+ if TYPE_CHECKING:
+     from synkro.types.core import Trace
+     from synkro.types.logic_map import LogicMap, GoldenScenario
+
+ console = Console()
+
+
+ class CheckpointData(BaseModel):
+     """Data stored in a checkpoint file."""
+
+     # Metadata
+     created_at: str = Field(default_factory=lambda: datetime.now().isoformat())
+     policy_hash: str = ""
+     target_traces: int = 0
+     dataset_type: str = ""
+
+     # Stage 1: Logic Map
+     logic_map_data: dict | None = None
+
+     # Stage 2: Scenarios
+     scenarios_data: list[dict] = Field(default_factory=list)
+     scenario_distribution: dict[str, int] = Field(default_factory=dict)
+
+     # Stage 3: Generated traces
+     traces_data: list[dict] = Field(default_factory=list)
+     completed_scenario_indices: list[int] = Field(default_factory=list)
+
+     # Stage 4: Verified traces
+     verified_traces_data: list[dict] = Field(default_factory=list)
+     verification_complete: bool = False
+
+
+ class CheckpointManager:
+     """
+     Manages checkpoints for resumable generation.
+
+     Saves progress after each stage and allows resuming from the last
+     successful checkpoint.
+
+     Examples:
+         >>> manager = CheckpointManager("./checkpoints")
+         >>> manager.save_logic_map(logic_map, policy_hash, 100, "sft")
+         >>> manager.save_scenarios(scenarios, distribution)
+         >>> manager.save_trace(trace, scenario_index)
+
+         >>> # Resume from checkpoint
+         >>> checkpoint = manager.load()
+         >>> if checkpoint:
+         ...     logic_map = checkpoint.get_logic_map()
+         ...     completed = checkpoint.completed_scenario_indices
+     """
+
+     def __init__(self, checkpoint_dir: str | Path):
+         """
+         Initialize the checkpoint manager.
+
+         Args:
+             checkpoint_dir: Directory to store checkpoint files
+         """
+         self.checkpoint_dir = Path(checkpoint_dir)
+         self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+         self.checkpoint_file = self.checkpoint_dir / "checkpoint.json"
+         self._data: CheckpointData | None = None
+
+     def _load_or_create(self) -> CheckpointData:
+         """Load existing checkpoint or create new one."""
+         if self._data is not None:
+             return self._data
+
+         if self.checkpoint_file.exists():
+             with open(self.checkpoint_file) as f:
+                 data = json.load(f)
+             self._data = CheckpointData.model_validate(data)
+         else:
+             self._data = CheckpointData()
+
+         return self._data
+
+     def _save(self) -> None:
+         """Save checkpoint to disk."""
+         if self._data is None:
+             return
+
+         with open(self.checkpoint_file, "w") as f:
+             json.dump(self._data.model_dump(), f, indent=2)
+
+     def has_checkpoint(self) -> bool:
+         """Check if a checkpoint exists."""
+         return self.checkpoint_file.exists()
+
+     def load(self) -> CheckpointData | None:
+         """Load checkpoint if it exists."""
+         if not self.has_checkpoint():
+             return None
+         return self._load_or_create()
+
+     def matches_config(self, policy_hash: str, target_traces: int, dataset_type: str) -> bool:
+         """Check if checkpoint matches the current generation config."""
+         data = self._load_or_create()
+         return (
+             data.policy_hash == policy_hash
+             and data.target_traces == target_traces
+             and data.dataset_type == dataset_type
+         )
+
+     def save_logic_map(
+         self,
+         logic_map: "LogicMap",
+         policy_hash: str,
+         target_traces: int,
+         dataset_type: str,
+     ) -> None:
+         """Save Logic Map (Stage 1 complete)."""
+         data = self._load_or_create()
+         data.policy_hash = policy_hash
+         data.target_traces = target_traces
+         data.dataset_type = dataset_type
+         data.logic_map_data = logic_map.model_dump()
+         self._save()
+         console.print("[dim]💾 Checkpoint: Logic Map saved[/dim]")
+
+     def save_scenarios(
+         self,
+         scenarios: list["GoldenScenario"],
+         distribution: dict[str, int],
+     ) -> None:
+         """Save scenarios (Stage 2 complete)."""
+         data = self._load_or_create()
+         data.scenarios_data = [s.model_dump() for s in scenarios]
+         data.scenario_distribution = distribution
+         self._save()
+         console.print("[dim]💾 Checkpoint: Scenarios saved[/dim]")
+
+     def save_trace(self, trace: "Trace", scenario_index: int) -> None:
+         """Save a generated trace (incremental Stage 3)."""
+         data = self._load_or_create()
+         data.traces_data.append(trace.model_dump())
+         data.completed_scenario_indices.append(scenario_index)
+         self._save()
+
+     def save_traces_batch(self, traces: list["Trace"], indices: list[int]) -> None:
+         """Save a batch of traces at once."""
+         data = self._load_or_create()
+         for trace, idx in zip(traces, indices):
+             data.traces_data.append(trace.model_dump())
+             data.completed_scenario_indices.append(idx)
+         self._save()
+         console.print(f"[dim]💾 Checkpoint: {len(traces)} traces saved[/dim]")
+
+     def save_verified_traces(self, traces: list["Trace"]) -> None:
+         """Save verified traces (Stage 4 complete)."""
+         data = self._load_or_create()
+         data.verified_traces_data = [t.model_dump() for t in traces]
+         data.verification_complete = True
+         self._save()
+         console.print("[dim]💾 Checkpoint: Verification complete[/dim]")
+
+     def get_logic_map(self) -> "LogicMap | None":
+         """Retrieve Logic Map from checkpoint."""
+         from synkro.types.logic_map import LogicMap
+
+         data = self._load_or_create()
+         if data.logic_map_data:
+             return LogicMap.model_validate(data.logic_map_data)
+         return None
+
+     def get_scenarios(self) -> list["GoldenScenario"]:
+         """Retrieve scenarios from checkpoint."""
+         from synkro.types.logic_map import GoldenScenario
+
+         data = self._load_or_create()
+         return [GoldenScenario.model_validate(s) for s in data.scenarios_data]
+
+     def get_traces(self) -> list["Trace"]:
+         """Retrieve traces from checkpoint."""
+         from synkro.types.core import Trace
+
+         data = self._load_or_create()
+         return [Trace.model_validate(t) for t in data.traces_data]
+
+     def get_verified_traces(self) -> list["Trace"]:
+         """Retrieve verified traces from checkpoint."""
+         from synkro.types.core import Trace
+
+         data = self._load_or_create()
+         return [Trace.model_validate(t) for t in data.verified_traces_data]
+
+     def get_pending_scenario_indices(self, total: int) -> list[int]:
+         """Get indices of scenarios that haven't been processed yet."""
+         data = self._load_or_create()
+         completed = set(data.completed_scenario_indices)
+         return [i for i in range(total) if i not in completed]
+
+     def clear(self) -> None:
+         """Clear the checkpoint."""
+         if self.checkpoint_file.exists():
+             self.checkpoint_file.unlink()
+         self._data = None
+         console.print("[dim]🗑️ Checkpoint cleared[/dim]")
+
+     @property
+     def stage(self) -> str:
+         """Get the current stage based on checkpoint data."""
+         data = self._load_or_create()
+
+         if data.verification_complete:
+             return "complete"
+         if data.traces_data:
+             return "traces"  # In progress or done
+         if data.scenarios_data:
+             return "scenarios"
+         if data.logic_map_data:
+             return "logic_map"
+         return "start"
+
+     def summary(self) -> str:
+         """Get a summary of the checkpoint status."""
+         data = self._load_or_create()
+
+         lines = [
+             "Checkpoint Status",
+             "=================",
+             f"Stage: {self.stage}",
+             f"Target traces: {data.target_traces}",
+             f"Logic Map: {'✓' if data.logic_map_data else '✗'}",
+             f"Scenarios: {len(data.scenarios_data)}",
+             f"Traces: {len(data.traces_data)}/{data.target_traces}",
+             f"Verified: {'✓' if data.verification_complete else '✗'}",
+         ]
+
+         return "\n".join(lines)
+
+
+ def hash_policy(policy_text: str) -> str:
+     """Create a hash of policy text for checkpoint matching."""
+     import hashlib
+
+     return hashlib.sha256(policy_text.encode()).hexdigest()[:16]
+
+
+ __all__ = ["CheckpointManager", "CheckpointData", "hash_policy"]
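
For orientation, a minimal resume-flow sketch using only the API above; `policy_text` and `generate_trace()` are hypothetical stand-ins for the caller's own policy document and generation step:

    from synkro.core.checkpoint import CheckpointManager, hash_policy

    manager = CheckpointManager("./checkpoints")
    policy_hash = hash_policy(policy_text)  # policy_text: caller-supplied (hypothetical)

    # A checkpoint from a different policy/target/type must not be resumed
    if manager.has_checkpoint() and not manager.matches_config(policy_hash, 100, "sft"):
        manager.clear()

    # Resume: regenerate only the scenarios that were never checkpointed
    scenarios = manager.get_scenarios()
    for i in manager.get_pending_scenario_indices(len(scenarios)):
        trace = generate_trace(scenarios[i])  # hypothetical generation call
        manager.save_trace(trace, i)  # persisted immediately; a crash loses at most one trace

Note that save_trace() rewrites checkpoint.json on every call, which keeps the file consistent at the cost of one full serialization per trace; save_traces_batch() amortizes that cost when traces are generated concurrently.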
synkro/core/dataset.py ADDED
@@ -0,0 +1,402 @@
+ """Dataset class for managing generated traces."""
+
+ import json
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Iterator
+
+ from pydantic import BaseModel, Field
+ from rich.console import Console
+
+ from synkro.types.core import Trace
+
+ console = Console()
+
+
+ class Dataset(BaseModel):
+     """
+     A collection of generated training traces.
+
+     Provides methods for filtering, saving, and exporting traces
+     in various formats.
+
+     Examples:
+         >>> dataset = generator.generate(policy, traces=100)
+
+         >>> # Filter to only passing traces
+         >>> passing = dataset.filter(passed=True)
+
+         >>> # Save to JSONL
+         >>> dataset.save("training.jsonl")
+
+         >>> # Push to HuggingFace
+         >>> dataset.to_huggingface().push_to_hub("my-org/dataset")
+     """
+
+     traces: list[Trace] = Field(default_factory=list)
+
+     class Config:
+         arbitrary_types_allowed = True
+
+     def __len__(self) -> int:
+         return len(self.traces)
+
+     def __iter__(self) -> Iterator[Trace]:
+         return iter(self.traces)
+
+     def __getitem__(self, idx: int) -> Trace:
+         return self.traces[idx]
+
+     def filter(
+         self,
+         passed: bool | None = None,
+         category: str | None = None,
+         min_length: int | None = None,
+     ) -> "Dataset":
+         """
+         Filter traces by criteria.
+
+         Args:
+             passed: Filter by grade pass/fail status
+             category: Filter by scenario category
+             min_length: Minimum response length in characters
+
+         Returns:
+             New Dataset with filtered traces
+         """
+         filtered = self.traces
+
+         if passed is not None:
+             filtered = [
+                 t for t in filtered if t.grade and t.grade.passed == passed
+             ]
+
+         if category is not None:
+             filtered = [
+                 t for t in filtered if t.scenario.category == category
+             ]
+
+         if min_length is not None:
+             filtered = [
+                 t for t in filtered if len(t.assistant_message) >= min_length
+             ]
+
+         return Dataset(traces=filtered)
+
+     def dedupe(
+         self,
+         threshold: float = 0.85,
+         method: str = "semantic",
+         field: str = "user",
+     ) -> "Dataset":
+         """
+         Remove duplicate or near-duplicate traces.
+
+         Args:
+             threshold: Similarity threshold (0-1). Higher values require closer
+                 matches before a trace is dropped. Only used for the semantic
+                 method. (default: 0.85)
+             method: Deduplication method:
+                 - "exact": Remove exact text duplicates (fast)
+                 - "semantic": Remove semantically similar traces (requires sentence-transformers)
+             field: Which field to dedupe on - "user", "assistant", or "both"
+
+         Returns:
+             New Dataset with duplicates removed
+
+         Examples:
+             >>> # Remove exact duplicates (fast)
+             >>> deduped = dataset.dedupe(method="exact")
+
+             >>> # Remove semantically similar (needs sentence-transformers)
+             >>> deduped = dataset.dedupe(threshold=0.9, method="semantic")
+
+             >>> # Dedupe based on assistant responses
+             >>> deduped = dataset.dedupe(field="assistant")
+         """
+         if not self.traces:
+             return Dataset(traces=[])
+
+         if method == "exact":
+             return self._dedupe_exact(field)
+         elif method == "semantic":
+             return self._dedupe_semantic(threshold, field)
+         else:
+             raise ValueError(f"Unknown method: {method}. Use 'exact' or 'semantic'")
+
+     def _dedupe_exact(self, field: str) -> "Dataset":
+         """Remove exact text duplicates."""
+         seen = set()
+         unique_traces = []
+
+         for trace in self.traces:
+             if field == "user":
+                 key = trace.user_message
+             elif field == "assistant":
+                 key = trace.assistant_message
+             else:  # both
+                 key = (trace.user_message, trace.assistant_message)
+
+             if key not in seen:
+                 seen.add(key)
+                 unique_traces.append(trace)
+
+         removed = len(self.traces) - len(unique_traces)
+         if removed > 0:
+             console.print(f"[yellow]🔍 Dedupe:[/yellow] Removed {removed} exact duplicates")
+
+         return Dataset(traces=unique_traces)
+
+     def _dedupe_semantic(self, threshold: float, field: str) -> "Dataset":
+         """Remove semantically similar traces using embeddings."""
+         try:
+             from sentence_transformers import SentenceTransformer
+             import numpy as np
+         except ImportError:
+             raise ImportError(
+                 "sentence-transformers is required for semantic deduplication. "
+                 "Install with: pip install sentence-transformers"
+             )
+
+         # Get texts to embed
+         if field == "user":
+             texts = [t.user_message for t in self.traces]
+         elif field == "assistant":
+             texts = [t.assistant_message for t in self.traces]
+         else:  # both
+             texts = [f"{t.user_message} {t.assistant_message}" for t in self.traces]
+
+         # Compute embeddings
+         console.print("[dim]Computing embeddings for deduplication...[/dim]")
+         model = SentenceTransformer("all-MiniLM-L6-v2")
+         embeddings = model.encode(texts, show_progress_bar=False)
+         embeddings = np.array(embeddings)
+
+         # Normalize for cosine similarity
+         norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+         embeddings = embeddings / norms
+
+         # Find duplicates using cosine similarity
+         unique_indices = []
+         duplicate_of = {}  # Maps duplicate index to original index
+
+         for i in range(len(embeddings)):
+             is_duplicate = False
+             for j in unique_indices:
+                 similarity = np.dot(embeddings[i], embeddings[j])
+                 if similarity >= threshold:
+                     is_duplicate = True
+                     duplicate_of[i] = j
+                     break
+             if not is_duplicate:
+                 unique_indices.append(i)
+
+         unique_traces = [self.traces[i] for i in unique_indices]
+         removed = len(self.traces) - len(unique_traces)
+
+         if removed > 0:
+             console.print(f"[yellow]🔍 Dedupe:[/yellow] Removed {removed} semantic duplicates (threshold={threshold})")
+
+         return Dataset(traces=unique_traces)
+
+     @property
+     def passing_rate(self) -> float:
+         """Get the fraction of traces that passed grading (0.0-1.0)."""
+         if not self.traces:
+             return 0.0
+
+         passed = sum(1 for t in self.traces if t.grade and t.grade.passed)
+         return passed / len(self.traces)
+
+     @property
+     def categories(self) -> list[str]:
+         """Get unique categories in the dataset."""
+         return list(set(t.scenario.category for t in self.traces if t.scenario.category))
+
+     def save(self, path: str | Path | None = None, format: str = "sft") -> "Dataset":
+         """
+         Save dataset to a JSONL file.
+
+         Args:
+             path: Output file path (auto-generated if not provided)
+             format: Output format - "sft", "qa", or "tool_call"
+
+         Returns:
+             Self for method chaining
+
+         Example:
+             >>> dataset.save()  # Auto-names: synkro_sft_2024-01-15_1430.jsonl
+             >>> dataset.save("training.jsonl")
+             >>> dataset.save("qa_data.jsonl", format="qa")
+             >>> dataset.save("tools.jsonl", format="tool_call")
+         """
+         from synkro.formatters import SFTFormatter, QAFormatter, ToolCallFormatter
+
+         # Auto-generate filename if not provided
+         if path is None:
+             timestamp = datetime.now().strftime("%Y-%m-%d_%H%M")
+             path = f"synkro_{format}_{timestamp}.jsonl"
+
+         path = Path(path)
+
+         if format == "sft":
+             SFTFormatter().save(self.traces, path)
+         elif format == "qa":
+             QAFormatter().save(self.traces, path)
+         elif format == "tool_call":
+             ToolCallFormatter().save(self.traces, path)
+         else:
+             raise ValueError(f"Unknown format: {format}. Use 'sft', 'qa', or 'tool_call'")
+
+         # Print confirmation
+         file_size = path.stat().st_size
+         size_str = f"{file_size / 1024:.1f} KB" if file_size < 1024 * 1024 else f"{file_size / 1024 / 1024:.1f} MB"
+         console.print(f"[green]📁 Saved:[/green] {path} ({size_str})")
+
+         return self
+
+     def to_jsonl(self, format: str = "sft") -> str:
+         """
+         Convert dataset to JSONL string.
+
+         Args:
+             format: Output format - "sft", "qa", or "tool_call"
+
+         Returns:
+             JSONL formatted string
+         """
+         from synkro.formatters import SFTFormatter, QAFormatter, ToolCallFormatter
+
+         if format == "sft":
+             return SFTFormatter().to_jsonl(self.traces)
+         elif format == "qa":
+             return QAFormatter().to_jsonl(self.traces)
+         elif format == "tool_call":
+             return ToolCallFormatter().to_jsonl(self.traces)
+         else:
+             raise ValueError(f"Unknown format: {format}. Use 'sft', 'qa', or 'tool_call'")
+
+     def to_hf_dataset(self, format: str = "sft"):
+         """
+         Convert to HuggingFace Dataset.
+
+         Args:
+             format: Output format - "sft", "qa", or "tool_call"
+
+         Returns:
+             HuggingFace datasets.Dataset object
+
+         Example:
+             >>> hf_dataset = dataset.to_hf_dataset()
+             >>> hf_dataset.push_to_hub("my-org/policy-traces")
+
+             >>> # With train/test split
+             >>> hf_dataset = dataset.to_hf_dataset()
+             >>> split = hf_dataset.train_test_split(test_size=0.1)
+             >>> split.push_to_hub("my-org/policy-traces")
+         """
+         try:
+             from datasets import Dataset as HFDataset
+         except ImportError:
+             raise ImportError(
+                 "datasets is required for HuggingFace export. "
+                 "Install with: pip install datasets"
+             )
+
+         from synkro.formatters import SFTFormatter, QAFormatter, ToolCallFormatter
+
+         if format == "sft":
+             examples = SFTFormatter(include_metadata=True).format(self.traces)
+         elif format == "qa":
+             examples = QAFormatter().format(self.traces)
+         elif format == "tool_call":
+             examples = ToolCallFormatter().format(self.traces)
+         else:
+             raise ValueError(f"Unknown format: {format}")
+
+         return HFDataset.from_list(examples)
+
+     # Alias for backwards compatibility
+     to_huggingface = to_hf_dataset
+
+     def push_to_hub(
+         self,
+         repo_id: str,
+         format: str = "sft",
+         private: bool = False,
+         split: str = "train",
+         token: str | None = None,
+     ) -> str:
+         """
+         Push dataset directly to HuggingFace Hub.
+
+         Args:
+             repo_id: HuggingFace repo ID (e.g., "my-org/policy-sft")
+             format: Output format - "sft", "qa", or "tool_call"
+             private: Whether the repo should be private
+             split: Dataset split name (default: "train")
+             token: HuggingFace token (uses cached token if not provided)
+
+         Returns:
+             URL of the uploaded dataset
+
+         Example:
+             >>> dataset.push_to_hub("my-org/policy-sft")
+             >>> dataset.push_to_hub("my-org/policy-sft", private=True)
+         """
+         hf_dataset = self.to_hf_dataset(format=format)
+         hf_dataset.push_to_hub(
+             repo_id,
+             private=private,
+             split=split,
+             token=token,
+         )
+         url = f"https://huggingface.co/datasets/{repo_id}"
+         console.print(f"[green]🤗 Pushed to Hub:[/green] {url}")
+         return url
+
+     def to_dict(self) -> dict:
+         """
+         Convert dataset to a dictionary.
+
+         Returns:
+             Dictionary with trace data
+         """
+         return {
+             "traces": [t.model_dump() for t in self.traces],
+             "stats": {
+                 "total": len(self.traces),
+                 "passing_rate": self.passing_rate,
+                 "categories": self.categories,
+             },
+         }
+
+     def summary(self) -> str:
+         """
+         Get a summary of the dataset.
+
+         Returns:
+             Human-readable summary string
+         """
+         lines = [
+             "Dataset Summary",
+             "===============",
+             f"Total traces: {len(self.traces)}",
+             f"Passing rate: {self.passing_rate:.1%}",
+             f"Categories: {len(self.categories)}",
+         ]
+
+         if self.categories:
+             lines.append("")
+             lines.append("By category:")
+             for cat in self.categories:
+                 count = sum(1 for t in self.traces if t.scenario.category == cat)
+                 lines.append(f"  - {cat}: {count}")
+
+         return "\n".join(lines)
+
+     def __str__(self) -> str:
+         return f"Dataset(traces={len(self.traces)}, passing={self.passing_rate:.1%})"
+
+     def __repr__(self) -> str:
+         return self.__str__()
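
Taken together, the Dataset API supports a filter → dedupe → export pipeline. A sketch of typical post-generation cleanup; the generator call and repo ID are placeholders borrowed from the docstring examples:

    dataset = generator.generate(policy, traces=100)  # placeholder, per the class docstring

    cleaned = (
        dataset.filter(passed=True)  # keep only traces that passed grading
        .dedupe(method="exact")  # cheap exact-match pass first
        .dedupe(threshold=0.9, method="semantic")  # then embedding-based near-duplicates
    )
    print(cleaned.summary())

    cleaned.save("training.jsonl")  # JSONL on disk, or:
    cleaned.push_to_hub("my-org/policy-sft", private=True)

Ordering the exact pass before the semantic pass matters: the semantic dedupe greedily compares each trace against every previously kept trace (O(n²) dot products), so shrinking n with the cheap exact pass first keeps the embedding stage fast.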