DeepFabric 4.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/generator.py
ADDED
|
@@ -0,0 +1,1084 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import math
|
|
5
|
+
import random
|
|
6
|
+
|
|
7
|
+
from collections.abc import AsyncGenerator
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
9
|
+
|
|
10
|
+
from datasets import Dataset as HFDataset
|
|
11
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
12
|
+
|
|
13
|
+
from .builders import ConversationBuilderFactory
|
|
14
|
+
from .config import _normalize_reasoning_style
|
|
15
|
+
from .constants import (
|
|
16
|
+
API_ERROR_INDICATORS,
|
|
17
|
+
DEFAULT_MAX_RETRIES,
|
|
18
|
+
DEFAULT_REQUEST_TIMEOUT,
|
|
19
|
+
DEFAULT_SAMPLE_RETRIES,
|
|
20
|
+
ENGINE_DEFAULT_BATCH_SIZE,
|
|
21
|
+
ENGINE_DEFAULT_NUM_EXAMPLES,
|
|
22
|
+
ENGINE_DEFAULT_TEMPERATURE,
|
|
23
|
+
ERROR_CATEGORIES,
|
|
24
|
+
ERROR_DATASET_FILENAME,
|
|
25
|
+
INTERRUPTED_DATASET_FILENAME,
|
|
26
|
+
)
|
|
27
|
+
from .error_codes import classify_error
|
|
28
|
+
from .exceptions import DataSetGeneratorError
|
|
29
|
+
from .llm import LLMClient
|
|
30
|
+
from .metrics import trace
|
|
31
|
+
from .progress import ProgressReporter
|
|
32
|
+
from .prompts import (
|
|
33
|
+
AGENT_COT_MULTI_TURN_PROMPT,
|
|
34
|
+
AGENT_COT_TOOLS_PROMPT,
|
|
35
|
+
CONVERSATION_GENERATION_PROMPT,
|
|
36
|
+
FREETEXT_COT_PROMPT,
|
|
37
|
+
STRUCTURED_COT_PROMPT,
|
|
38
|
+
AgentPromptBuilder,
|
|
39
|
+
)
|
|
40
|
+
from .schemas import Conversation, ToolRegistry, get_conversation_schema
|
|
41
|
+
from .tools import BUILTIN_TOOL_REGISTRY
|
|
42
|
+
from .tools.loader import load_tools_from_dict, load_tools_from_endpoint
|
|
43
|
+
from .topic_model import TopicModel
|
|
44
|
+
from .utils import ensure_not_running_loop, is_validation_error
|
|
45
|
+
|
|
46
|
+
# Handle circular import for type hints
|
|
47
|
+
if TYPE_CHECKING:
|
|
48
|
+
from .topic_model import TopicModel
|
|
49
|
+
|
|
50
|
+
# Module-level logger named after this module (standard logging hierarchy).
logger = logging.getLogger(__name__)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class DataSetGeneratorConfig(BaseModel):
    """Configuration for the data engine.

    Validated settings consumed by ``DataSetGenerator``: LLM provider/model,
    prompts, sampling limits, conversation/reasoning style, agent and tool
    configuration, and multi-turn limits. Field bounds are enforced by the
    ``ge``/``le``/``min_length`` constraints declared on each field.
    """

    # arbitrary_types_allowed is required because example_data is a Hugging Face
    # Dataset, which is not a pydantic-native type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    instructions: str = Field(default="", description="Additional instructions for data generation")
    generation_system_prompt: str = Field(
        ..., min_length=1, description="System prompt for content generation"
    )
    dataset_system_prompt: str | None = Field(
        None,
        description="System prompt that goes into the final dataset (falls back to generation_system_prompt if not provided)",
    )
    provider: str = Field(
        ...,
        min_length=1,
        description="LLM provider (openai, anthropic, gemini, ollama)",
    )
    model_name: str = Field(..., min_length=1, description="Name of the model to use")
    prompt_template: str | None = Field(default=None, description="Custom prompt template")
    example_data: HFDataset | None = Field(
        default=None, description="Example dataset for few-shot learning"
    )
    temperature: float = Field(
        default=ENGINE_DEFAULT_TEMPERATURE,
        ge=0.0,
        le=2.0,
        description="Temperature for model generation",
    )
    max_retries: int = Field(
        default=DEFAULT_MAX_RETRIES,
        ge=1,
        le=10,
        description="Maximum number of retries for failed requests (deprecated, use rate_limit config)",
    )
    max_tokens: int = Field(
        default=2000,
        ge=1,
        description="Maximum tokens to generate in a single call to the llm",
    )
    default_batch_size: int = Field(
        default=ENGINE_DEFAULT_BATCH_SIZE,
        ge=1,
        le=100,
        description="Default batch size for generation",
    )
    default_num_examples: int = Field(
        default=ENGINE_DEFAULT_NUM_EXAMPLES,
        ge=0,
        le=10,
        description="Default number of examples to include",
    )
    request_timeout: int = Field(
        default=DEFAULT_REQUEST_TIMEOUT,
        ge=5,
        le=300,
        description="Request timeout in seconds",
    )
    # DataSetGenerator makes sample_retries + 1 attempts per sample.
    sample_retries: int = Field(
        default=DEFAULT_SAMPLE_RETRIES,
        ge=0,
        le=5,
        description="Number of retries for individual sample validation failures",
    )
    sys_msg: bool = Field(default=True, description="Whether to include system message in dataset")
    base_url: str | None = Field(
        default=None,
        description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
    )

    # Rate limiting configuration
    # Passed through unchanged to LLMClient as rate_limit_config.
    rate_limit: dict[str, int | float | str | bool] | None = Field(
        default=None,
        description="Rate limiting and retry configuration (uses provider defaults if not specified)",
    )

    # Modular conversation configuration
    conversation_type: Literal["basic", "chain_of_thought"] = Field(
        default="basic",
        description="Base conversation type: basic (simple chat), chain_of_thought (with reasoning traces)",
    )

    reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = Field(
        default=None,
        description="Reasoning style for chain_of_thought type: freetext (natural language) or agent (structured step-by-step for tool-calling). Note: 'structured' and 'hybrid' are deprecated.",
    )

    # mode="before" runs ahead of the Literal check, so the raw input value is
    # normalized by config._normalize_reasoning_style first (presumably mapping
    # the deprecated values; see that helper).
    @field_validator("reasoning_style", mode="before")
    @classmethod
    def normalize_reasoning_style(cls, v: str | None) -> str | None:
        """Normalize deprecated reasoning_style values."""
        return _normalize_reasoning_style(v)

    # Any non-None agent_mode makes DataSetGenerator build a tool registry.
    agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
        default=None,
        description="Agent mode: single_turn (one-shot tool use), multi_turn (extended agent conversations). Requires tools to be configured.",
    )

    # Tool configuration (used when agent_mode is enabled or for tool_calling)
    tool_components: dict[str, list[str]] = Field(
        default_factory=dict,
        description=(
            "Map of component name to tool names. 'builtin' uses built-in tools "
            "and routes to /vfs/execute. Other components load from tools_endpoint "
            "and route to /{component}/execute."
        ),
    )
    custom_tools: list[dict] = Field(
        default_factory=list, description="Custom tool definitions as dictionaries"
    )
    # Compared against the number of tool executions in each agent-mode sample.
    max_tools_per_query: int = Field(
        default=3, ge=1, le=10, description="Maximum number of tools per query/turn"
    )
    max_tools_strict: bool = Field(
        default=True,
        description="If True, discard samples exceeding max_tools_per_query. If False, keep sample but truncate executions to limit.",
    )

    # Spin integration for real tool execution
    spin_endpoint: str | None = Field(
        default=None,
        description="Spin service URL for real tool execution (e.g., 'http://localhost:3000')",
    )
    scenario_seed: dict | None = Field(
        default=None,
        description="Initial state to seed into Spin VFS before generation (e.g., {'files': {'main.py': '...'}})",
    )
    max_agent_steps: int = Field(
        default=5,
        ge=1,
        le=10,
        description="Maximum ReAct reasoning steps per sample before forcing conclusion",
    )

    # MCP/Mock tool integration - load tools from HTTP endpoint instead of code
    # Required whenever tool_components contains a non-"builtin" component.
    tools_endpoint: str | None = Field(
        default=None,
        description="HTTP endpoint to load tool definitions from (e.g., 'http://localhost:3000/mock/list-tools'). Tools are loaded in MCP format.",
    )
    tool_execute_path: str | None = Field(
        default=None,
        description="Path for tool execution when using tools_endpoint (e.g., '/mock/execute'). Combined with spin_endpoint.",
    )

    # Multi-turn configuration (used when agent_mode="multi_turn")
    # NOTE(review): nothing in this class enforces min_turns <= max_turns;
    # confirm the multi-turn builder tolerates or rejects an inverted range.
    min_turns: int = Field(
        default=2,
        ge=1,
        le=10,
        description="Minimum number of conversation turns for multi-turn agent mode",
    )
    max_turns: int = Field(
        default=4,
        ge=1,
        le=10,
        description="Maximum number of conversation turns for multi-turn agent mode",
    )
    min_tool_calls: int = Field(
        default=2,
        ge=0,
        le=20,
        description="Minimum number of tool calls required before allowing early conversation conclusion",
    )
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class DataSetGenerator:
|
|
219
|
+
def __init__(self, **kwargs):
    """Validate configuration and wire up the LLM client and optional tools.

    Args:
        **kwargs: Field values for ``DataSetGeneratorConfig``.

    Raises:
        DataSetGeneratorError: If ``kwargs`` fail config validation, or if the
            tool registry cannot be initialized.
    """
    try:
        self.config = DataSetGeneratorConfig.model_validate(kwargs)
    except Exception as e:  # noqa: TRY003
        raise DataSetGeneratorError(f"Invalid generator configuration: {str(e)}") from e

    cfg = self.config

    # Identity of the backing model.
    self.provider = cfg.provider
    self.model_name = cfg.model_name

    # Per-run bookkeeping for samples and failures.
    self._samples: list[dict] = []
    self.failed_samples = []
    self.failure_analysis = {category: [] for category in ERROR_CATEGORIES}

    # base_url is only forwarded when configured, so LLMClient never receives
    # an explicit base_url=None.
    client_kwargs: dict[str, Any] = {"rate_limit_config": cfg.rate_limit}
    if cfg.base_url:
        client_kwargs["base_url"] = cfg.base_url
    self.llm_client = LLMClient(
        provider=self.provider,
        model_name=self.model_name,
        **client_kwargs,
    )

    trace(
        "generator_created",
        {
            "provider": self.provider,
            "model_name": self.model_name,
            "conversation_type": cfg.conversation_type,
        },
    )

    # Prompt used to drive generation; the prompt written into the dataset
    # falls back to it when dataset_system_prompt is empty/None.
    self.generation_prompt = cfg.generation_system_prompt
    self.dataset_system_prompt = cfg.dataset_system_prompt or self.generation_prompt

    # Tools are needed for agent modes or whenever any tools are configured.
    self.tool_registry = None
    wants_tools = (
        cfg.agent_mode is not None or bool(cfg.tool_components) or bool(cfg.custom_tools)
    )
    if wants_tools:
        self._initialize_tool_registry()

    # Assigned later by external callers that want streaming progress events.
    self.progress_reporter: ProgressReporter | None = None
|
|
270
|
+
|
|
271
|
+
def _initialize_tool_registry(self):
    """Build ``self.tool_registry`` from component configuration and custom tools.

    Tools are selected per entry of ``config.tool_components``:

    - ``"builtin"``: tools from ``BUILTIN_TOOL_REGISTRY`` whose name is listed.
    - any other component: tools loaded from ``config.tools_endpoint`` whose
      name is listed, each copied with its ``component`` field set to the
      component name (the component determines execution routing).

    Tool definitions from ``config.custom_tools`` are appended afterwards.
    Requested tool names that do not exist in their source registry are
    logged as a warning instead of being silently dropped.

    Raises:
        DataSetGeneratorError: If non-builtin components are configured
            without ``tools_endpoint`` (raised unchanged), or if loading or
            building the registry fails for any other reason (wrapped).
    """
    try:
        all_tools = []
        endpoint_registry = None

        # Load tools from endpoint if needed for non-builtin components
        non_builtin_components = {
            k: v for k, v in self.config.tool_components.items() if k != "builtin"
        }
        if non_builtin_components:
            if not self.config.tools_endpoint:
                raise DataSetGeneratorError(
                    f"Non-builtin components {list(non_builtin_components.keys())} require "
                    "'tools_endpoint' to load tool definitions."
                )
            endpoint_registry = load_tools_from_endpoint(self.config.tools_endpoint)
            logger.info(
                "Loaded %d tools from endpoint: %s",
                len(endpoint_registry.tools),
                self.config.tools_endpoint,
            )

        # Process each component
        for component_name, tool_names in self.config.tool_components.items():
            # Set lookup instead of scanning the name list for every tool.
            wanted = set(tool_names)
            if component_name == "builtin":
                selected = [
                    tool for tool in BUILTIN_TOOL_REGISTRY.tools if tool.name in wanted
                ]
            elif endpoint_registry:
                # Copy so the endpoint registry's own tool objects stay untouched.
                selected = [
                    tool.model_copy(update={"component": component_name})
                    for tool in endpoint_registry.tools
                    if tool.name in wanted
                ]
            else:
                continue
            all_tools.extend(selected)

            # Surface typos / missing tools instead of silently shrinking the registry.
            missing = wanted - {tool.name for tool in selected}
            if missing:
                logger.warning(
                    "Tool component '%s': requested tools not found: %s",
                    component_name,
                    ", ".join(sorted(missing)),
                )

        # Add custom tools if provided
        if self.config.custom_tools:
            custom_registry = load_tools_from_dict(self.config.custom_tools)
            all_tools.extend(custom_registry.tools)

        self.tool_registry = ToolRegistry(tools=all_tools)
        logger.info("Initialized tool registry with %d tools", len(all_tools))

    except DataSetGeneratorError:
        # Already descriptive; do not wrap it in a second DataSetGeneratorError.
        raise
    except Exception as e:  # noqa: BLE001
        raise DataSetGeneratorError(f"Failed to initialize tool registry: {str(e)}") from e
|
|
326
|
+
|
|
327
|
+
def _validate_create_data_params(
|
|
328
|
+
self,
|
|
329
|
+
num_steps: int,
|
|
330
|
+
batch_size: int,
|
|
331
|
+
topic_model: "TopicModel | None" = None,
|
|
332
|
+
) -> None:
|
|
333
|
+
"""Validate parameters for data creation."""
|
|
334
|
+
if num_steps is None or num_steps <= 0:
|
|
335
|
+
raise DataSetGeneratorError("num_steps must be a positive integer")
|
|
336
|
+
|
|
337
|
+
if batch_size <= 0:
|
|
338
|
+
raise DataSetGeneratorError("batch_size must be a positive integer")
|
|
339
|
+
|
|
340
|
+
if topic_model and len(topic_model.get_all_paths()) == 0:
|
|
341
|
+
raise DataSetGeneratorError(
|
|
342
|
+
"Topic model has no paths. Ensure the topic tree was built successfully."
|
|
343
|
+
)
|
|
344
|
+
|
|
345
|
+
def _prepare_topic_paths(
|
|
346
|
+
self,
|
|
347
|
+
num_steps: int,
|
|
348
|
+
batch_size: int,
|
|
349
|
+
topic_model: "TopicModel | None" = None,
|
|
350
|
+
) -> tuple[list | None, int]:
|
|
351
|
+
"""Prepare and validate topic paths for data generation."""
|
|
352
|
+
topic_paths = None
|
|
353
|
+
if topic_model is not None:
|
|
354
|
+
topic_paths = topic_model.get_all_paths()
|
|
355
|
+
total_paths = len(topic_paths)
|
|
356
|
+
required_samples = num_steps * batch_size
|
|
357
|
+
|
|
358
|
+
if required_samples > total_paths:
|
|
359
|
+
# Provide detailed error with recommendations
|
|
360
|
+
max_steps_for_batch = total_paths // batch_size
|
|
361
|
+
max_batch_for_steps = total_paths // num_steps if num_steps > 0 else total_paths
|
|
362
|
+
|
|
363
|
+
error_msg = (
|
|
364
|
+
f"Insufficient topic paths for dataset generation:\n"
|
|
365
|
+
f" • Available paths: {total_paths}\n"
|
|
366
|
+
f" • Requested samples: {required_samples} ({num_steps} steps × {batch_size} batch size)\n"
|
|
367
|
+
f" • Shortfall: {required_samples - total_paths} samples\n\n"
|
|
368
|
+
f"Recommendations:\n"
|
|
369
|
+
f" • Reduce --num-steps to {max_steps_for_batch} (with current batch size {batch_size})\n"
|
|
370
|
+
f" • Reduce --batch-size to {max_batch_for_steps} (with current {num_steps} steps)\n"
|
|
371
|
+
f" • Increase topic tree/graph depth or degree to generate more paths"
|
|
372
|
+
)
|
|
373
|
+
raise DataSetGeneratorError(error_msg)
|
|
374
|
+
|
|
375
|
+
# Bandit: not a security function
|
|
376
|
+
topic_paths = random.sample(topic_paths, required_samples) # nosec
|
|
377
|
+
num_steps = math.ceil(len(topic_paths) / batch_size)
|
|
378
|
+
|
|
379
|
+
return topic_paths, num_steps
|
|
380
|
+
|
|
381
|
+
def _generate_batch_prompts(
|
|
382
|
+
self,
|
|
383
|
+
batch_size: int,
|
|
384
|
+
start_idx: int,
|
|
385
|
+
topic_paths: list,
|
|
386
|
+
data_creation_prompt: str,
|
|
387
|
+
num_example_demonstrations: int,
|
|
388
|
+
) -> tuple[list[str], list[list[str] | None]]:
|
|
389
|
+
"""Generate prompts for a batch and return the associated paths used.
|
|
390
|
+
|
|
391
|
+
Returns:
|
|
392
|
+
(prompts, used_paths) where used_paths aligns with prompts order.
|
|
393
|
+
"""
|
|
394
|
+
prompts: list[str] = []
|
|
395
|
+
used_paths: list[list[str] | None] = []
|
|
396
|
+
for i in range(batch_size):
|
|
397
|
+
path = None
|
|
398
|
+
if topic_paths:
|
|
399
|
+
current_idx = start_idx + i
|
|
400
|
+
if current_idx < len(topic_paths):
|
|
401
|
+
path = topic_paths[current_idx]
|
|
402
|
+
else:
|
|
403
|
+
break
|
|
404
|
+
|
|
405
|
+
sample_prompt = self.build_prompt(
|
|
406
|
+
data_creation_prompt=data_creation_prompt,
|
|
407
|
+
num_example_demonstrations=num_example_demonstrations,
|
|
408
|
+
subtopics_list=path,
|
|
409
|
+
)
|
|
410
|
+
prompts.append(sample_prompt)
|
|
411
|
+
used_paths.append(path)
|
|
412
|
+
return prompts, used_paths
|
|
413
|
+
|
|
414
|
+
def _get_minimal_schema(self) -> type:
    """Return the conversation schema type selected by ``config.conversation_type``."""
    conversation_type = self.config.conversation_type
    return get_conversation_schema(conversation_type)
|
|
417
|
+
|
|
418
|
+
def _emit_retry(
|
|
419
|
+
self,
|
|
420
|
+
sample_idx: int,
|
|
421
|
+
attempt: int,
|
|
422
|
+
max_attempts: int,
|
|
423
|
+
error: Exception | str,
|
|
424
|
+
) -> None:
|
|
425
|
+
"""Emit a retry event if a progress reporter is attached.
|
|
426
|
+
|
|
427
|
+
Args:
|
|
428
|
+
sample_idx: 0-based sample index (will be converted to 1-based)
|
|
429
|
+
attempt: 0-based attempt number (will be converted to 1-based)
|
|
430
|
+
max_attempts: Total number of attempts allowed
|
|
431
|
+
error: The error that triggered the retry
|
|
432
|
+
"""
|
|
433
|
+
if self.progress_reporter:
|
|
434
|
+
self.progress_reporter.emit_retry(
|
|
435
|
+
sample_idx=sample_idx + 1,
|
|
436
|
+
attempt=attempt + 1,
|
|
437
|
+
max_attempts=max_attempts,
|
|
438
|
+
error_summary=str(error)[:100],
|
|
439
|
+
)
|
|
440
|
+
|
|
441
|
+
async def _generate_structured_samples_async(
    self,
    prompts: list[str],
    include_sys_msg: bool,
    start_sample_idx: int = 0,
    paths_for_batch: list[list[str] | None] | None = None,
) -> tuple[list, list]:
    """Generate structured samples concurrently using the builder pattern.

    One asyncio task is created per prompt and all tasks are awaited together
    with ``asyncio.gather``. Each task retries up to ``config.sample_retries``
    extra times on validation errors and, in agent mode, on tool-execution
    rejections. Failures are recorded in ``self.failure_analysis`` and reported
    to the progress reporter when one is attached.

    Args:
        prompts: List of topic prompts to generate samples for.
        include_sys_msg: Whether to include system message in output.
        start_sample_idx: 0-based index of the first prompt within the whole
            run, used for progress reporting.
        paths_for_batch: Optional topic paths aligned with ``prompts``;
            forwarded to the progress reporter as ``topic_path``.

    Returns:
        Tuple of (successful samples, failed responses). Each failed response
        is a dict with an ``"error"`` message plus ``"raw_content"`` when the
        exception carried a ``context`` dict containing that key.
    """
    samples = []
    failed_responses = []

    # Override sys_msg on a copy so self.config itself is never mutated.
    config = self.config
    if include_sys_msg != self.config.sys_msg:
        config = self.config.model_copy(update={"sys_msg": include_sys_msg})

    async def _generate_with_retry(
        prompt: str, sample_idx: int, path_info: list[str] | None
    ) -> tuple[bool, Exception | Conversation]:
        """Generate a single sample, retrying validation and tool-count rejections.

        Each parallel task gets its own builder instance to avoid Spin session
        conflicts when running samples concurrently (batch_size > 1).

        Returns:
            ``(True, conversation)`` on success, ``(False, error)`` otherwise.
        """
        builder = ConversationBuilderFactory.create(
            config=config,
            llm=self.llm_client,
            tool_registry=self.tool_registry,
            progress_reporter=self.progress_reporter,
        )

        last_error: Exception | None = None
        error_feedback: str | None = None
        max_attempts = self.config.sample_retries + 1
        logger.debug(
            "Sample %d: max_attempts=%d (sample_retries=%d)",
            sample_idx + 1,
            max_attempts,
            self.config.sample_retries,
        )

        for attempt in range(max_attempts):
            can_retry = attempt < self.config.sample_retries

            # Notify progress reporter about which sample we're working on
            if self.progress_reporter:
                retry_suffix = f" (retry {attempt})" if attempt > 0 else ""
                self.progress_reporter.emit_step_start(
                    f"Generating sample {sample_idx + 1}{retry_suffix}",
                    sample_idx=sample_idx + 1,
                    topic_path=path_info,
                )

            try:
                # Pass error feedback from the previous attempt if this is a retry
                conversation = await builder.generate(prompt, error_feedback)
            except Exception as e:  # noqa: BLE001
                last_error = e
                is_validation = is_validation_error(e)
                logger.debug(
                    "Sample %d error: is_validation=%s, can_retry=%s, attempt=%d/%d, error=%s",
                    sample_idx + 1,
                    is_validation,
                    can_retry,
                    attempt + 1,
                    self.config.sample_retries + 1,
                    str(e)[:200],
                )
                # Only retry validation errors, not API/network errors
                if is_validation and can_retry:
                    error_feedback = str(e)
                    self._emit_retry(sample_idx, attempt, max_attempts, e)
                    continue
                # Non-retryable error or exhausted retries
                return False, last_error or Exception("Sample generation failed")

            # Validate tool execution count for agent modes
            rejection: Exception | None = None
            if self.config.agent_mode is not None:
                tool_context = conversation.tool_context
                if not tool_context or not tool_context.executions:
                    rejection = ValueError("Agent mode requires at least one tool execution")
                else:
                    num_executions = len(tool_context.executions)
                    limit = self.config.max_tools_per_query
                    if num_executions > limit:
                        if self.config.max_tools_strict:
                            # Strict mode: discard entire sample
                            rejection = ValueError(
                                f"Sample has {num_executions} tool executions, "
                                f"exceeds limit of {limit}"
                            )
                        else:
                            # Non-strict mode: truncate to limit and keep sample
                            tool_context.executions = tool_context.executions[:limit]

            if rejection is None:
                return True, conversation

            last_error = rejection
            if not can_retry:
                return False, last_error
            # Tell the next attempt why this one was rejected; otherwise a stale
            # message from an earlier validation failure (or none) would be reused.
            error_feedback = str(rejection)
            self._emit_retry(sample_idx, attempt, max_attempts, rejection)

        return False, last_error or Exception("Sample generation failed")

    # Generate all samples concurrently with sample indices
    tasks = []
    for idx, prompt in enumerate(prompts):
        path_info = None
        if paths_for_batch and idx < len(paths_for_batch):
            path_info = paths_for_batch[idx]
        tasks.append(
            asyncio.create_task(_generate_with_retry(prompt, start_sample_idx + idx, path_info))
        )
    results = await asyncio.gather(*tasks)

    for idx, (success, payload) in enumerate(results):
        if success:
            samples.append(payload)
            continue

        error = payload
        error_msg = f"Generation failed: {error}"
        # Build failure record with raw content if available
        failure_record = {"error": error_msg}
        if isinstance(error, Exception):
            context = getattr(error, "context", None)
            if isinstance(context, dict) and "raw_content" in context:
                failure_record["raw_content"] = context["raw_content"]
        failed_responses.append(failure_record)

        failure_type = self.analyze_failure(
            str(error), error=error if isinstance(error, Exception) else None
        )
        self.failure_analysis[failure_type].append(error_msg)

        # Classify and emit error to progress reporter
        classified = classify_error(
            error if isinstance(error, Exception) else str(error),
            provider=self.provider,
            context={"error_type": failure_type},
        )
        if self.progress_reporter:
            # Reporter sample indices are 1-based, matching emit_step_start and
            # _emit_retry above; the loop index here is 0-based.
            self.progress_reporter.emit_error(
                classified,
                sample_idx=start_sample_idx + idx + 1,
            )

    return samples, failed_responses
|
|
611
|
+
|
|
612
|
+
def analyze_failure(self, response_content: str, error: Exception | None = None) -> str:
|
|
613
|
+
"""Analyze the failure reason for a sample."""
|
|
614
|
+
if error:
|
|
615
|
+
error_str = str(error)
|
|
616
|
+
if "schema" in error_str.lower():
|
|
617
|
+
return "invalid_schema"
|
|
618
|
+
if any(api_err in error_str.lower() for api_err in API_ERROR_INDICATORS):
|
|
619
|
+
return "api_errors"
|
|
620
|
+
return "other_errors"
|
|
621
|
+
|
|
622
|
+
if not response_content or response_content.isspace():
|
|
623
|
+
return "empty_responses"
|
|
624
|
+
|
|
625
|
+
# Check if response seems to be attempting JSON but failing
|
|
626
|
+
if any(char in response_content for char in "{}[]"):
|
|
627
|
+
return "json_parsing_errors"
|
|
628
|
+
return "malformed_responses"
|
|
629
|
+
|
|
630
|
+
def summarize_failures(self) -> dict:
    """Generate a summary of all recorded failures.

    Returns:
        A dict with:

        - ``"total_failures"``: ``len(self.failed_samples)``.
        - ``"failure_types"``: number of recorded messages per category in
          ``self.failure_analysis``.
        - ``"failure_examples"``: for each category with at least one failure,
          up to three example messages, each cut to 200 characters followed
          by ``"..."`` when longer.
    """
    max_examples = 3
    max_chars = 200

    # Keyed by category. The previous code called .append() on this dict,
    # which raised AttributeError as soon as any failure had been recorded.
    examples_by_category: dict[str, list[str]] = {}
    for category, failures in self.failure_analysis.items():
        if not failures:
            continue
        examples = []
        for example in failures[:max_examples]:
            text = str(example)
            examples.append(text[:max_chars] + "..." if len(text) > max_chars else text)
        examples_by_category[category] = examples

    return {
        "total_failures": len(self.failed_samples),
        "failure_types": {k: len(v) for k, v in self.failure_analysis.items()},
        "failure_examples": examples_by_category,
    }
|
|
652
|
+
|
|
653
|
+
def create_data(
    self,
    num_steps: int | None = None,
    num_example_demonstrations: int = 3,
    batch_size: int = 10,
    topic_model: TopicModel | None = None,
    model_name: str | None = None,
    sys_msg: bool | None = None,
):
    """Blocking wrapper that runs ``create_data_async`` via ``asyncio.run``.

    ``ensure_not_running_loop`` is consulted first, since ``asyncio.run``
    cannot be used from inside an already-running event loop. All arguments
    are forwarded unchanged to ``create_data_async``.
    """
    ensure_not_running_loop("DataSetGenerator.create_data")
    coroutine = self.create_data_async(
        num_steps=num_steps,
        num_example_demonstrations=num_example_demonstrations,
        batch_size=batch_size,
        topic_model=topic_model,
        model_name=model_name,
        sys_msg=sys_msg,
    )
    return asyncio.run(coroutine)
|
|
673
|
+
|
|
674
|
+
def create_data_with_events(
    self,
    num_steps: int | None = None,
    num_example_demonstrations: int = 3,
    batch_size: int = 10,
    topic_model: TopicModel | None = None,
    model_name: str | None = None,
    sys_msg: bool | None = None,
):
    """Synchronously iterate over the events of a generation run.

    Returns a plain (synchronous) generator that drives
    ``create_data_with_events_async`` on a private event loop and yields
    every item it produces: progress event dicts, and finally the dataset
    yielded by the generation loop.

    ``ensure_not_running_loop`` is called immediately (not lazily on first
    iteration); presumably it rejects calls from inside a running event
    loop — confirm against its definition.

    The private loop is created on first iteration. When the generator is
    exhausted, closed early, or raises, the async generator is closed, any
    other async generators started on the loop are finalized, and the loop
    is closed even if that cleanup itself fails.
    """
    ensure_not_running_loop("DataSetGenerator.create_data_with_events")

    agen = self.create_data_with_events_async(
        num_steps=num_steps,
        num_example_demonstrations=num_example_demonstrations,
        batch_size=batch_size,
        topic_model=topic_model,
        model_name=model_name,
        sys_msg=sys_msg,
    )

    def _sync_generator():
        loop = asyncio.new_event_loop()
        try:
            while True:
                try:
                    event = loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    break
                yield event
        finally:
            try:
                # Let the outer async generator run its own cleanup on this loop.
                loop.run_until_complete(agen.aclose())
                # Finalize nested async generators (e.g. the inner generation loop)
                # started on this loop; otherwise their finalizers would later try
                # to schedule aclose() on an already-closed loop.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    return _sync_generator()
|
|
713
|
+
|
|
714
|
+
async def create_data_async(
    self,
    num_steps: int | None = None,
    num_example_demonstrations: int = 3,
    batch_size: int = 10,
    topic_model: TopicModel | None = None,
    model_name: str | None = None,
    sys_msg: bool | None = None,
) -> HFDataset:
    """Generate a dataset asynchronously and return it.

    Args:
        num_steps: Number of batches to generate; ``None`` means 1. The value
            may be adjusted by ``_prepare_topic_paths``.
        num_example_demonstrations: How many configured examples to include
            in each prompt.
        batch_size: Number of prompts per step.
        topic_model: Optional topic source used to derive topic paths.
        model_name: Overrides ``self.model_name`` (whitespace-stripped) when
            non-empty.
        sys_msg: Whether to include the system message; ``None`` falls back
            to ``self.config.sys_msg``.

    Returns:
        The dataset yielded last by the generation loop.

    Raises:
        DataSetGeneratorError: If no model name is available, or if the
            generation loop does not finish with a dataset.
    """
    if num_steps is None:
        num_steps = 1

    self._validate_create_data_params(num_steps, batch_size, topic_model)

    if model_name:
        self.model_name = model_name.strip()

    if not self.model_name:
        msg = "Model name is required: pass model_name or set it on the generator"
        raise DataSetGeneratorError(msg)

    include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg

    topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)

    total_samples = num_steps * batch_size
    data_creation_prompt = self._get_cot_prompt_template()

    final_result: HFDataset | dict | None = None
    async for event in self._run_generation_loop_async(
        num_steps=num_steps,
        batch_size=batch_size,
        total_samples=total_samples,
        topic_paths=topic_paths or [],
        data_creation_prompt=data_creation_prompt,
        num_example_demonstrations=num_example_demonstrations,
        include_sys_msg=include_sys_msg,
    ):
        # Only the last item matters: progress events are discarded here.
        final_result = event

    if isinstance(final_result, HFDataset):
        trace(
            "dataset_created",
            {
                "provider": self.provider,
                "model_name": self.model_name,
                "conversation_type": self.config.conversation_type,
                "samples_count": len(final_result),
                "failed_samples": len(self.failed_samples),
                "success": len(final_result) > 0,
            },
        )
        return final_result

    msg = "Dataset generation failed"
    raise DataSetGeneratorError(msg)
|
|
769
|
+
|
|
770
|
+
async def create_data_with_events_async(
    self,
    num_steps: int | None = None,
    num_example_demonstrations: int = 3,
    batch_size: int = 10,
    topic_model: TopicModel | None = None,
    model_name: str | None = None,
    sys_msg: bool | None = None,
) -> AsyncGenerator[dict | HFDataset, None]:
    """Generate a dataset, yielding progress events as they happen.

    Because this is an async generator, the argument handling below
    (defaults, ``_validate_create_data_params``, model-name check) only runs
    once iteration starts.

    Args:
        num_steps: Number of batches to generate; ``None`` means 1. The value
            may be adjusted by ``_prepare_topic_paths``.
        num_example_demonstrations: How many configured examples to include
            in each prompt.
        batch_size: Number of prompts per step.
        topic_model: Optional topic source used to derive topic paths.
        model_name: Overrides ``self.model_name`` (whitespace-stripped) when
            non-empty.
        sys_msg: Whether to include the system message; ``None`` falls back
            to ``self.config.sys_msg``.

    Yields:
        Every item produced by ``_run_generation_loop_async``.

    Raises:
        DataSetGeneratorError: If no model name is available.
    """
    if num_steps is None:
        num_steps = 1

    self._validate_create_data_params(num_steps, batch_size, topic_model)

    if model_name:
        self.model_name = model_name.strip()

    if not self.model_name:
        msg = "Model name is required: pass model_name or set it on the generator"
        raise DataSetGeneratorError(msg)

    include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg

    topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)

    total_samples = num_steps * batch_size
    data_creation_prompt = self._get_cot_prompt_template()

    # These two values are only echoed in the ``generation_start`` event.
    root_topic_prompt = None
    topic_model_type = None
    if topic_model is not None:
        root_topic_prompt = getattr(topic_model, "topic_prompt", None)
        topic_model_type = type(topic_model).__name__.lower()

    async for event in self._run_generation_loop_async(
        num_steps=num_steps,
        batch_size=batch_size,
        total_samples=total_samples,
        topic_paths=topic_paths or [],
        data_creation_prompt=data_creation_prompt,
        num_example_demonstrations=num_example_demonstrations,
        include_sys_msg=include_sys_msg,
        root_topic_prompt=root_topic_prompt,
        topic_model_type=topic_model_type,
    ):
        yield event
|
|
815
|
+
|
|
816
|
+
async def _run_generation_loop_async(  # noqa: PLR0912
    self,
    num_steps: int,
    batch_size: int,
    total_samples: int,
    topic_paths: list,
    data_creation_prompt: str,
    num_example_demonstrations: int,
    include_sys_msg: bool,
    root_topic_prompt: str | None = None,
    topic_model_type: str | None = None,
) -> AsyncGenerator[dict | HFDataset, None]:
    """Run the main generation loop yielding progress events.

    Yields, in order (each event is a dict with an ``"event"`` key):
      - ``generation_start`` once, echoing the run parameters.
      - For each step: ``step_start``, then ``step_complete`` (samples
        generated, success flag, number of new entries in
        ``self.failed_samples`` and up to three of them), followed by
        ``step_failed`` when the batch did not succeed.
      - ``generation_complete`` with ``len(self._samples)`` and
        ``len(self.failed_samples)``.
      - Finally, an ``HFDataset`` built from ``self._samples`` (an empty
        dataset when there are none).

    On ``KeyboardInterrupt`` a ``generation_interrupted`` event is yielded,
    then the failure summary is printed and the samples are written to
    ``INTERRUPTED_DATASET_FILENAME``; the final dataset is still yielded.

    On any other exception a ``generation_error`` event is yielded, then the
    failure summary is printed, the samples are written to
    ``ERROR_DATASET_FILENAME``, and ``DataSetGeneratorError`` is raised,
    chained to the original error.

    Note: in both error paths the follow-up steps run only when the consumer
    requests the item after the error/interrupt event. A consumer that stops
    iterating at that event skips the summary, the save and the raise.
    """
    try:
        yield {
            "event": "generation_start",
            "model_name": self.model_name,
            "num_steps": num_steps,
            "batch_size": batch_size,
            "total_samples": total_samples,
            "root_topic_prompt": root_topic_prompt,
            "topic_model_type": topic_model_type,
        }

        for step in range(num_steps):
            yield {
                "event": "step_start",
                "step": step + 1,
                "total_steps": num_steps,
            }

            # Offset of this batch's first sample within the whole run.
            start_idx = step * batch_size
            prompts, used_paths = self._generate_batch_prompts(
                batch_size,
                start_idx,
                topic_paths,
                data_creation_prompt,
                num_example_demonstrations,
            )

            # Snapshot the failure count so this step's new failures can be isolated below.
            failed_before = len(self.failed_samples)

            success, samples_generated = await self._process_batch_with_retries_async(
                prompts, include_sys_msg, start_idx, used_paths
            )

            failed_in_batch = len(self.failed_samples) - failed_before
            failure_reasons = []
            if failed_in_batch > 0 and self.failed_samples:
                # _process_batch_with_retries_async appends to failed_samples, so this
                # step's failures are the last failed_in_batch entries.
                recent_failures = self.failed_samples[-failed_in_batch:]
                # Report at most three failures per step.
                failure_reasons = recent_failures[:3]

            yield {
                "event": "step_complete",
                "step": step + 1,
                "samples_generated": samples_generated,
                "success": success,
                "failed_in_step": failed_in_batch,
                "failure_reasons": failure_reasons,
            }

            if not success:
                yield {
                    "event": "step_failed",
                    "step": step + 1,
                    "message": f"Failed to process batch {step + 1} after all retries",
                }

        yield {
            "event": "generation_complete",
            "total_samples": len(self._samples),
            "failed_samples": len(self.failed_samples),
        }

    except KeyboardInterrupt:
        # Report the interruption, then persist partial progress; execution falls
        # through to the final dataset yield below.
        yield {
            "event": "generation_interrupted",
            "message": "Generation interrupted by user.",
        }
        self.print_failure_summary()
        self._save_samples_to_file(INTERRUPTED_DATASET_FILENAME)

    except Exception as e:  # noqa: BLE001
        # Emit the error as an event first, persist partial progress, then fail hard.
        yield {"event": "generation_error", "error": str(e)}
        self.print_failure_summary()
        self._save_samples_to_file(ERROR_DATASET_FILENAME)
        raise DataSetGeneratorError("failed") from e

    yield (HFDataset.from_list(self._samples) if self._samples else HFDataset.from_list([]))
|
|
905
|
+
|
|
906
|
+
async def _process_batch_with_retries_async(
    self,
    prompts: list[str],
    include_sys_msg: bool,
    start_sample_idx: int = 0,
    paths_for_batch: list[list[str] | None] | None = None,
) -> tuple[bool, int]:
    """Process a batch with retry logic.

    Calls ``_generate_structured_samples_async`` up to
    ``self.config.max_retries`` times. Failed responses reported by each
    call are appended to ``self.failed_samples``.

    Outcomes of each attempt:
      - Samples returned: they are converted with
        ``model_dump(exclude_none=True)`` and appended to ``self._samples``;
        returns ``(True, len(samples))``.
      - No samples and no exception: returns ``(False, 0)`` without retrying.
      - ``DataSetGeneratorError``: not retried. It is classified as an
        authentication or API error, recorded in ``failure_analysis`` and
        ``failed_samples``, logged, and ``(False, 0)`` is returned.
      - Any other exception: logged and retried. After the last attempt it
        is recorded via ``analyze_failure`` and ``(False, 0)`` is returned.

    Returns:
        ``(success, samples_added)``. Also ``(False, 0)`` when ``max_retries``
        is not positive, since no attempt is made.
    """
    for attempt in range(self.config.max_retries):
        try:
            samples, failed_responses = await self._generate_structured_samples_async(
                prompts, include_sys_msg, start_sample_idx, paths_for_batch
            )

            # Update failed samples
            self.failed_samples.extend(failed_responses)

            if samples:
                # Convert Pydantic models to dicts and add to samples list
                sample_dicts = [s.model_dump(exclude_none=True) for s in samples]
                self._samples.extend(sample_dicts)
                return True, len(samples)  # Success - exit retry loop

        except DataSetGeneratorError as e:
            # Authentication and API errors are now wrapped in DataSetGeneratorError
            error_str = str(e).lower()
            if any(
                keyword in error_str
                for keyword in [
                    "api_key",
                    "api key",
                    "authentication",
                    "unauthorized",
                ]
            ):
                error_msg = f"Authentication failed for provider '{self.provider}'. Please set the required API key environment variable."
                self.failure_analysis["authentication_error"].append(error_msg)
            else:
                error_msg = f"API error for provider '{self.provider}': {str(e)[:100]}..."
                self.failure_analysis["api_errors"].append(error_msg)

            self.failed_samples.append(error_msg)
            logger.exception("API error: %s", error_msg)
            return False, 0  # Don't retry authentication/API errors
        except Exception as e:
            if attempt == self.config.max_retries - 1:
                logger.warning(
                    "Batch generation failed after %s attempts: %s",
                    self.config.max_retries,
                    e,
                )
                self.failed_samples.append(str(e))
                failure_type = self.analyze_failure(str(e), error=e)
                self.failure_analysis[failure_type].append(str(e))
                return False, 0
            # Surface transient errors; otherwise a failure that a later retry
            # recovers from would leave no trace at all.
            logger.warning(
                "Batch generation attempt %s/%s failed, retrying: %s",
                attempt + 1,
                self.config.max_retries,
                e,
            )
        else:
            # If no exception and no samples, return False, 0
            return False, 0

    return False, 0
|
|
961
|
+
|
|
962
|
+
def print_failure_summary(self):
    """Write a human-readable failure report to stdout.

    The report is based on ``self.summarize_failures()``. It prints the
    total count, then one section per failure category with a positive
    count. A section lists numbered examples when the summary has an entry
    for that category.
    """
    report = self.summarize_failures()

    print("\n=== Failure Analysis Summary ===")
    print(f"Total Failed Samples: {report['total_failures']}")
    print("\nFailure Types Breakdown:")

    for category, occurrences in report["failure_types"].items():
        if occurrences <= 0:
            continue
        heading = category.replace("_", " ").title()
        print(f"\n{heading}: {occurrences}")

        examples_by_category = report["failure_examples"]
        if category not in examples_by_category:
            continue
        print("Example failures:")
        for position, example in enumerate(examples_by_category.get(category, []), start=1):
            print(f" {position}. {example}")

    print("\n=============================")
|
|
979
|
+
|
|
980
|
+
def build_prompt(
|
|
981
|
+
self,
|
|
982
|
+
data_creation_prompt: str,
|
|
983
|
+
num_example_demonstrations: int,
|
|
984
|
+
subtopics_list: list[str] | None = None,
|
|
985
|
+
) -> str:
|
|
986
|
+
prompt = data_creation_prompt.replace("{{{{system_prompt}}}}", self.generation_prompt)
|
|
987
|
+
prompt = prompt.replace("{{{{instructions}}}}", self.build_custom_instructions_text())
|
|
988
|
+
prompt = prompt.replace(
|
|
989
|
+
"{{{{examples}}}}", self.build_examples_text(num_example_demonstrations)
|
|
990
|
+
)
|
|
991
|
+
return prompt.replace("{{{{subtopics}}}}", self.build_subtopics_text(subtopics_list))
|
|
992
|
+
|
|
993
|
+
def build_system_prompt(self):
    """Return ``self.dataset_system_prompt``, the system prompt stored with dataset samples."""
    system_prompt = self.dataset_system_prompt
    return system_prompt
|
|
996
|
+
|
|
997
|
+
def build_custom_instructions_text(self) -> str:
    """Wrap ``self.config.instructions`` in an ``<instructions>`` section.

    Returns ``""`` when the configured instructions are ``None`` or the
    empty string. Any other value is interpolated into the section as-is.
    """
    instructions = self.config.instructions
    if instructions is not None and instructions != "":
        return f"\nHere are additional instructions:\n<instructions>\n{instructions}\n</instructions>\n"
    return ""
|
|
1001
|
+
|
|
1002
|
+
def build_examples_text(self, num_example_demonstrations: int) -> str:
    """Render a random selection of configured examples for the prompt.

    Samples up to ``num_example_demonstrations`` items, without replacement,
    from ``self.config.example_data`` and formats them inside a single
    ``<examples>`` section.

    Returns:
        The formatted section. Returns ``""`` when there is no example data,
        the example data is empty, or ``num_example_demonstrations`` is not
        positive. A negative count would otherwise make ``random.sample``
        raise.
    """
    example_data = self.config.example_data
    if example_data is None or num_example_demonstrations <= 0:
        return ""
    # HF Dataset supports len() and indexing, convert to list for sampling
    example_list = list(example_data)
    if not example_list:
        return ""
    # Bandit: not a security function
    examples = random.sample(example_list, min(num_example_demonstrations, len(example_list)))  # nosec
    examples_text = "\n".join(f"Example {i + 1}: \n\n{ex}\n" for i, ex in enumerate(examples))
    # The "Here are output examples:" header appears exactly once, outside the tags.
    return f"\nHere are output examples:\n<examples>\n{examples_text}\n</examples>\n"
|
|
1012
|
+
|
|
1013
|
+
def build_tools_text(self) -> str:
    """Build formatted tools text for XLAM multi-turn prompts.

    Each tool in ``self.tool_registry.tools`` becomes a bullet with its name,
    its description and one line per parameter (name, type, required or
    optional flag, description). Tool sections are separated by a blank
    line. Returns ``"No tools available"`` when the registry is falsy.
    """
    registry = self.tool_registry
    if not registry:
        return "No tools available"

    def _describe_param(param) -> str:
        flag = " (required)" if param.required else " (optional)"
        return f" - {param.name} ({param.type}){flag}: {param.description}"

    sections = [
        f"• {tool.name}: {tool.description}\n Parameters:\n"
        + "\n".join(_describe_param(param) for param in tool.parameters)
        for tool in registry.tools
    ]
    return "\n\n".join(sections)
|
|
1031
|
+
|
|
1032
|
+
def build_subtopics_text(self, subtopic_list: list[str] | None):
|
|
1033
|
+
if subtopic_list is None:
|
|
1034
|
+
return ""
|
|
1035
|
+
return f"\nLastly, the topic of the training data should be related to the following subtopics: {' -> '.join(subtopic_list)}"
|
|
1036
|
+
|
|
1037
|
+
def _get_cot_prompt_template(self) -> str:  # noqa: PLR0911
    """Pick the prompt template that matches the configured conversation shape.

    - Any conversation type other than ``"chain_of_thought"`` (``"basic"``
      included) uses ``CONVERSATION_GENERATION_PROMPT``.
    - Chain of thought with a truthy tool registry and ``agent_mode`` of
      ``"single_turn"`` or ``"multi_turn"`` uses the prompt built by
      ``AgentPromptBuilder``. If the builder returns a falsy value, it falls
      back to ``AGENT_COT_TOOLS_PROMPT`` or ``AGENT_COT_MULTI_TURN_PROMPT``
      respectively.
    - Otherwise the reasoning style decides:
      - ``"freetext"`` -> ``FREETEXT_COT_PROMPT``
      - ``"agent"`` -> ``STRUCTURED_COT_PROMPT``
      - anything else -> ``CONVERSATION_GENERATION_PROMPT``
    """
    config = self.config
    if config.conversation_type != "chain_of_thought":
        return CONVERSATION_GENERATION_PROMPT

    if config.agent_mode == "single_turn" and self.tool_registry:
        built_prompt = AgentPromptBuilder.build_tool_context_prompt(
            self.tool_registry,
            max_tools_per_query=config.max_tools_per_query,
        )
        return built_prompt or AGENT_COT_TOOLS_PROMPT

    if config.agent_mode == "multi_turn" and self.tool_registry:
        built_prompt = AgentPromptBuilder.build_multi_turn_context_prompt(
            self.tool_registry,
            max_tools_per_query=config.max_tools_per_query,
        )
        return built_prompt or AGENT_COT_MULTI_TURN_PROMPT

    style = config.reasoning_style
    if style == "freetext":
        return FREETEXT_COT_PROMPT
    if style == "agent":
        return STRUCTURED_COT_PROMPT
    return CONVERSATION_GENERATION_PROMPT
|
|
1074
|
+
|
|
1075
|
+
def _save_samples_to_file(self, save_path: str):
|
|
1076
|
+
"""Save the current samples to a JSONL file."""
|
|
1077
|
+
|
|
1078
|
+
with open(save_path, "w") as f:
|
|
1079
|
+
for sample in self._samples:
|
|
1080
|
+
f.write(json.dumps(sample, separators=(",", ":")) + "\n")
|
|
1081
|
+
|
|
1082
|
+
def save_dataset(self, save_path: str):
    """Persist the collected samples to ``save_path`` as JSON Lines.

    Public entry point that delegates to ``_save_samples_to_file``.
    """
    writer = self._save_samples_to_file
    writer(save_path)
|