DeepFabric 4.8.3__py3-none-any.whl → 4.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deepfabric/generator.py CHANGED
@@ -2,9 +2,12 @@ import asyncio
2
2
  import json
3
3
  import logging
4
4
  import math
5
+ import os
5
6
  import random
6
7
 
7
8
  from collections.abc import AsyncGenerator
9
+ from datetime import datetime, timezone
10
+ from pathlib import Path
8
11
  from typing import TYPE_CHECKING, Any, Literal
9
12
 
10
13
  from datasets import Dataset as HFDataset
@@ -14,6 +17,10 @@ from .builders import ConversationBuilderFactory
14
17
  from .config import _normalize_reasoning_style
15
18
  from .constants import (
16
19
  API_ERROR_INDICATORS,
20
+ CHECKPOINT_FAILURES_SUFFIX,
21
+ CHECKPOINT_METADATA_SUFFIX,
22
+ CHECKPOINT_SAMPLES_SUFFIX,
23
+ CHECKPOINT_VERSION,
17
24
  DEFAULT_MAX_RETRIES,
18
25
  DEFAULT_REQUEST_TIMEOUT,
19
26
  DEFAULT_SAMPLE_RETRIES,
@@ -30,7 +37,6 @@ from .llm import LLMClient
30
37
  from .metrics import trace
31
38
  from .progress import ProgressReporter
32
39
  from .prompts import (
33
- AGENT_COT_MULTI_TURN_PROMPT,
34
40
  AGENT_COT_TOOLS_PROMPT,
35
41
  CONVERSATION_GENERATION_PROMPT,
36
42
  FREETEXT_COT_PROMPT,
@@ -40,8 +46,8 @@ from .prompts import (
40
46
  from .schemas import Conversation, ToolRegistry, get_conversation_schema
41
47
  from .tools import BUILTIN_TOOL_REGISTRY
42
48
  from .tools.loader import load_tools_from_dict, load_tools_from_endpoint
43
- from .topic_model import TopicModel
44
- from .utils import ensure_not_running_loop, is_validation_error
49
+ from .topic_model import TopicModel, TopicPath
50
+ from .utils import ensure_not_running_loop, get_checkpoint_dir, is_validation_error
45
51
 
46
52
  # Handle circular import for type hints
47
53
  if TYPE_CHECKING:
@@ -143,12 +149,7 @@ class DataSetGeneratorConfig(BaseModel):
143
149
  """Normalize deprecated reasoning_style values."""
144
150
  return _normalize_reasoning_style(v)
145
151
 
146
- agent_mode: Literal["single_turn", "multi_turn"] | None = Field(
147
- default=None,
148
- description="Agent mode: single_turn (one-shot tool use), multi_turn (extended agent conversations). Requires tools to be configured.",
149
- )
150
-
151
- # Tool configuration (used when agent_mode is enabled or for tool_calling)
152
+ # Tool configuration (used when tools are configured for agent mode)
152
153
  tool_components: dict[str, list[str]] = Field(
153
154
  default_factory=dict,
154
155
  description=(
@@ -194,28 +195,32 @@ class DataSetGeneratorConfig(BaseModel):
194
195
  description="Path for tool execution when using tools_endpoint (e.g., '/mock/execute'). Combined with spin_endpoint.",
195
196
  )
196
197
 
197
- # Multi-turn configuration (used when agent_mode="multi_turn")
198
- min_turns: int = Field(
199
- default=2,
200
- ge=1,
201
- le=10,
202
- description="Minimum number of conversation turns for multi-turn agent mode",
198
+ tool_inclusion_strategy: Literal["all", "used_only"] = Field(
199
+ default="used_only",
200
+ description="Which tools to include in each sample: 'all' includes full catalog, 'used_only' includes only tools actually called (recommended for training)",
203
201
  )
204
- max_turns: int = Field(
205
- default=4,
202
+
203
+ # Checkpoint configuration
204
+ checkpoint_interval: int | None = Field(
205
+ default=None,
206
206
  ge=1,
207
- le=10,
208
- description="Maximum number of conversation turns for multi-turn agent mode",
207
+ description="Save checkpoint every N samples. None disables checkpointing.",
209
208
  )
210
- min_tool_calls: int = Field(
211
- default=2,
212
- ge=0,
213
- le=20,
214
- description="Minimum number of tool calls required before allowing early conversation conclusion",
209
+ checkpoint_path: str | None = Field(
210
+ default=None,
211
+ description="Directory to store checkpoint files. None uses fallback '.checkpoints'",
215
212
  )
216
- tool_inclusion_strategy: Literal["all", "used_only"] = Field(
217
- default="used_only",
218
- description="Which tools to include in each sample: 'all' includes full catalog, 'used_only' includes only tools actually called (recommended for training)",
213
+ checkpoint_retry_failed: bool = Field(
214
+ default=False,
215
+ description="When resuming, retry previously failed samples",
216
+ )
217
+ output_save_as: str | None = Field(
218
+ default=None,
219
+ description="Output file path (used to derive checkpoint file names)",
220
+ )
221
+ topics_file: str | None = Field(
222
+ default=None,
223
+ description="Topics file path (stored in checkpoint metadata for auto-resume)",
219
224
  )
220
225
 
221
226
 
@@ -260,18 +265,27 @@ class DataSetGenerator:
260
265
  # Store generation prompt for content generation
261
266
  self.generation_prompt = self.config.generation_system_prompt
262
267
 
263
- # Initialize tool registry when agent_mode is enabled or tools are configured
268
+ # Initialize tool registry when tools are configured (enables agent mode)
264
269
  self.tool_registry = None
265
- if (
266
- self.config.agent_mode is not None
267
- or self.config.tool_components
268
- or self.config.custom_tools
269
- ):
270
+ if self.config.tool_components or self.config.custom_tools:
270
271
  self._initialize_tool_registry()
271
272
 
272
273
  # Progress reporter for streaming feedback (set by external callers)
273
274
  self.progress_reporter: ProgressReporter | None = None
274
275
 
276
+ # Checkpoint state
277
+ self._checkpoint_samples_since_save = 0
278
+ self._processed_ids: set[str] = set() # Track processed topic IDs (UUIDs)
279
+ self._checkpoint_metadata_path: Path | None = None
280
+ self._checkpoint_samples_path: Path | None = None
281
+ self._checkpoint_failures_path: Path | None = None
282
+ # Memory optimization: track flushed counts for checkpoint mode
283
+ self._flushed_samples_count = 0
284
+ self._flushed_failures_count = 0
285
+
286
+ # Graceful stop flag - set by signal handler to stop at next checkpoint
287
+ self.stop_requested = False
288
+
275
289
  def _initialize_tool_registry(self):
276
290
  """Initialize tool registry from component configuration.
277
291
 
@@ -328,6 +342,400 @@ class DataSetGenerator:
328
342
  except Exception as e: # noqa: BLE001
329
343
  raise DataSetGeneratorError(f"Failed to initialize tool registry: {str(e)}") from e
330
344
 
345
def _get_checkpoint_paths(self) -> tuple[Path, Path, Path]:
    """Resolve this run's checkpoint file locations, creating the directory.

    The directory is ``config.checkpoint_path`` when set, otherwise the value
    returned by ``get_checkpoint_dir(config_path=None)``; it is created (with
    parents) if missing. Each file name is the stem of
    ``config.output_save_as`` followed by the metadata / samples / failures
    suffix constant.

    Returns:
        ``(metadata_path, samples_path, failures_path)``.

    Raises:
        DataSetGeneratorError: If ``config.output_save_as`` is not set.
    """
    save_as = self.config.output_save_as
    if not save_as:
        raise DataSetGeneratorError(
            "Cannot create checkpoint paths: output_save_as not configured"
        )

    # The CLI normally resolves checkpoint_path; fall back to the shared helper otherwise.
    base_dir = Path(self.config.checkpoint_path or get_checkpoint_dir(config_path=None))
    base_dir.mkdir(parents=True, exist_ok=True)

    stem = Path(save_as).stem
    metadata_file, samples_file, failures_file = (
        base_dir / f"{stem}{suffix}"
        for suffix in (
            CHECKPOINT_METADATA_SUFFIX,
            CHECKPOINT_SAMPLES_SUFFIX,
            CHECKPOINT_FAILURES_SUFFIX,
        )
    )
    return metadata_file, samples_file, failures_file
368
+
369
def _initialize_checkpoint_paths(self) -> None:
    """Populate the checkpoint path attributes when checkpointing is enabled.

    A no-op when ``config.checkpoint_interval`` is None. Otherwise the three
    paths from ``_get_checkpoint_paths`` are stored on the instance and an
    info message is logged. This is called from ``has_checkpoint``,
    ``load_checkpoint`` and the generation loop, so the message may be
    logged more than once per run.
    """
    interval = self.config.checkpoint_interval
    if interval is None:
        return

    (
        self._checkpoint_metadata_path,
        self._checkpoint_samples_path,
        self._checkpoint_failures_path,
    ) = self._get_checkpoint_paths()

    logger.info(
        "Checkpointing enabled: saving every %d samples to %s",
        interval,
        self._checkpoint_samples_path,
    )
381
+
382
def _save_checkpoint(
    self,
    new_samples: list[dict],
    new_failures: list[dict],
    processed_topic_paths: list[TopicPath | None],
    flush_memory: bool = True,
) -> None:
    """Persist one incremental checkpoint.

    Appends ``new_samples`` to the samples file and ``new_failures`` to the
    failures file, one compact JSON object per line. It then adds the topic
    IDs of ``processed_topic_paths`` to the processed set and rewrites the
    metadata file.

    Args:
        new_samples: Successful samples produced since the last checkpoint.
        new_failures: Failure records produced since the last checkpoint.
        processed_topic_paths: Topic paths handled since the last checkpoint;
            ``None`` entries are ignored.
        flush_memory: When True, add the lengths of ``new_samples`` and
            ``new_failures`` to the flushed counters and empty the in-memory
            ``_samples`` and ``failed_samples`` lists.

    Does nothing if the samples checkpoint path has not been initialized.
    """
    samples_file = self._checkpoint_samples_path
    if samples_file is None:
        return

    def _append_records(destination: Path, records: list[dict]) -> None:
        # JSONL append; compact separators keep every record on a single line.
        with open(destination, "a", encoding="utf-8") as sink:
            sink.writelines(
                json.dumps(record, separators=(",", ":")) + "\n" for record in records
            )

    if new_samples:
        _append_records(samples_file, new_samples)
    if new_failures and self._checkpoint_failures_path:
        _append_records(self._checkpoint_failures_path, new_failures)

    self._processed_ids.update(
        topic_path.topic_id
        for topic_path in processed_topic_paths
        if topic_path is not None
    )

    # Counters are bumped before the metadata write so its totals include
    # what was just flushed out of memory.
    if flush_memory:
        self._flushed_samples_count += len(new_samples)
        self._flushed_failures_count += len(new_failures)
        self._samples.clear()
        self.failed_samples.clear()

    self._save_checkpoint_metadata()

    logger.debug(
        "Checkpoint saved: %d samples, %d failures, %d total IDs processed (flushed=%s)",
        len(new_samples),
        len(new_failures),
        len(self._processed_ids),
        flush_memory,
    )
436
+
437
def _save_checkpoint_metadata(self) -> None:
    """Write the checkpoint metadata JSON file.

    The metadata records:
      - the checkpoint format version and a UTC timestamp;
      - provider, model, conversation type and reasoning style;
      - running sample and failure totals;
      - the processed topic IDs;
      - the checkpoint interval and the topics file.

    Totals combine the counts already flushed to disk with whatever is still
    held in ``_samples`` / ``failed_samples``.

    The JSON is first written to a ``.tmp`` sibling and fsync'd. It is then
    moved over the real file with ``os.replace``, which is an atomic rename
    on POSIX (the temp file sits in the same directory, so it is on the same
    filesystem). An interruption mid-write therefore leaves the previous
    metadata intact. Without this, a truncated JSON file would make
    ``load_checkpoint`` reject the whole checkpoint.

    Does nothing when checkpointing has not been initialized.
    """
    metadata_path = self._checkpoint_metadata_path
    if metadata_path is None:
        return

    # Total counts include both flushed (on disk) and in-memory samples
    total_samples = self._flushed_samples_count + len(self._samples)
    total_failures = self._flushed_failures_count + len(self.failed_samples)

    metadata = {
        "version": CHECKPOINT_VERSION,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "provider": self.provider,
        "model_name": self.model_name,
        "conversation_type": self.config.conversation_type,
        "reasoning_style": self.config.reasoning_style,
        "total_samples": total_samples,
        "total_failures": total_failures,
        "processed_ids": list(self._processed_ids),
        "checkpoint_interval": self.config.checkpoint_interval,
        "topics_file": self.config.topics_file,
    }

    tmp_path = metadata_path.with_name(metadata_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
            # Make sure the bytes are on disk before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, metadata_path)
    except BaseException:
        # Also covers KeyboardInterrupt: remove the partial temp file
        # (best effort) and re-raise. The previous metadata file is untouched.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise
462
+
463
def _validate_checkpoint_compatibility(self, metadata: dict) -> None:
    """Warn when the current configuration differs from the checkpoint's.

    Compares provider, model name, conversation type and reasoning style.
    A field is only compared when the checkpoint stores a truthy value for
    it. Differences are logged as a single warning; resumption is never
    blocked.

    Args:
        metadata: Parsed checkpoint metadata dictionary.
    """
    # Each metadata key doubles as the label used in the warning text.
    current_values = (
        ("provider", self.provider),
        ("model_name", self.model_name),
        ("conversation_type", self.config.conversation_type),
        ("reasoning_style", self.config.reasoning_style),
    )

    differences: list[str] = []
    for key, current in current_values:
        stored = metadata.get(key)
        if not stored or stored == current:
            continue
        differences.append(f"{key}: checkpoint={stored}, current={current}")

    if not differences:
        return

    logger.warning(
        "Config mismatch with checkpoint. Resuming may produce inconsistent results. "
        "Differences: %s",
        "; ".join(differences),
    )
509
+
510
+ def _validate_checkpoint_integrity(self, metadata: dict) -> tuple[bool, str | None]:
511
+ """Validate checkpoint file integrity.
512
+
513
+ Checks that:
514
+ 1. Metadata version is supported
515
+ 2. Required metadata fields are present
516
+ 3. Sample count in metadata matches actual file line count
517
+ 4. Sample file contains valid JSON on each line
518
+
519
+ Args:
520
+ metadata: Checkpoint metadata dictionary
521
+
522
+ Returns:
523
+ Tuple of (is_valid, error_message). error_message is None if valid.
524
+ """
525
+ error_msg: str | None = None
526
+
527
+ # Check version
528
+ version = metadata.get("version")
529
+ if version is None:
530
+ error_msg = "Missing 'version' field in checkpoint metadata"
531
+ elif version != CHECKPOINT_VERSION:
532
+ error_msg = f"Unsupported checkpoint version: {version} (expected {CHECKPOINT_VERSION})"
533
+
534
+ # Check required fields
535
+ if error_msg is None:
536
+ required_fields = ["created_at", "total_samples", "processed_ids"]
537
+ for field in required_fields:
538
+ if field not in metadata:
539
+ error_msg = f"Missing required field in checkpoint metadata: {field}"
540
+ break
541
+
542
+ # Validate sample count matches file
543
+ if error_msg is None:
544
+ expected_samples = metadata.get("total_samples", 0)
545
+ if self._checkpoint_samples_path and self._checkpoint_samples_path.exists():
546
+ actual_count = 0
547
+ try:
548
+ with open(self._checkpoint_samples_path, encoding="utf-8") as f:
549
+ for line_num, raw_line in enumerate(f, 1):
550
+ stripped = raw_line.strip()
551
+ if stripped:
552
+ try:
553
+ json.loads(stripped)
554
+ actual_count += 1
555
+ except json.JSONDecodeError as e:
556
+ error_msg = f"Invalid JSON on line {line_num} of checkpoint samples: {e}"
557
+ break
558
+ except OSError as e:
559
+ error_msg = f"Failed to read checkpoint samples file: {e}"
560
+
561
+ if error_msg is None and actual_count != expected_samples:
562
+ error_msg = (
563
+ f"Sample count mismatch: metadata says {expected_samples}, "
564
+ f"file has {actual_count} samples"
565
+ )
566
+ elif expected_samples > 0:
567
+ error_msg = f"Checkpoint metadata expects {expected_samples} samples but samples file missing"
568
+
569
+ return (error_msg is None, error_msg)
570
+
571
def has_checkpoint(self) -> bool:
    """Report whether checkpoint metadata exists on disk, without loading it.

    Always False when checkpointing is disabled
    (``config.checkpoint_interval`` is None). Otherwise the checkpoint paths
    are (re)initialized and the metadata file's existence is checked.

    Returns:
        True if the checkpoint metadata file exists, False otherwise.
    """
    if self.config.checkpoint_interval is not None:
        self._initialize_checkpoint_paths()
        metadata_file = self._checkpoint_metadata_path
        return metadata_file is not None and metadata_file.exists()
    return False
585
+
586
def load_checkpoint(self, retry_failed: bool = False) -> bool:
    """Restore resume state from an existing checkpoint.

    Steps:
      1. Read the metadata file and check its integrity.
      2. Warn on config mismatches.
      3. Restore the processed topic IDs and the number of samples and
         failures already stored on disk. Samples are counted, not loaded
         into memory.

    All files are parsed into local variables first. Results are committed
    to the generator only once everything has been read successfully. A
    checkpoint that fails to load therefore leaves ``_processed_ids`` and the
    flushed counters untouched. This avoids half-restored state, such as
    processed IDs restored but an unreadable failures file, which would make
    the generation loop skip topics even though this method reported
    failure.

    Args:
        retry_failed: If True and at least one failure record carries a
            ``topic_id``, those IDs are removed from the processed set so
            they are generated again, and the failures file is deleted.
            Failure records without a ``topic_id`` are discarded with it.

    Returns:
        True if the checkpoint was loaded. False if checkpointing is
        disabled, no metadata file exists, or the checkpoint could not be
        read or validated.
    """
    if self.config.checkpoint_interval is None:
        return False

    self._initialize_checkpoint_paths()

    metadata_path = self._checkpoint_metadata_path
    if metadata_path is None or not metadata_path.exists():
        return False

    try:
        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)

        is_valid, error_msg = self._validate_checkpoint_integrity(metadata)
        if not is_valid:
            logger.error("Checkpoint integrity check failed: %s", error_msg)
            return False

        # Logs a warning on mismatch but never blocks resumption.
        self._validate_checkpoint_compatibility(metadata)

        processed_ids = set(metadata.get("processed_ids", []))

        # Count samples already on disk instead of loading them into memory.
        sample_count = 0
        samples_path = self._checkpoint_samples_path
        if samples_path and samples_path.exists():
            with open(samples_path, encoding="utf-8") as f:
                sample_count = sum(1 for raw_line in f if raw_line.strip())

        # Failure records are small; parse them to learn which topic IDs failed.
        failure_count = 0
        failed_ids: set[str] = set()
        failures_path = self._checkpoint_failures_path
        if failures_path and failures_path.exists():
            with open(failures_path, encoding="utf-8") as f:
                for raw_line in f:
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    failure = json.loads(stripped)
                    failure_count += 1
                    if isinstance(failure, dict) and "topic_id" in failure:
                        failed_ids.add(failure["topic_id"])

        if retry_failed and failed_ids:
            ids_to_retry = processed_ids & failed_ids
            processed_ids -= ids_to_retry
            # Drop the failures file: these topics are about to be regenerated.
            if failures_path and failures_path.exists():
                os.remove(failures_path)
            failure_count = 0
            logger.info(
                "Retry mode: %d failed IDs will be retried",
                len(ids_to_retry),
            )

        # Commit only now that every file has been read successfully.
        self._processed_ids = processed_ids
        self._flushed_samples_count = sample_count
        self._flushed_failures_count = failure_count

        logger.info(
            "Loaded checkpoint: %d samples, %d failures, %d IDs processed",
            self._flushed_samples_count,
            self._flushed_failures_count,
            len(self._processed_ids),
        )
    except Exception as e:  # noqa: BLE001
        logger.warning("Failed to load checkpoint: %s", e)
        return False
    else:
        return True
670
+
671
def clear_checkpoint(self) -> None:
    """Delete the checkpoint files and reset the in-memory resume state.

    Paths that were never initialized, and files that do not exist, are
    skipped. The processed-ID set is emptied and both flushed counters are
    reset to zero.
    """
    checkpoint_files = (
        self._checkpoint_metadata_path,
        self._checkpoint_samples_path,
        self._checkpoint_failures_path,
    )
    for checkpoint_file in checkpoint_files:
        if checkpoint_file and checkpoint_file.exists():
            checkpoint_file.unlink()

    self._processed_ids.clear()
    self._flushed_samples_count = self._flushed_failures_count = 0
    logger.info("Checkpoint files cleared")
683
+
684
def _load_all_samples_from_checkpoint(self) -> list[dict]:
    """Read every sample stored in the checkpoint samples file.

    Used when samples have been flushed out of memory and the full dataset
    has to be rebuilt from disk. Blank lines are ignored; every other line
    is parsed as one JSON object.

    Returns:
        The parsed samples in file order, or an empty list when the samples
        path is unset or the file does not exist.
    """
    samples_file = self._checkpoint_samples_path
    if not samples_file or not samples_file.exists():
        return []

    with open(samples_file, encoding="utf-8") as handle:
        return [json.loads(text) for text in map(str.strip, handle) if text]
701
+
702
def get_all_failures(self) -> list[dict]:
    """Return every failure record: those on disk followed by those in memory.

    Failures previously flushed to the checkpoint failures file are read
    first (blank lines skipped). Then the contents of ``failed_samples`` are
    appended. A new list is returned.

    NOTE(review): when ``_save_checkpoint`` runs with ``flush_memory=False``,
    it appends failures to the file without clearing ``failed_samples``.
    Those records may then appear twice here. Confirm whether that mode is
    used together with this method.

    Returns:
        List of all failure dictionaries.
    """
    failures_file = self._checkpoint_failures_path
    on_disk: list[dict] = []
    if failures_file and failures_file.exists():
        with open(failures_file, encoding="utf-8") as handle:
            on_disk = [json.loads(text) for text in map(str.strip, handle) if text]

    return [*on_disk, *self.failed_samples]
725
+
726
def _is_topic_processed(self, topic_path: TopicPath | None) -> bool:
    """Tell whether ``topic_path``'s topic ID is already in the processed set.

    A ``None`` topic path is never considered processed.

    NOTE(review): ``_processed_ids`` is keyed only by topic ID and is extended
    by ``_save_checkpoint``. When topic paths are cycled to produce more
    samples than there are paths, the same ID occurs several times. Once
    that ID is recorded, every later repeat of it is treated as processed as
    well, both on resume and within the same run after a checkpoint. Confirm
    this is intended.

    Args:
        topic_path: Topic path to look up, or None.

    Returns:
        True if the topic ID is in ``_processed_ids``.
    """
    return topic_path is not None and topic_path.topic_id in self._processed_ids
738
+
331
739
  def _validate_create_data_params(
332
740
  self,
333
741
  num_steps: int,
@@ -351,34 +759,24 @@ class DataSetGenerator:
351
759
  num_steps: int,
352
760
  batch_size: int,
353
761
  topic_model: "TopicModel | None" = None,
354
- ) -> tuple[list | None, int]:
762
+ ) -> tuple[list[TopicPath] | None, int]:
355
763
  """Prepare and validate topic paths for data generation."""
356
- topic_paths = None
764
+ topic_paths: list[TopicPath] | None = None
357
765
  if topic_model is not None:
358
- topic_paths = topic_model.get_all_paths()
766
+ topic_paths = topic_model.get_all_paths_with_ids()
359
767
  total_paths = len(topic_paths)
360
768
  required_samples = num_steps * batch_size
361
769
 
362
770
  if required_samples > total_paths:
363
- # Provide detailed error with recommendations
364
- max_steps_for_batch = total_paths // batch_size
365
- max_batch_for_steps = total_paths // num_steps if num_steps > 0 else total_paths
366
-
367
- error_msg = (
368
- f"Insufficient topic paths for dataset generation:\n"
369
- f" • Available paths: {total_paths}\n"
370
- f" • Requested samples: {required_samples} ({num_steps} steps × {batch_size} batch size)\n"
371
- f" • Shortfall: {required_samples - total_paths} samples\n\n"
372
- f"Recommendations:\n"
373
- f" • Reduce --num-steps to {max_steps_for_batch} (with current batch size {batch_size})\n"
374
- f" • Reduce --batch-size to {max_batch_for_steps} (with current {num_steps} steps)\n"
375
- f" • Increase topic tree/graph depth or degree to generate more paths"
376
- )
377
- raise DataSetGeneratorError(error_msg)
378
-
379
- # Bandit: not a security function
380
- topic_paths = random.sample(topic_paths, required_samples) # nosec
381
- num_steps = math.ceil(len(topic_paths) / batch_size)
771
+ # Cycle through topics to generate more samples than paths
772
+ # Each topic will be used multiple times for even coverage
773
+ multiplier = math.ceil(required_samples / total_paths)
774
+ topic_paths = (topic_paths * multiplier)[:required_samples]
775
+ elif required_samples < total_paths:
776
+ # Sample subset (percentage case or explicit count < total)
777
+ # Bandit: not a security function
778
+ topic_paths = random.sample(topic_paths, required_samples) # nosec
779
+ # else: required_samples == total_paths - use all paths as-is
382
780
 
383
781
  return topic_paths, num_steps
384
782
 
@@ -386,23 +784,25 @@ class DataSetGenerator:
386
784
  self,
387
785
  batch_size: int,
388
786
  start_idx: int,
389
- topic_paths: list,
787
+ topic_paths: list[TopicPath],
390
788
  data_creation_prompt: str,
391
789
  num_example_demonstrations: int,
392
- ) -> tuple[list[str], list[list[str] | None]]:
393
- """Generate prompts for a batch and return the associated paths used.
790
+ ) -> tuple[list[str], list[TopicPath | None]]:
791
+ """Generate prompts for a batch and return the associated TopicPaths used.
394
792
 
395
793
  Returns:
396
- (prompts, used_paths) where used_paths aligns with prompts order.
794
+ (prompts, used_topic_paths) where used_topic_paths aligns with prompts order.
397
795
  """
398
796
  prompts: list[str] = []
399
- used_paths: list[list[str] | None] = []
797
+ used_topic_paths: list[TopicPath | None] = []
400
798
  for i in range(batch_size):
401
- path = None
799
+ topic_path: TopicPath | None = None
800
+ path: list[str] | None = None
402
801
  if topic_paths:
403
802
  current_idx = start_idx + i
404
803
  if current_idx < len(topic_paths):
405
- path = topic_paths[current_idx]
804
+ topic_path = topic_paths[current_idx]
805
+ path = topic_path.path
406
806
  else:
407
807
  break
408
808
 
@@ -412,8 +812,8 @@ class DataSetGenerator:
412
812
  subtopics_list=path,
413
813
  )
414
814
  prompts.append(sample_prompt)
415
- used_paths.append(path)
416
- return prompts, used_paths
815
+ used_topic_paths.append(topic_path)
816
+ return prompts, used_topic_paths
417
817
 
418
818
  def _get_minimal_schema(self) -> type:
419
819
  """Get the conversation schema for the current config."""
@@ -447,7 +847,7 @@ class DataSetGenerator:
447
847
  prompts: list[str],
448
848
  include_sys_msg: bool,
449
849
  start_sample_idx: int = 0,
450
- paths_for_batch: list[list[str] | None] | None = None,
850
+ topic_paths_for_batch: list[TopicPath | None] | None = None,
451
851
  ) -> tuple[list, list]:
452
852
  """Generate structured samples using builder pattern.
453
853
 
@@ -455,6 +855,7 @@ class DataSetGenerator:
455
855
  prompts: List of topic prompts to generate samples for
456
856
  include_sys_msg: Whether to include system message in output
457
857
  start_sample_idx: Starting sample index for progress reporting
858
+ topic_paths_for_batch: TopicPath objects for each sample (includes topic_id)
458
859
 
459
860
  Returns:
460
861
  Tuple of (successful samples, failed responses)
@@ -470,7 +871,7 @@ class DataSetGenerator:
470
871
  config = self.config.model_copy(update={"sys_msg": include_sys_msg})
471
872
 
472
873
  async def _generate_with_retry(
473
- prompt: str, sample_idx: int, path_info: list[str] | None
874
+ prompt: str, sample_idx: int, topic_path_info: TopicPath | None
474
875
  ) -> tuple[bool, Exception | Conversation]:
475
876
  """Generate a single sample with per-sample retry for validation errors.
476
877
 
@@ -496,6 +897,9 @@ class DataSetGenerator:
496
897
  self.config.sample_retries,
497
898
  )
498
899
 
900
+ # Extract path for progress reporting
901
+ path_info = topic_path_info.path if topic_path_info else None
902
+
499
903
  for attempt in range(max_attempts):
500
904
  # Notify progress reporter about which sample we're working on
501
905
  if self.progress_reporter:
@@ -533,8 +937,8 @@ class DataSetGenerator:
533
937
  return False, last_error or Exception("Sample generation failed")
534
938
 
535
939
  else:
536
- # Validate tool execution count for agent modes
537
- if self.config.agent_mode is not None:
940
+ # Validate tool execution count for agent mode (when tools configured)
941
+ if self.tool_registry is not None:
538
942
  if (
539
943
  not conversation.tool_context
540
944
  or not conversation.tool_context.executions
@@ -566,6 +970,12 @@ class DataSetGenerator:
566
970
  ]
567
971
  )
568
972
 
973
+ # Add topic_id to conversation metadata for traceability
974
+ if topic_path_info and hasattr(conversation, "metadata"):
975
+ if conversation.metadata is None:
976
+ conversation.metadata = {}
977
+ conversation.metadata["topic_id"] = topic_path_info.topic_id
978
+
569
979
  return True, conversation
570
980
 
571
981
  return False, last_error or Exception("Sample generation failed")
@@ -573,11 +983,13 @@ class DataSetGenerator:
573
983
  # Generate all samples concurrently with sample indices
574
984
  tasks = []
575
985
  for idx, prompt in enumerate(prompts):
576
- path_info = None
577
- if paths_for_batch and idx < len(paths_for_batch):
578
- path_info = paths_for_batch[idx]
986
+ topic_path_info = None
987
+ if topic_paths_for_batch and idx < len(topic_paths_for_batch):
988
+ topic_path_info = topic_paths_for_batch[idx]
579
989
  tasks.append(
580
- asyncio.create_task(_generate_with_retry(prompt, start_sample_idx + idx, path_info))
990
+ asyncio.create_task(
991
+ _generate_with_retry(prompt, start_sample_idx + idx, topic_path_info)
992
+ )
581
993
  )
582
994
  results = await asyncio.gather(*tasks)
583
995
 
@@ -587,12 +999,18 @@ class DataSetGenerator:
587
999
  else:
588
1000
  error = payload
589
1001
  error_msg = f"Generation failed: {error}"
590
- # Build failure record with raw content if available
591
- failure_record = {"error": error_msg}
1002
+ # Build failure record with raw content and topic_id if available
1003
+ failure_record: dict[str, str | None] = {"error": error_msg}
592
1004
  if isinstance(error, Exception):
593
1005
  context = getattr(error, "context", None)
594
1006
  if isinstance(context, dict) and "raw_content" in context:
595
1007
  failure_record["raw_content"] = context["raw_content"]
1008
+ # Include topic_id and path for checkpoint retry functionality
1009
+ if topic_paths_for_batch and idx < len(topic_paths_for_batch):
1010
+ tp = topic_paths_for_batch[idx]
1011
+ if tp:
1012
+ failure_record["topic_id"] = tp.topic_id
1013
+ failure_record["path"] = " -> ".join(tp.path)
596
1014
  failed_responses.append(failure_record)
597
1015
  failure_type = self.analyze_failure(
598
1016
  str(error), error=error if isinstance(error, Exception) else None
@@ -817,12 +1235,12 @@ class DataSetGenerator:
817
1235
  ):
818
1236
  yield event
819
1237
 
820
- async def _run_generation_loop_async( # noqa: PLR0912
1238
+ async def _run_generation_loop_async( # noqa: PLR0912, PLR0915
821
1239
  self,
822
1240
  num_steps: int,
823
1241
  batch_size: int,
824
1242
  total_samples: int,
825
- topic_paths: list,
1243
+ topic_paths: list[TopicPath],
826
1244
  data_creation_prompt: str,
827
1245
  num_example_demonstrations: int,
828
1246
  include_sys_msg: bool,
@@ -830,6 +1248,16 @@ class DataSetGenerator:
830
1248
  topic_model_type: str | None = None,
831
1249
  ) -> AsyncGenerator[dict | HFDataset, None]:
832
1250
  """Run the main generation loop yielding progress events."""
1251
+ # Initialize checkpoint paths if checkpointing is enabled
1252
+ if self.config.checkpoint_interval is not None:
1253
+ self._initialize_checkpoint_paths()
1254
+
1255
+ # Track samples added in this run for checkpointing
1256
+ samples_since_checkpoint = 0
1257
+ samples_in_current_batch: list[dict] = []
1258
+ failures_in_current_batch: list[dict] = []
1259
+ topic_paths_in_current_batch: list[TopicPath | None] = []
1260
+
833
1261
  try:
834
1262
  yield {
835
1263
  "event": "generation_start",
@@ -839,6 +1267,11 @@ class DataSetGenerator:
839
1267
  "total_samples": total_samples,
840
1268
  "root_topic_prompt": root_topic_prompt,
841
1269
  "topic_model_type": topic_model_type,
1270
+ "resumed_from_checkpoint": len(self._processed_ids) > 0,
1271
+ "previously_processed": len(self._processed_ids),
1272
+ "resumed_samples": self._flushed_samples_count,
1273
+ "resumed_failures": self._flushed_failures_count,
1274
+ "checkpoint_enabled": self.config.checkpoint_interval is not None,
842
1275
  }
843
1276
 
844
1277
  for step in range(num_steps):
@@ -849,7 +1282,7 @@ class DataSetGenerator:
849
1282
  }
850
1283
 
851
1284
  start_idx = step * batch_size
852
- prompts, used_paths = self._generate_batch_prompts(
1285
+ prompts, used_topic_paths = self._generate_batch_prompts(
853
1286
  batch_size,
854
1287
  start_idx,
855
1288
  topic_paths,
@@ -857,17 +1290,85 @@ class DataSetGenerator:
857
1290
  num_example_demonstrations,
858
1291
  )
859
1292
 
1293
+ # Filter out already-processed topics when resuming
1294
+ if self._processed_ids:
1295
+ filtered_prompts = []
1296
+ filtered_topic_paths: list[TopicPath | None] = []
1297
+ for prompt, tp in zip(prompts, used_topic_paths, strict=False):
1298
+ if not self._is_topic_processed(tp):
1299
+ filtered_prompts.append(prompt)
1300
+ filtered_topic_paths.append(tp)
1301
+
1302
+ if not filtered_prompts:
1303
+ # All topics in this batch were already processed
1304
+ yield {
1305
+ "event": "step_complete",
1306
+ "step": step + 1,
1307
+ "samples_generated": 0,
1308
+ "success": True,
1309
+ "failed_in_step": 0,
1310
+ "failure_reasons": [],
1311
+ "skipped": len(prompts),
1312
+ }
1313
+ continue
1314
+
1315
+ prompts = filtered_prompts
1316
+ used_topic_paths = filtered_topic_paths
1317
+
860
1318
  failed_before = len(self.failed_samples)
1319
+ samples_before = len(self._samples)
861
1320
 
862
1321
  success, samples_generated = await self._process_batch_with_retries_async(
863
- prompts, include_sys_msg, start_idx, used_paths
1322
+ prompts, include_sys_msg, start_idx, used_topic_paths
864
1323
  )
865
1324
 
1325
+ # Track new samples and failures for checkpointing
1326
+ new_samples = self._samples[samples_before:]
1327
+ new_failures = self.failed_samples[failed_before:]
1328
+ samples_in_current_batch.extend(new_samples)
1329
+ failures_in_current_batch.extend(new_failures)
1330
+ topic_paths_in_current_batch.extend(used_topic_paths)
1331
+ samples_since_checkpoint += samples_generated
1332
+
1333
+ # Save checkpoint if we've reached the interval
1334
+ if (
1335
+ self.config.checkpoint_interval is not None
1336
+ and samples_since_checkpoint >= self.config.checkpoint_interval
1337
+ ):
1338
+ self._save_checkpoint(
1339
+ samples_in_current_batch,
1340
+ failures_in_current_batch,
1341
+ topic_paths_in_current_batch,
1342
+ )
1343
+ samples_in_current_batch = []
1344
+ failures_in_current_batch = []
1345
+ topic_paths_in_current_batch = []
1346
+ samples_since_checkpoint = 0
1347
+ yield {
1348
+ "event": "checkpoint_saved",
1349
+ "total_samples": self._flushed_samples_count,
1350
+ "total_failures": self._flushed_failures_count,
1351
+ }
1352
+
1353
+ # Check for graceful stop request after checkpoint save
1354
+ if self.stop_requested:
1355
+ yield {
1356
+ "event": "generation_stopped",
1357
+ "message": "Stopped at checkpoint as requested",
1358
+ "total_samples": self._flushed_samples_count,
1359
+ "total_failures": self._flushed_failures_count,
1360
+ }
1361
+ return # Exit generator cleanly
1362
+
866
1363
  failed_in_batch = len(self.failed_samples) - failed_before
867
- failure_reasons = []
1364
+ failure_reasons: list[str] = []
868
1365
  if failed_in_batch > 0 and self.failed_samples:
869
1366
  recent_failures = self.failed_samples[-failed_in_batch:]
870
- failure_reasons = recent_failures[:3]
1367
+ for f in recent_failures[:3]:
1368
+ if isinstance(f, dict):
1369
+ failure_reasons.append(f.get("error", str(f)))
1370
+ else:
1371
+ failure_reasons.append(str(f))
871
1372
 
872
1373
  yield {
873
1374
  "event": "step_complete",
@@ -885,13 +1386,42 @@ class DataSetGenerator:
885
1386
  "message": f"Failed to process batch {step + 1} after all retries",
886
1387
  }
887
1388
 
1389
+ # Save final checkpoint with any remaining samples
1390
+ if self.config.checkpoint_interval is not None and (
1391
+ samples_in_current_batch or failures_in_current_batch
1392
+ ):
1393
+ self._save_checkpoint(
1394
+ samples_in_current_batch,
1395
+ failures_in_current_batch,
1396
+ topic_paths_in_current_batch,
1397
+ )
1398
+ yield {
1399
+ "event": "checkpoint_saved",
1400
+ "total_samples": self._flushed_samples_count,
1401
+ "total_failures": self._flushed_failures_count,
1402
+ "final": True,
1403
+ }
1404
+
1405
+ # Calculate total counts including flushed data
1406
+ total_samples = self._flushed_samples_count + len(self._samples)
1407
+ total_failures = self._flushed_failures_count + len(self.failed_samples)
1408
+
888
1409
  yield {
889
1410
  "event": "generation_complete",
890
- "total_samples": len(self._samples),
891
- "failed_samples": len(self.failed_samples),
1411
+ "total_samples": total_samples,
1412
+ "failed_samples": total_failures,
892
1413
  }
893
1414
 
894
1415
  except KeyboardInterrupt:
1416
+ # Save checkpoint on interrupt
1417
+ if self.config.checkpoint_interval is not None and (
1418
+ samples_in_current_batch or failures_in_current_batch
1419
+ ):
1420
+ self._save_checkpoint(
1421
+ samples_in_current_batch,
1422
+ failures_in_current_batch,
1423
+ topic_paths_in_current_batch,
1424
+ )
895
1425
  yield {
896
1426
  "event": "generation_interrupted",
897
1427
  "message": "Generation interrupted by user.",
@@ -900,25 +1430,39 @@ class DataSetGenerator:
900
1430
  self._save_samples_to_file(INTERRUPTED_DATASET_FILENAME)
901
1431
 
902
1432
  except Exception as e: # noqa: BLE001
1433
+ # Save checkpoint on error
1434
+ if self.config.checkpoint_interval is not None and (
1435
+ samples_in_current_batch or failures_in_current_batch
1436
+ ):
1437
+ self._save_checkpoint(
1438
+ samples_in_current_batch,
1439
+ failures_in_current_batch,
1440
+ topic_paths_in_current_batch,
1441
+ )
903
1442
  yield {"event": "generation_error", "error": str(e)}
904
1443
  self.print_failure_summary()
905
1444
  self._save_samples_to_file(ERROR_DATASET_FILENAME)
906
1445
  raise DataSetGeneratorError("failed") from e
907
1446
 
908
- yield (HFDataset.from_list(self._samples) if self._samples else HFDataset.from_list([]))
1447
+ # Build final dataset: if samples were flushed to disk, load them from checkpoint
1448
+ if self._flushed_samples_count > 0:
1449
+ all_samples = self._load_all_samples_from_checkpoint()
1450
+ yield HFDataset.from_list(all_samples) if all_samples else HFDataset.from_list([])
1451
+ else:
1452
+ yield (HFDataset.from_list(self._samples) if self._samples else HFDataset.from_list([]))
909
1453
 
910
1454
  async def _process_batch_with_retries_async(
911
1455
  self,
912
1456
  prompts: list[str],
913
1457
  include_sys_msg: bool,
914
1458
  start_sample_idx: int = 0,
915
- paths_for_batch: list[list[str] | None] | None = None,
1459
+ topic_paths_for_batch: list[TopicPath | None] | None = None,
916
1460
  ) -> tuple[bool, int]:
917
1461
  """Process a batch with retry logic."""
918
1462
  for attempt in range(self.config.max_retries):
919
1463
  try:
920
1464
  samples, failed_responses = await self._generate_structured_samples_async(
921
- prompts, include_sys_msg, start_sample_idx, paths_for_batch
1465
+ prompts, include_sys_msg, start_sample_idx, topic_paths_for_batch
922
1466
  )
923
1467
 
924
1468
  # Update failed samples
@@ -948,14 +1492,33 @@ class DataSetGenerator:
948
1492
  error_msg = f"API error for provider '{self.provider}': {str(e)[:100]}..."
949
1493
  self.failure_analysis["api_errors"].append(error_msg)
950
1494
 
951
- self.failed_samples.append(error_msg)
1495
+ # Build failure records for each topic path in the batch
1496
+ if topic_paths_for_batch:
1497
+ for tp in topic_paths_for_batch:
1498
+ failure_record: dict[str, str | None] = {"error": error_msg}
1499
+ if tp:
1500
+ failure_record["topic_id"] = tp.topic_id
1501
+ failure_record["path"] = " -> ".join(tp.path)
1502
+ self.failed_samples.append(failure_record)
1503
+ else:
1504
+ self.failed_samples.append({"error": error_msg})
952
1505
  logger.exception("API error: %s", error_msg)
953
1506
  return False, 0 # Don't retry authentication/API errors
954
1507
  except Exception as e:
955
1508
  if attempt == self.config.max_retries - 1:
956
- self.failed_samples.append(str(e))
957
- failure_type = self.analyze_failure(str(e), error=e)
958
- self.failure_analysis[failure_type].append(str(e))
1509
+ error_msg = str(e)
1510
+ # Build failure records for each topic path in the batch
1511
+ if topic_paths_for_batch:
1512
+ for tp in topic_paths_for_batch:
1513
+ failure_record_exc: dict[str, str | None] = {"error": error_msg}
1514
+ if tp:
1515
+ failure_record_exc["topic_id"] = tp.topic_id
1516
+ failure_record_exc["path"] = " -> ".join(tp.path)
1517
+ self.failed_samples.append(failure_record_exc)
1518
+ else:
1519
+ self.failed_samples.append({"error": error_msg})
1520
+ failure_type = self.analyze_failure(error_msg, error=e)
1521
+ self.failure_analysis[failure_type].append(error_msg)
959
1522
  return False, 0
960
1523
  else:
961
1524
  # If no exception and no samples, return False, 0
@@ -1015,7 +1578,7 @@ class DataSetGenerator:
1015
1578
  return f"\nHere are output examples:\n<examples>\n{examples_text}\n</examples>\n"
1016
1579
 
1017
1580
  def build_tools_text(self) -> str:
1018
- """Build formatted tools text for XLAM multi-turn prompts."""
1581
+ """Build formatted tools text for XLAM prompts."""
1019
1582
  if not self.tool_registry:
1020
1583
  return "No tools available"
1021
1584
 
@@ -1046,8 +1609,8 @@ class DataSetGenerator:
1046
1609
 
1047
1610
  # Handle chain of thought conversations
1048
1611
  if self.config.conversation_type == "cot":
1049
- # Agent mode with tools - use agent prompts
1050
- if self.config.agent_mode == "single_turn" and self.tool_registry:
1612
+ # Agent mode with tools - use agent prompts (implicit when tools configured)
1613
+ if self.tool_registry:
1051
1614
  # Use agent prompt for single-turn tool calling
1052
1615
  return (
1053
1616
  AgentPromptBuilder.build_tool_context_prompt(
@@ -1057,16 +1620,6 @@ class DataSetGenerator:
1057
1620
  or AGENT_COT_TOOLS_PROMPT
1058
1621
  )
1059
1622
 
1060
- if self.config.agent_mode == "multi_turn" and self.tool_registry:
1061
- # Standard multi-turn agent
1062
- return (
1063
- AgentPromptBuilder.build_multi_turn_context_prompt(
1064
- self.tool_registry,
1065
- max_tools_per_query=self.config.max_tools_per_query,
1066
- )
1067
- or AGENT_COT_MULTI_TURN_PROMPT
1068
- )
1069
-
1070
1623
  # Non-agent CoT - select based on reasoning style
1071
1624
  if self.config.reasoning_style == "freetext":
1072
1625
  return FREETEXT_COT_PROMPT
@@ -1079,7 +1632,7 @@ class DataSetGenerator:
1079
1632
  def _save_samples_to_file(self, save_path: str):
1080
1633
  """Save the current samples to a JSONL file."""
1081
1634
 
1082
- with open(save_path, "w") as f:
1635
+ with open(save_path, "w", encoding="utf-8") as f:
1083
1636
  for sample in self._samples:
1084
1637
  f.write(json.dumps(sample, separators=(",", ":")) + "\n")
1085
1638