synkro-0.4.5-py3-none-any.whl

This diff shows the contents of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.

Potentially problematic release.

Files changed (58)
  1. synkro/__init__.py +165 -0
  2. synkro/cli.py +120 -0
  3. synkro/core/__init__.py +7 -0
  4. synkro/core/dataset.py +233 -0
  5. synkro/core/policy.py +337 -0
  6. synkro/errors.py +178 -0
  7. synkro/examples/__init__.py +148 -0
  8. synkro/factory.py +160 -0
  9. synkro/formatters/__init__.py +12 -0
  10. synkro/formatters/qa.py +85 -0
  11. synkro/formatters/sft.py +90 -0
  12. synkro/formatters/tool_call.py +127 -0
  13. synkro/generation/__init__.py +9 -0
  14. synkro/generation/generator.py +163 -0
  15. synkro/generation/planner.py +87 -0
  16. synkro/generation/responses.py +160 -0
  17. synkro/generation/scenarios.py +90 -0
  18. synkro/generation/tool_responses.py +370 -0
  19. synkro/generation/tool_simulator.py +114 -0
  20. synkro/llm/__init__.py +7 -0
  21. synkro/llm/client.py +235 -0
  22. synkro/llm/rate_limits.py +95 -0
  23. synkro/models/__init__.py +43 -0
  24. synkro/models/anthropic.py +26 -0
  25. synkro/models/google.py +19 -0
  26. synkro/models/openai.py +31 -0
  27. synkro/modes/__init__.py +15 -0
  28. synkro/modes/config.py +66 -0
  29. synkro/modes/qa.py +18 -0
  30. synkro/modes/sft.py +18 -0
  31. synkro/modes/tool_call.py +18 -0
  32. synkro/parsers.py +442 -0
  33. synkro/pipeline/__init__.py +20 -0
  34. synkro/pipeline/phases.py +237 -0
  35. synkro/pipeline/runner.py +198 -0
  36. synkro/pipelines.py +105 -0
  37. synkro/prompts/__init__.py +44 -0
  38. synkro/prompts/base.py +167 -0
  39. synkro/prompts/qa_templates.py +97 -0
  40. synkro/prompts/templates.py +281 -0
  41. synkro/prompts/tool_templates.py +201 -0
  42. synkro/quality/__init__.py +14 -0
  43. synkro/quality/grader.py +130 -0
  44. synkro/quality/refiner.py +137 -0
  45. synkro/quality/tool_grader.py +126 -0
  46. synkro/quality/tool_refiner.py +128 -0
  47. synkro/reporting.py +213 -0
  48. synkro/schemas.py +325 -0
  49. synkro/types/__init__.py +41 -0
  50. synkro/types/core.py +113 -0
  51. synkro/types/dataset_type.py +30 -0
  52. synkro/types/tool.py +94 -0
  53. synkro-0.4.5.data/data/examples/__init__.py +148 -0
  54. synkro-0.4.5.dist-info/METADATA +221 -0
  55. synkro-0.4.5.dist-info/RECORD +58 -0
  56. synkro-0.4.5.dist-info/WHEEL +4 -0
  57. synkro-0.4.5.dist-info/entry_points.txt +2 -0
  58. synkro-0.4.5.dist-info/licenses/LICENSE +21 -0
synkro/factory.py ADDED
@@ -0,0 +1,160 @@
+"""Component factory for dependency injection.
+
+This module provides a factory for creating pipeline components,
+enabling testability and flexible configuration.
+"""
+
+from typing import TYPE_CHECKING
+
+from synkro.llm.client import LLM
+from synkro.modes.config import ModeConfig
+from synkro.generation.planner import Planner
+from synkro.generation.scenarios import ScenarioGenerator
+from synkro.generation.responses import ResponseGenerator
+from synkro.quality.grader import Grader
+from synkro.quality.refiner import Refiner
+
+if TYPE_CHECKING:
+    from synkro.types.tool import ToolDefinition
+    from synkro.generation.tool_simulator import ToolSimulator
+    from synkro.generation.tool_responses import ToolCallResponseGenerator
+    from synkro.quality.tool_grader import ToolCallGrader
+    from synkro.quality.tool_refiner import ToolCallRefiner
+
+
+class ComponentFactory:
+    """
+    Factory for creating pipeline components with shared LLM clients.
+
+    This centralizes component creation and ensures consistent configuration
+    across the pipeline.
+
+    Examples:
+        >>> factory = ComponentFactory(gen_llm, grade_llm, mode_config)
+        >>> planner = factory.create_planner()
+        >>> grader = factory.create_grader()
+
+        >>> # With tools for tool_call dataset type
+        >>> factory = ComponentFactory(gen_llm, grade_llm, mode_config, tools=[...])
+        >>> simulator = factory.create_tool_simulator()
+    """
+
+    def __init__(
+        self,
+        generation_llm: LLM,
+        grading_llm: LLM,
+        mode_config: ModeConfig,
+        tools: list["ToolDefinition"] | None = None,
+    ):
+        """
+        Initialize the factory.
+
+        Args:
+            generation_llm: LLM client for generation tasks (scenarios, responses, refinement)
+            grading_llm: LLM client for grading and planning (typically stronger model)
+            mode_config: Configuration for the dataset type (prompts, etc.)
+            tools: Optional list of tool definitions for tool_call dataset type
+        """
+        self.generation_llm = generation_llm
+        self.grading_llm = grading_llm
+        self.mode_config = mode_config
+        self.tools = tools or []
+
+    def create_planner(self) -> Planner:
+        """Create a Planner instance."""
+        return Planner(llm=self.grading_llm)
+
+    def create_scenario_generator(self) -> ScenarioGenerator:
+        """Create a ScenarioGenerator with mode-specific prompts."""
+        gen = ScenarioGenerator(llm=self.generation_llm)
+        gen.prompt_template = self.mode_config.scenario_prompt
+        return gen
+
+    def create_response_generator(self) -> ResponseGenerator:
+        """Create a ResponseGenerator with mode-specific prompts."""
+        gen = ResponseGenerator(llm=self.generation_llm)
+        gen.prompt_template = self.mode_config.response_prompt
+        return gen
+
+    def create_grader(self) -> "Grader | ToolCallGrader":
+        """
+        Create a Grader with mode-specific prompts.
+
+        Auto-selects ToolCallGrader when tools are configured.
+        """
+        if self.has_tools:
+            from synkro.quality.tool_grader import ToolCallGrader
+            return ToolCallGrader(llm=self.grading_llm, tools=self.tools)
+
+        grader = Grader(llm=self.grading_llm)
+        grader.prompt_template = self.mode_config.grade_prompt
+        return grader
+
+    def create_refiner(self) -> "Refiner | ToolCallRefiner":
+        """
+        Create a Refiner with mode-specific prompts.
+
+        Auto-selects ToolCallRefiner when tools are configured.
+        This ensures tool_calls format is preserved during refinement.
+        """
+        if self.has_tools:
+            from synkro.quality.tool_refiner import ToolCallRefiner
+            simulator = self.create_tool_simulator()
+            return ToolCallRefiner(
+                llm=self.generation_llm,
+                tools=self.tools,
+                simulator=simulator,
+            )
+
+        refiner = Refiner(llm=self.generation_llm)
+        refiner.prompt_template = self.mode_config.refine_prompt
+        return refiner
+
+    def create_tool_simulator(self) -> "ToolSimulator":
+        """Create a ToolSimulator instance for tool_call dataset type."""
+        from synkro.generation.tool_simulator import ToolSimulator
+
+        if not self.tools:
+            raise ValueError("Cannot create ToolSimulator without tools")
+
+        return ToolSimulator(tools=self.tools, llm=self.generation_llm)
+
+    def create_tool_call_response_generator(self) -> "ToolCallResponseGenerator":
+        """
+        Create a ToolCallResponseGenerator for generating proper tool call traces.
+
+        This generator uses JSON mode to produce structured tool calls in
+        OpenAI function calling format.
+        """
+        from synkro.generation.tool_responses import ToolCallResponseGenerator
+
+        if not self.tools:
+            raise ValueError("Cannot create ToolCallResponseGenerator without tools")
+
+        # Create simulator for generating tool responses
+        simulator = self.create_tool_simulator()
+
+        return ToolCallResponseGenerator(
+            tools=self.tools,
+            llm=self.generation_llm,
+            simulator=simulator,
+        )
+
+    def get_tools_description(self) -> str:
+        """Get formatted description of all available tools."""
+        if not self.tools:
+            return "No tools available"
+
+        descriptions = []
+        for tool in self.tools:
+            descriptions.append(tool.to_system_prompt())
+        return "\n\n".join(descriptions)
+
+    @property
+    def has_tools(self) -> bool:
+        """Check if tools are configured."""
+        return bool(self.tools)
+
+
+__all__ = ["ComponentFactory"]
+
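
For orientation, a minimal sketch of how this factory gets wired up, using only names that appear elsewhere in this release (LLM, the OpenAI model enum, get_mode_config, DatasetType); constructing the clients this way mirrors generator.py further down and is otherwise an assumption, not documented usage.

from synkro.factory import ComponentFactory
from synkro.llm.client import LLM
from synkro.models import OpenAI
from synkro.modes.config import get_mode_config
from synkro.types.dataset_type import DatasetType

# Cheaper model for generation, stronger model for grading/planning,
# matching the defaults used by Generator below.
gen_llm = LLM(model=OpenAI.GPT_4O_MINI)
grade_llm = LLM(model=OpenAI.GPT_4O)

factory = ComponentFactory(
    generation_llm=gen_llm,
    grading_llm=grade_llm,
    mode_config=get_mode_config(DatasetType.SFT),
)

# Without tools configured, the factory returns the plain Grader/Refiner;
# with tools=[...] it auto-selects the ToolCall* variants instead.
grader = factory.create_grader()
refiner = factory.create_refiner()
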
synkro/formatters/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""Output formatters for different training data formats."""
+
+from synkro.formatters.sft import SFTFormatter
+from synkro.formatters.qa import QAFormatter
+from synkro.formatters.tool_call import ToolCallFormatter
+
+__all__ = [
+    "SFTFormatter",
+    "QAFormatter",
+    "ToolCallFormatter",
+]
+
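
The package exposes three formatters through this module. The dispatch below is an illustrative helper, not code shipped in this release, showing one plausible way a caller might pick a formatter per dataset type.

from synkro.formatters import QAFormatter, SFTFormatter, ToolCallFormatter
from synkro.types.dataset_type import DatasetType

# Hypothetical mapping: not part of synkro itself.
FORMATTERS = {
    DatasetType.QA: QAFormatter,
    DatasetType.SFT: SFTFormatter,
    DatasetType.TOOL_CALL: ToolCallFormatter,
}

def formatter_for(dataset_type: DatasetType):
    """Return a formatter instance for the given dataset type."""
    return FORMATTERS[dataset_type]()
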
synkro/formatters/qa.py ADDED
@@ -0,0 +1,85 @@
+"""QA (Question-Answer) formatter."""
+
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from synkro.types.core import Trace
+
+
+class QAFormatter:
+    """
+    Format traces for Question-Answer datasets.
+
+    QA format is simple question/answer pairs with optional context,
+    suitable for RAG training and knowledge extraction.
+
+    Example output:
+        {"question": "...", "answer": "...", "context": "..."}
+        {"question": "...", "answer": "...", "context": "..."}
+    """
+
+    def __init__(self, include_context: bool = True):
+        """
+        Initialize the QA formatter.
+
+        Args:
+            include_context: If True, include source context in output
+        """
+        self.include_context = include_context
+
+    def format(self, traces: list["Trace"]) -> list[dict]:
+        """
+        Format traces as QA pairs.
+
+        Args:
+            traces: List of traces to format
+
+        Returns:
+            List of QA examples (dicts with 'question', 'answer', optionally 'context')
+        """
+        examples = []
+
+        for trace in traces:
+            example = {
+                "question": trace.user_message,
+                "answer": trace.assistant_message,
+            }
+
+            if self.include_context:
+                # Use scenario context or the source section if available
+                example["context"] = trace.scenario.context or ""
+
+            examples.append(example)
+
+        return examples
+
+    def save(self, traces: list["Trace"], path: str | Path) -> None:
+        """
+        Save formatted traces to a JSONL file.
+
+        Args:
+            traces: List of traces to save
+            path: Output file path
+        """
+        path = Path(path)
+        examples = self.format(traces)
+
+        with open(path, "w") as f:
+            for example in examples:
+                f.write(json.dumps(example) + "\n")
+
+    def to_jsonl(self, traces: list["Trace"]) -> str:
+        """
+        Convert traces to JSONL string.
+
+        Args:
+            traces: List of traces to convert
+
+        Returns:
+            JSONL formatted string
+        """
+        examples = self.format(traces)
+        return "\n".join(json.dumps(e) for e in examples)
+
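
To make the on-disk shape concrete, here is a standalone snippet that writes the same JSONL layout QAFormatter.save() produces (one JSON object per line with question, answer, and optional context keys); the questions and answers are invented placeholders.

import json

# Each QA pair becomes one JSON object per line, with an optional "context" key.
rows = [
    {"question": "What is the refund window?", "answer": "30 days from delivery.", "context": "Refund policy section"},
    {"question": "Who approves exceptions?", "answer": "A support lead.", "context": "Escalation section"},
]

with open("qa_dataset.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
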
synkro/formatters/sft.py ADDED
@@ -0,0 +1,90 @@
+"""SFT (Supervised Fine-Tuning) formatter."""
+
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from synkro.types.core import Trace
+
+
+class SFTFormatter:
+    """
+    Format traces for Supervised Fine-Tuning (SFT).
+
+    SFT format is a simple array of conversations, each with messages.
+    This is the standard format used by OpenAI, HuggingFace, and most
+    fine-tuning platforms.
+
+    Example output:
+        {"messages": [{"role": "system", "content": "..."}, ...]}
+        {"messages": [{"role": "system", "content": "..."}, ...]}
+    """
+
+    def __init__(self, include_metadata: bool = False):
+        """
+        Initialize the SFT formatter.
+
+        Args:
+            include_metadata: If True, include trace metadata in output
+        """
+        self.include_metadata = include_metadata
+
+    def format(self, traces: list["Trace"]) -> list[dict]:
+        """
+        Format traces as SFT training examples.
+
+        Args:
+            traces: List of traces to format
+
+        Returns:
+            List of SFT examples (dicts with 'messages' key)
+        """
+        examples = []
+
+        for trace in traces:
+            example = {
+                "messages": [
+                    {"role": m.role, "content": m.content} for m in trace.messages
+                ]
+            }
+
+            if self.include_metadata:
+                example["metadata"] = {
+                    "scenario": trace.scenario.description,
+                    "category": trace.scenario.category,
+                    "grade": trace.grade.model_dump() if trace.grade else None,
+                }
+
+            examples.append(example)
+
+        return examples
+
+    def save(self, traces: list["Trace"], path: str | Path) -> None:
+        """
+        Save formatted traces to a JSONL file.
+
+        Args:
+            traces: List of traces to save
+            path: Output file path (should end in .jsonl)
+        """
+        path = Path(path)
+        examples = self.format(traces)
+
+        with open(path, "w") as f:
+            for example in examples:
+                f.write(json.dumps(example) + "\n")
+
+    def to_jsonl(self, traces: list["Trace"]) -> str:
+        """
+        Convert traces to JSONL string.
+
+        Args:
+            traces: List of traces to convert
+
+        Returns:
+            JSONL formatted string
+        """
+        examples = self.format(traces)
+        return "\n".join(json.dumps(e) for e in examples)
+
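
The corresponding on-disk shape for SFT output, written directly with json for illustration; the conversation content is a placeholder, and only the single "messages" key per line is taken from SFTFormatter.format() above.

import json

# One conversation per JSONL line, under a single "messages" key,
# mirroring what SFTFormatter.format() emits per trace.
conversation = {
    "messages": [
        {"role": "system", "content": "Follow the refund policy."},
        {"role": "user", "content": "Can I return an opened item?"},
        {"role": "assistant", "content": "Yes, within 30 days if it is undamaged."},
    ]
}

with open("sft_dataset.jsonl", "w") as f:
    f.write(json.dumps(conversation) + "\n")
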
synkro/formatters/tool_call.py ADDED
@@ -0,0 +1,127 @@
+"""Tool Call formatter for training data."""
+
+import json
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from synkro.types.core import Trace
+
+
+class ToolCallFormatter:
+    """
+    Format traces with tool calls for fine-tuning.
+
+    Outputs OpenAI function calling format compatible with most fine-tuning platforms.
+
+    Example output:
+        {
+            "messages": [
+                {"role": "system", "content": "You have access to: web_search(query)"},
+                {"role": "user", "content": "What's the weather in NYC?"},
+                {"role": "assistant", "content": null, "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "web_search", "arguments": "{\\"query\\": \\"weather NYC\\"}"}}
+                ]},
+                {"role": "tool", "tool_call_id": "call_1", "content": "NYC: 72°F, sunny"},
+                {"role": "assistant", "content": "The weather in NYC is currently 72°F and sunny."}
+            ]
+        }
+    """
+
+    def __init__(self, include_metadata: bool = False):
+        """
+        Initialize the ToolCallFormatter.
+
+        Args:
+            include_metadata: If True, include trace metadata in output
+        """
+        self.include_metadata = include_metadata
+
+    def format(self, traces: list["Trace"]) -> list[dict]:
+        """
+        Format traces as tool-calling training examples.
+
+        Args:
+            traces: List of traces to format
+
+        Returns:
+            List of formatted examples with tool calls
+        """
+        examples = []
+
+        for trace in traces:
+            messages = []
+
+            for m in trace.messages:
+                msg = {"role": m.role}
+
+                # Handle content (can be None for tool-calling assistant messages)
+                if m.content is not None:
+                    msg["content"] = m.content
+                elif m.role == "assistant" and m.tool_calls:
+                    msg["content"] = None
+                else:
+                    msg["content"] = ""
+
+                # Handle tool calls
+                if m.tool_calls:
+                    msg["tool_calls"] = [
+                        {
+                            "id": tc.id,
+                            "type": tc.type,
+                            "function": {
+                                "name": tc.function.name,
+                                "arguments": tc.function.arguments,
+                            }
+                        }
+                        for tc in m.tool_calls
+                    ]
+
+                # Handle tool response
+                if m.tool_call_id:
+                    msg["tool_call_id"] = m.tool_call_id
+
+                messages.append(msg)
+
+            example = {"messages": messages}
+
+            if self.include_metadata:
+                example["metadata"] = {
+                    "scenario": trace.scenario.description,
+                    "category": trace.scenario.category,
+                    "grade": trace.grade.model_dump() if trace.grade else None,
+                    "has_tool_calls": trace.has_tool_calls,
+                }
+
+            examples.append(example)
+
+        return examples
+
+    def save(self, traces: list["Trace"], path: str | Path) -> None:
+        """
+        Save formatted traces to a JSONL file.
+
+        Args:
+            traces: List of traces to save
+            path: Output file path (should end in .jsonl)
+        """
+        path = Path(path)
+        examples = self.format(traces)
+
+        with open(path, "w") as f:
+            for example in examples:
+                f.write(json.dumps(example) + "\n")
+
+    def to_jsonl(self, traces: list["Trace"]) -> str:
+        """
+        Convert traces to JSONL string.
+
+        Args:
+            traces: List of traces to convert
+
+        Returns:
+            JSONL formatted string
+        """
+        examples = self.format(traces)
+        return "\n".join(json.dumps(e) for e in examples)
+
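
The class docstring's example, restated as a runnable snippet: the assistant's tool-calling turn carries content=None plus a "tool_calls" list whose "arguments" field is a JSON-encoded string, and the tool result is a separate message keyed by "tool_call_id". The weather values are the docstring's own placeholders.

import json

example = {
    "messages": [
        {"role": "user", "content": "What's the weather in NYC?"},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "web_search", "arguments": json.dumps({"query": "weather NYC"})},
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "NYC: 72°F, sunny"},
        {"role": "assistant", "content": "The weather in NYC is currently 72°F and sunny."},
    ]
}

# json.dumps renders None as null, matching the JSON shown in the docstring.
print(json.dumps(example))
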
synkro/generation/__init__.py ADDED
@@ -0,0 +1,9 @@
+"""Generation components for creating training data."""
+
+from synkro.generation.generator import Generator
+from synkro.generation.scenarios import ScenarioGenerator
+from synkro.generation.responses import ResponseGenerator
+from synkro.generation.planner import Planner
+
+__all__ = ["Generator", "ScenarioGenerator", "ResponseGenerator", "Planner"]
+
synkro/generation/generator.py ADDED
@@ -0,0 +1,163 @@
+"""Main Generator class orchestrating the full trace generation pipeline."""
+
+import asyncio
+from enum import Enum
+from typing import TYPE_CHECKING
+
+from synkro.llm.client import LLM
+from synkro.llm.rate_limits import auto_workers
+from synkro.models import Model, OpenAI
+from synkro.types.dataset_type import DatasetType
+from synkro.core.policy import Policy
+from synkro.core.dataset import Dataset
+from synkro.modes.config import get_mode_config
+from synkro.errors import handle_error
+from synkro.factory import ComponentFactory
+from synkro.reporting import ProgressReporter, RichReporter
+from synkro.pipeline.runner import GenerationPipeline
+
+if TYPE_CHECKING:
+    from synkro.types.tool import ToolDefinition
+
+
+class Generator:
+    """
+    Main orchestrator for generating training datasets.
+
+    The Generator handles the full pipeline:
+    1. Plan: Analyze policy and create category distribution
+    2. Generate: Create scenarios and responses
+    3. Grade: Evaluate response quality
+    4. Refine: Fix failed responses
+    5. Return: Dataset of passing traces
+
+    Examples:
+        >>> generator = Generator()
+        >>> dataset = generator.generate(policy, traces=20)
+
+        >>> # QA dataset
+        >>> generator = Generator(dataset_type=DatasetType.QA)
+        >>> dataset = generator.generate(policy)
+
+        >>> # Silent mode (no console output)
+        >>> from synkro.reporting import SilentReporter
+        >>> generator = Generator(reporter=SilentReporter())
+        >>> dataset = generator.generate(policy)
+
+        >>> # Tool call dataset
+        >>> from synkro import ToolDefinition
+        >>> tools = [ToolDefinition(name="search", description="...", parameters={})]
+        >>> generator = Generator(dataset_type=DatasetType.TOOL_CALL, tools=tools)
+        >>> dataset = generator.generate("Usage guidelines", traces=20)
+    """
+
+    def __init__(
+        self,
+        dataset_type: DatasetType = DatasetType.SFT,
+        generation_model: Model = OpenAI.GPT_4O_MINI,
+        grading_model: Model = OpenAI.GPT_4O,
+        max_iterations: int = 1,
+        skip_grading: bool = False,
+        reporter: ProgressReporter | None = None,
+        tools: list["ToolDefinition"] | None = None,
+    ):
+        """
+        Initialize the Generator.
+
+        Args:
+            dataset_type: Type of dataset to generate (QA, SFT, or TOOL_CALL)
+            generation_model: Model for scenarios/responses (default: gpt-4o-mini)
+            grading_model: Model for grading (default: gpt-4o, recommend stronger)
+            max_iterations: Max refinement iterations per trace (default: 1, no retries)
+            skip_grading: Skip grading phase for faster generation (default: False)
+            reporter: Progress reporter (default: RichReporter for console output)
+            tools: List of ToolDefinition for TOOL_CALL dataset type
+        """
+        self.dataset_type = dataset_type
+        self.mode_config = get_mode_config(dataset_type)
+        self.max_iterations = max_iterations
+        self.skip_grading = skip_grading
+        self.tools = tools
+
+        # Validate tools for TOOL_CALL dataset type
+        if dataset_type == DatasetType.TOOL_CALL and not tools:
+            raise ValueError("TOOL_CALL dataset type requires tools parameter")
+
+        # Store model info for reporting
+        self.generation_model = generation_model
+        self.grading_model = grading_model
+
+        # Create LLM clients
+        self.generation_llm = LLM(model=generation_model)
+        self.grading_llm = LLM(model=grading_model)
+
+        # Create factory for component creation
+        self.factory = ComponentFactory(
+            generation_llm=self.generation_llm,
+            grading_llm=self.grading_llm,
+            mode_config=self.mode_config,
+            tools=tools,
+        )
+
+        # Reporter for progress output
+        self.reporter = reporter or RichReporter()
+
+        # Auto-scale workers based on provider
+        model_str = generation_model.value if isinstance(generation_model, Enum) else str(generation_model)
+        self.workers = auto_workers(model_str)
+
+        # Create pipeline
+        self.pipeline = GenerationPipeline(
+            factory=self.factory,
+            reporter=self.reporter,
+            workers=self.workers,
+            max_iterations=max_iterations,
+            skip_grading=skip_grading,
+        )
+
+    @handle_error
+    def generate(self, policy: Policy | str, traces: int = 20) -> Dataset:
+        """
+        Generate a training dataset from a policy.
+
+        Args:
+            policy: Policy object or text string
+            traces: Target number of traces to generate (default: 20)
+
+        Returns:
+            Dataset with generated traces
+        """
+        if isinstance(policy, str):
+            policy = Policy(text=policy)
+
+        # Validate policy has enough content
+        policy.validate_length()
+
+        return asyncio.run(self._generate_async(policy, traces))
+
+    async def _generate_async(self, policy: Policy, traces: int) -> Dataset:
+        """Async implementation of generation pipeline."""
+        model_str = self.generation_model.value if isinstance(self.generation_model, Enum) else str(self.generation_model)
+
+        return await self.pipeline.run(
+            policy=policy,
+            traces=traces,
+            model=model_str,
+            dataset_type=self.dataset_type.value,
+        )
+
+    async def generate_async(self, policy: Policy | str, traces: int = 20) -> Dataset:
+        """
+        Async version of generate for use in async contexts.
+
+        Args:
+            policy: Policy object or text string
+            traces: Target number of traces to generate (default: 20)
+
+        Returns:
+            Dataset with generated traces
+        """
+        if isinstance(policy, str):
+            policy = Policy(text=policy)
+
+        return await self._generate_async(policy, traces)
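
Pulling the pieces together, a sketch assembled from the docstring examples above. It assumes an OpenAI API key is configured; the tool schema is a placeholder, and the dataset.traces attribute used for export is an assumption not confirmed by this diff.

from synkro import ToolDefinition
from synkro.generation.generator import Generator
from synkro.types.dataset_type import DatasetType
from synkro.formatters.tool_call import ToolCallFormatter

# Tool-call dataset, mirroring the Generator docstring example.
tools = [ToolDefinition(name="search", description="Web search", parameters={})]
generator = Generator(dataset_type=DatasetType.TOOL_CALL, tools=tools)

# Generate 20 tool-call traces from a short policy string, then export JSONL.
dataset = generator.generate("Usage guidelines", traces=20)
ToolCallFormatter().save(dataset.traces, "tool_call_dataset.jsonl")  # traces attribute assumed
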