DeepFabric 4.10.1__py3-none-any.whl → 4.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -778,7 +778,7 @@ def handle_cloud_upload(  # noqa: PLR0911
 
     # Build prompt based on what's available
     if has_dataset and has_graph:
-        prompt_text = " Upload to DeepFabric Cloud?"
+        prompt_text = " Upload graph and dataset to DeepFabric Cloud?"
         hint = "[dim](Y=both, n=skip, c=choose)[/dim]"
    elif has_dataset:
        prompt_text = " Upload dataset to DeepFabric Cloud?"
deepfabric/config.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator
 
 from .constants import (
     DEFAULT_MAX_RETRIES,
+    DEFAULT_MAX_TOKENS,
     DEFAULT_MODEL,
     DEFAULT_PROVIDER,
     DEFAULT_SAMPLE_RETRIES,
@@ -86,7 +87,7 @@ class TopicsConfig(BaseModel):
         ..., min_length=1, description="The initial prompt to start topic generation"
     )
     mode: Literal["tree", "graph"] = Field(
-        default="tree", description="Topic generation mode: tree or graph"
+        default="graph", description="Topic generation mode: tree or graph"
     )
     system_prompt: str = Field(
         default="", description="System prompt for topic exploration and generation"
@@ -109,6 +110,11 @@ class TopicsConfig(BaseModel):
         le=20,
         description="Maximum concurrent LLM calls during graph expansion (helps avoid rate limits)",
     )
+    max_tokens: int = Field(
+        default=DEFAULT_MAX_TOKENS,
+        ge=1,
+        description="Maximum tokens for topic generation LLM calls",
+    )
     save_as: str | None = Field(default=None, description="Where to save the generated topics")
     prompt_style: Literal["default", "isolated", "anchored"] = Field(
         default="default",
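Taken together, the `TopicsConfig` hunks are the two user-visible knobs in this release: the topic `mode` default flips from `"tree"` to `"graph"`, and a new `max_tokens` field caps topic-generation LLM calls. A minimal sketch of constructing the model with the new field; the `topic_prompt` argument name is an assumption (the diff only shows the required prompt field's description), and it assumes no other required fields exist:

```python
from deepfabric.config import TopicsConfig  # module path from the "deepfabric/config.py" header above

# "topic_prompt" is an assumed name for the required prompt field; adjust to the real field name.
topics = TopicsConfig(
    topic_prompt="Common failure modes in distributed systems",
    max_tokens=2048,  # new in 4.12.0: cap on tokens per topic-generation LLM call
)
print(topics.mode)        # -> "graph" (default changed from "tree" in this release)
print(topics.max_tokens)  # -> 2048
```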
@@ -589,6 +595,7 @@ See documentation for full examples.
             "depth": self.topics.depth,
             "degree": self.topics.degree,
             "max_concurrent": self.topics.max_concurrent,
+            "max_tokens": self.topics.max_tokens,
             "prompt_style": self.topics.prompt_style,
         }
 
@@ -628,11 +635,13 @@ See documentation for full examples.
             "output_save_as": self.output.save_as,
             # Checkpoint config (nested inside output)
             # Note: checkpoint_path can be None, meaning "auto-resolve" at runtime
-            "checkpoint_interval": self.output.checkpoint.interval if self.output.checkpoint else None,
+            "checkpoint_interval": self.output.checkpoint.interval
+            if self.output.checkpoint
+            else None,
             "checkpoint_path": self.output.checkpoint.path if self.output.checkpoint else None,
-            "checkpoint_retry_failed": (
-                self.output.checkpoint.retry_failed if self.output.checkpoint else False
-            ),
+            "checkpoint_retry_failed": self.output.checkpoint.retry_failed
+            if self.output.checkpoint
+            else False,
         }
 
         # Tool config
@@ -37,7 +37,7 @@ def load_config(  # noqa: PLR0913
     topics_save_as: str | None = None,
     output_save_as: str | None = None,
     include_system_message: bool | None = None,
-    mode: str = "tree",
+    mode: str | None = None,
     # Modular conversation configuration
     conversation_type: str | None = None,
     reasoning_style: str | None = None,
@@ -83,6 +83,8 @@ def load_config(  # noqa: PLR0913
         except Exception as e:
             raise ConfigurationError(f"Error loading config file: {str(e)}") from e
         else:
+            if mode is not None:
+                config.topics.mode = mode
             return config
 
     # No config file provided - create minimal configuration from CLI args
@@ -92,6 +94,9 @@ def load_config(  # noqa: PLR0913
     tui = get_tui()
     tui.info("No config file provided - using CLI parameters")
 
+    # Default to graph mode when no config file and no explicit mode
+    mode = mode or "graph"
+
     # Create minimal config dict with new structure
     default_prompt = generation_system_prompt or "You are a helpful AI assistant."
 
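The net effect of the `load_config` hunks is a three-way precedence for the topic mode: an explicit `mode` argument always wins, a config file otherwise keeps whatever it declares, and with neither the default is now `"graph"`. A standalone sketch of that resolution order (illustrative only, not the library's own helper):

```python
def resolve_mode(cli_mode: str | None, config_mode: str | None) -> str:
    """Resolution order implied by the load_config changes above."""
    if cli_mode is not None:      # explicit mode argument overrides everything
        return cli_mode
    if config_mode is not None:   # config file value is respected when no override is given
        return config_mode
    return "graph"                # new default in 4.12.0 (previously "tree")

assert resolve_mode(None, None) == "graph"
assert resolve_mode("tree", "graph") == "tree"
assert resolve_mode(None, "tree") == "tree"
```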
deepfabric/constants.py CHANGED
@@ -93,7 +93,7 @@ FAILED_TREE_SUFFIX = "_failed.jsonl"
 CHECKPOINT_METADATA_SUFFIX = ".checkpoint.json"
 CHECKPOINT_SAMPLES_SUFFIX = ".checkpoint.jsonl"
 CHECKPOINT_FAILURES_SUFFIX = ".checkpoint.failures.jsonl"
-CHECKPOINT_VERSION = 3  # Increment when checkpoint format changes
+CHECKPOINT_VERSION = 4  # v4: (uuid, cycle) tuple tracking for cycle-based generation
 
 # Stream simulation defaults
 STREAM_SIM_CHUNK_SIZE = 8  # characters per chunk
@@ -1,6 +1,7 @@
 import asyncio
 import contextlib
 import json
+import math
 import os
 import traceback
 
@@ -12,6 +13,15 @@ from typing import TYPE_CHECKING, Any
 from datasets import Dataset as HFDataset
 from rich.layout import Layout
 from rich.live import Live
+from rich.progress import (
+    BarColumn,
+    MofNCompleteColumn,
+    Progress,
+    SpinnerColumn,
+    TextColumn,
+    TimeElapsedColumn,
+)
+from rich.table import Column
 
 from .config import DeepFabricConfig
 from .config_manager import DEFAULT_MODEL
@@ -93,7 +103,14 @@ async def handle_dataset_events_async(
     footer_prog = None
     task = None
     live = None
-    simple_task = None
+    simple_progress = None  # Progress bar for simple/headless mode
+    simple_progress_task = None
+    simple_checkpoint_task = None  # Checkpoint progress task for simple mode
+    checkpoint_interval = 0
+    headless_completed = 0  # Counter for non-TTY headless mode
+    headless_total = 0
+    headless_batch_size = 0  # Print threshold for headless mode
+    headless_last_printed = 0  # Last count at which we printed
 
     final_result: HFDataset | None = None
     try:
@@ -101,9 +118,26 @@ async def handle_dataset_events_async(
             if isinstance(event, dict) and "event" in event:
                 if event["event"] == "generation_start":
                     settings = _get_tui_settings()
+                    # Handle both cycle-based and step-based event formats
+                    # Cycle-based: unique_topics, cycles_needed, concurrency
+                    # Step-based: num_steps, batch_size
+                    is_cycle_based = "cycles_needed" in event
+                    if is_cycle_based:
+                        display_steps = event.get("cycles_needed", 1)
+                        display_batch_size = event.get("concurrency", 1)
+                    else:
+                        display_steps = event.get("num_steps", 1)
+                        display_batch_size = event.get("batch_size", 1)
                     # Build header and params panels for layout
                     header_panel, params_panel = tui.build_generation_panels(
-                        event["model_name"], event["num_steps"], event["batch_size"]
+                        event["model_name"],
+                        display_steps,
+                        display_batch_size,
+                        total_samples=event["total_samples"],
+                        is_cycle_based=is_cycle_based,
+                        unique_topics=event.get("unique_topics", 0),
+                        final_cycle_size=event.get("final_cycle_size", 0),
+                        checkpoint_interval=event.get("checkpoint_interval", 0),
                     )
                     # Capture context for the run
                     tui.root_topic_prompt = event.get("root_topic_prompt")
@@ -112,9 +146,10 @@ async def handle_dataset_events_async(
                     if settings.mode == "rich":
                         # Initialize status tracking
                         tui.init_status(
-                            total_steps=event["num_steps"],
+                            total_steps=display_steps,
                             total_samples=event["total_samples"],
                             checkpoint_enabled=event.get("checkpoint_enabled", False),
+                            is_cycle_based=is_cycle_based,
                         )
 
                         # Build layout with footer card
@@ -144,12 +179,15 @@ async def handle_dataset_events_async(
                         )
                         layout["main"].split_row(left, right)
 
+                        prog_total = event["total_samples"]
+                        resumed_samples = event.get("resumed_samples", 0)
+
                         # Footer run status
                         footer_prog = tui.tui.create_footer(layout, title="Run Status")
                         task = footer_prog.add_task(
                             "Generating dataset samples",
-                            total=event["total_samples"],
-                            completed=event.get("resumed_samples", 0),
+                            total=prog_total,
+                            completed=min(resumed_samples, prog_total),
                         )
 
                         # Use alternate screen to avoid scroll trails; leave a clean terminal
@@ -162,15 +200,97 @@ async def handle_dataset_events_async(
                         tui.live_display = live  # Give TUI reference to update it
                         tui.live_layout = layout  # Allow TUI to update panes
                         live.start()
+                        if resumed_samples >= prog_total:
+                            tui.log_event(
+                                f"Checkpoint already complete: {resumed_samples} samples "
+                                f"(target: {prog_total})"
+                            )
                     else:
-                        # Simple/headless mode: print and proceed without Live
-                        tui.show_generation_header(
-                            event["model_name"], event["num_steps"], event["batch_size"]
-                        )
-                        simple_task = {
-                            "count": event.get("resumed_samples", 0),
-                            "total": event["total_samples"],
-                        }
+                        prog_total = event["total_samples"]
+                        resumed_samples = event.get("resumed_samples", 0)
+
+                        # Simple/headless mode: runtime summary then progress bar
+                        tui.console.print("\n[bold cyan]Dataset Generation[/bold cyan]")
+                        model_line = f"Model: {event['model_name']}"
+                        if event.get("topic_model_type"):
+                            topic_type = event["topic_model_type"]
+                            if is_cycle_based and event.get("unique_topics"):
+                                model_line += (
+                                    f" ({topic_type}, {event['unique_topics']} unique topics)"
+                                )
+                            else:
+                                model_line += f" ({topic_type})"
+                        tui.info(model_line)
+
+                        if is_cycle_based:
+                            output_line = (
+                                f"Output: num_samples={prog_total}, "
+                                f"concurrency={display_batch_size}"
+                            )
+                        else:
+                            output_line = (
+                                f"Output: num_samples={prog_total}, batch_size={display_batch_size}"
+                            )
+                        tui.info(output_line)
+
+                        if is_cycle_based:
+                            cycles = event.get("cycles_needed", 1)
+                            unique = event.get("unique_topics", 0)
+                            tui.info(
+                                f" → Cycles needed: {cycles} "
+                                f"({prog_total} samples ÷ {unique} unique topics)"
+                            )
+                            final_cycle = event.get("final_cycle_size", 0)
+                            if final_cycle and unique and final_cycle < unique:
+                                tui.info(f" → Final cycle: {final_cycle} topics (partial)")
+
+                        tui.console.print()
+                        cp_interval = event.get("checkpoint_interval")
+                        if cp_interval and cp_interval > 0:
+                            total_cp = math.ceil(prog_total / cp_interval)
+                            tui.info(
+                                f"Checkpoint: every {cp_interval} samples "
+                                f"({total_cp} total checkpoints)"
+                            )
+                            tui.console.print()
+
+                        if resumed_samples >= prog_total:
+                            # Checkpoint already has enough samples
+                            tui.success(
+                                f"Checkpoint already complete: {resumed_samples} samples "
+                                f"(target: {prog_total})"
+                            )
+                        elif tui.console.is_terminal:
+                            simple_progress = Progress(
+                                SpinnerColumn(),
+                                TextColumn("[progress.description]{task.description}"),
+                                BarColumn(),
+                                MofNCompleteColumn(table_column=Column(justify="right")),
+                                TimeElapsedColumn(),
+                                console=tui.console,
+                            )
+                            simple_progress_task = simple_progress.add_task(
+                                "Generating",
+                                total=prog_total,
+                                completed=resumed_samples,
+                            )
+                            simple_progress.start()
+                            tui.simple_progress = simple_progress
+                            # Add checkpoint progress task if interval is set
+                            checkpoint_interval = event.get("checkpoint_interval") or 0
+                            if checkpoint_interval > 0:
+                                simple_checkpoint_task = simple_progress.add_task(
+                                    "Next checkpoint",
+                                    total=checkpoint_interval,
+                                    completed=0,
+                                )
+                        else:
+                            # Headless (non-TTY): track progress with counters
+                            headless_total = prog_total
+                            headless_completed = resumed_samples
+                            headless_batch_size = display_batch_size
+                            headless_last_printed = resumed_samples
+                            checkpoint_interval = event.get("checkpoint_interval") or 0
                 elif event["event"] == "step_complete":
                     samples_generated = event.get("samples_generated", 0)
                     if footer_prog and task is not None:
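For readers who have not used `rich` progress bars, this is the column layout the simple-mode branch assembles: spinner, task description, bar, an M-of-N counter right-aligned in its own table column, and elapsed time. A self-contained sketch with a dummy loop standing in for the generation events:

```python
import time

from rich.console import Console
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)
from rich.table import Column

console = Console()
progress = Progress(
    SpinnerColumn(),
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    MofNCompleteColumn(table_column=Column(justify="right")),
    TimeElapsedColumn(),
    console=console,
)

with progress:  # starts and stops the live display
    task = progress.add_task("Generating", total=50)
    for _ in range(50):
        time.sleep(0.02)  # stand-in for an LLM call
        progress.update(task, advance=1)
```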
@@ -182,39 +302,80 @@ async def handle_dataset_events_async(
                         tui.status_step_complete(
                             samples_generated, int(event.get("failed_in_step", 0))
                         )
-                    elif isinstance(simple_task, dict):
-                        simple_task["count"] += samples_generated
-                        failed_in_step = int(event.get("failed_in_step", 0))
-                        retry_summary = tui.get_step_retry_summary()
-
-                        # Build step summary message
-                        step_msg = f"Step {event.get('step')}: +{samples_generated}"
-                        if failed_in_step > 0:
-                            step_msg += f" (-{failed_in_step} failed)"
-                        step_msg += f" (total {simple_task['count']}/{simple_task['total']})"
-
-                        # Display with appropriate style based on failures
-                        if failed_in_step > 0:
-                            tui.warning(step_msg)
-                        else:
-                            tui.info(step_msg)
-
-                        # Show retry summary if there were retries
-                        if retry_summary:
-                            tui.console.print(f" [dim]{retry_summary}[/dim]")
-
-                        # Clear retries for next step
+                    elif simple_progress is not None and simple_progress_task is not None:
+                        with contextlib.suppress(Exception):
+                            simple_progress.update(simple_progress_task, advance=samples_generated)
+                        if simple_checkpoint_task is not None and samples_generated > 0:
+                            with contextlib.suppress(Exception):
+                                simple_progress.update(
+                                    simple_checkpoint_task, advance=samples_generated
+                                )
                         tui.clear_step_retries()
+                    elif headless_total > 0 and samples_generated > 0:
+                        headless_completed += samples_generated
+                        if (
+                            headless_completed - headless_last_printed >= headless_batch_size
+                            or headless_completed >= headless_total
+                        ):
+                            tui.info(f"Generated {headless_completed}/{headless_total} samples")
+                            headless_last_printed = headless_completed
                 elif event["event"] == "step_start":
                     # Keep status panel in sync
                     step = int(event.get("step", 0))
                     total = int(event.get("total_steps", 0))
                     tui.status_step_start(step, total)
 
+                elif event["event"] == "cycle_start":
+                    # Cycle-based generation: keep status panel in sync
+                    cycle = int(event.get("cycle", 0))
+                    total_cycles = int(event.get("total_cycles", 0))
+                    tui.status_step_start(cycle, total_cycles)
+
+                elif event["event"] == "batch_complete":
+                    # Per-batch progress: advance bars after each concurrency batch
+                    batch_generated = event.get("samples_generated", 0)
+                    batch_failed = event.get("samples_failed", 0)
+                    if footer_prog and task is not None:
+                        if batch_generated > 0:
+                            with contextlib.suppress(Exception):
+                                footer_prog.update(task, advance=batch_generated)
+                        tui.status_step_complete(batch_generated, batch_failed)
+                    elif simple_progress is not None and simple_progress_task is not None:
+                        with contextlib.suppress(Exception):
+                            simple_progress.update(simple_progress_task, advance=batch_generated)
+                        if simple_checkpoint_task is not None and batch_generated > 0:
+                            with contextlib.suppress(Exception):
+                                simple_progress.update(
+                                    simple_checkpoint_task, advance=batch_generated
+                                )
+                    elif headless_total > 0 and batch_generated > 0:
+                        headless_completed += batch_generated
+                        if (
+                            headless_completed - headless_last_printed >= headless_batch_size
+                            or headless_completed >= headless_total
+                        ):
+                            tui.info(f"Generated {headless_completed}/{headless_total} samples")
+                            headless_last_printed = headless_completed
+
+                elif event["event"] == "cycle_complete":
+                    # Cycle-based generation: log cycle summary (progress already advanced by batch_complete)
+                    samples_in_cycle = event.get("samples_in_cycle", 0)
+                    failures_in_cycle = event.get("failures_in_cycle", 0)
+                    if footer_prog and task is not None:
+                        tui.log_event(
+                            f"✓ Cycle {event.get('cycle')}: "
+                            f"+{samples_in_cycle} samples"
+                            + (f" (-{failures_in_cycle} failed)" if failures_in_cycle else "")
+                        )
+                    elif headless_total > 0:
+                        msg = f"Cycle {event.get('cycle')}: +{samples_in_cycle} samples"
+                        if failures_in_cycle:
+                            msg += f" (-{failures_in_cycle} failed)"
+                        tui.info(msg)
+
                 elif event["event"] == "checkpoint_saved":
                     # Display checkpoint save notification
                     total_samples = event.get("total_samples", 0)
-                    total_failures = event.get("total_failures", 0)
                     is_final = event.get("final", False)
 
                     if footer_prog and task is not None:
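The headless (non-TTY) branches throttle their output: a progress line is printed only once at least one batch worth of new samples has accumulated, or when the target is reached. A standalone sketch of that rule, using the variable roles from the hunk above:

```python
def should_print(completed: int, last_printed: int, batch_size: int, total: int) -> bool:
    """Throttle headless progress output to roughly one line per batch (illustrative only)."""
    return completed - last_printed >= batch_size or completed >= total

# With batch_size=8 and total=50, lines appear at 8, 16, ... and again at 50.
assert should_print(8, 0, 8, 50)
assert not should_print(7, 0, 8, 50)
assert should_print(50, 48, 8, 50)
```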
@@ -224,19 +385,26 @@ async def handle_dataset_events_async(
                         else:
                             tui.log_event(f"💾 Checkpoint: {total_samples} samples")
                             tui.status_checkpoint_saved(total_samples)
-                    elif isinstance(simple_task, dict):
-                        # Simple mode: print checkpoint notification
-                        checkpoint_msg = f"Checkpoint saved: {total_samples} samples"
-                        if total_failures > 0:
-                            checkpoint_msg += f" ({total_failures} failures)"
+                    elif simple_progress is not None:
+                        # Simple mode: reset checkpoint progress bar instead of stacking print lines
+                        if simple_checkpoint_task is not None and not is_final:
+                            with contextlib.suppress(Exception):
+                                simple_progress.reset(
+                                    simple_checkpoint_task, total=checkpoint_interval
+                                )
+                    elif headless_total > 0:
                         if is_final:
-                            checkpoint_msg = "Final " + checkpoint_msg.lower()
-                        tui.info(checkpoint_msg)
+                            tui.info(f"Checkpoint (final): {total_samples} samples saved")
+                        else:
+                            tui.info(f"Checkpoint: {total_samples} samples saved")
 
                 elif event["event"] == "generation_stopped":
                     # Graceful stop at checkpoint
                     if live:
                         live.stop()
+                    if simple_progress is not None:
+                        simple_progress.stop()
+                        tui.simple_progress = None
                     tui.console.print()
                     tui.success(
                         f"Gracefully stopped: {event['total_samples']} samples saved to checkpoint"
@@ -248,13 +416,33 @@ async def handle_dataset_events_async(
                 elif event["event"] == "generation_complete":
                     if live:
                         live.stop()
+                    if simple_progress is not None:
+                        simple_progress.stop()
+                        tui.simple_progress = None
                     tui.console.print()  # Add blank line after live display
                     tui.success(f"Successfully generated {event['total_samples']} samples")
+
+                    # Show accounting summary
+                    expected = event.get("expected_samples", 0)
+                    topics_exhausted = event.get("topics_exhausted", 0)
+                    unaccounted = event.get("unaccounted", 0)
                     tui.log_event(
-                        f"Done • total={event['total_samples']} failed={event['failed_samples']}"
+                        f"Done • expected={expected} generated={event['total_samples']} "
+                        f"failed={event['failed_samples']} topics_exhausted={topics_exhausted} "
+                        f"unaccounted={unaccounted}"
                     )
                     if event["failed_samples"] > 0:
                         tui.warning(f"Failed to generate {event['failed_samples']} samples")
+                    if topics_exhausted > 0:
+                        tui.warning(
+                            f"Topics exhausted: {topics_exhausted} samples could not be generated "
+                            f"(not enough unique topics for requested sample count)"
+                        )
+                    if unaccounted > 0:
+                        tui.error(
+                            f"WARNING: {unaccounted} samples unaccounted for "
+                            f"(neither generated nor recorded as failures)"
+                        )
 
                     # Show detailed failure information in debug mode
                     if debug and engine and hasattr(engine, "failed_samples"):
@@ -367,23 +558,17 @@ async def create_dataset_async(
     generation_params = config.get_generation_params(**(generation_overrides or {}))
     final_model = model or generation_params.get("model_name", DEFAULT_MODEL)
 
-    # Convert total samples to number of steps (batches)
-    # The generator expects num_steps where total_samples = num_steps * batch_size
-    import math  # noqa: PLC0415
-
+    # Still compute num_steps for backward compat with the generator's step-based path
     final_num_steps = math.ceil(final_num_samples / final_batch_size)
 
-    tui.info(
-        f"Dataset generation: {final_num_samples} samples in {final_num_steps} steps "
-        f"(batch_size={final_batch_size})"
-    )
-
     # Create progress reporter and attach TUI as observer for streaming feedback
     progress_reporter = ProgressReporter()
     progress_reporter.attach(tui)
 
-    # Attach progress reporter to engine
+    # Attach progress reporter to engine and its LLM retry handler
     engine.progress_reporter = progress_reporter
+    if hasattr(engine, "llm_client"):
+        engine.llm_client.retry_handler.progress_reporter = progress_reporter
 
     try:
         generator = engine.create_data_with_events_async(
@@ -548,19 +733,30 @@ def _save_jsonl_without_nulls(dataset: HFDataset, save_path: str) -> None:
             f.write(json.dumps(cleaned, separators=(",", ":")) + "\n")
 
 
-def _save_failed_samples(save_path: str, failed_samples: list, tui) -> None:
-    """Save failed samples to a timestamped file alongside the main dataset.
+def _save_failed_samples(
+    save_path: str,
+    failed_samples: list,
+    tui,
+    use_path_directly: bool = False,
+) -> None:
+    """Save failed samples to a file.
 
     Args:
-        save_path: Path to the main dataset file (e.g., "my-dataset.jsonl")
+        save_path: Path for failures file. If use_path_directly is False, this is treated as the
+            main dataset path and a timestamped filename is generated alongside it.
         failed_samples: List of failed samples - can be dicts with 'error' and 'raw_content' keys,
             or plain strings/other types for legacy compatibility
         tui: TUI instance for output
+        use_path_directly: If True, use save_path as-is. If False, generate timestamped filename.
     """
-    # Generate timestamped filename: my-dataset.jsonl -> my-dataset_failures_20231130_143022.jsonl
-    base_path = save_path.rsplit(".", 1)[0] if "." in save_path else save_path
-    timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
-    failures_path = f"{base_path}_failures_{timestamp}.jsonl"
+    if use_path_directly:
+        # Use the provided path directly
+        failures_path = save_path
+    else:
+        # Generate timestamped filename: my-dataset.jsonl -> my-dataset_failures_20231130_143022.jsonl
+        base_path = save_path.rsplit(".", 1)[0] if "." in save_path else save_path
+        timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
+        failures_path = f"{base_path}_failures_{timestamp}.jsonl"
 
     try:
         with open(failures_path, "w") as f:
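Because the two naming paths are easy to confuse, here is the filename each branch of `_save_failed_samples` produces, sketched with the same string logic as the hunk above (illustrative only):

```python
from datetime import datetime, timezone

save_path = "my-dataset.jsonl"

# use_path_directly=False (default): a timestamped file is written next to the dataset
base_path = save_path.rsplit(".", 1)[0] if "." in save_path else save_path
timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
print(f"{base_path}_failures_{timestamp}.jsonl")  # e.g. my-dataset_failures_20231130_143022.jsonl

# use_path_directly=True: the caller-supplied path is used verbatim (no timestamp added)
explicit_path = "my-dataset.failures.jsonl"  # hypothetical explicit failures path
print(explicit_path)
```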
@@ -589,6 +785,7 @@ def save_dataset(
     save_path: str,
     config: DeepFabricConfig | None = None,
     engine: DataSetGenerator | None = None,
+    failures_path: str | None = None,
 ) -> None:
     """
     Save dataset to file.
@@ -598,6 +795,7 @@
         save_path: Path where to save the dataset
         config: Optional configuration for upload settings
         engine: Optional DataSetGenerator to save failed samples from
+        failures_path: Optional explicit path for failures file (overrides auto-generated path)
 
     Raises:
         ConfigurationError: If saving fails
@@ -614,7 +812,11 @@
     if engine:
         all_failures = engine.get_all_failures()
         if all_failures:
-            _save_failed_samples(save_path, all_failures, tui)
+            # Use explicit failures_path if provided, otherwise auto-generate from save_path
+            actual_failures_path = failures_path or save_path
+            _save_failed_samples(
+                actual_failures_path, all_failures, tui, use_path_directly=bool(failures_path)
+            )
 
     # Handle automatic uploads if configured
     if config: