deepfabric-4.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. deepfabric/__init__.py +70 -0
  2. deepfabric/__main__.py +6 -0
  3. deepfabric/auth.py +382 -0
  4. deepfabric/builders.py +303 -0
  5. deepfabric/builders_agent.py +1304 -0
  6. deepfabric/cli.py +1288 -0
  7. deepfabric/config.py +899 -0
  8. deepfabric/config_manager.py +251 -0
  9. deepfabric/constants.py +94 -0
  10. deepfabric/dataset_manager.py +534 -0
  11. deepfabric/error_codes.py +581 -0
  12. deepfabric/evaluation/__init__.py +47 -0
  13. deepfabric/evaluation/backends/__init__.py +32 -0
  14. deepfabric/evaluation/backends/ollama_backend.py +137 -0
  15. deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
  16. deepfabric/evaluation/backends/transformers_backend.py +326 -0
  17. deepfabric/evaluation/evaluator.py +845 -0
  18. deepfabric/evaluation/evaluators/__init__.py +13 -0
  19. deepfabric/evaluation/evaluators/base.py +104 -0
  20. deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
  21. deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
  22. deepfabric/evaluation/evaluators/registry.py +66 -0
  23. deepfabric/evaluation/inference.py +155 -0
  24. deepfabric/evaluation/metrics.py +397 -0
  25. deepfabric/evaluation/parser.py +304 -0
  26. deepfabric/evaluation/reporters/__init__.py +13 -0
  27. deepfabric/evaluation/reporters/base.py +56 -0
  28. deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
  29. deepfabric/evaluation/reporters/file_reporter.py +61 -0
  30. deepfabric/evaluation/reporters/multi_reporter.py +56 -0
  31. deepfabric/exceptions.py +67 -0
  32. deepfabric/factory.py +26 -0
  33. deepfabric/generator.py +1084 -0
  34. deepfabric/graph.py +545 -0
  35. deepfabric/hf_hub.py +214 -0
  36. deepfabric/kaggle_hub.py +219 -0
  37. deepfabric/llm/__init__.py +41 -0
  38. deepfabric/llm/api_key_verifier.py +534 -0
  39. deepfabric/llm/client.py +1206 -0
  40. deepfabric/llm/errors.py +105 -0
  41. deepfabric/llm/rate_limit_config.py +262 -0
  42. deepfabric/llm/rate_limit_detector.py +278 -0
  43. deepfabric/llm/retry_handler.py +270 -0
  44. deepfabric/metrics.py +212 -0
  45. deepfabric/progress.py +262 -0
  46. deepfabric/prompts.py +290 -0
  47. deepfabric/schemas.py +1000 -0
  48. deepfabric/spin/__init__.py +6 -0
  49. deepfabric/spin/client.py +263 -0
  50. deepfabric/spin/models.py +26 -0
  51. deepfabric/stream_simulator.py +90 -0
  52. deepfabric/tools/__init__.py +5 -0
  53. deepfabric/tools/defaults.py +85 -0
  54. deepfabric/tools/loader.py +87 -0
  55. deepfabric/tools/mcp_client.py +677 -0
  56. deepfabric/topic_manager.py +303 -0
  57. deepfabric/topic_model.py +20 -0
  58. deepfabric/training/__init__.py +35 -0
  59. deepfabric/training/api_key_prompt.py +302 -0
  60. deepfabric/training/callback.py +363 -0
  61. deepfabric/training/metrics_sender.py +301 -0
  62. deepfabric/tree.py +438 -0
  63. deepfabric/tui.py +1267 -0
  64. deepfabric/update_checker.py +166 -0
  65. deepfabric/utils.py +150 -0
  66. deepfabric/validation.py +143 -0
  67. deepfabric-4.4.0.dist-info/METADATA +702 -0
  68. deepfabric-4.4.0.dist-info/RECORD +71 -0
  69. deepfabric-4.4.0.dist-info/WHEEL +4 -0
  70. deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
  71. deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/config.py ADDED
@@ -0,0 +1,899 @@
+ import warnings
+
+ from typing import Literal
+
+ import yaml
+
+ from pydantic import BaseModel, Field, field_validator, model_validator
+
+ from .constants import (
+     DEFAULT_MAX_RETRIES,
+     DEFAULT_MODEL,
+     DEFAULT_PROVIDER,
+     DEFAULT_SAMPLE_RETRIES,
+     ENGINE_DEFAULT_BATCH_SIZE,
+     ENGINE_DEFAULT_NUM_EXAMPLES,
+     ENGINE_DEFAULT_TEMPERATURE,
+     TOPIC_TREE_DEFAULT_DEGREE,
+     TOPIC_TREE_DEFAULT_DEPTH,
+     TOPIC_TREE_DEFAULT_TEMPERATURE,
+ )
+ from .exceptions import ConfigurationError
+ from .metrics import trace
+
+
+ def _normalize_reasoning_style(value: str | None) -> str | None:
+     """Normalize reasoning_style with deprecation warnings for old values.
+
+     Args:
+         value: The reasoning_style value to normalize
+
+     Returns:
+         Normalized value ('freetext', 'agent', or None)
+     """
+     if value is None:
+         return None
+     if value == "structured":
+         warnings.warn(
+             "reasoning_style='structured' is deprecated. Use 'agent' instead.",
+             DeprecationWarning,
+             stacklevel=4,
+         )
+         return "agent"
+     if value == "hybrid":
+         warnings.warn(
+             "reasoning_style='hybrid' is deprecated and was non-functional. Use 'agent' instead.",
+             DeprecationWarning,
+             stacklevel=4,
+         )
+         return "agent"
+     return value
+
+
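A minimal sketch of the normalization above, assuming the published wheel is installed as deepfabric (illustration only; the helper is module-private, so calling it directly is purely for demonstration):

    import warnings

    from deepfabric.config import _normalize_reasoning_style

    # Deprecated values are rewritten to "agent" and emit a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert _normalize_reasoning_style("structured") == "agent"
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # Current values and None pass through unchanged.
    assert _normalize_reasoning_style("freetext") == "freetext"
    assert _normalize_reasoning_style(None) is None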
+ # =============================================================================
+ # NEW CONFIG STRUCTURE
+ # =============================================================================
+
+
+ class LLMConfig(BaseModel):
+     """Shared LLM configuration that can be inherited by topics and generation."""
+
+     provider: str | None = Field(
+         default=None,
+         description="LLM provider (openai, anthropic, gemini, ollama)",
+     )
+     model: str | None = Field(
+         default=None,
+         description="The name of the model to be used",
+     )
+     temperature: float | None = Field(
+         default=None,
+         ge=0.0,
+         le=2.0,
+         description="Temperature for model generation",
+     )
+     base_url: str | None = Field(
+         default=None,
+         description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
+     )
+
+
+ class TopicsConfig(BaseModel):
+     """Configuration for topic generation (tree or graph mode)."""
+
+     prompt: str = Field(
+         ..., min_length=1, description="The initial prompt to start topic generation"
+     )
+     mode: Literal["tree", "graph"] = Field(
+         default="tree", description="Topic generation mode: tree or graph"
+     )
+     system_prompt: str = Field(
+         default="", description="System prompt for topic exploration and generation"
+     )
+     depth: int = Field(
+         default=TOPIC_TREE_DEFAULT_DEPTH,
+         ge=1,
+         le=10,
+         description="Depth of the tree/graph",
+     )
+     degree: int = Field(
+         default=TOPIC_TREE_DEFAULT_DEGREE,
+         ge=1,
+         le=50,
+         description="Number of subtopics per node (branching factor)",
+     )
+     max_concurrent: int = Field(
+         default=4,
+         ge=1,
+         le=20,
+         description="Maximum concurrent LLM calls during graph expansion (helps avoid rate limits)",
+     )
+     save_as: str | None = Field(default=None, description="Where to save the generated topics")
+
+     # Optional LLM overrides (inherits from top-level llm if not specified)
+     llm: LLMConfig | None = Field(
+         default=None, description="Optional LLM configuration overrides for topics"
+     )
+
+
+ class ConversationConfig(BaseModel):
+     """Configuration for conversation structure in generation."""
+
+     type: Literal["basic", "chain_of_thought"] = Field(
+         default="basic",
+         description="Base conversation type: basic (simple chat), chain_of_thought (with reasoning)",
+     )
+     reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = Field(
+         default=None,
+         description="Reasoning style for chain_of_thought: freetext or agent. Note: 'structured' and 'hybrid' are deprecated.",
+     )
+     agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
+         default=None,
+         description="Agent mode: single_turn (one-shot tool use), multi_turn (extended conversations)",
+     )
+     min_turns: int = Field(
+         default=2,
+         ge=1,
+         le=10,
+         description="Minimum conversation turns for multi_turn agent mode",
+     )
+     max_turns: int = Field(
+         default=4,
+         ge=1,
+         le=10,
+         description="Maximum conversation turns for multi_turn agent mode",
+     )
+     min_tool_calls: int = Field(
+         default=2,
+         ge=0,
+         le=20,
+         description="Minimum tool calls before allowing conversation conclusion",
+     )
+
+     @field_validator("reasoning_style", mode="before")
+     @classmethod
+     def normalize_reasoning_style(cls, v: str | None) -> str | None:
+         """Normalize deprecated reasoning_style values."""
+         return _normalize_reasoning_style(v)
+
+     @model_validator(mode="after")
+     def validate_configuration(self):
+         """Validate that configuration combinations are consistent."""
+         if self.reasoning_style is not None and self.type != "chain_of_thought":
+             raise ValueError(
+                 f"reasoning_style can only be set when type='chain_of_thought', "
+                 f"got type='{self.type}'"
+             )
+
+         if self.type == "chain_of_thought" and self.reasoning_style is None:
+             raise ValueError(
+                 "reasoning_style must be specified when type='chain_of_thought'. "
+                 "Choose from: 'freetext' or 'agent'"
+             )
+
+         if self.agent_mode is not None and self.reasoning_style == "freetext":
+             raise ValueError(
+                 "reasoning_style='freetext' is not compatible with agent_mode. "
+                 "Agent mode requires structured reasoning. Use reasoning_style='agent' instead."
+             )
+
+         return self
+
+
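A sketch of the combinations the model_validator above accepts and rejects, assuming the wheel is installed (pydantic surfaces the ValueError as a ValidationError):

    from pydantic import ValidationError

    from deepfabric.config import ConversationConfig

    # Accepted: chain_of_thought with an explicit reasoning style.
    ConversationConfig(type="chain_of_thought", reasoning_style="agent")

    # Rejected: chain_of_thought without a reasoning style.
    try:
        ConversationConfig(type="chain_of_thought")
    except ValidationError as err:
        print(err)  # "reasoning_style must be specified ..."

    # Rejected: reasoning_style on a basic conversation.
    try:
        ConversationConfig(type="basic", reasoning_style="freetext")
    except ValidationError as err:
        print(err)  # "reasoning_style can only be set when type='chain_of_thought' ..."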
+ class ToolsConfig(BaseModel):
+     """Configuration for tool/function calling in generation.
+
+     Tools are organized by component - each component routes to a different
+     Spin endpoint (e.g., /vfs/execute, /github/execute, /slack/execute).
+
+     Example:
+         tools:
+           spin_endpoint: "http://localhost:3000"
+           components:
+             builtin: [read_file, write_file]  # Routes to /vfs/execute
+             github: [gh_get_file_contents]    # Routes to /github/execute
+             slack: [send_message]             # Routes to /slack/execute
+           tools_endpoint: "http://localhost:3000/mock/list-tools"  # For non-builtin tools
+     """
+
+     spin_endpoint: str | None = Field(
+         default=None,
+         description="Spin service URL for real tool execution (e.g., 'http://localhost:3000')",
+     )
+     components: dict[str, list[str]] = Field(
+         default_factory=dict,
+         description=(
+             "Map of component name to tool names. 'builtin' uses built-in tools "
+             "(read_file, write_file, list_files, delete_file) and routes to /vfs/execute. "
+             "Other components (github, slack, etc.) load tools from tools_endpoint "
+             "and route to /{component}/execute."
+         ),
+     )
+     tools_endpoint: str | None = Field(
+         default=None,
+         description=(
+             "HTTP endpoint to load tool definitions from in MCP format "
+             "(e.g., 'http://localhost:3000/mock/list-tools'). "
+             "Required for non-builtin components."
+         ),
+     )
+     custom: list[dict] = Field(
+         default_factory=list,
+         description="Custom tool definitions as dictionaries (for inline tool definitions)",
+     )
+     max_per_query: int = Field(
+         default=3, ge=1, le=10, description="Maximum number of tools per query/turn"
+     )
+     strict: bool = Field(
+         default=True,
+         description="If True, discard samples exceeding max_per_query. If False, truncate.",
+     )
+     scenario_seed: dict | None = Field(
+         default=None,
+         description="Initial state to seed into Spin VFS before generation starts",
+     )
+     max_agent_steps: int = Field(
+         default=5,
+         ge=1,
+         le=10,
+         description="Maximum ReAct reasoning steps before forcing conclusion",
+     )
+
+     tool_execute_path: str | None = Field(
+         default=None,
+         description=(
+             "Custom path for tool execution (e.g., '/mock/execute'). "
+             "If not set, uses component-based routing (/{component}/execute)."
+         ),
+     )
+
+
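The Python equivalent of the YAML example in the docstring, for readers constructing the config programmatically (a sketch; the endpoint URLs are the placeholder values from the docstring):

    from deepfabric.config import ToolsConfig

    tools = ToolsConfig(
        spin_endpoint="http://localhost:3000",
        components={
            "builtin": ["read_file", "write_file"],  # routes to /vfs/execute
            "github": ["gh_get_file_contents"],      # routes to /github/execute
            "slack": ["send_message"],               # routes to /slack/execute
        },
        tools_endpoint="http://localhost:3000/mock/list-tools",  # required for non-builtin components
    )
    print(tools.model_dump(exclude_none=True))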
+ class GenerationConfig(BaseModel):
+     """Configuration for sample/conversation generation."""
+
+     system_prompt: str = Field(
+         ..., min_length=1, description="System prompt for content generation"
+     )
+     instructions: str = Field(default="", description="Additional instructions for data generation")
+     conversation: ConversationConfig = Field(
+         default_factory=ConversationConfig,
+         description="Conversation structure configuration",
+     )
+     tools: ToolsConfig | None = Field(
+         default=None, description="Tool configuration (required for agent modes)"
+     )
+     max_retries: int = Field(
+         default=DEFAULT_MAX_RETRIES,
+         ge=0,
+         le=10,
+         description="Maximum retries for failed generations",
+     )
+     sample_retries: int = Field(
+         default=DEFAULT_SAMPLE_RETRIES,
+         ge=0,
+         le=5,
+         description="Retries for individual sample validation failures",
+     )
+     max_tokens: int = Field(default=2000, ge=1, description="Maximum tokens to generate per call")
+     rate_limit: dict[str, int | float | str | bool] | None = Field(
+         default=None,
+         description="Rate limiting and retry configuration",
+     )
+     save_as: str | None = Field(default=None, description="Where to save the generated samples")
+
+     # Optional LLM overrides
+     llm: LLMConfig | None = Field(
+         default=None, description="Optional LLM configuration overrides for generation"
+     )
+
+     @model_validator(mode="after")
+     def validate_agent_requires_tools(self):
+         """Validate that agent_mode requires tools with Spin endpoint."""
+         if self.conversation.agent_mode is not None:
+             if self.tools is None:
+                 raise ValueError(
+                     "agent_mode requires tools to be configured. "
+                     "Specify tools.spin_endpoint and optionally tools.components to filter tools."
+                 )
+             if not self.tools.spin_endpoint:
+                 raise ValueError(
+                     "agent_mode requires a Spin endpoint for tool execution. "
+                     "Set tools.spin_endpoint (e.g., 'http://localhost:3000'). "
+                     "See: cd tools-sdk && spin build && spin up"
+                 )
+         return self
+
+
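A sketch of the guard above: enabling agent_mode without a tools block (or without tools.spin_endpoint) fails at construction time (illustration only):

    from pydantic import ValidationError

    from deepfabric.config import ConversationConfig, GenerationConfig

    try:
        GenerationConfig(
            system_prompt="You are a helpful assistant.",
            conversation=ConversationConfig(
                type="chain_of_thought",
                reasoning_style="agent",
                agent_mode="single_turn",
            ),
            # tools omitted: validate_agent_requires_tools raises
        )
    except ValidationError as err:
        print(err)  # "agent_mode requires tools to be configured ..."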
+ class OutputConfig(BaseModel):
+     """Configuration for final dataset output."""
+
+     system_prompt: str | None = Field(
+         None,
+         description="System prompt that goes INTO the training data (falls back to generation.system_prompt)",
+     )
+     include_system_message: bool = Field(
+         default=True,
+         description="Whether to include system message in output format",
+     )
+     num_samples: int = Field(
+         default=ENGINE_DEFAULT_NUM_EXAMPLES,
+         ge=1,
+         description="Number of training samples to generate",
+     )
+     batch_size: int = Field(
+         default=ENGINE_DEFAULT_BATCH_SIZE,
+         ge=1,
+         description="Number of samples to process at a time",
+     )
+     save_as: str = Field(..., min_length=1, description="Where to save the final dataset")
+
+
+ class HuggingFaceConfig(BaseModel):
+     """Configuration for Hugging Face Hub integration."""
+
+     repository: str = Field(..., min_length=1, description="HuggingFace repository name")
+     tags: list[str] = Field(default_factory=list, description="Tags for the dataset")
+
+
+ class KaggleConfig(BaseModel):
+     """Configuration for Kaggle integration."""
+
+     handle: str = Field(
+         ..., min_length=1, description="Kaggle dataset handle (username/dataset-name)"
+     )
+     tags: list[str] = Field(default_factory=list, description="Tags for the dataset")
+     description: str | None = Field(None, description="Description for the dataset")
+     version_notes: str | None = Field(None, description="Version notes for dataset update")
+
+
+ class EvaluationConfig(BaseModel):
+     """Configuration for model evaluation."""
+
+     conversation_type: Literal["basic", "chain_of_thought"] = Field(
+         ...,
+         description="Conversation type (must match dataset generation)",
+     )
+     reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = Field(
+         default=None,
+         description="Reasoning style for chain_of_thought type",
+     )
+
+     @field_validator("reasoning_style", mode="before")
+     @classmethod
+     def normalize_reasoning_style(cls, v: str | None) -> str | None:
+         """Normalize deprecated reasoning_style values."""
+         return _normalize_reasoning_style(v)
+
+     agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
+         default=None,
+         description="Agent mode if tools are used",
+     )
+     metrics: list[str] = Field(
+         default_factory=lambda: [
+             "tool_selection_accuracy",
+             "parameter_accuracy",
+             "execution_success_rate",
+             "response_quality",
+         ],
+         description="Metrics to compute during evaluation",
+     )
+     thresholds: dict[str, float] = Field(
+         default_factory=dict,
+         description="Pass/fail thresholds for metrics",
+     )
+     weights: dict[str, float] = Field(
+         default_factory=lambda: {
+             "tool_selection": 0.40,
+             "parameter_accuracy": 0.30,
+             "execution_success": 0.20,
+             "response_quality": 0.10,
+         },
+         description="Metric weights for overall score calculation",
+     )
+     output_dir: str = Field(
+         default="./eval_results",
+         description="Output directory for evaluation results",
+     )
+     output_formats: list[Literal["json", "html", "csv"]] = Field(
+         default_factory=lambda: ["json", "html", "csv"],
+         description="Output formats to generate",
+     )
+     include_failures: bool = Field(
+         default=True,
+         description="Include failed examples in output",
+     )
+     generate_charts: bool = Field(
+         default=True,
+         description="Generate visualization charts",
+     )
+     batch_size: int = Field(
+         default=1,
+         ge=1,
+         description="Batch size for model inference",
+     )
+     max_samples: int | None = Field(
+         default=None,
+         description="Maximum number of samples to evaluate (None for all)",
+     )
+
+     @model_validator(mode="after")
+     def validate_evaluation_config(self) -> "EvaluationConfig":
+         """Validate evaluation configuration consistency."""
+         if self.reasoning_style is not None and self.conversation_type != "chain_of_thought":
+             raise ValueError(
+                 f"reasoning_style can only be set when conversation_type='chain_of_thought', "
+                 f"got conversation_type='{self.conversation_type}'"
+             )
+
+         if self.conversation_type == "chain_of_thought" and self.reasoning_style is None:
+             raise ValueError(
+                 "reasoning_style must be specified when conversation_type='chain_of_thought'. "
+                 "Choose from: 'freetext' or 'agent'"
+             )
+
+         if self.agent_mode is not None and self.reasoning_style == "freetext":
+             raise ValueError(
+                 "reasoning_style='freetext' is not compatible with agent_mode. "
+                 "Agent mode requires structured reasoning. Use reasoning_style='agent' instead."
+             )
+
+         return self
+
+
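The default weights sum to 1.0, which reads as a convex combination of per-metric scores. A sketch under that assumption (the actual aggregation lives elsewhere in the package, per the deepfabric/evaluation/metrics.py entry in the file list, and is not shown in this hunk; the per-metric scores below are invented):

    from deepfabric.config import EvaluationConfig

    cfg = EvaluationConfig(conversation_type="basic")
    assert abs(sum(cfg.weights.values()) - 1.0) < 1e-9  # defaults: 0.40 + 0.30 + 0.20 + 0.10

    # Hypothetical per-metric scores, keyed by the default weight names.
    scores = {
        "tool_selection": 0.9,
        "parameter_accuracy": 0.8,
        "execution_success": 1.0,
        "response_quality": 0.7,
    }
    overall = sum(cfg.weights[name] * value for name, value in scores.items())
    print(f"overall: {overall:.2f}")  # 0.87 for these made-up scores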
+ class DeepFabricConfig(BaseModel):
+     """Main configuration for DeepFabric tasks using the new structure."""
+
+     # Optional shared LLM defaults
+     llm: LLMConfig | None = Field(
+         None, description="Shared LLM defaults inherited by topics and generation"
+     )
+
+     # Core sections
+     topics: TopicsConfig = Field(..., description="Topic generation configuration")
+     generation: GenerationConfig = Field(..., description="Sample generation configuration")
+     output: OutputConfig = Field(..., description="Output dataset configuration")
+
+     # Optional integrations
+     evaluation: EvaluationConfig | None = Field(None, description="Evaluation configuration")
+     huggingface: HuggingFaceConfig | None = Field(None, description="Hugging Face configuration")
+     kaggle: KaggleConfig | None = Field(None, description="Kaggle configuration")
+
+     @classmethod
+     def _detect_old_format(cls, config_dict: dict) -> bool:
+         """Detect if config uses old format."""
+         old_keys = ["topic_tree", "topic_graph", "data_engine", "dataset_system_prompt"]
+         return any(key in config_dict for key in old_keys)
+
+     @classmethod
+     def _get_migration_message(cls) -> str:
+         """Return migration message for old config format."""
+         return """
+ Configuration format has changed. Please update your config to the new structure:
+
+ OLD FORMAT                      NEW FORMAT
+ ----------                      ----------
+ dataset_system_prompt        -> output.system_prompt
+ topic_tree/topic_graph       -> topics (with mode: tree|graph)
+   topic_prompt               ->   prompt
+   topic_system_prompt        ->   system_prompt
+ data_engine                  -> generation
+   generation_system_prompt   ->   system_prompt
+   conversation_type          ->   conversation.type
+   reasoning_style            ->   conversation.reasoning_style
+   agent_mode                 ->   conversation.agent_mode
+   available_tools            ->   tools.components
+   custom_tools               ->   tools.custom
+   max_tools_per_query        ->   tools.max_per_query
+   max_tools_strict           ->   tools.strict
+   spin_endpoint              ->   tools.spin_endpoint
+ dataset.creation.num_steps   -> output.num_samples
+ dataset.creation.batch_size  -> output.batch_size
+ dataset.creation.sys_msg     -> output.include_system_message
+ dataset.save_as              -> output.save_as
+
+ See documentation for full examples.
+ """
+
+     @classmethod
+     def from_yaml(cls, yaml_path: str) -> "DeepFabricConfig":
+         """Load configuration from a YAML file."""
+         try:
+             with open(yaml_path, encoding="utf-8") as f:
+                 config_dict = yaml.safe_load(f)
+         except FileNotFoundError as e:
+             raise ConfigurationError(f"not found: {yaml_path}") from e
+         except yaml.YAMLError as e:
+             raise ConfigurationError(f"invalid YAML: {str(e)}") from e
+         except Exception as e:
+             raise ConfigurationError(f"read error: {str(e)}") from e
+
+         if not isinstance(config_dict, dict):
+             raise ConfigurationError("must be dictionary")
+
+         # Detect and reject old format
+         if cls._detect_old_format(config_dict):
+             raise ConfigurationError(cls._get_migration_message())
+
+         try:
+             config = cls(**config_dict)
+             trace(
+                 "config_loaded",
+                 {
+                     "method": "yaml",
+                     "topics_mode": config.topics.mode,
+                     "has_huggingface": config.huggingface is not None,
+                     "has_kaggle": config.kaggle is not None,
+                 },
+             )
+         except Exception as e:
+             raise ConfigurationError(f"invalid structure: {str(e)}") from e
+         else:
+             return config
+
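A minimal end-to-end sketch of from_yaml. Only topics.prompt, generation.system_prompt, and output.save_as are required; everything else falls back to defaults (the provider and model below are placeholder values, not package recommendations):

    import tempfile

    from deepfabric.config import DeepFabricConfig

    minimal = """
    llm:
      provider: ollama
      model: llama3

    topics:
      prompt: "Cloud security fundamentals"

    generation:
      system_prompt: "You are a precise technical writer."

    output:
      save_as: dataset.jsonl
    """

    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as handle:
        handle.write(minimal)

    config = DeepFabricConfig.from_yaml(handle.name)
    assert config.topics.mode == "tree"  # default mode
    assert config.output.save_as == "dataset.jsonl"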
+     def _resolve_llm_config(self, section_llm: LLMConfig | None) -> LLMConfig:
+         """Resolve LLM config with inheritance from top-level.
+
+         Priority order (highest to lowest):
+         1. Section-specific llm config (e.g., generation.llm)
+         2. Top-level shared llm config
+         3. Built-in defaults (DEFAULT_PROVIDER, DEFAULT_MODEL, etc.)
+         """
+         # Get values from section-specific config (if any)
+         section_provider = section_llm.provider if section_llm else None
+         section_model = section_llm.model if section_llm else None
+         section_temperature = section_llm.temperature if section_llm else None
+         section_base_url = section_llm.base_url if section_llm else None
+
+         # Get values from top-level shared config (if any)
+         shared_provider = self.llm.provider if self.llm else None
+         shared_model = self.llm.model if self.llm else None
+         shared_temperature = self.llm.temperature if self.llm else None
+         shared_base_url = self.llm.base_url if self.llm else None
+
+         # Resolve with priority: section > shared > defaults
+         return LLMConfig(
+             provider=section_provider or shared_provider or DEFAULT_PROVIDER,
+             model=section_model or shared_model or DEFAULT_MODEL,
+             temperature=(
+                 section_temperature
+                 if section_temperature is not None
+                 else (
+                     shared_temperature
+                     if shared_temperature is not None
+                     else TOPIC_TREE_DEFAULT_TEMPERATURE
+                 )
+             ),
+             base_url=section_base_url or shared_base_url,
+         )
+
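A sketch of the precedence rules: a section-level llm block overrides only the fields it sets, the shared block fills the rest, and built-in defaults cover whatever remains. Calling the private helper directly is for illustration only; the model names are placeholders:

    from deepfabric.config import DeepFabricConfig, LLMConfig

    config = DeepFabricConfig(
        llm=LLMConfig(provider="openai", model="gpt-4o-mini", temperature=0.7),
        topics={"prompt": "Databases", "llm": {"model": "gpt-4o"}},
        generation={"system_prompt": "You are a precise technical writer."},
        output={"save_as": "out.jsonl"},
    )

    resolved = config._resolve_llm_config(config.topics.llm)
    assert resolved.model == "gpt-4o"      # section override wins
    assert resolved.provider == "openai"   # inherited from the shared block
    assert resolved.temperature == 0.7     # inherited; built-in default only if both unset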
+     def get_topics_params(self, **overrides) -> dict:
+         """Get parameters for Tree/Graph instantiation."""
+         llm = self._resolve_llm_config(self.topics.llm)
+
+         params = {
+             "topic_prompt": self.topics.prompt,
+             "topic_system_prompt": self.topics.system_prompt,
+             "provider": llm.provider,
+             "model_name": llm.model,
+             "temperature": llm.temperature,
+             "base_url": llm.base_url,
+             "depth": self.topics.depth,
+             "degree": self.topics.degree,
+             "max_concurrent": self.topics.max_concurrent,
+         }
+
+         # Handle overrides
+         override_provider = overrides.pop("provider", None)
+         override_model = overrides.pop("model", None)
+
+         if override_provider:
+             params["provider"] = override_provider
+         if override_model:
+             params["model_name"] = override_model
+
+         params.update(overrides)
+         return params
+
+     def get_generation_params(self, **overrides) -> dict:
+         """Get parameters for DataSetGenerator instantiation."""
+         llm = self._resolve_llm_config(self.generation.llm)
+
+         params = {
+             "generation_system_prompt": self.generation.system_prompt,
+             "instructions": self.generation.instructions,
+             "provider": llm.provider,
+             "model_name": llm.model,
+             "temperature": llm.temperature,
+             "base_url": llm.base_url,
+             "max_retries": self.generation.max_retries,
+             "sample_retries": self.generation.sample_retries,
+             "max_tokens": self.generation.max_tokens,
+             "rate_limit": self.generation.rate_limit,
+             # Conversation config
+             "conversation_type": self.generation.conversation.type,
+             "reasoning_style": self.generation.conversation.reasoning_style,
+             "agent_mode": self.generation.conversation.agent_mode,
+             "min_turns": self.generation.conversation.min_turns,
+             "max_turns": self.generation.conversation.max_turns,
+             "min_tool_calls": self.generation.conversation.min_tool_calls,
+             # Output config
+             "sys_msg": self.output.include_system_message,
+             "dataset_system_prompt": self.output.system_prompt or self.generation.system_prompt,
+         }
+
+         # Tool config
+         if self.generation.tools:
+             params["tool_components"] = self.generation.tools.components
+             params["tools_endpoint"] = self.generation.tools.tools_endpoint
+             params["tool_execute_path"] = self.generation.tools.tool_execute_path
+             params["custom_tools"] = self.generation.tools.custom
+             params["max_tools_per_query"] = self.generation.tools.max_per_query
+             params["max_tools_strict"] = self.generation.tools.strict
+             params["spin_endpoint"] = self.generation.tools.spin_endpoint
+             params["scenario_seed"] = self.generation.tools.scenario_seed
+             params["max_agent_steps"] = self.generation.tools.max_agent_steps
+
+         # Handle overrides
+         override_provider = overrides.pop("provider", None)
+         override_model = overrides.pop("model", None)
+
+         if override_provider:
+             params["provider"] = override_provider
+         if override_model:
+             params["model_name"] = override_model
+
+         params.update(overrides)
+         return params
+
+     def get_output_config(self) -> dict:
+         """Get output configuration."""
+         return {
+             "system_prompt": self.output.system_prompt,
+             "include_system_message": self.output.include_system_message,
+             "num_samples": self.output.num_samples,
+             "batch_size": self.output.batch_size,
+             "save_as": self.output.save_as,
+         }
+
+     def get_huggingface_config(self) -> dict:
+         """Get Hugging Face configuration."""
+         return self.huggingface.model_dump() if self.huggingface else {}
+
+     def get_kaggle_config(self) -> dict:
+         """Get Kaggle configuration."""
+         return self.kaggle.model_dump() if self.kaggle else {}
+
+     def get_configured_providers(self) -> set[str]:
+         """Get the set of LLM providers configured in this config."""
+         providers = set()
+
+         # Get topics provider
+         topics_llm = self._resolve_llm_config(self.topics.llm)
+         providers.add(topics_llm.provider)
+
+         # Get generation provider
+         gen_llm = self._resolve_llm_config(self.generation.llm)
+         providers.add(gen_llm.provider)
+
+         return providers
+
+
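A sketch of the old-format rejection path: any legacy top-level key short-circuits parsing and raises ConfigurationError carrying the migration table from _get_migration_message (illustration only, assuming the wheel is installed):

    import tempfile

    from deepfabric.config import DeepFabricConfig
    from deepfabric.exceptions import ConfigurationError

    legacy = """
    dataset_system_prompt: "You are helpful."
    topic_tree:
      topic_prompt: "Networking basics"
    """

    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as handle:
        handle.write(legacy)

    try:
        DeepFabricConfig.from_yaml(handle.name)
    except ConfigurationError as err:
        print(err)  # "Configuration format has changed. Please update ..."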
+ # =============================================================================
+ # LEGACY CONFIG CLASSES (for reference during migration - can be removed later)
+ # =============================================================================
+
+
+ class TopicTreeConfig(BaseModel):
+     """DEPRECATED: Configuration for topic tree generation. Use TopicsConfig instead."""
+
+     topic_prompt: str = Field(
+         ..., min_length=1, description="The initial prompt to start the topic tree"
+     )
+     topic_system_prompt: str = Field(
+         default="", description="System prompt for topic exploration and generation"
+     )
+     provider: str = Field(
+         default=DEFAULT_PROVIDER,
+         min_length=1,
+         description="LLM provider (openai, anthropic, gemini, ollama)",
+     )
+     model: str = Field(
+         default=DEFAULT_MODEL,
+         min_length=1,
+         description="The name of the model to be used",
+     )
+     temperature: float = Field(
+         default=TOPIC_TREE_DEFAULT_TEMPERATURE,
+         ge=0.0,
+         le=2.0,
+         description="Temperature for model generation",
+     )
+     degree: int = Field(
+         default=TOPIC_TREE_DEFAULT_DEGREE,
+         ge=1,
+         le=50,
+         description="Number of subtopics per node",
+     )
+     depth: int = Field(
+         default=TOPIC_TREE_DEFAULT_DEPTH,
+         ge=1,
+         le=10,
+         description="Depth of the tree",
+     )
+     base_url: str | None = Field(
+         default=None,
+         description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
+     )
+     save_as: str | None = Field(default=None, description="Where to save the generated topic tree")
+
+
+ class TopicGraphConfig(BaseModel):
+     """DEPRECATED: Configuration for topic graph generation. Use TopicsConfig instead."""
+
+     topic_prompt: str = Field(
+         ..., min_length=1, description="The initial prompt to start the topic graph"
+     )
+     topic_system_prompt: str = Field(
+         default="", description="System prompt for topic exploration and generation"
+     )
+     provider: str = Field(
+         default=DEFAULT_PROVIDER,
+         min_length=1,
+         description="LLM provider (openai, anthropic, gemini, ollama)",
+     )
+     model: str = Field(
+         default=DEFAULT_MODEL,
+         min_length=1,
+         description="The name of the model to be used",
+     )
+     temperature: float = Field(
+         default=0.6,
+         ge=0.0,
+         le=2.0,
+         description="Temperature for model generation",
+     )
+     degree: int = Field(default=3, ge=1, le=10, description="The branching factor of the graph")
+     depth: int = Field(default=2, ge=1, le=5, description="The depth of the graph")
+     base_url: str | None = Field(
+         default=None,
+         description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
+     )
+     save_as: str | None = Field(default=None, description="Where to save the generated topic graph")
+
+
+ class DataEngineConfig(BaseModel):
+     """DEPRECATED: Configuration for data engine generation. Use GenerationConfig instead."""
+
+     instructions: str = Field(default="", description="Additional instructions for data generation")
+     generation_system_prompt: str = Field(
+         ..., min_length=1, description="System prompt for content generation"
+     )
+     provider: str = Field(
+         default=DEFAULT_PROVIDER,
+         min_length=1,
+         description="LLM provider (openai, anthropic, gemini, ollama)",
+     )
+     model: str = Field(
+         default=DEFAULT_MODEL,
+         min_length=1,
+         description="The name of the model to be used",
+     )
+     temperature: float = Field(
+         default=ENGINE_DEFAULT_TEMPERATURE,
+         ge=0.0,
+         le=2.0,
+         description="Temperature for model generation",
+     )
+     max_retries: int = Field(
+         default=DEFAULT_MAX_RETRIES,
+         ge=0,
+         le=10,
+         description="Maximum number of retries for failed generations",
+     )
+     sample_retries: int = Field(
+         default=DEFAULT_SAMPLE_RETRIES,
+         ge=0,
+         le=5,
+         description="Number of retries for individual sample validation failures",
+     )
+     max_tokens: int = Field(
+         default=2000, ge=1, description="Maximum tokens to generate in a single call to the LLM"
+     )
+     base_url: str | None = Field(
+         default=None,
+         description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
+     )
+     save_as: str | None = Field(default=None, description="Where to save the generated data")
+     rate_limit: dict[str, int | float | str | bool] | None = Field(
+         default=None,
+         description="Rate limiting and retry configuration",
+     )
+     conversation_type: Literal["basic", "chain_of_thought"] = Field(
+         default="basic",
+         description="Base conversation type",
+     )
+     reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = Field(
+         default=None,
+         description="Reasoning style for chain_of_thought type",
+     )
+
+     @field_validator("reasoning_style", mode="before")
+     @classmethod
+     def normalize_reasoning_style(cls, v: str | None) -> str | None:
+         return _normalize_reasoning_style(v)
+
+     agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
+         default=None,
+         description="Agent mode for tool use",
+     )
+     available_tools: list[str] = Field(
+         default_factory=list,
+         description="List of tool names available",
+     )
+     custom_tools: list[dict] = Field(default_factory=list, description="Custom tool definitions")
+     max_tools_per_query: int = Field(default=3, ge=1, le=10, description="Maximum tools per query")
+     max_tools_strict: bool = Field(
+         default=True,
+         description="Strict mode for tool limits",
+     )
+
+     @model_validator(mode="after")
+     def validate_configuration(self):
+         if self.reasoning_style is not None and self.conversation_type != "chain_of_thought":
+             raise ValueError(
+                 f"reasoning_style can only be set when conversation_type='chain_of_thought', "
+                 f"got conversation_type='{self.conversation_type}'"
+             )
+
+         if self.conversation_type == "chain_of_thought" and self.reasoning_style is None:
+             raise ValueError(
+                 "reasoning_style must be specified when conversation_type='chain_of_thought'. "
+                 "Choose from: 'freetext' or 'agent'"
+             )
+
+         if self.agent_mode is not None:
+             has_tools = bool(self.available_tools or self.custom_tools)
+             if not has_tools:
+                 raise ValueError("agent_mode requires tools to be configured.")
+
+         if self.agent_mode is not None and self.reasoning_style == "freetext":
+             raise ValueError("reasoning_style='freetext' is not compatible with agent_mode.")
+
+         return self
+
+
+ class DatasetCreationConfig(BaseModel):
+     """DEPRECATED: Configuration for dataset creation. Use OutputConfig instead."""
+
+     num_steps: int = Field(
+         default=ENGINE_DEFAULT_NUM_EXAMPLES,
+         ge=1,
+         description="Number of training examples to generate",
+     )
+     batch_size: int = Field(
+         default=ENGINE_DEFAULT_BATCH_SIZE,
+         ge=1,
+         description="Number of examples to process at a time",
+     )
+     sys_msg: bool | None = Field(
+         default=None,
+         description="Include system messages in output format",
+     )
+     provider: str | None = Field(
+         default=None,
+         description="Optional provider override",
+     )
+     model: str | None = Field(
+         default=None,
+         description="Optional model override",
+     )
+
+
+ class DatasetConfig(BaseModel):
+     """DEPRECATED: Configuration for dataset assembly. Use OutputConfig instead."""
+
+     creation: DatasetCreationConfig = Field(
+         default_factory=DatasetCreationConfig,
+         description="Dataset creation parameters",
+     )
+     save_as: str = Field(..., min_length=1, description="Where to save the final dataset")