deepfabric-4.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. deepfabric/__init__.py +70 -0
  2. deepfabric/__main__.py +6 -0
  3. deepfabric/auth.py +382 -0
  4. deepfabric/builders.py +303 -0
  5. deepfabric/builders_agent.py +1304 -0
  6. deepfabric/cli.py +1288 -0
  7. deepfabric/config.py +899 -0
  8. deepfabric/config_manager.py +251 -0
  9. deepfabric/constants.py +94 -0
  10. deepfabric/dataset_manager.py +534 -0
  11. deepfabric/error_codes.py +581 -0
  12. deepfabric/evaluation/__init__.py +47 -0
  13. deepfabric/evaluation/backends/__init__.py +32 -0
  14. deepfabric/evaluation/backends/ollama_backend.py +137 -0
  15. deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
  16. deepfabric/evaluation/backends/transformers_backend.py +326 -0
  17. deepfabric/evaluation/evaluator.py +845 -0
  18. deepfabric/evaluation/evaluators/__init__.py +13 -0
  19. deepfabric/evaluation/evaluators/base.py +104 -0
  20. deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
  21. deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
  22. deepfabric/evaluation/evaluators/registry.py +66 -0
  23. deepfabric/evaluation/inference.py +155 -0
  24. deepfabric/evaluation/metrics.py +397 -0
  25. deepfabric/evaluation/parser.py +304 -0
  26. deepfabric/evaluation/reporters/__init__.py +13 -0
  27. deepfabric/evaluation/reporters/base.py +56 -0
  28. deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
  29. deepfabric/evaluation/reporters/file_reporter.py +61 -0
  30. deepfabric/evaluation/reporters/multi_reporter.py +56 -0
  31. deepfabric/exceptions.py +67 -0
  32. deepfabric/factory.py +26 -0
  33. deepfabric/generator.py +1084 -0
  34. deepfabric/graph.py +545 -0
  35. deepfabric/hf_hub.py +214 -0
  36. deepfabric/kaggle_hub.py +219 -0
  37. deepfabric/llm/__init__.py +41 -0
  38. deepfabric/llm/api_key_verifier.py +534 -0
  39. deepfabric/llm/client.py +1206 -0
  40. deepfabric/llm/errors.py +105 -0
  41. deepfabric/llm/rate_limit_config.py +262 -0
  42. deepfabric/llm/rate_limit_detector.py +278 -0
  43. deepfabric/llm/retry_handler.py +270 -0
  44. deepfabric/metrics.py +212 -0
  45. deepfabric/progress.py +262 -0
  46. deepfabric/prompts.py +290 -0
  47. deepfabric/schemas.py +1000 -0
  48. deepfabric/spin/__init__.py +6 -0
  49. deepfabric/spin/client.py +263 -0
  50. deepfabric/spin/models.py +26 -0
  51. deepfabric/stream_simulator.py +90 -0
  52. deepfabric/tools/__init__.py +5 -0
  53. deepfabric/tools/defaults.py +85 -0
  54. deepfabric/tools/loader.py +87 -0
  55. deepfabric/tools/mcp_client.py +677 -0
  56. deepfabric/topic_manager.py +303 -0
  57. deepfabric/topic_model.py +20 -0
  58. deepfabric/training/__init__.py +35 -0
  59. deepfabric/training/api_key_prompt.py +302 -0
  60. deepfabric/training/callback.py +363 -0
  61. deepfabric/training/metrics_sender.py +301 -0
  62. deepfabric/tree.py +438 -0
  63. deepfabric/tui.py +1267 -0
  64. deepfabric/update_checker.py +166 -0
  65. deepfabric/utils.py +150 -0
  66. deepfabric/validation.py +143 -0
  67. deepfabric-4.4.0.dist-info/METADATA +702 -0
  68. deepfabric-4.4.0.dist-info/RECORD +71 -0
  69. deepfabric-4.4.0.dist-info/WHEEL +4 -0
  70. deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
  71. deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/generator.py
@@ -0,0 +1,1084 @@
+ import asyncio
+ import json
+ import logging
+ import math
+ import random
+
+ from collections.abc import AsyncGenerator
+ from typing import TYPE_CHECKING, Any, Literal
+
+ from datasets import Dataset as HFDataset
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+ from .builders import ConversationBuilderFactory
+ from .config import _normalize_reasoning_style
+ from .constants import (
+     API_ERROR_INDICATORS,
+     DEFAULT_MAX_RETRIES,
+     DEFAULT_REQUEST_TIMEOUT,
+     DEFAULT_SAMPLE_RETRIES,
+     ENGINE_DEFAULT_BATCH_SIZE,
+     ENGINE_DEFAULT_NUM_EXAMPLES,
+     ENGINE_DEFAULT_TEMPERATURE,
+     ERROR_CATEGORIES,
+     ERROR_DATASET_FILENAME,
+     INTERRUPTED_DATASET_FILENAME,
+ )
+ from .error_codes import classify_error
+ from .exceptions import DataSetGeneratorError
+ from .llm import LLMClient
+ from .metrics import trace
+ from .progress import ProgressReporter
+ from .prompts import (
+     AGENT_COT_MULTI_TURN_PROMPT,
+     AGENT_COT_TOOLS_PROMPT,
+     CONVERSATION_GENERATION_PROMPT,
+     FREETEXT_COT_PROMPT,
+     STRUCTURED_COT_PROMPT,
+     AgentPromptBuilder,
+ )
+ from .schemas import Conversation, ToolRegistry, get_conversation_schema
+ from .tools import BUILTIN_TOOL_REGISTRY
+ from .tools.loader import load_tools_from_dict, load_tools_from_endpoint
+ from .topic_model import TopicModel
+ from .utils import ensure_not_running_loop, is_validation_error
+
+ # Handle circular import for type hints
+ if TYPE_CHECKING:
+     from .topic_model import TopicModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class DataSetGeneratorConfig(BaseModel):
+     """Configuration for the data engine."""
+
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     instructions: str = Field(default="", description="Additional instructions for data generation")
+     generation_system_prompt: str = Field(
+         ..., min_length=1, description="System prompt for content generation"
+     )
+     dataset_system_prompt: str | None = Field(
+         None,
+         description="System prompt that goes into the final dataset (falls back to generation_system_prompt if not provided)",
+     )
+     provider: str = Field(
+         ...,
+         min_length=1,
+         description="LLM provider (openai, anthropic, gemini, ollama)",
+     )
+     model_name: str = Field(..., min_length=1, description="Name of the model to use")
+     prompt_template: str | None = Field(default=None, description="Custom prompt template")
+     example_data: HFDataset | None = Field(
+         default=None, description="Example dataset for few-shot learning"
+     )
+     temperature: float = Field(
+         default=ENGINE_DEFAULT_TEMPERATURE,
+         ge=0.0,
+         le=2.0,
+         description="Temperature for model generation",
+     )
+     max_retries: int = Field(
+         default=DEFAULT_MAX_RETRIES,
+         ge=1,
+         le=10,
+         description="Maximum number of retries for failed requests (deprecated, use rate_limit config)",
+     )
+     max_tokens: int = Field(
+         default=2000,
+         ge=1,
+         description="Maximum tokens to generate in a single call to the llm",
+     )
+     default_batch_size: int = Field(
+         default=ENGINE_DEFAULT_BATCH_SIZE,
+         ge=1,
+         le=100,
+         description="Default batch size for generation",
+     )
+     default_num_examples: int = Field(
+         default=ENGINE_DEFAULT_NUM_EXAMPLES,
+         ge=0,
+         le=10,
+         description="Default number of examples to include",
+     )
+     request_timeout: int = Field(
+         default=DEFAULT_REQUEST_TIMEOUT,
+         ge=5,
+         le=300,
+         description="Request timeout in seconds",
+     )
+     sample_retries: int = Field(
+         default=DEFAULT_SAMPLE_RETRIES,
+         ge=0,
+         le=5,
+         description="Number of retries for individual sample validation failures",
+     )
+     sys_msg: bool = Field(default=True, description="Whether to include system message in dataset")
+     base_url: str | None = Field(
+         default=None,
+         description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
+     )
+
+     # Rate limiting configuration
+     rate_limit: dict[str, int | float | str | bool] | None = Field(
+         default=None,
+         description="Rate limiting and retry configuration (uses provider defaults if not specified)",
+     )
+
+     # Modular conversation configuration
+     conversation_type: Literal["basic", "chain_of_thought"] = Field(
+         default="basic",
+         description="Base conversation type: basic (simple chat), chain_of_thought (with reasoning traces)",
+     )
+
+     reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = Field(
+         default=None,
+         description="Reasoning style for chain_of_thought type: freetext (natural language) or agent (structured step-by-step for tool-calling). Note: 'structured' and 'hybrid' are deprecated.",
+     )
+
+     @field_validator("reasoning_style", mode="before")
+     @classmethod
+     def normalize_reasoning_style(cls, v: str | None) -> str | None:
+         """Normalize deprecated reasoning_style values."""
+         return _normalize_reasoning_style(v)
+
+     agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
+         default=None,
+         description="Agent mode: single_turn (one-shot tool use), multi_turn (extended agent conversations). Requires tools to be configured.",
+     )
+
+     # Tool configuration (used when agent_mode is enabled or for tool_calling)
+     tool_components: dict[str, list[str]] = Field(
+         default_factory=dict,
+         description=(
+             "Map of component name to tool names. 'builtin' uses built-in tools "
+             "and routes to /vfs/execute. Other components load from tools_endpoint "
+             "and route to /{component}/execute."
+         ),
+     )
+     custom_tools: list[dict] = Field(
+         default_factory=list, description="Custom tool definitions as dictionaries"
+     )
+     max_tools_per_query: int = Field(
+         default=3, ge=1, le=10, description="Maximum number of tools per query/turn"
+     )
+     max_tools_strict: bool = Field(
+         default=True,
+         description="If True, discard samples exceeding max_tools_per_query. If False, keep sample but truncate executions to limit.",
+     )
+
+     # Spin integration for real tool execution
+     spin_endpoint: str | None = Field(
+         default=None,
+         description="Spin service URL for real tool execution (e.g., 'http://localhost:3000')",
+     )
+     scenario_seed: dict | None = Field(
+         default=None,
+         description="Initial state to seed into Spin VFS before generation (e.g., {'files': {'main.py': '...'}})",
+     )
+     max_agent_steps: int = Field(
+         default=5,
+         ge=1,
+         le=10,
+         description="Maximum ReAct reasoning steps per sample before forcing conclusion",
+     )
+
+     # MCP/Mock tool integration - load tools from HTTP endpoint instead of code
+     tools_endpoint: str | None = Field(
+         default=None,
+         description="HTTP endpoint to load tool definitions from (e.g., 'http://localhost:3000/mock/list-tools'). Tools are loaded in MCP format.",
+     )
+     tool_execute_path: str | None = Field(
+         default=None,
+         description="Path for tool execution when using tools_endpoint (e.g., '/mock/execute'). Combined with spin_endpoint.",
+     )
+
+     # Multi-turn configuration (used when agent_mode="multi_turn")
+     min_turns: int = Field(
+         default=2,
+         ge=1,
+         le=10,
+         description="Minimum number of conversation turns for multi-turn agent mode",
+     )
+     max_turns: int = Field(
+         default=4,
+         ge=1,
+         le=10,
+         description="Maximum number of conversation turns for multi-turn agent mode",
+     )
+     min_tool_calls: int = Field(
+         default=2,
+         ge=0,
+         le=20,
+         description="Minimum number of tool calls required before allowing early conversation conclusion",
+     )
+
+
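
Before the generator class itself, a quick orientation: DataSetGenerator below validates arbitrary kwargs against this config model, so construction is plain keyword arguments. A minimal sketch (the provider, model, and prompt values are illustrative, not package defaults):

```python
# Minimal construction sketch, assuming the import path of this module.
# Provider/model/prompt values are illustrative, not package defaults.
from deepfabric.generator import DataSetGenerator

generator = DataSetGenerator(
    generation_system_prompt="You are a careful synthetic-data writer.",
    provider="openai",              # openai, anthropic, gemini, or ollama
    model_name="gpt-4o-mini",       # any model name the provider accepts
    conversation_type="chain_of_thought",
    reasoning_style="freetext",
    temperature=0.7,
    sample_retries=2,               # retries for per-sample validation failures
)
```
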
+ class DataSetGenerator:
+     def __init__(self, **kwargs):
+         """Initialize DataSetGenerator with parameters."""
+         try:
+             self.config = DataSetGeneratorConfig.model_validate(kwargs)
+         except Exception as e:  # noqa: TRY003
+             raise DataSetGeneratorError(f"Invalid generator configuration: {str(e)}") from e
+
+         # Initialize from config
+         self.provider = self.config.provider
+         self.model_name = self.config.model_name
+         self._samples: list[dict] = []
+         self.failed_samples = []
+         self.failure_analysis = {category: [] for category in ERROR_CATEGORIES}
+
+         # Initialize LLM client with rate limiting configuration
+         llm_kwargs: dict[str, Any] = {"rate_limit_config": self.config.rate_limit}
+         if self.config.base_url:
+             llm_kwargs["base_url"] = self.config.base_url
+
+         self.llm_client = LLMClient(
+             provider=self.provider,
+             model_name=self.model_name,
+             **llm_kwargs,
+         )
+         trace(
+             "generator_created",
+             {
+                 "provider": self.provider,
+                 "model_name": self.model_name,
+                 "conversation_type": self.config.conversation_type,
+             },
+         )
+
+         # Store dataset system prompt for dataset inclusion (with fallback)
+         self.dataset_system_prompt = (
+             self.config.dataset_system_prompt or self.config.generation_system_prompt
+         )
+         # Store generation prompt for content generation
+         self.generation_prompt = self.config.generation_system_prompt
+
+         # Initialize tool registry when agent_mode is enabled or tools are configured
+         self.tool_registry = None
+         if (
+             self.config.agent_mode is not None
+             or self.config.tool_components
+             or self.config.custom_tools
+         ):
+             self._initialize_tool_registry()
+
+         # Progress reporter for streaming feedback (set by external callers)
+         self.progress_reporter: ProgressReporter | None = None
+
+     def _initialize_tool_registry(self):
+         """Initialize tool registry from component configuration.
+
+         Tools are loaded based on the tool_components mapping:
+         - 'builtin': Uses BUILTIN_TOOL_REGISTRY (read_file, write_file, etc.)
+         - Other components: Loads from tools_endpoint and sets component field
+
+         Each tool's component field determines routing (/{component}/execute).
+         """
+         try:
+             all_tools = []
+             endpoint_registry = None
+
+             # Load tools from endpoint if needed for non-builtin components
+             non_builtin_components = {
+                 k: v for k, v in self.config.tool_components.items() if k != "builtin"
+             }
+             if non_builtin_components:
+                 if not self.config.tools_endpoint:
+                     raise DataSetGeneratorError(
+                         f"Non-builtin components {list(non_builtin_components.keys())} require "
+                         "'tools_endpoint' to load tool definitions."
+                     )
+                 endpoint_registry = load_tools_from_endpoint(self.config.tools_endpoint)
+                 logger.info(
+                     "Loaded %d tools from endpoint: %s",
+                     len(endpoint_registry.tools),
+                     self.config.tools_endpoint,
+                 )
+
+             # Process each component
+             for component_name, tool_names in self.config.tool_components.items():
+                 if component_name == "builtin":
+                     # Filter from builtin registry
+                     for tool in BUILTIN_TOOL_REGISTRY.tools:
+                         if tool.name in tool_names:
+                             all_tools.append(tool)
+                 elif endpoint_registry:
+                     # Filter from endpoint registry and set component
+                     for tool in endpoint_registry.tools:
+                         if tool.name in tool_names:
+                             # Create copy with component set
+                             tool_copy = tool.model_copy(update={"component": component_name})
+                             all_tools.append(tool_copy)
+
+             # Add custom tools if provided
+             if self.config.custom_tools:
+                 custom_registry = load_tools_from_dict(self.config.custom_tools)
+                 all_tools.extend(custom_registry.tools)
+
+             self.tool_registry = ToolRegistry(tools=all_tools)
+             logger.info("Initialized tool registry with %d tools", len(all_tools))
+
+         except Exception as e:  # noqa: BLE001
+             raise DataSetGeneratorError(f"Failed to initialize tool registry: {str(e)}") from e
+
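
As the docstring above notes, 'builtin' names are filtered from BUILTIN_TOOL_REGISTRY while any other component must be resolvable via tools_endpoint. A hedged configuration sketch (the 'mock' component, its tool name, and the endpoint URL are hypothetical):

```python
# Hypothetical tool wiring: builtin tools come from BUILTIN_TOOL_REGISTRY;
# the "mock" component is fetched from tools_endpoint and later routed to
# /mock/execute. Component, tool, and endpoint names are assumptions.
generator = DataSetGenerator(
    generation_system_prompt="Generate realistic tool-use conversations.",
    provider="ollama",
    model_name="llama3.1",
    conversation_type="chain_of_thought",
    agent_mode="single_turn",
    tool_components={
        "builtin": ["read_file", "write_file"],
        "mock": ["get_weather"],
    },
    tools_endpoint="http://localhost:3000/mock/list-tools",
)
```
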
+     def _validate_create_data_params(
+         self,
+         num_steps: int,
+         batch_size: int,
+         topic_model: "TopicModel | None" = None,
+     ) -> None:
+         """Validate parameters for data creation."""
+         if num_steps is None or num_steps <= 0:
+             raise DataSetGeneratorError("num_steps must be a positive integer")
+
+         if batch_size <= 0:
+             raise DataSetGeneratorError("batch_size must be a positive integer")
+
+         if topic_model and len(topic_model.get_all_paths()) == 0:
+             raise DataSetGeneratorError(
+                 "Topic model has no paths. Ensure the topic tree was built successfully."
+             )
+
+     def _prepare_topic_paths(
+         self,
+         num_steps: int,
+         batch_size: int,
+         topic_model: "TopicModel | None" = None,
+     ) -> tuple[list | None, int]:
+         """Prepare and validate topic paths for data generation."""
+         topic_paths = None
+         if topic_model is not None:
+             topic_paths = topic_model.get_all_paths()
+             total_paths = len(topic_paths)
+             required_samples = num_steps * batch_size
+
+             if required_samples > total_paths:
+                 # Provide detailed error with recommendations
+                 max_steps_for_batch = total_paths // batch_size
+                 max_batch_for_steps = total_paths // num_steps if num_steps > 0 else total_paths
+
+                 error_msg = (
+                     f"Insufficient topic paths for dataset generation:\n"
+                     f"  • Available paths: {total_paths}\n"
+                     f"  • Requested samples: {required_samples} ({num_steps} steps × {batch_size} batch size)\n"
+                     f"  • Shortfall: {required_samples - total_paths} samples\n\n"
+                     f"Recommendations:\n"
+                     f"  • Reduce --num-steps to {max_steps_for_batch} (with current batch size {batch_size})\n"
+                     f"  • Reduce --batch-size to {max_batch_for_steps} (with current {num_steps} steps)\n"
+                     f"  • Increase topic tree/graph depth or degree to generate more paths"
+                 )
+                 raise DataSetGeneratorError(error_msg)
+
+             # Bandit: not a security function
+             topic_paths = random.sample(topic_paths, required_samples)  # nosec
+             num_steps = math.ceil(len(topic_paths) / batch_size)
+
+         return topic_paths, num_steps
+
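
The sizing check above is plain budget arithmetic: num_steps × batch_size must not exceed the number of available topic paths, and the error message's recommendations are the two integer divisions. A worked example:

```python
# Worked example of the budget check in _prepare_topic_paths.
total_paths = 25                              # len(topic_model.get_all_paths())
num_steps, batch_size = 10, 4
required_samples = num_steps * batch_size     # 40 > 25 -> DataSetGeneratorError
max_steps_for_batch = total_paths // batch_size   # 6: feasible --num-steps
max_batch_for_steps = total_paths // num_steps    # 2: feasible --batch-size
```
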
+     def _generate_batch_prompts(
+         self,
+         batch_size: int,
+         start_idx: int,
+         topic_paths: list,
+         data_creation_prompt: str,
+         num_example_demonstrations: int,
+     ) -> tuple[list[str], list[list[str] | None]]:
+         """Generate prompts for a batch and return the associated paths used.
+
+         Returns:
+             (prompts, used_paths) where used_paths aligns with prompts order.
+         """
+         prompts: list[str] = []
+         used_paths: list[list[str] | None] = []
+         for i in range(batch_size):
+             path = None
+             if topic_paths:
+                 current_idx = start_idx + i
+                 if current_idx < len(topic_paths):
+                     path = topic_paths[current_idx]
+                 else:
+                     break
+
+             sample_prompt = self.build_prompt(
+                 data_creation_prompt=data_creation_prompt,
+                 num_example_demonstrations=num_example_demonstrations,
+                 subtopics_list=path,
+             )
+             prompts.append(sample_prompt)
+             used_paths.append(path)
+         return prompts, used_paths
+
+     def _get_minimal_schema(self) -> type:
+         """Get the conversation schema for the current config."""
+         return get_conversation_schema(self.config.conversation_type)
+
+     def _emit_retry(
+         self,
+         sample_idx: int,
+         attempt: int,
+         max_attempts: int,
+         error: Exception | str,
+     ) -> None:
+         """Emit a retry event if a progress reporter is attached.
+
+         Args:
+             sample_idx: 0-based sample index (will be converted to 1-based)
+             attempt: 0-based attempt number (will be converted to 1-based)
+             max_attempts: Total number of attempts allowed
+             error: The error that triggered the retry
+         """
+         if self.progress_reporter:
+             self.progress_reporter.emit_retry(
+                 sample_idx=sample_idx + 1,
+                 attempt=attempt + 1,
+                 max_attempts=max_attempts,
+                 error_summary=str(error)[:100],
+             )
+
+     async def _generate_structured_samples_async(
+         self,
+         prompts: list[str],
+         include_sys_msg: bool,
+         start_sample_idx: int = 0,
+         paths_for_batch: list[list[str] | None] | None = None,
+     ) -> tuple[list, list]:
+         """Generate structured samples using builder pattern.
+
+         Args:
+             prompts: List of topic prompts to generate samples for
+             include_sys_msg: Whether to include system message in output
+             start_sample_idx: Starting sample index for progress reporting
+             paths_for_batch: Topic paths aligned with prompts, for progress events
+
+         Returns:
+             Tuple of (successful samples, failed responses)
+         """
+
+         samples = []
+         failed_responses = []
+
+         # Create config with overridden sys_msg if needed
+         config = self.config
+         if include_sys_msg != self.config.sys_msg:
+             # Create a copy of config with sys_msg overridden
+             config = self.config.model_copy(update={"sys_msg": include_sys_msg})
+
+         async def _generate_with_retry(
+             prompt: str, sample_idx: int, path_info: list[str] | None
+         ) -> tuple[bool, Exception | Conversation]:
+             """Generate a single sample with per-sample retry for validation errors.
+
+             Each parallel task gets its own builder instance to avoid Spin session
+             conflicts when running samples concurrently (batch_size > 1).
+             """
+             # Create a fresh builder for this sample to avoid session conflicts
+             # when running in parallel batches
+             builder = ConversationBuilderFactory.create(
+                 config=config,
+                 llm=self.llm_client,
+                 tool_registry=self.tool_registry,
+                 progress_reporter=self.progress_reporter,
+             )
+
+             last_error: Exception | None = None
+             error_feedback: str | None = None
+             max_attempts = self.config.sample_retries + 1
+             logger.debug(
+                 "Sample %d: max_attempts=%d (sample_retries=%d)",
+                 sample_idx + 1,
+                 max_attempts,
+                 self.config.sample_retries,
+             )
+
+             for attempt in range(max_attempts):
+                 # Notify progress reporter about which sample we're working on
+                 if self.progress_reporter:
+                     retry_suffix = f" (retry {attempt})" if attempt > 0 else ""
+                     self.progress_reporter.emit_step_start(
+                         f"Generating sample {sample_idx + 1}{retry_suffix}",
+                         sample_idx=sample_idx + 1,
+                         topic_path=path_info,
+                     )
+
+                 try:
+                     # Builder handles all generation complexity
+                     # Pass error feedback from previous attempt if this is a retry
+                     conversation = await builder.generate(prompt, error_feedback)
+                 except Exception as e:  # noqa: BLE001
+                     last_error = e
+                     is_validation = is_validation_error(e)
+                     can_retry = attempt < self.config.sample_retries
+                     logger.debug(
+                         "Sample %d error: is_validation=%s, can_retry=%s, attempt=%d/%d, error=%s",
+                         sample_idx + 1,
+                         is_validation,
+                         can_retry,
+                         attempt + 1,
+                         self.config.sample_retries + 1,
+                         str(e)[:200],
+                     )
+                     # Only retry validation errors, not API/network errors
+                     if is_validation and can_retry:
+                         # Extract error message for feedback to the model
+                         error_feedback = str(e)
+                         self._emit_retry(sample_idx, attempt, max_attempts, e)
+                         continue
+                     # Non-retryable error or exhausted retries
+                     return False, last_error or Exception("Sample generation failed")
+
+                 else:
+                     # Validate tool execution count for agent modes
+                     if self.config.agent_mode is not None:
+                         if (
+                             not conversation.tool_context
+                             or not conversation.tool_context.executions
+                         ):
+                             last_error = ValueError(
+                                 "Agent mode requires at least one tool execution"
+                             )
+                             if attempt < self.config.sample_retries:
+                                 self._emit_retry(sample_idx, attempt, max_attempts, last_error)
+                                 continue
+                             return False, last_error or Exception("Sample generation failed")
+
+                         num_executions = len(conversation.tool_context.executions)
+                         if num_executions > self.config.max_tools_per_query:
+                             if self.config.max_tools_strict:
+                                 # Strict mode: discard entire sample
+                                 last_error = ValueError(
+                                     f"Sample has {num_executions} tool executions, "
+                                     f"exceeds limit of {self.config.max_tools_per_query}"
+                                 )
+                                 if attempt < self.config.sample_retries:
+                                     self._emit_retry(sample_idx, attempt, max_attempts, last_error)
+                                     continue
+                                 return False, last_error or Exception("Sample generation failed")
+                             # Non-strict mode: truncate to limit and keep sample
+                             conversation.tool_context.executions = (
+                                 conversation.tool_context.executions[
+                                     : self.config.max_tools_per_query
+                                 ]
+                             )
+
+                     return True, conversation
+
+             return False, last_error or Exception("Sample generation failed")
+
+         # Generate all samples concurrently with sample indices
+         tasks = []
+         for idx, prompt in enumerate(prompts):
+             path_info = None
+             if paths_for_batch and idx < len(paths_for_batch):
+                 path_info = paths_for_batch[idx]
+             tasks.append(
+                 asyncio.create_task(_generate_with_retry(prompt, start_sample_idx + idx, path_info))
+             )
+         results = await asyncio.gather(*tasks)
+
+         for idx, (success, payload) in enumerate(results):
+             if success:
+                 samples.append(payload)
+             else:
+                 error = payload
+                 error_msg = f"Generation failed: {error}"
+                 # Build failure record with raw content if available
+                 failure_record = {"error": error_msg}
+                 if isinstance(error, Exception):
+                     context = getattr(error, "context", None)
+                     if isinstance(context, dict) and "raw_content" in context:
+                         failure_record["raw_content"] = context["raw_content"]
+                 failed_responses.append(failure_record)
+                 failure_type = self.analyze_failure(
+                     str(error), error=error if isinstance(error, Exception) else None
+                 )
+                 self.failure_analysis[failure_type].append(error_msg)
+
+                 # Classify and emit error to progress reporter
+                 classified = classify_error(
+                     error if isinstance(error, Exception) else str(error),
+                     provider=self.provider,
+                     context={"error_type": failure_type},
+                 )
+                 if self.progress_reporter:
+                     self.progress_reporter.emit_error(
+                         classified,
+                         sample_idx=start_sample_idx + idx,
+                     )
+
+         return samples, failed_responses
+
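
Two retry details in the method above are easy to misread: sample_retries counts retries rather than attempts, and only failures recognized by is_validation_error are retried (with the error text fed back to the model); API and network errors fail the sample immediately. In short:

```python
# Retry budget sketch for _generate_with_retry (values illustrative).
sample_retries = 2                      # config value
max_attempts = sample_retries + 1       # 3 total attempts per sample
# attempt 1: validation error -> error_feedback = str(e), retry
# attempt 2: validation error -> retry again
# attempt 3: any failure -> sample recorded in failed_responses
```
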
+     def analyze_failure(self, response_content: str, error: Exception | None = None) -> str:
+         """Analyze the failure reason for a sample."""
+         if error:
+             error_str = str(error)
+             if "schema" in error_str.lower():
+                 return "invalid_schema"
+             if any(api_err in error_str.lower() for api_err in API_ERROR_INDICATORS):
+                 return "api_errors"
+             return "other_errors"
+
+         if not response_content or response_content.isspace():
+             return "empty_responses"
+
+         # Check if response seems to be attempting JSON but failing
+         if any(char in response_content for char in "{}[]"):
+             return "json_parsing_errors"
+         return "malformed_responses"
+
+     def summarize_failures(self) -> dict:
+         """Generate a summary of all failures."""
+         summary = {
+             "total_failures": len(self.failed_samples),
+             "failure_types": {k: len(v) for k, v in self.failure_analysis.items()},
+             "failure_examples": {},
+         }
+
+         # Add example failures for each category, keyed by category name
+         # (print_failure_summary looks entries up with .get(failure_type))
+         for category, failures in self.failure_analysis.items():
+             if failures:
+                 # Get up to 3 examples for each category
+                 examples = failures[:3]
+                 summary["failure_examples"][category] = [
+                     str(ex)[:200] + "..."
+                     if len(str(ex)) > 200  # noqa: PLR2004
+                     else str(ex)
+                     for ex in examples
+                 ]
+         return summary
+
+     def create_data(
+         self,
+         num_steps: int | None = None,
+         num_example_demonstrations: int = 3,
+         batch_size: int = 10,
+         topic_model: TopicModel | None = None,
+         model_name: str | None = None,
+         sys_msg: bool | None = None,
+     ):
+         ensure_not_running_loop("DataSetGenerator.create_data")
+         return asyncio.run(
+             self.create_data_async(
+                 num_steps=num_steps,
+                 num_example_demonstrations=num_example_demonstrations,
+                 batch_size=batch_size,
+                 topic_model=topic_model,
+                 model_name=model_name,
+                 sys_msg=sys_msg,
+             )
+         )
+
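
A minimal synchronous usage sketch: create_data wraps create_data_async in asyncio.run, so it must be called outside any running event loop (that is what ensure_not_running_loop guards). The topic_model here is assumed to be an already-built tree or graph from this package:

```python
# Sketch: blocking generation run. Assumes `generator` from the earlier
# sketch and an already-built topic model (a TopicModel instance).
dataset = generator.create_data(
    num_steps=5,
    batch_size=2,
    topic_model=topic_model,
    sys_msg=True,               # include the dataset system prompt in samples
)
print(len(dataset))             # a HuggingFace Dataset of generated samples
generator.save_dataset("dataset.jsonl")
```
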
+     def create_data_with_events(
+         self,
+         num_steps: int | None = None,
+         num_example_demonstrations: int = 3,
+         batch_size: int = 10,
+         topic_model: TopicModel | None = None,
+         model_name: str | None = None,
+         sys_msg: bool | None = None,
+     ):
+         ensure_not_running_loop("DataSetGenerator.create_data_with_events")
+
+         async def _async_generator() -> AsyncGenerator[dict | HFDataset, None]:
+             async for event in self.create_data_with_events_async(
+                 num_steps=num_steps,
+                 num_example_demonstrations=num_example_demonstrations,
+                 batch_size=batch_size,
+                 topic_model=topic_model,
+                 model_name=model_name,
+                 sys_msg=sys_msg,
+             ):
+                 yield event
+
+         agen = _async_generator()
+
+         def _sync_generator():
+             loop = asyncio.new_event_loop()
+             try:
+                 while True:
+                     try:
+                         event = loop.run_until_complete(agen.__anext__())
+                     except StopAsyncIteration:
+                         break
+                     else:
+                         yield event
+             finally:
+                 loop.run_until_complete(agen.aclose())
+                 loop.close()
+
+         return _sync_generator()
+
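
Consuming the synchronous wrapper looks roughly like the sketch below: dict events report progress, and the final yield is the finished HuggingFace Dataset (event keys are those emitted by _run_generation_loop_async further down):

```python
# Sketch: driving generation with streamed progress events.
dataset = None
for event in generator.create_data_with_events(
    num_steps=5, batch_size=2, topic_model=topic_model
):
    if isinstance(event, dict):
        if event["event"] == "step_complete":
            print(f"step {event['step']}: {event['samples_generated']} samples")
    else:
        dataset = event         # the HFDataset is yielded last
```
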
+     async def create_data_async(
+         self,
+         num_steps: int | None = None,
+         num_example_demonstrations: int = 3,
+         batch_size: int = 10,
+         topic_model: TopicModel | None = None,
+         model_name: str | None = None,
+         sys_msg: bool | None = None,
+     ) -> HFDataset:
+         if num_steps is None:
+             num_steps = 1
+
+         self._validate_create_data_params(num_steps, batch_size, topic_model)
+
+         if model_name:
+             self.model_name = model_name.strip()
+
+         if not self.model_name:
+             raise DataSetGeneratorError("model_name must be a non-empty string")
+
+         include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg
+
+         topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
+
+         total_samples = num_steps * batch_size
+         data_creation_prompt = self._get_cot_prompt_template()
+
+         final_result: HFDataset | dict | None = None
+         async for event in self._run_generation_loop_async(
+             num_steps=num_steps,
+             batch_size=batch_size,
+             total_samples=total_samples,
+             topic_paths=topic_paths or [],
+             data_creation_prompt=data_creation_prompt,
+             num_example_demonstrations=num_example_demonstrations,
+             include_sys_msg=include_sys_msg,
+         ):
+             final_result = event
+
+         if isinstance(final_result, HFDataset):
+             trace(
+                 "dataset_created",
+                 {
+                     "provider": self.provider,
+                     "model_name": self.model_name,
+                     "conversation_type": self.config.conversation_type,
+                     "samples_count": len(final_result),
+                     "failed_samples": len(self.failed_samples),
+                     "success": len(final_result) > 0,
+                 },
+             )
+             return final_result
+
+         msg = "Dataset generation failed"
+         raise DataSetGeneratorError(msg)
+
770
+ async def create_data_with_events_async(
771
+ self,
772
+ num_steps: int | None = None,
773
+ num_example_demonstrations: int = 3,
774
+ batch_size: int = 10,
775
+ topic_model: TopicModel | None = None,
776
+ model_name: str | None = None,
777
+ sys_msg: bool | None = None,
778
+ ) -> AsyncGenerator[dict | HFDataset, None]:
779
+ if num_steps is None:
780
+ num_steps = 1
781
+
782
+ self._validate_create_data_params(num_steps, batch_size, topic_model)
783
+
784
+ if model_name:
785
+ self.model_name = model_name.strip()
786
+
787
+ if not self.model_name:
788
+ raise DataSetGeneratorError("")
789
+
790
+ include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg
791
+
792
+ topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
793
+
794
+ total_samples = num_steps * batch_size
795
+ data_creation_prompt = self._get_cot_prompt_template()
796
+
797
+ root_topic_prompt = None
798
+ topic_model_type = None
799
+ if topic_model is not None:
800
+ root_topic_prompt = getattr(topic_model, "topic_prompt", None)
801
+ topic_model_type = type(topic_model).__name__.lower()
802
+
803
+ async for event in self._run_generation_loop_async(
804
+ num_steps=num_steps,
805
+ batch_size=batch_size,
806
+ total_samples=total_samples,
807
+ topic_paths=topic_paths or [],
808
+ data_creation_prompt=data_creation_prompt,
809
+ num_example_demonstrations=num_example_demonstrations,
810
+ include_sys_msg=include_sys_msg,
811
+ root_topic_prompt=root_topic_prompt,
812
+ topic_model_type=topic_model_type,
813
+ ):
814
+ yield event
815
+
+     async def _run_generation_loop_async(  # noqa: PLR0912
+         self,
+         num_steps: int,
+         batch_size: int,
+         total_samples: int,
+         topic_paths: list,
+         data_creation_prompt: str,
+         num_example_demonstrations: int,
+         include_sys_msg: bool,
+         root_topic_prompt: str | None = None,
+         topic_model_type: str | None = None,
+     ) -> AsyncGenerator[dict | HFDataset, None]:
+         """Run the main generation loop yielding progress events."""
+         try:
+             yield {
+                 "event": "generation_start",
+                 "model_name": self.model_name,
+                 "num_steps": num_steps,
+                 "batch_size": batch_size,
+                 "total_samples": total_samples,
+                 "root_topic_prompt": root_topic_prompt,
+                 "topic_model_type": topic_model_type,
+             }
+
+             for step in range(num_steps):
+                 yield {
+                     "event": "step_start",
+                     "step": step + 1,
+                     "total_steps": num_steps,
+                 }
+
+                 start_idx = step * batch_size
+                 prompts, used_paths = self._generate_batch_prompts(
+                     batch_size,
+                     start_idx,
+                     topic_paths,
+                     data_creation_prompt,
+                     num_example_demonstrations,
+                 )
+
+                 failed_before = len(self.failed_samples)
+
+                 success, samples_generated = await self._process_batch_with_retries_async(
+                     prompts, include_sys_msg, start_idx, used_paths
+                 )
+
+                 failed_in_batch = len(self.failed_samples) - failed_before
+                 failure_reasons = []
+                 if failed_in_batch > 0 and self.failed_samples:
+                     recent_failures = self.failed_samples[-failed_in_batch:]
+                     failure_reasons = recent_failures[:3]
+
+                 yield {
+                     "event": "step_complete",
+                     "step": step + 1,
+                     "samples_generated": samples_generated,
+                     "success": success,
+                     "failed_in_step": failed_in_batch,
+                     "failure_reasons": failure_reasons,
+                 }
+
+                 if not success:
+                     yield {
+                         "event": "step_failed",
+                         "step": step + 1,
+                         "message": f"Failed to process batch {step + 1} after all retries",
+                     }
+
+             yield {
+                 "event": "generation_complete",
+                 "total_samples": len(self._samples),
+                 "failed_samples": len(self.failed_samples),
+             }
+
+         except KeyboardInterrupt:
+             yield {
+                 "event": "generation_interrupted",
+                 "message": "Generation interrupted by user.",
+             }
+             self.print_failure_summary()
+             self._save_samples_to_file(INTERRUPTED_DATASET_FILENAME)
+
+         except Exception as e:  # noqa: BLE001
+             yield {"event": "generation_error", "error": str(e)}
+             self.print_failure_summary()
+             self._save_samples_to_file(ERROR_DATASET_FILENAME)
+             raise DataSetGeneratorError("Dataset generation failed") from e
+
+         yield (HFDataset.from_list(self._samples) if self._samples else HFDataset.from_list([]))
+
+     async def _process_batch_with_retries_async(
+         self,
+         prompts: list[str],
+         include_sys_msg: bool,
+         start_sample_idx: int = 0,
+         paths_for_batch: list[list[str] | None] | None = None,
+     ) -> tuple[bool, int]:
+         """Process a batch with retry logic."""
+         for attempt in range(self.config.max_retries):
+             try:
+                 samples, failed_responses = await self._generate_structured_samples_async(
+                     prompts, include_sys_msg, start_sample_idx, paths_for_batch
+                 )
+
+                 # Update failed samples
+                 self.failed_samples.extend(failed_responses)
+
+                 if samples:
+                     # Convert Pydantic models to dicts and add to samples list
+                     sample_dicts = [s.model_dump(exclude_none=True) for s in samples]
+                     self._samples.extend(sample_dicts)
+                     return True, len(samples)  # Success - exit retry loop
+
+             except DataSetGeneratorError as e:
+                 # Authentication and API errors are now wrapped in DataSetGeneratorError
+                 error_str = str(e).lower()
+                 if any(
+                     keyword in error_str
+                     for keyword in [
+                         "api_key",
+                         "api key",
+                         "authentication",
+                         "unauthorized",
+                     ]
+                 ):
+                     error_msg = f"Authentication failed for provider '{self.provider}'. Please set the required API key environment variable."
+                     self.failure_analysis["authentication_error"].append(error_msg)
+                 else:
+                     error_msg = f"API error for provider '{self.provider}': {str(e)[:100]}..."
+                     self.failure_analysis["api_errors"].append(error_msg)
+
+                 self.failed_samples.append(error_msg)
+                 logger.exception("API error: %s", error_msg)
+                 return False, 0  # Don't retry authentication/API errors
+             except Exception as e:
+                 if attempt == self.config.max_retries - 1:
+                     self.failed_samples.append(str(e))
+                     failure_type = self.analyze_failure(str(e), error=e)
+                     self.failure_analysis[failure_type].append(str(e))
+                     return False, 0
+             else:
+                 # If no exception and no samples, return False, 0
+                 return False, 0
+
+         return False, 0
+
+     def print_failure_summary(self):
+         """Print a detailed summary of all failures."""
+         summary = self.summarize_failures()
+
+         print("\n=== Failure Analysis Summary ===")
+         print(f"Total Failed Samples: {summary['total_failures']}")
+         print("\nFailure Types Breakdown:")
+         for failure_type, count in summary["failure_types"].items():
+             if count > 0:
+                 print(f"\n{failure_type.replace('_', ' ').title()}: {count}")
+                 if failure_type in summary["failure_examples"]:
+                     print("Example failures:")
+                     for i, example in enumerate(
+                         summary["failure_examples"].get(failure_type, []), 1
+                     ):
+                         print(f"  {i}. {example}")
+         print("\n=============================")
+
+     def build_prompt(
+         self,
+         data_creation_prompt: str,
+         num_example_demonstrations: int,
+         subtopics_list: list[str] | None = None,
+     ) -> str:
+         prompt = data_creation_prompt.replace("{{{{system_prompt}}}}", self.generation_prompt)
+         prompt = prompt.replace("{{{{instructions}}}}", self.build_custom_instructions_text())
+         prompt = prompt.replace(
+             "{{{{examples}}}}", self.build_examples_text(num_example_demonstrations)
+         )
+         return prompt.replace("{{{{subtopics}}}}", self.build_subtopics_text(subtopics_list))
+
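
build_prompt fills its template by literal string replacement, so a custom data_creation_prompt must carry the quadruple-brace placeholders verbatim. An illustrative template (the text around the placeholders is made up):

```python
# Illustrative template showing the four placeholders build_prompt replaces.
# The quadruple braces are literal text, not Python str.format fields.
template = (
    "{{{{system_prompt}}}}\n\n"
    "{{{{instructions}}}}\n"
    "{{{{examples}}}}\n"
    "{{{{subtopics}}}}\n"
    "Write one training sample as JSON."
)
prompt = generator.build_prompt(template, num_example_demonstrations=3)
```
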
+     def build_system_prompt(self):
+         """Return the original system prompt for dataset inclusion."""
+         return self.dataset_system_prompt
+
+     def build_custom_instructions_text(self) -> str:
+         if self.config.instructions is None or self.config.instructions == "":
+             return ""
+         return f"\nHere are additional instructions:\n<instructions>\n{self.config.instructions}\n</instructions>\n"
+
+     def build_examples_text(self, num_example_demonstrations: int):
+         if self.config.example_data is None or num_example_demonstrations == 0:
+             return ""
+         # Bandit: not a security function
+         # HF Dataset supports len() and indexing, convert to list for sampling
+         example_list = list(self.config.example_data)
+         examples = random.sample(example_list, min(num_example_demonstrations, len(example_list)))  # nosec
+         examples_text = "\n".join(f"Example {i + 1}: \n\n{ex}\n" for i, ex in enumerate(examples))
+         return f"\nHere are output examples:\n<examples>\n{examples_text}\n</examples>\n"
+
+     def build_tools_text(self) -> str:
+         """Build formatted tools text for XLAM multi-turn prompts."""
+         if not self.tool_registry:
+             return "No tools available"
+
+         tools_text = []
+         for tool in self.tool_registry.tools:
+             params_text = []
+             for param in tool.parameters:
+                 req = " (required)" if param.required else " (optional)"
+                 params_text.append(f"  - {param.name} ({param.type}){req}: {param.description}")
+
+             tool_text = f"• {tool.name}: {tool.description}\n  Parameters:\n" + "\n".join(
+                 params_text
+             )
+             tools_text.append(tool_text)
+
+         return "\n\n".join(tools_text)
+
+     def build_subtopics_text(self, subtopic_list: list[str] | None):
+         if subtopic_list is None:
+             return ""
+         return f"\nLastly, the topic of the training data should be related to the following subtopics: {' -> '.join(subtopic_list)}"
+
+     def _get_cot_prompt_template(self) -> str:  # noqa: PLR0911
+         """Get the appropriate prompt template based on modular configuration."""
+         # Handle basic conversations
+         if self.config.conversation_type == "basic":
+             return CONVERSATION_GENERATION_PROMPT
+
+         # Handle chain of thought conversations
+         if self.config.conversation_type == "chain_of_thought":
+             # Agent mode with tools - use agent prompts
+             if self.config.agent_mode == "single_turn" and self.tool_registry:
+                 # Use agent prompt for single-turn tool calling
+                 return (
+                     AgentPromptBuilder.build_tool_context_prompt(
+                         self.tool_registry,
+                         max_tools_per_query=self.config.max_tools_per_query,
+                     )
+                     or AGENT_COT_TOOLS_PROMPT
+                 )
+
+             if self.config.agent_mode == "multi_turn" and self.tool_registry:
+                 # Standard multi-turn agent
+                 return (
+                     AgentPromptBuilder.build_multi_turn_context_prompt(
+                         self.tool_registry,
+                         max_tools_per_query=self.config.max_tools_per_query,
+                     )
+                     or AGENT_COT_MULTI_TURN_PROMPT
+                 )
+
+             # Non-agent CoT - select based on reasoning style
+             if self.config.reasoning_style == "freetext":
+                 return FREETEXT_COT_PROMPT
+             if self.config.reasoning_style == "agent":
+                 return STRUCTURED_COT_PROMPT
+
+         # Fallback to basic conversation prompt
+         return CONVERSATION_GENERATION_PROMPT
+
+     def _save_samples_to_file(self, save_path: str):
+         """Save the current samples to a JSONL file."""
+
+         with open(save_path, "w") as f:
+             for sample in self._samples:
+                 f.write(json.dumps(sample, separators=(",", ":")) + "\n")
+
+     def save_dataset(self, save_path: str):
+         """Save the dataset to a JSONL file."""
+         self._save_samples_to_file(save_path)
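
Since _save_samples_to_file writes one compact JSON object per line, a saved (or interrupted/error-dump) dataset can be reloaded with a few lines — a small sketch, assuming the file name used earlier:

```python
import json

# Sketch: reload a JSONL dataset written by save_dataset (or the
# interrupted/error dumps). One JSON object per line.
with open("dataset.jsonl") as f:
    samples = [json.loads(line) for line in f]
print(len(samples), "samples reloaded")
```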