DeepFabric 4.10.0__py3-none-any.whl → 4.11.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- deepfabric/cli.py +83 -27
- deepfabric/cloud_upload.py +1 -1
- deepfabric/config.py +6 -4
- deepfabric/constants.py +1 -1
- deepfabric/dataset_manager.py +264 -62
- deepfabric/generator.py +687 -82
- deepfabric/graph.py +25 -1
- deepfabric/llm/retry_handler.py +28 -9
- deepfabric/progress.py +42 -0
- deepfabric/topic_manager.py +22 -2
- deepfabric/topic_model.py +26 -0
- deepfabric/tree.py +41 -16
- deepfabric/tui.py +448 -349
- deepfabric/utils.py +4 -1
- {deepfabric-4.10.0.dist-info → deepfabric-4.11.0.dist-info}/METADATA +4 -2
- {deepfabric-4.10.0.dist-info → deepfabric-4.11.0.dist-info}/RECORD +19 -19
- {deepfabric-4.10.0.dist-info → deepfabric-4.11.0.dist-info}/licenses/LICENSE +1 -1
- {deepfabric-4.10.0.dist-info → deepfabric-4.11.0.dist-info}/WHEEL +0 -0
- {deepfabric-4.10.0.dist-info → deepfabric-4.11.0.dist-info}/entry_points.txt +0 -0
deepfabric/dataset_manager.py
CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import contextlib
 import json
+import math
 import os
 import traceback
 
@@ -12,6 +13,15 @@ from typing import TYPE_CHECKING, Any
 from datasets import Dataset as HFDataset
 from rich.layout import Layout
 from rich.live import Live
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    TimeElapsedColumn,
+)
+from rich.table import Column
 
 from .config import DeepFabricConfig
 from .config_manager import DEFAULT_MODEL
@@ -93,7 +103,14 @@ async def handle_dataset_events_async(
     footer_prog = None
     task = None
     live = None
-
+    simple_progress = None  # Progress bar for simple/headless mode
+    simple_progress_task = None
+    simple_checkpoint_task = None  # Checkpoint progress task for simple mode
+    checkpoint_interval = 0
+    headless_completed = 0  # Counter for non-TTY headless mode
+    headless_total = 0
+    headless_batch_size = 0  # Print threshold for headless mode
+    headless_last_printed = 0  # Last count at which we printed
 
     final_result: HFDataset | None = None
     try:
@@ -101,9 +118,26 @@
             if isinstance(event, dict) and "event" in event:
                 if event["event"] == "generation_start":
                     settings = _get_tui_settings()
+                    # Handle both cycle-based and step-based event formats
+                    # Cycle-based: unique_topics, cycles_needed, concurrency
+                    # Step-based: num_steps, batch_size
+                    is_cycle_based = "cycles_needed" in event
+                    if is_cycle_based:
+                        display_steps = event.get("cycles_needed", 1)
+                        display_batch_size = event.get("concurrency", 1)
+                    else:
+                        display_steps = event.get("num_steps", 1)
+                        display_batch_size = event.get("batch_size", 1)
                     # Build header and params panels for layout
                     header_panel, params_panel = tui.build_generation_panels(
-                        event["model_name"],
+                        event["model_name"],
+                        display_steps,
+                        display_batch_size,
+                        total_samples=event["total_samples"],
+                        is_cycle_based=is_cycle_based,
+                        unique_topics=event.get("unique_topics", 0),
+                        final_cycle_size=event.get("final_cycle_size", 0),
+                        checkpoint_interval=event.get("checkpoint_interval", 0),
                     )
                     # Capture context for the run
                     tui.root_topic_prompt = event.get("root_topic_prompt")
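For orientation, here are two hypothetical generation_start payloads matching the shapes the handler distinguishes above; the key names come from the diff, the values are invented:

    # Step-based shape (legacy): total = num_steps * batch_size
    step_based = {
        "event": "generation_start",
        "model_name": "example-model",
        "total_samples": 100,
        "num_steps": 20,
        "batch_size": 5,
    }

    # Cycle-based shape: topics are revisited in cycles until the target is met
    cycle_based = {
        "event": "generation_start",
        "model_name": "example-model",
        "total_samples": 100,
        "unique_topics": 40,
        "cycles_needed": 3,      # ceil(100 / 40)
        "final_cycle_size": 20,  # 100 - 2 * 40 topics in the last, partial cycle
        "concurrency": 8,
    }

    is_cycle_based = "cycles_needed" in cycle_based  # True

The presence of cycles_needed is the sole discriminator, so step-based producers never need to know the cycle-based keys exist.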
@@ -112,9 +146,10 @@
                     if settings.mode == "rich":
                         # Initialize status tracking
                         tui.init_status(
-                            total_steps=
+                            total_steps=display_steps,
                             total_samples=event["total_samples"],
                             checkpoint_enabled=event.get("checkpoint_enabled", False),
+                            is_cycle_based=is_cycle_based,
                         )
 
                         # Build layout with footer card
@@ -144,12 +179,15 @@
                         )
                         layout["main"].split_row(left, right)
 
+                        prog_total = event["total_samples"]
+                        resumed_samples = event.get("resumed_samples", 0)
+
                         # Footer run status
                         footer_prog = tui.tui.create_footer(layout, title="Run Status")
                         task = footer_prog.add_task(
                             "Generating dataset samples",
-                            total=
-                            completed=
+                            total=prog_total,
+                            completed=min(resumed_samples, prog_total),
                         )
 
                         # Use alternate screen to avoid scroll trails; leave a clean terminal
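The min() clamp matters when resuming: a checkpoint can already hold more samples than the (possibly lowered) target, and the bar would otherwise start over-full. A minimal illustration with invented numbers:

    prog_total = 100
    resumed_samples = 120  # checkpoint left over from an earlier, larger run
    completed = min(resumed_samples, prog_total)  # bar starts full at 100/100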
@@ -162,15 +200,97 @@
                         tui.live_display = live  # Give TUI reference to update it
                         tui.live_layout = layout  # Allow TUI to update panes
                         live.start()
+                        if resumed_samples >= prog_total:
+                            tui.log_event(
+                                f"Checkpoint already complete: {resumed_samples} samples "
+                                f"(target: {prog_total})"
+                            )
                     else:
-
-
-
-
-
-
-
-
+                        prog_total = event["total_samples"]
+                        resumed_samples = event.get("resumed_samples", 0)
+
+                        # Simple/headless mode: runtime summary then progress bar
+                        tui.console.print("\n[bold cyan]Dataset Generation[/bold cyan]")
+                        model_line = f"Model: {event['model_name']}"
+                        if event.get("topic_model_type"):
+                            topic_type = event["topic_model_type"]
+                            if is_cycle_based and event.get("unique_topics"):
+                                model_line += (
+                                    f" ({topic_type}, {event['unique_topics']} unique topics)"
+                                )
+                            else:
+                                model_line += f" ({topic_type})"
+                        tui.info(model_line)
+
+                        if is_cycle_based:
+                            output_line = (
+                                f"Output: num_samples={prog_total}, "
+                                f"concurrency={display_batch_size}"
+                            )
+                        else:
+                            output_line = (
+                                f"Output: num_samples={prog_total}, batch_size={display_batch_size}"
+                            )
+                        tui.info(output_line)
+
+                        if is_cycle_based:
+                            cycles = event.get("cycles_needed", 1)
+                            unique = event.get("unique_topics", 0)
+                            tui.info(
+                                f"  → Cycles needed: {cycles} "
+                                f"({prog_total} samples ÷ {unique} unique topics)"
+                            )
+                            final_cycle = event.get("final_cycle_size", 0)
+                            if final_cycle and unique and final_cycle < unique:
+                                tui.info(f"  → Final cycle: {final_cycle} topics (partial)")
+
+                        tui.console.print()
+                        cp_interval = event.get("checkpoint_interval")
+                        if cp_interval and cp_interval > 0:
+                            total_cp = math.ceil(prog_total / cp_interval)
+                            tui.info(
+                                f"Checkpoint: every {cp_interval} samples "
+                                f"({total_cp} total checkpoints)"
+                            )
+                            tui.console.print()
+
+                        if resumed_samples >= prog_total:
+                            # Checkpoint already has enough samples
+                            tui.success(
+                                f"Checkpoint already complete: {resumed_samples} samples "
+                                f"(target: {prog_total})"
+                            )
+                        elif tui.console.is_terminal:
+                            simple_progress = Progress(
+                                SpinnerColumn(),
+                                TextColumn("[progress.description]{task.description}"),
+                                BarColumn(),
+                                MofNCompleteColumn(table_column=Column(justify="right")),
+                                TimeElapsedColumn(),
+                                console=tui.console,
+                            )
+                            simple_progress_task = simple_progress.add_task(
+                                "Generating",
+                                total=prog_total,
+                                completed=resumed_samples,
+                            )
+                            simple_progress.start()
+                            tui.simple_progress = simple_progress
+                            # Add checkpoint progress task if interval is set
+                            checkpoint_interval = event.get("checkpoint_interval") or 0
+                            if checkpoint_interval > 0:
+                                simple_checkpoint_task = simple_progress.add_task(
+                                    "Next checkpoint",
+                                    total=checkpoint_interval,
+                                    completed=0,
+                                )
+                        else:
+                            # Headless (non-TTY): track progress with counters
+                            headless_total = prog_total
+                            headless_completed = resumed_samples
+                            headless_batch_size = display_batch_size
+                            headless_last_printed = resumed_samples
+                            checkpoint_interval = event.get("checkpoint_interval") or 0
                 elif event["event"] == "step_complete":
                     samples_generated = event.get("samples_generated", 0)
                     if footer_prog and task is not None:
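The simple-mode bar is plain rich.progress machinery. A self-contained sketch of the same column stack, runnable outside DeepFabric (task names and totals are invented):

    import time

    from rich.console import Console
    from rich.progress import (
        BarColumn,
        MofNCompleteColumn,
        Progress,
        SpinnerColumn,
        TextColumn,
        TimeElapsedColumn,
    )
    from rich.table import Column

    console = Console()
    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(table_column=Column(justify="right")),
        TimeElapsedColumn(),
        console=console,
    )
    task = progress.add_task("Generating", total=100, completed=25)  # resume at 25
    progress.start()
    try:
        for _ in range(75):
            time.sleep(0.01)
            progress.update(task, advance=1)
    finally:
        progress.stop()

Explicit start()/stop() rather than the usual context manager mirrors the event handler, which must stop the bar from several exit paths (complete, stopped, exception).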
@@ -182,39 +302,80 @@
                         tui.status_step_complete(
                             samples_generated, int(event.get("failed_in_step", 0))
                         )
-                    elif
-
-
-
-
-
-
-                        step_msg += f" (-{failed_in_step} failed)"
-                        step_msg += f" (total {simple_task['count']}/{simple_task['total']})"
-
-                        # Display with appropriate style based on failures
-                        if failed_in_step > 0:
-                            tui.warning(step_msg)
-                        else:
-                            tui.info(step_msg)
-
-                        # Show retry summary if there were retries
-                        if retry_summary:
-                            tui.console.print(f"  [dim]{retry_summary}[/dim]")
-
-                        # Clear retries for next step
+                    elif simple_progress is not None and simple_progress_task is not None:
+                        with contextlib.suppress(Exception):
+                            simple_progress.update(simple_progress_task, advance=samples_generated)
+                        if simple_checkpoint_task is not None and samples_generated > 0:
+                            with contextlib.suppress(Exception):
+                                simple_progress.update(
+                                    simple_checkpoint_task, advance=samples_generated
+                                )
                         tui.clear_step_retries()
+                    elif headless_total > 0 and samples_generated > 0:
+                        headless_completed += samples_generated
+                        if (
+                            headless_completed - headless_last_printed >= headless_batch_size
+                            or headless_completed >= headless_total
+                        ):
+                            tui.info(f"Generated {headless_completed}/{headless_total} samples")
+                            headless_last_printed = headless_completed
                 elif event["event"] == "step_start":
                     # Keep status panel in sync
                     step = int(event.get("step", 0))
                     total = int(event.get("total_steps", 0))
                     tui.status_step_start(step, total)
 
+                elif event["event"] == "cycle_start":
+                    # Cycle-based generation: keep status panel in sync
+                    cycle = int(event.get("cycle", 0))
+                    total_cycles = int(event.get("total_cycles", 0))
+                    tui.status_step_start(cycle, total_cycles)
+
+                elif event["event"] == "batch_complete":
+                    # Per-batch progress: advance bars after each concurrency batch
+                    batch_generated = event.get("samples_generated", 0)
+                    batch_failed = event.get("samples_failed", 0)
+                    if footer_prog and task is not None:
+                        if batch_generated > 0:
+                            with contextlib.suppress(Exception):
+                                footer_prog.update(task, advance=batch_generated)
+                        tui.status_step_complete(batch_generated, batch_failed)
+                    elif simple_progress is not None and simple_progress_task is not None:
+                        with contextlib.suppress(Exception):
+                            simple_progress.update(simple_progress_task, advance=batch_generated)
+                        if simple_checkpoint_task is not None and batch_generated > 0:
+                            with contextlib.suppress(Exception):
+                                simple_progress.update(
+                                    simple_checkpoint_task, advance=batch_generated
+                                )
+                    elif headless_total > 0 and batch_generated > 0:
+                        headless_completed += batch_generated
+                        if (
+                            headless_completed - headless_last_printed >= headless_batch_size
+                            or headless_completed >= headless_total
+                        ):
+                            tui.info(f"Generated {headless_completed}/{headless_total} samples")
+                            headless_last_printed = headless_completed
+
+                elif event["event"] == "cycle_complete":
+                    # Cycle-based generation: log cycle summary (progress already advanced by batch_complete)
+                    samples_in_cycle = event.get("samples_in_cycle", 0)
+                    failures_in_cycle = event.get("failures_in_cycle", 0)
+                    if footer_prog and task is not None:
+                        tui.log_event(
+                            f"✓ Cycle {event.get('cycle')}: "
+                            f"+{samples_in_cycle} samples"
+                            + (f" (-{failures_in_cycle} failed)" if failures_in_cycle else "")
+                        )
+                    elif headless_total > 0:
+                        msg = f"Cycle {event.get('cycle')}: +{samples_in_cycle} samples"
+                        if failures_in_cycle:
+                            msg += f" (-{failures_in_cycle} failed)"
+                        tui.info(msg)
+
                 elif event["event"] == "checkpoint_saved":
                     # Display checkpoint save notification
                     total_samples = event.get("total_samples", 0)
-                    total_failures = event.get("total_failures", 0)
                     is_final = event.get("final", False)
 
                     if footer_prog and task is not None:
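The headless branch throttles output so logs are not flooded with one line per sample: a line is printed only once at least one batch worth of samples has accumulated, or at completion. A minimal sketch of the same bookkeeping (the class name is illustrative, not the module's API):

    class HeadlessPrinter:
        """Print a progress line at most once per batch_size samples."""

        def __init__(self, total: int, batch_size: int, resumed: int = 0) -> None:
            self.total = total
            self.batch_size = batch_size
            self.completed = resumed
            self.last_printed = resumed

        def on_progress(self, generated: int) -> None:
            self.completed += generated
            if (
                self.completed - self.last_printed >= self.batch_size
                or self.completed >= self.total
            ):
                print(f"Generated {self.completed}/{self.total} samples")
                self.last_printed = self.completed

    printer = HeadlessPrinter(total=100, batch_size=8, resumed=25)
    for generated in (8, 3, 5, 8):
        printer.on_progress(generated)  # prints at 33, 41, then 49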
@@ -224,19 +385,26 @@
                         else:
                             tui.log_event(f"💾 Checkpoint: {total_samples} samples")
                             tui.status_checkpoint_saved(total_samples)
-                    elif
-                        # Simple mode:
-
-
-
+                    elif simple_progress is not None:
+                        # Simple mode: reset checkpoint progress bar instead of stacking print lines
+                        if simple_checkpoint_task is not None and not is_final:
+                            with contextlib.suppress(Exception):
+                                simple_progress.reset(
+                                    simple_checkpoint_task, total=checkpoint_interval
+                                )
+                    elif headless_total > 0:
                         if is_final:
-
-
+                            tui.info(f"Checkpoint (final): {total_samples} samples saved")
+                        else:
+                            tui.info(f"Checkpoint: {total_samples} samples saved")
 
                 elif event["event"] == "generation_stopped":
                     # Graceful stop at checkpoint
                     if live:
                         live.stop()
+                    if simple_progress is not None:
+                        simple_progress.stop()
+                        tui.simple_progress = None
                     tui.console.print()
                     tui.success(
                         f"Gracefully stopped: {event['total_samples']} samples saved to checkpoint"
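Resetting the checkpoint task, rather than printing a line per checkpoint, keeps the simple-mode display to two stable bars. A sketch of the same idea with rich's Progress.reset, using invented numbers and a local counter in place of the checkpoint_saved event:

    from rich.progress import Progress

    checkpoint_interval = 20
    with Progress() as progress:
        samples = progress.add_task("Generating", total=100)
        next_cp = progress.add_task("Next checkpoint", total=checkpoint_interval)
        accumulated = 0
        for batch in (8, 8, 8, 8, 8):
            progress.update(samples, advance=batch)
            progress.update(next_cp, advance=batch)
            accumulated += batch
            if accumulated >= checkpoint_interval:  # checkpoint_saved would fire here
                progress.reset(next_cp, total=checkpoint_interval)
                accumulated -= checkpoint_interval

In the handler the reset is skipped on the final checkpoint, so the bar does not snap back to zero just before teardown.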
@@ -248,13 +416,33 @@
                 elif event["event"] == "generation_complete":
                     if live:
                         live.stop()
+                    if simple_progress is not None:
+                        simple_progress.stop()
+                        tui.simple_progress = None
                     tui.console.print()  # Add blank line after live display
                     tui.success(f"Successfully generated {event['total_samples']} samples")
+
+                    # Show accounting summary
+                    expected = event.get("expected_samples", 0)
+                    topics_exhausted = event.get("topics_exhausted", 0)
+                    unaccounted = event.get("unaccounted", 0)
                     tui.log_event(
-                        f"Done •
+                        f"Done • expected={expected} generated={event['total_samples']} "
+                        f"failed={event['failed_samples']} topics_exhausted={topics_exhausted} "
+                        f"unaccounted={unaccounted}"
                     )
                     if event["failed_samples"] > 0:
                         tui.warning(f"Failed to generate {event['failed_samples']} samples")
+                    if topics_exhausted > 0:
+                        tui.warning(
+                            f"Topics exhausted: {topics_exhausted} samples could not be generated "
+                            f"(not enough unique topics for requested sample count)"
+                        )
+                    if unaccounted > 0:
+                        tui.error(
+                            f"WARNING: {unaccounted} samples unaccounted for "
+                            f"(neither generated nor recorded as failures)"
+                        )
 
                     # Show detailed failure information in debug mode
                     if debug and engine and hasattr(engine, "failed_samples"):
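The summary line encodes a conservation check: every expected sample should end up generated, failed, or attributed to topic exhaustion, and anything left over is flagged as unaccounted. With invented numbers:

    expected = 100
    generated = 91
    failed = 6
    topics_exhausted = 2

    unaccounted = expected - (generated + failed + topics_exhausted)  # 1
    assert unaccounted == 1  # would trigger the tui.error() warning above

(This assumes unaccounted is computed upstream as the remainder of that identity, which is what the field name and the error text suggest.)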
@@ -275,6 +463,9 @@
     except Exception as e:
         if live:
             live.stop()
+        if simple_progress is not None:
+            simple_progress.stop()
+            tui.simple_progress = None
         if debug:
             get_tui().error(f"🔍 Debug: Full traceback:\n{traceback.format_exc()}")
         get_tui().error(f"Dataset generation failed: {str(e)}")
@@ -367,23 +558,17 @@
     generation_params = config.get_generation_params(**(generation_overrides or {}))
     final_model = model or generation_params.get("model_name", DEFAULT_MODEL)
 
-    #
-    # The generator expects num_steps where total_samples = num_steps * batch_size
-    import math  # noqa: PLC0415
-
+    # Still compute num_steps for backward compat with the generator's step-based path
     final_num_steps = math.ceil(final_num_samples / final_batch_size)
 
-    tui.info(
-        f"Dataset generation: {final_num_samples} samples in {final_num_steps} steps "
-        f"(batch_size={final_batch_size})"
-    )
-
     # Create progress reporter and attach TUI as observer for streaming feedback
     progress_reporter = ProgressReporter()
     progress_reporter.attach(tui)
 
-    # Attach progress reporter to engine
+    # Attach progress reporter to engine and its LLM retry handler
     engine.progress_reporter = progress_reporter
+    if hasattr(engine, "llm_client"):
+        engine.llm_client.retry_handler.progress_reporter = progress_reporter
 
     try:
         generator = engine.create_data_with_events_async(
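The step count follows the usual ceiling division, so the last step may be partial. For example:

    import math

    final_num_samples = 100
    final_batch_size = 8
    final_num_steps = math.ceil(final_num_samples / final_batch_size)  # 13

    # 12 full batches of 8 cover 96 samples; the 13th generates the last 4.
    assert final_num_steps * final_batch_size >= final_num_samples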
@@ -548,19 +733,30 @@ def _save_jsonl_without_nulls(dataset: HFDataset, save_path: str) -> None:
             f.write(json.dumps(cleaned, separators=(",", ":")) + "\n")
 
 
-def _save_failed_samples(
-
+def _save_failed_samples(
+    save_path: str,
+    failed_samples: list,
+    tui,
+    use_path_directly: bool = False,
+) -> None:
+    """Save failed samples to a file.
 
     Args:
-        save_path: Path
+        save_path: Path for failures file. If use_path_directly is False, this is treated as the
+            main dataset path and a timestamped filename is generated alongside it.
         failed_samples: List of failed samples - can be dicts with 'error' and 'raw_content' keys,
             or plain strings/other types for legacy compatibility
         tui: TUI instance for output
+        use_path_directly: If True, use save_path as-is. If False, generate timestamped filename.
     """
-
-
-
-
+    if use_path_directly:
+        # Use the provided path directly
+        failures_path = save_path
+    else:
+        # Generate timestamped filename: my-dataset.jsonl -> my-dataset_failures_20231130_143022.jsonl
+        base_path = save_path.rsplit(".", 1)[0] if "." in save_path else save_path
+        timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
+        failures_path = f"{base_path}_failures_{timestamp}.jsonl"
 
     try:
         with open(failures_path, "w") as f:
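The auto-generated failures filename is derived purely from save_path and a UTC timestamp, so repeated runs never clobber each other. The same three lines in isolation:

    from datetime import datetime, timezone

    save_path = "my-dataset.jsonl"
    base_path = save_path.rsplit(".", 1)[0] if "." in save_path else save_path
    timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
    failures_path = f"{base_path}_failures_{timestamp}.jsonl"
    # e.g. "my-dataset_failures_20231130_143022.jsonl"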
@@ -589,6 +785,7 @@ def save_dataset(
     save_path: str,
     config: DeepFabricConfig | None = None,
     engine: DataSetGenerator | None = None,
+    failures_path: str | None = None,
 ) -> None:
     """
     Save dataset to file.
@@ -598,6 +795,7 @@
         save_path: Path where to save the dataset
         config: Optional configuration for upload settings
         engine: Optional DataSetGenerator to save failed samples from
+        failures_path: Optional explicit path for failures file (overrides auto-generated path)
 
     Raises:
         ConfigurationError: If saving fails
@@ -614,7 +812,11 @@
     if engine:
         all_failures = engine.get_all_failures()
         if all_failures:
-
+            # Use explicit failures_path if provided, otherwise auto-generate from save_path
+            actual_failures_path = failures_path or save_path
+            _save_failed_samples(
+                actual_failures_path, all_failures, tui, use_path_directly=bool(failures_path)
+            )
 
     # Handle automatic uploads if configured
     if config: