DeepFabric 4.9.0__py3-none-any.whl → 4.10.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- deepfabric/builders.py +7 -21
- deepfabric/builders_agent.py +0 -542
- deepfabric/cli.py +505 -74
- deepfabric/config.py +57 -73
- deepfabric/config_manager.py +8 -6
- deepfabric/constants.py +6 -0
- deepfabric/dataset_manager.py +107 -11
- deepfabric/evaluation/parser.py +7 -7
- deepfabric/generator.py +656 -103
- deepfabric/graph.py +46 -1
- deepfabric/prompts.py +0 -39
- deepfabric/schemas.py +4 -3
- deepfabric/topic_model.py +32 -0
- deepfabric/tree.py +23 -1
- deepfabric/tui.py +66 -21
- deepfabric/utils.py +184 -0
- deepfabric/validation.py +47 -77
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/METADATA +5 -6
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/RECORD +22 -22
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/WHEEL +0 -0
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/entry_points.txt +0 -0
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/licenses/LICENSE +0 -0
deepfabric/config.py
CHANGED
```diff
@@ -20,6 +20,7 @@ from .constants import (
 )
 from .exceptions import ConfigurationError
 from .metrics import trace
+from .utils import parse_num_samples
 
 
 def _normalize_reasoning_style(value: str | None) -> str | None:
@@ -131,28 +132,6 @@ class ConversationConfig(BaseModel):
         default=None,
         description="Reasoning style for cot: freetext or agent. Note: 'structured' and 'hybrid' are deprecated.",
     )
-    agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
-        default=None,
-        description="Agent mode: single_turn (one-shot tool use), multi_turn (extended conversations)",
-    )
-    min_turns: int = Field(
-        default=2,
-        ge=1,
-        le=10,
-        description="Minimum conversation turns for multi_turn agent mode",
-    )
-    max_turns: int = Field(
-        default=4,
-        ge=1,
-        le=10,
-        description="Maximum conversation turns for multi_turn agent mode",
-    )
-    min_tool_calls: int = Field(
-        default=2,
-        ge=0,
-        le=20,
-        description="Minimum tool calls before allowing conversation conclusion",
-    )
 
     @field_validator("reasoning_style", mode="before")
     @classmethod
@@ -174,12 +153,6 @@ class ConversationConfig(BaseModel):
                 "Choose from: 'freetext' or 'agent'"
             )
 
-        if self.agent_mode is not None and self.reasoning_style == "freetext":
-            raise ValueError(
-                "reasoning_style='freetext' is not compatible with agent_mode. "
-                "Agent mode requires structured reasoning. Use reasoning_style='agent' instead."
-            )
-
         return self
 
 
@@ -289,22 +262,28 @@ class GenerationConfig(BaseModel):
         default=None, description="Optional LLM configuration overrides for generation"
     )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+class CheckpointConfig(BaseModel):
+    """Configuration for checkpoint-based resume capability.
+
+    Checkpoints allow pausing and resuming long-running dataset generation
+    without losing progress. When enabled, samples are periodically saved
+    to disk and can be resumed if generation is interrupted.
+    """
+
+    interval: int = Field(
+        ...,
+        ge=1,
+        description="Save checkpoint every N samples",
+    )
+    path: str | None = Field(
+        default=None,
+        description="Directory to store checkpoint files. If not specified, uses ~/.deepfabric/checkpoints/{config_hash}/",
+    )
+    retry_failed: bool = Field(
+        default=False,
+        description="When resuming, retry previously failed samples",
+    )
 
 
 class OutputConfig(BaseModel):
@@ -318,10 +297,9 @@ class OutputConfig(BaseModel):
         default=True,
         description="Whether to include system message in output format",
     )
-    num_samples: int = Field(
+    num_samples: int | str = Field(
         default=ENGINE_DEFAULT_NUM_EXAMPLES,
-
-        description="Number of training samples to generate",
+        description="Number of samples: integer, 'auto' (100% of topics), or percentage like '50%'",
     )
     batch_size: int = Field(
         default=ENGINE_DEFAULT_BATCH_SIZE,
@@ -330,6 +308,20 @@ class OutputConfig(BaseModel):
     )
     save_as: str = Field(..., min_length=1, description="Where to save the final dataset")
 
+    # Optional checkpoint configuration (nested inside output)
+    checkpoint: CheckpointConfig | None = Field(
+        None, description="Checkpoint configuration for resumable generation"
+    )
+
+    @field_validator("num_samples", mode="before")
+    @classmethod
+    def validate_num_samples(cls, v: int | str) -> int | str:
+        """Validate num_samples: integer, 'auto', or percentage like '50%'."""
+        result = parse_num_samples(v)
+        if result is None:
+            raise ValueError("num_samples cannot be None")
+        return result
+
 
 class HuggingFaceConfig(BaseModel):
     """Configuration for Hugging Face Hub integration."""
@@ -388,10 +380,6 @@ class EvaluationConfig(BaseModel):
         """Normalize deprecated reasoning_style values."""
         return _normalize_reasoning_style(v)
 
-    agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
-        default=None,
-        description="Agent mode if tools are used",
-    )
     metrics: list[str] = Field(
         default_factory=lambda: [
             "tool_selection_accuracy",
@@ -455,12 +443,6 @@ class EvaluationConfig(BaseModel):
                 "Choose from: 'freetext' or 'agent'"
            )
 
-        if self.agent_mode is not None and self.reasoning_style == "freetext":
-            raise ValueError(
-                "reasoning_style='freetext' is not compatible with agent_mode. "
-                "Agent mode requires structured reasoning. Use reasoning_style='agent' instead."
-            )
-
         return self
 
 
@@ -640,13 +622,17 @@ See documentation for full examples.
             # Conversation config
             "conversation_type": self.generation.conversation.type,
             "reasoning_style": self.generation.conversation.reasoning_style,
-            "agent_mode": self.generation.conversation.agent_mode,
-            "min_turns": self.generation.conversation.min_turns,
-            "max_turns": self.generation.conversation.max_turns,
-            "min_tool_calls": self.generation.conversation.min_tool_calls,
             # Output config
             "sys_msg": self.output.include_system_message,
             "dataset_system_prompt": self.output.system_prompt or self.generation.system_prompt,
+            "output_save_as": self.output.save_as,
+            # Checkpoint config (nested inside output)
+            # Note: checkpoint_path can be None, meaning "auto-resolve" at runtime
+            "checkpoint_interval": self.output.checkpoint.interval if self.output.checkpoint else None,
+            "checkpoint_path": self.output.checkpoint.path if self.output.checkpoint else None,
+            "checkpoint_retry_failed": (
+                self.output.checkpoint.retry_failed if self.output.checkpoint else False
+            ),
         }
 
         # Tool config
@@ -683,6 +669,16 @@ See documentation for full examples.
             "save_as": self.output.save_as,
         }
 
+    def get_checkpoint_config(self) -> dict:
+        """Get checkpoint configuration."""
+        if self.output.checkpoint is None:
+            return {
+                "interval": None,
+                "path": None,  # None means "auto-resolve" at runtime
+                "retry_failed": False,
+            }
+        return self.output.checkpoint.model_dump()
+
     def get_huggingface_config(self) -> dict:
         """Get Hugging Face configuration."""
         return self.huggingface.model_dump() if self.huggingface else {}
@@ -854,10 +850,6 @@ class DataEngineConfig(BaseModel):
     def normalize_reasoning_style(cls, v: str | None) -> str | None:
         return _normalize_reasoning_style(v)
 
-    agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
-        default=None,
-        description="Agent mode for tool use",
-    )
     available_tools: list[str] = Field(
         default_factory=list,
         description="List of tool names available",
@@ -883,14 +875,6 @@ class DataEngineConfig(BaseModel):
                 "Choose from: 'freetext' or 'agent'"
             )
 
-        if self.agent_mode is not None:
-            has_tools = bool(self.available_tools or self.custom_tools)
-            if not has_tools:
-                raise ValueError("agent_mode requires tools to be configured.")
-
-        if self.agent_mode is not None and self.reasoning_style == "freetext":
-            raise ValueError("reasoning_style='freetext' is not compatible with agent_mode.")
-
         return self
```
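The net effect of these config changes: `agent_mode` and the multi-turn knobs are gone from `ConversationConfig`, `num_samples` now accepts strings, and a `checkpoint` block can be nested under `output`. A minimal sketch of the new surface, built only from what the diff above shows (the models live in `deepfabric/config.py`; fields not set here take the defaults in the model definitions):

```python
# Sketch based on the models in deepfabric/config.py as of 4.10.0.
from deepfabric.config import CheckpointConfig, OutputConfig

output = OutputConfig(
    save_as="out/dataset.jsonl",   # required (min_length=1)
    num_samples="50%",             # now int | str: an integer, "auto", or a percentage
    checkpoint=CheckpointConfig(
        interval=100,       # required: save a checkpoint every 100 samples (ge=1)
        path=None,          # None -> auto-resolve to ~/.deepfabric/checkpoints/{config_hash}/
        retry_failed=True,  # on resume, retry samples that previously failed
    ),
)

# The num_samples validator runs parse_num_samples() at load time and rejects
# None, so malformed values fail before generation starts.
```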
deepfabric/config_manager.py
CHANGED
```diff
@@ -65,7 +65,7 @@ def load_config(  # noqa: PLR0913
         mode: Topic generation mode (tree or graph)
         conversation_type: Base conversation type (basic, cot)
         reasoning_style: Reasoning style for cot (freetext, agent)
-        agent_mode: Agent mode (single_turn, multi_turn)
+        agent_mode: [Deprecated] Agent mode (single_turn only, multi_turn no longer supported)
 
     Returns:
         DeepFabricConfig object
@@ -140,7 +140,7 @@ def load_config(  # noqa: PLR0913
             "include_system_message": include_system_message
             if include_system_message is not None
             else True,
-            "num_samples": num_samples
+            "num_samples": num_samples if num_samples is not None else ENGINE_DEFAULT_NUM_EXAMPLES,
             "batch_size": batch_size or ENGINE_DEFAULT_BATCH_SIZE,
             "save_as": output_save_as or "dataset.jsonl",
         },
@@ -221,27 +221,29 @@ def apply_cli_overrides(
 
 def get_final_parameters(
     config: DeepFabricConfig,
-    num_samples: int | None = None,
+    num_samples: int | str | None = None,
     batch_size: int | None = None,
     depth: int | None = None,
     degree: int | None = None,
-) -> tuple[int, int, int, int]:
+) -> tuple[int | str, int, int, int]:
     """
     Get final parameters from config and CLI overrides.
 
     Args:
         config: DeepFabricConfig object
-        num_samples: CLI override for num_samples
+        num_samples: CLI override for num_samples (int, "auto", or percentage like "50%")
         batch_size: CLI override for batch_size
        depth: CLI override for depth
        degree: CLI override for degree
 
     Returns:
         Tuple of (num_samples, batch_size, depth, degree)
+        Note: num_samples may be int, "auto", or percentage string
     """
     output_config = config.get_output_config()
 
-
+    # Use 'is not None' to allow passing through "auto" or percentage strings
+    final_num_samples = num_samples if num_samples is not None else output_config["num_samples"]
     final_batch_size = batch_size or output_config["batch_size"]
 
     # Get depth and degree from topics config
```
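Note the switch from truthiness to `is not None` in `get_final_parameters`: with strings now flowing through `num_samples`, an override must win whenever it was explicitly provided. A standalone illustration of that precedence rule (just the pattern, not DeepFabric API):

```python
# The precedence rule used above: a CLI value wins whenever it is not None,
# even when it is a string that will only be resolved later.
def pick_num_samples(cli_value: int | str | None, config_value: int | str) -> int | str:
    return cli_value if cli_value is not None else config_value

assert pick_num_samples(None, 500) == 500        # no CLI override -> config value
assert pick_num_samples("auto", 500) == "auto"   # strings pass through unresolved
assert pick_num_samples("25%", 500) == "25%"     # resolved later against topic count
```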
deepfabric/constants.py
CHANGED
```diff
@@ -89,6 +89,12 @@ ERROR_DATASET_FILENAME = "error_dataset.jsonl"
 PARTIAL_TREE_FILENAME = "partial_tree.jsonl"
 FAILED_TREE_SUFFIX = "_failed.jsonl"
 
+# Checkpoint file patterns
+CHECKPOINT_METADATA_SUFFIX = ".checkpoint.json"
+CHECKPOINT_SAMPLES_SUFFIX = ".checkpoint.jsonl"
+CHECKPOINT_FAILURES_SUFFIX = ".checkpoint.failures.jsonl"
+CHECKPOINT_VERSION = 3  # Increment when checkpoint format changes
+
 # Stream simulation defaults
 STREAM_SIM_CHUNK_SIZE = 8  # characters per chunk
 STREAM_SIM_CHUNK_DELAY_MS = 10.0  # milliseconds between chunks
```
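The diff defines the checkpoint suffixes and a format version, but not the code that applies them. A hypothetical sketch of how the three file names might be derived from a dataset name; only the suffix values come from the diff, the composition rule is an assumption:

```python
from pathlib import Path

# Suffix constants copied from deepfabric/constants.py; the helper below and
# its naming scheme are hypothetical, not shown in this diff.
CHECKPOINT_METADATA_SUFFIX = ".checkpoint.json"
CHECKPOINT_SAMPLES_SUFFIX = ".checkpoint.jsonl"
CHECKPOINT_FAILURES_SUFFIX = ".checkpoint.failures.jsonl"


def checkpoint_files(save_as: str, checkpoint_dir: str) -> dict[str, Path]:
    """Hypothetical: map a dataset path onto the three checkpoint files."""
    stem = Path(save_as).stem  # "dataset" for "dataset.jsonl"
    return {
        "metadata": Path(checkpoint_dir) / (stem + CHECKPOINT_METADATA_SUFFIX),
        "samples": Path(checkpoint_dir) / (stem + CHECKPOINT_SAMPLES_SUFFIX),
        "failures": Path(checkpoint_dir) / (stem + CHECKPOINT_FAILURES_SUFFIX),
    }
```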
deepfabric/dataset_manager.py
CHANGED
```diff
@@ -6,6 +6,7 @@ import traceback
 
 from collections.abc import AsyncIterator
 from datetime import datetime, timezone
+from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 from datasets import Dataset as HFDataset
@@ -51,6 +52,39 @@ if TYPE_CHECKING:
 DEBUG_MAX_FAILURES_TO_SHOW = 10
 
 
+def resolve_num_samples(num_samples: int | str, topic_count: int) -> int:
+    """Resolve num_samples to an integer based on topic count.
+
+    Args:
+        num_samples: Integer, "auto", or percentage string like "50%"
+        topic_count: Number of available topic paths
+
+    Returns:
+        Resolved integer sample count
+
+    Raises:
+        ConfigurationError: If topic_count is 0 and dynamic sampling is requested
+    """
+    if isinstance(num_samples, int):
+        return num_samples
+
+    if topic_count == 0:
+        raise ConfigurationError(
+            "Cannot use 'auto' or percentage num_samples with empty topic model. "
+            "Ensure topic generation produced paths."
+        )
+
+    if num_samples == "auto":
+        return topic_count
+
+    if isinstance(num_samples, str) and num_samples.endswith("%"):
+        percentage = float(num_samples[:-1]) / 100.0
+        return max(1, int(topic_count * percentage))
+
+    # Fallback - try to parse as int (shouldn't reach here if validated properly)
+    return int(num_samples)
+
+
 async def handle_dataset_events_async(
     generator: AsyncIterator[dict | HFDataset], engine=None, debug: bool = False
 ) -> HFDataset | None:
@@ -80,6 +114,7 @@
             tui.init_status(
                 total_steps=event["num_steps"],
                 total_samples=event["total_samples"],
+                checkpoint_enabled=event.get("checkpoint_enabled", False),
             )
 
             # Build layout with footer card
@@ -112,7 +147,9 @@
             # Footer run status
             footer_prog = tui.tui.create_footer(layout, title="Run Status")
             task = footer_prog.add_task(
-                "Generating dataset samples",
+                "Generating dataset samples",
+                total=event["total_samples"],
+                completed=event.get("resumed_samples", 0),
             )
 
             # Use alternate screen to avoid scroll trails; leave a clean terminal
@@ -130,7 +167,10 @@
             tui.show_generation_header(
                 event["model_name"], event["num_steps"], event["batch_size"]
             )
-            simple_task = {
+            simple_task = {
+                "count": event.get("resumed_samples", 0),
+                "total": event["total_samples"],
+            }
         elif event["event"] == "step_complete":
             samples_generated = event.get("samples_generated", 0)
             if footer_prog and task is not None:
@@ -171,6 +211,40 @@
                 total = int(event.get("total_steps", 0))
                 tui.status_step_start(step, total)
 
+        elif event["event"] == "checkpoint_saved":
+            # Display checkpoint save notification
+            total_samples = event.get("total_samples", 0)
+            total_failures = event.get("total_failures", 0)
+            is_final = event.get("final", False)
+
+            if footer_prog and task is not None:
+                # Rich mode: log to events panel and update status
+                if is_final:
+                    tui.log_event(f"💾 Final checkpoint: {total_samples} samples")
+                else:
+                    tui.log_event(f"💾 Checkpoint: {total_samples} samples")
+                tui.status_checkpoint_saved(total_samples)
+            elif isinstance(simple_task, dict):
+                # Simple mode: print checkpoint notification
+                checkpoint_msg = f"Checkpoint saved: {total_samples} samples"
+                if total_failures > 0:
+                    checkpoint_msg += f" ({total_failures} failures)"
+                if is_final:
+                    checkpoint_msg = "Final " + checkpoint_msg.lower()
+                tui.info(checkpoint_msg)
+
+        elif event["event"] == "generation_stopped":
+            # Graceful stop at checkpoint
+            if live:
+                live.stop()
+            tui.console.print()
+            tui.success(
+                f"Gracefully stopped: {event['total_samples']} samples saved to checkpoint"
+            )
+            if event.get("total_failures", 0) > 0:
+                tui.info(f"({event['total_failures']} failures recorded)")
+            tui.info("Resume with: --resume flag")
+
         elif event["event"] == "generation_complete":
             if live:
                 live.stop()
@@ -219,7 +293,7 @@ def create_dataset(
     engine: DataSetGenerator,
     topic_model: "TopicModel",
     config: DeepFabricConfig,
-    num_samples: int | None = None,
+    num_samples: int | str | None = None,
     batch_size: int | None = None,
     include_system_message: bool | None = None,
     provider: str | None = None,  # noqa: ARG001
@@ -234,7 +308,7 @@ def create_dataset(
         engine: DataSetGenerator instance
         topic_model: TopicModel (Tree or Graph) to use for generation
         config: DeepFabricConfig object
-        num_samples: Override for number of samples
+        num_samples: Override for number of samples (int, "auto", or percentage like "50%")
         batch_size: Override for batch size
         include_system_message: Override for including system message
         provider: Override for LLM provider
@@ -268,7 +342,7 @@ async def create_dataset_async(
     engine: DataSetGenerator,
     topic_model: "TopicModel",
     config: DeepFabricConfig,
-    num_samples: int | None = None,
+    num_samples: int | str | None = None,
     batch_size: int | None = None,
     include_system_message: bool | None = None,
     provider: str | None = None,  # noqa: ARG001
@@ -278,15 +352,34 @@ async def create_dataset_async(
 ) -> HFDataset:
     output_config = config.get_output_config()
 
-
+    raw_num_samples = num_samples if num_samples is not None else output_config["num_samples"]
     final_batch_size = batch_size or output_config["batch_size"]
 
+    # Resolve "auto" or percentage to actual count based on topic paths
+    topic_count = len(topic_model.get_all_paths())
+    final_num_samples = resolve_num_samples(raw_num_samples, topic_count)
+
+    # Log resolution for dynamic values
+    tui = get_dataset_tui()
+    if isinstance(raw_num_samples, str):
+        tui.info(f"Resolved num_samples: {raw_num_samples} → {final_num_samples} samples")
+
     generation_params = config.get_generation_params(**(generation_overrides or {}))
     final_model = model or generation_params.get("model_name", DEFAULT_MODEL)
 
+    # Convert total samples to number of steps (batches)
+    # The generator expects num_steps where total_samples = num_steps * batch_size
+    import math  # noqa: PLC0415
+
+    final_num_steps = math.ceil(final_num_samples / final_batch_size)
+
+    tui.info(
+        f"Dataset generation: {final_num_samples} samples in {final_num_steps} steps "
+        f"(batch_size={final_batch_size})"
+    )
+
     # Create progress reporter and attach TUI as observer for streaming feedback
     progress_reporter = ProgressReporter()
-    tui = get_dataset_tui()
     progress_reporter.attach(tui)
 
     # Attach progress reporter to engine
@@ -294,7 +387,7 @@ async def create_dataset_async(
 
     try:
         generator = engine.create_data_with_events_async(
-            num_steps=
+            num_steps=final_num_steps,
             batch_size=final_batch_size,
             topic_model=topic_model,
             model_name=final_model,
@@ -448,6 +541,7 @@ def _strip_nulls(obj: Any) -> Any:
 
 def _save_jsonl_without_nulls(dataset: HFDataset, save_path: str) -> None:
     """Save HF Dataset to JSONL, stripping null values injected by Arrow schema."""
+    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
     with open(save_path, "w") as f:
         for row in dataset:
             cleaned = _strip_nulls(dict(row))
@@ -516,9 +610,11 @@ def save_dataset(
     _save_jsonl_without_nulls(dataset, save_path)
     tui.success(f"Dataset saved to: {save_path}")
 
-    # Save failed samples if engine has any
-    if engine
-
+    # Save failed samples if engine has any (including flushed to checkpoint)
+    if engine:
+        all_failures = engine.get_all_failures()
+        if all_failures:
+            _save_failed_samples(save_path, all_failures, tui)
 
     # Handle automatic uploads if configured
     if config:
```
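`resolve_num_samples` and the sample-to-step conversion are both fully shown above, so their behavior can be pinned down exactly (only the import path is inferred from the file name):

```python
import math

from deepfabric.dataset_manager import resolve_num_samples

topic_count = 240  # i.e. len(topic_model.get_all_paths())

resolve_num_samples(500, topic_count)      # 500: integers pass through unchanged
resolve_num_samples("auto", topic_count)   # 240: one sample per topic path
resolve_num_samples("50%", topic_count)    # 120: max(1, int(240 * 0.5))
resolve_num_samples("0.1%", topic_count)   # 1: the floored value is clamped to at least 1

# The generator consumes steps, not samples, and the count is rounded UP into
# whole batches, so slightly more samples than requested may be generated:
assert math.ceil(120 / 32) == 4  # 120 samples at batch_size=32 -> 4 steps (128 samples)
```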
deepfabric/evaluation/parser.py
CHANGED
```diff
@@ -56,9 +56,9 @@ class GroundTruth(BaseModel):
         default=None,
         description="Reasoning style if cot",
     )
-    agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
+    agent_mode: Literal["single_turn"] | None = Field(
         default=None,
-        description="Agent mode if tools are used",
+        description="Agent mode if tools are used (single_turn only)",
     )
     metadata: dict[str, str | int | float | bool] = Field(
         default_factory=dict,
@@ -77,20 +77,20 @@ class GroundTruthParser:
         self,
         conversation_type: Literal["basic", "cot"],
         reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = None,
-        agent_mode: Literal["single_turn", "multi_turn"] | None = None,
+        agent_mode: Literal["single_turn"] | None = None,
     ):
         """Initialize parser with conversation configuration.
 
         Args:
             conversation_type: Type of conversation (basic, cot)
             reasoning_style: Reasoning style for cot
-            agent_mode: Agent mode if tools are used
+            agent_mode: Agent mode if tools are used (single_turn only)
         """
         self.conversation_type: Literal["basic", "cot"] = conversation_type
         self.reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = (
             reasoning_style
         )
-        self.agent_mode: Literal["single_turn", "multi_turn"] | None = agent_mode
+        self.agent_mode: Literal["single_turn"] | None = agent_mode
 
     def parse(self, conversation: Conversation) -> GroundTruth:
         """Extract ground truth from a conversation sample.
@@ -272,7 +272,7 @@ def parse_batch(
     conversations: list[Conversation],
     conversation_type: Literal["basic", "cot"],
     reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = None,
-    agent_mode: Literal["single_turn", "multi_turn"] | None = None,
+    agent_mode: Literal["single_turn"] | None = None,
 ) -> list[GroundTruth]:
     """Parse a batch of conversations to extract ground truth.
 
@@ -280,7 +280,7 @@ def parse_batch(
         conversations: List of Conversation objects
         conversation_type: Type of conversation
         reasoning_style: Reasoning style if cot
-        agent_mode: Agent mode if tools are used
+        agent_mode: Agent mode if tools are used (single_turn only)
 
     Returns:
         List of GroundTruth objects
```