deepfabric-4.10.1-py3-none-any.whl → deepfabric-4.11.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deepfabric/generator.py CHANGED
@@ -46,7 +46,7 @@ from .prompts import (
 from .schemas import Conversation, ToolRegistry, get_conversation_schema
 from .tools import BUILTIN_TOOL_REGISTRY
 from .tools.loader import load_tools_from_dict, load_tools_from_endpoint
-from .topic_model import TopicModel, TopicPath
+from .topic_model import Topic, TopicModel, TopicPath
 from .utils import ensure_not_running_loop, get_checkpoint_dir, is_validation_error
 
 # Handle circular import for type hints
@@ -275,13 +275,16 @@ class DataSetGenerator:
 
         # Checkpoint state
         self._checkpoint_samples_since_save = 0
-        self._processed_ids: set[str] = set()  # Track processed topic IDs (UUIDs)
+        self._completed: set[tuple[str, int]] = set()  # Track (uuid, cycle) tuples
         self._checkpoint_metadata_path: Path | None = None
         self._checkpoint_samples_path: Path | None = None
        self._checkpoint_failures_path: Path | None = None
         # Memory optimization: track flushed counts for checkpoint mode
         self._flushed_samples_count = 0
         self._flushed_failures_count = 0
+        # Generation state for cycle-based iteration
+        self._unique_topics: int = 0  # Number of unique topics
+        self._cycles_needed: int = 0  # Total cycles for requested samples
 
         # Graceful stop flag - set by signal handler to stop at next checkpoint
         self.stop_requested = False
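
The move from a flat `set[str]` of topic IDs to `(uuid, cycle)` tuples is what lets a resumed run distinguish "topic finished in pass 0" from "topic finished in every pass". A minimal standalone sketch of the bookkeeping, using hypothetical UUIDs rather than DeepFabric's real ones:

    # Sketch: resume bookkeeping with (uuid, cycle) tuples.
    # "t1"/"t2" are hypothetical topic UUIDs for illustration.
    completed: set[tuple[str, int]] = set()

    completed.add(("t1", 0))  # topic t1 finished in cycle 0
    completed.add(("t2", 0))
    completed.add(("t1", 1))  # t1 also finished in cycle 1

    # Old model: a bare set of UUIDs could only say "t1 was processed",
    # so a resumed run would skip t1 in *every* cycle.
    # New model: t2 still needs cycle 1, even though it completed cycle 0.
    print(("t2", 1) in completed)  # False -> t2 is regenerated in cycle 1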
@@ -383,7 +386,7 @@ class DataSetGenerator:
         self,
         new_samples: list[dict],
         new_failures: list[dict],
-        processed_topic_paths: list[TopicPath | None],
+        completed_items: list[tuple[str, int]],
         flush_memory: bool = True,
     ) -> None:
         """Save checkpoint data incrementally.
@@ -391,7 +394,7 @@ class DataSetGenerator:
         Args:
             new_samples: New successful samples to append
             new_failures: New failed samples to append
-            processed_topic_paths: TopicPath objects that were processed in this batch
+            completed_items: List of (uuid, cycle) tuples that were completed
             flush_memory: If True, clear flushed samples from memory (memory optimization)
         """
         if self._checkpoint_samples_path is None:
@@ -409,10 +412,9 @@ class DataSetGenerator:
            for failure in new_failures:
                f.write(json.dumps(failure, separators=(",", ":")) + "\n")
 
-        # Track processed topic IDs
-        for topic_path in processed_topic_paths:
-            if topic_path is not None:
-                self._processed_ids.add(topic_path.topic_id)
+        # Track completed (uuid, cycle) tuples
+        for item in completed_items:
+            self._completed.add(item)
 
         # Memory optimization: track flushed counts and clear in-memory lists
         # Must happen BEFORE saving metadata so counts are accurate
@@ -427,10 +429,10 @@ class DataSetGenerator:
        self._save_checkpoint_metadata()
 
        logger.debug(
-            "Checkpoint saved: %d samples, %d failures, %d total IDs processed (flushed=%s)",
+            "Checkpoint saved: %d samples, %d failures, %d completed tuples (flushed=%s)",
            len(new_samples),
            len(new_failures),
-            len(self._processed_ids),
+            len(self._completed),
            flush_memory,
        )
 
@@ -443,6 +445,9 @@ class DataSetGenerator:
         total_samples = self._flushed_samples_count + len(self._samples)
         total_failures = self._flushed_failures_count + len(self.failed_samples)
 
+        # Convert completed set of (uuid, cycle) tuples to list of [uuid, cycle] arrays
+        completed_list = [[uuid, cycle] for uuid, cycle in self._completed]
+
         metadata = {
             "version": CHECKPOINT_VERSION,
             "created_at": datetime.now(timezone.utc).isoformat(),
@@ -452,7 +457,9 @@ class DataSetGenerator:
             "reasoning_style": self.config.reasoning_style,
             "total_samples": total_samples,
             "total_failures": total_failures,
-            "processed_ids": list(self._processed_ids),
+            "completed": completed_list,  # List of [uuid, cycle] arrays
+            "unique_topics": self._unique_topics,
+            "cycles_needed": self._cycles_needed,
             "checkpoint_interval": self.config.checkpoint_interval,
             "topics_file": self.config.topics_file,
         }
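
JSON has no tuple type, so the metadata stores the completed set as a list of `[uuid, cycle]` arrays; `load_checkpoint` (below) rebuilds hashable tuples from them. A self-contained round-trip sketch of that convention, with made-up UUIDs:

    import json

    # Save side: set of tuples -> list of [uuid, cycle] arrays (JSON-safe).
    completed = {("t1", 0), ("t1", 1), ("t2", 0)}  # hypothetical UUIDs
    payload = json.dumps({"completed": [[u, c] for u, c in completed]})

    # Load side: arrays come back as lists, so rebuild hashable tuples.
    restored = {(item[0], item[1]) for item in json.loads(payload)["completed"]}
    assert restored == completed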
@@ -528,12 +535,26 @@ class DataSetGenerator:
         version = metadata.get("version")
         if version is None:
             error_msg = "Missing 'version' field in checkpoint metadata"
-        elif version != CHECKPOINT_VERSION:
-            error_msg = f"Unsupported checkpoint version: {version} (expected {CHECKPOINT_VERSION})"
+        elif version < CHECKPOINT_VERSION:
+            # Old checkpoint format - require fresh start
+            error_msg = (
+                f"Checkpoint format v{version} is incompatible with current version v{CHECKPOINT_VERSION}. "
+                "Please delete the checkpoint and restart: rm -rf .checkpoints/"
+            )
+        elif version > CHECKPOINT_VERSION:
+            error_msg = (
+                f"Checkpoint version {version} is newer than supported version {CHECKPOINT_VERSION}"
+            )
 
-        # Check required fields
+        # Check required fields for v4 format
         if error_msg is None:
-            required_fields = ["created_at", "total_samples", "processed_ids"]
+            required_fields = [
+                "created_at",
+                "total_samples",
+                "completed",
+                "unique_topics",
+                "cycles_needed",
+            ]
             for field in required_fields:
                 if field not in metadata:
                     error_msg = f"Missing required field in checkpoint metadata: {field}"
@@ -579,8 +600,7 @@ class DataSetGenerator:
 
         self._initialize_checkpoint_paths()
         return (
-            self._checkpoint_metadata_path is not None
-            and self._checkpoint_metadata_path.exists()
+            self._checkpoint_metadata_path is not None and self._checkpoint_metadata_path.exists()
         )
 
     def load_checkpoint(self, retry_failed: bool = False) -> bool:
@@ -614,8 +634,11 @@ class DataSetGenerator:
             # Validate config compatibility
             self._validate_checkpoint_compatibility(metadata)
 
-            # Restore processed IDs
-            self._processed_ids = set(metadata.get("processed_ids", []))
+            # Restore completed (uuid, cycle) tuples
+            completed_list = metadata.get("completed", [])
+            self._completed = {(item[0], item[1]) for item in completed_list}
+            self._unique_topics = metadata.get("unique_topics", 0)
+            self._cycles_needed = metadata.get("cycles_needed", 0)
 
             # Count existing samples (don't load into memory - they're already on disk)
             # Memory optimization: track as flushed counts instead of loading into RAM
@@ -631,36 +654,47 @@ class DataSetGenerator:
             failed_ids: set[str] = set()
             if self._checkpoint_failures_path and self._checkpoint_failures_path.exists():
                 failure_count = 0
+                failed_tuples: set[tuple[str, int]] = set()
                 with open(self._checkpoint_failures_path, encoding="utf-8") as f:
                     for raw_line in f:
                         stripped = raw_line.strip()
                         if stripped:
                             failure = json.loads(stripped)
                             failure_count += 1
-                            # Track the topic_id that failed for potential retry
+                            # Track failed (topic_id, cycle) for targeted retry
                             if "topic_id" in failure:
                                 failed_ids.add(failure["topic_id"])
+                                if "cycle" in failure:
+                                    failed_tuples.add((failure["topic_id"], failure["cycle"]))
                 self._flushed_failures_count = failure_count
 
-            # If retry_failed is True, remove failed IDs from processed set
+            # If retry_failed is True, remove failed entries from completed set
             # so they will be retried during generation
             if retry_failed and failed_ids:
-                ids_to_retry = self._processed_ids & failed_ids
-                self._processed_ids -= ids_to_retry
+                if failed_tuples:
+                    # Targeted retry: only remove the specific (uuid, cycle) that failed
+                    tuples_to_remove = self._completed & failed_tuples
+                else:
+                    # Legacy fallback: no cycle info, remove all cycles for failed UUIDs
+                    tuples_to_remove = {
+                        (uuid, cycle) for uuid, cycle in self._completed if uuid in failed_ids
+                    }
+                self._completed -= tuples_to_remove
                 # Clear failures file since we're retrying
                 if self._checkpoint_failures_path and self._checkpoint_failures_path.exists():
                     os.remove(self._checkpoint_failures_path)
                 self._flushed_failures_count = 0
                 logger.info(
-                    "Retry mode: %d failed IDs will be retried",
-                    len(ids_to_retry),
+                    "Retry mode: %d failed UUIDs (%d tuples) will be retried",
+                    len(failed_ids),
+                    len(tuples_to_remove),
                 )
 
             logger.info(
-                "Loaded checkpoint: %d samples, %d failures, %d IDs processed",
+                "Loaded checkpoint: %d samples, %d failures, %d completed (uuid, cycle) tuples",
                 self._flushed_samples_count,
                 self._flushed_failures_count,
-                len(self._processed_ids),
+                len(self._completed),
             )
         except Exception as e:  # noqa: BLE001
             logger.warning("Failed to load checkpoint: %s", e)
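
The targeted-retry logic above is plain set algebra over the checkpoint state. A standalone sketch with hypothetical data, showing both the targeted path and the legacy fallback:

    completed = {("t1", 0), ("t1", 1), ("t2", 0), ("t3", 0)}  # hypothetical UUIDs
    failed_tuples = {("t1", 1)}  # modern failure records carry the cycle
    failed_ids = {"t1"}          # legacy records carry only the UUID

    if failed_tuples:
        # Targeted: only the exact (uuid, cycle) that failed is retried.
        to_remove = completed & failed_tuples  # {("t1", 1)}
    else:
        # Legacy fallback: every cycle of a failed UUID is retried.
        to_remove = {(u, c) for u, c in completed if u in failed_ids}

    completed -= to_remove
    print(completed)  # ("t1", 0) survives; only cycle 1 of t1 is regenerated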
@@ -676,7 +710,7 @@ class DataSetGenerator:
            os.remove(self._checkpoint_samples_path)
        if self._checkpoint_failures_path and self._checkpoint_failures_path.exists():
            os.remove(self._checkpoint_failures_path)
-        self._processed_ids.clear()
+        self._completed.clear()
        self._flushed_samples_count = 0
        self._flushed_failures_count = 0
        logger.info("Checkpoint files cleared")
@@ -723,18 +757,31 @@ class DataSetGenerator:
 
         return all_failures
 
-    def _is_topic_processed(self, topic_path: TopicPath | None) -> bool:
-        """Check if a topic has already been processed.
+    def _is_completed(self, uuid: str, cycle: int) -> bool:
+        """Check if a (uuid, cycle) combination has been completed.
 
         Args:
-            topic_path: TopicPath to check
+            uuid: Topic UUID
+            cycle: Cycle number (0-indexed)
 
         Returns:
-            True if topic was already processed in a previous run
+            True if this (uuid, cycle) was already completed
         """
-        if topic_path is None:
-            return False
-        return topic_path.topic_id in self._processed_ids
+        return (uuid, cycle) in self._completed
+
+    def _is_uuid_completed_any_cycle(self, uuid: str) -> bool:
+        """Check if a UUID has been completed in any cycle.
+
+        Used during transition period when step-based generation
+        still needs to check for completed topics.
+
+        Args:
+            uuid: Topic UUID to check
+
+        Returns:
+            True if this UUID has been completed in at least one cycle
+        """
+        return any(u == uuid for u, _cycle in self._completed)
 
     def _validate_create_data_params(
         self,
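
The two predicates trade cost for generality: the exact check is an O(1) set lookup, while the any-cycle check scans every stored tuple. Illustrated with hypothetical entries:

    completed = {("t1", 0), ("t1", 1), ("t2", 0)}  # hypothetical UUIDs

    # O(1): exact (uuid, cycle) membership, used by cycle-based generation.
    print(("t1", 1) in completed)                # True
    # O(n): any-cycle scan, used by the legacy step-based path.
    print(any(u == "t2" for u, _ in completed))  # True
    print(any(u == "t3" for u, _ in completed))  # False -> t3 still pending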
@@ -766,20 +813,109 @@ class DataSetGenerator:
         topic_paths = topic_model.get_all_paths_with_ids()
         total_paths = len(topic_paths)
         required_samples = num_steps * batch_size
+        logger.info(
+            "Topic preparation: total_paths=%d, required_samples=%d, num_steps=%d, batch_size=%d",
+            total_paths,
+            required_samples,
+            num_steps,
+            batch_size,
+        )
 
         if required_samples > total_paths:
             # Cycle through topics to generate more samples than paths
             # Each topic will be used multiple times for even coverage
+            # IMPORTANT: Create unique topic_ids using global index position.
+            # This handles two cases:
+            # 1. Cycling: same path appears multiple times across cycles
+            # 2. Graph duplicates: multiple paths can share the same node topic_id
+            # Using index ensures every position in the list has a unique ID.
             multiplier = math.ceil(required_samples / total_paths)
-            topic_paths = (topic_paths * multiplier)[:required_samples]
+            cycled_paths: list[TopicPath] = []
+            for _cycle in range(multiplier):
+                for _path_idx, tp in enumerate(topic_paths):
+                    if len(cycled_paths) >= required_samples:
+                        break
+                    # Use global index for uniqueness: handles both cycling and graph duplicates
+                    global_idx = len(cycled_paths)
+                    unique_id = f"{tp.topic_id}_idx_{global_idx}"
+                    cycled_paths.append(TopicPath(path=tp.path, topic_id=unique_id))
+                if len(cycled_paths) >= required_samples:
+                    break
+            topic_paths = cycled_paths
+            logger.info(
+                "Topics cycled: %d original paths × %d cycles = %d total (trimmed to %d)",
+                total_paths,
+                multiplier,
+                total_paths * multiplier,
+                len(topic_paths),
+            )
         elif required_samples < total_paths:
             # Sample subset (percentage case or explicit count < total)
             # Bandit: not a security function
             topic_paths = random.sample(topic_paths, required_samples)  # nosec
-        # else: required_samples == total_paths - use all paths as-is
+            # Assign unique IDs to handle graphs with duplicate node IDs
+            topic_paths = [
+                TopicPath(path=tp.path, topic_id=f"{tp.topic_id}_idx_{idx}")
+                for idx, tp in enumerate(topic_paths)
+            ]
+        else:
+            # required_samples == total_paths - use all paths but with unique IDs
+            # Graphs can have duplicate topic_ids (multiple paths to same node)
+            topic_paths = [
+                TopicPath(path=tp.path, topic_id=f"{tp.topic_id}_idx_{idx}")
+                for idx, tp in enumerate(topic_paths)
+            ]
+
+        logger.info("Topic paths after preparation: %d paths", len(topic_paths))
 
         return topic_paths, num_steps
 
+    def _prepare_unique_topics(
+        self,
+        total_samples: int,
+        topic_model: "TopicModel",
+    ) -> tuple[list[Topic], int]:
+        """Prepare unique topics and calculate cycles needed for generation.
+
+        This method supports the new cycle-based generation model where we iterate
+        over unique topics (by UUID) multiple times (cycles) to generate the
+        requested number of samples.
+
+        Args:
+            total_samples: Total number of samples to generate.
+            topic_model: The topic model (Tree or Graph) to extract topics from.
+
+        Returns:
+            Tuple of (unique_topics, cycles_needed):
+            - unique_topics: List of Topic namedtuples with (uuid, topic)
+            - cycles_needed: Number of times to iterate through all topics
+        """
+        unique_topics = topic_model.get_unique_topics()
+        unique_count = len(unique_topics)
+
+        if unique_count == 0:
+            raise DataSetGeneratorError(
+                "Topic model has no unique topics. Ensure the topic model was built successfully."
+            )
+
+        # Calculate cycles needed to cover the requested samples
+        # e.g., 5000 samples from 1875 topics = ceil(5000/1875) = 3 cycles
+        cycles_needed = math.ceil(total_samples / unique_count)
+
+        # Calculate how many samples we'll generate in the final (partial) cycle
+        final_cycle_size = total_samples - (cycles_needed - 1) * unique_count
+
+        logger.info(
+            "Topic preparation: unique_topics=%d, requested_samples=%d, cycles_needed=%d, "
+            "final_cycle_size=%d",
+            unique_count,
+            total_samples,
+            cycles_needed,
+            final_cycle_size,
+        )
+
+        return unique_topics, cycles_needed
+
     def _generate_batch_prompts(
         self,
         batch_size: int,
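
The cycle arithmetic in `_prepare_unique_topics` can be checked against the docstring's own example: 5000 requested samples over 1875 unique topics means 3 cycles, the last of which is partial:

    import math

    total_samples = 5000
    unique_count = 1875

    cycles_needed = math.ceil(total_samples / unique_count)                # 3
    final_cycle_size = total_samples - (cycles_needed - 1) * unique_count  # 1250

    # Two full passes plus a 1250-topic partial pass account for every sample.
    assert (cycles_needed - 1) * unique_count + final_cycle_size == total_samples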
@@ -1050,9 +1186,11 @@ class DataSetGenerator:
         return "malformed_responses"
 
     def summarize_failures(self) -> dict:
-        """Generate a summary of all failures."""
+        """Generate a summary of all failures including those flushed to checkpoint."""
+        # Include both in-memory and flushed failures for accurate total
+        total_failures = self._flushed_failures_count + len(self.failed_samples)
         summary = {
-            "total_failures": len(self.failed_samples),
+            "total_failures": total_failures,
             "failure_types": {k: len(v) for k, v in self.failure_analysis.items()},
             "failure_examples": {},
         }
@@ -1155,22 +1293,59 @@ class DataSetGenerator:
 
         include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg
 
-        topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
-
+        # Calculate total samples requested
         total_samples = num_steps * batch_size
         data_creation_prompt = self._get_cot_prompt_template()
 
-        final_result: HFDataset | dict | None = None
-        async for event in self._run_generation_loop_async(
-            num_steps=num_steps,
-            batch_size=batch_size,
-            total_samples=total_samples,
-            topic_paths=topic_paths or [],
-            data_creation_prompt=data_creation_prompt,
-            num_example_demonstrations=num_example_demonstrations,
-            include_sys_msg=include_sys_msg,
+        # Ensure checkpoint_interval is at least as large as concurrency/batch_size
+        # so checkpoints align with batch boundaries
+        if (
+            self.config.checkpoint_interval is not None
+            and self.config.checkpoint_interval < batch_size
         ):
-            final_result = event
+            logger.warning(
+                "checkpoint_interval (%d) is less than concurrency/batch_size (%d), "
+                "adjusting to %d to align with batch boundaries",
+                self.config.checkpoint_interval,
+                batch_size,
+                batch_size,
+            )
+            self.config.checkpoint_interval = batch_size
+
+        final_result: HFDataset | dict | None = None
+
+        # Use cycle-based generation when a topic model is provided
+        if topic_model is not None:
+            unique_topics, cycles_needed = self._prepare_unique_topics(total_samples, topic_model)
+
+            # batch_size becomes concurrency in the new model
+            concurrency = batch_size
+
+            async for event in self._run_cycle_based_generation_async(
+                unique_topics=unique_topics,
+                cycles_needed=cycles_needed,
+                total_samples=total_samples,
+                concurrency=concurrency,
+                data_creation_prompt=data_creation_prompt,
+                num_example_demonstrations=num_example_demonstrations,
+                include_sys_msg=include_sys_msg,
+                topic_model=topic_model,
+            ):
+                final_result = event
+        else:
+            # Fall back to step-based generation when no topic model
+            topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
+
+            async for event in self._run_generation_loop_async(
+                num_steps=num_steps,
+                batch_size=batch_size,
+                total_samples=total_samples,
+                topic_paths=topic_paths or [],
+                data_creation_prompt=data_creation_prompt,
+                num_example_demonstrations=num_example_demonstrations,
+                include_sys_msg=include_sys_msg,
+            ):
+                final_result = event
 
         if isinstance(final_result, HFDataset):
             trace(
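
The clamp above encodes the invariant that checkpoints fire only between batches, so an interval smaller than one batch is unreachable anyway. A sketch of the rule with hypothetical config values:

    # Hypothetical values for illustration.
    checkpoint_interval = 10
    batch_size = 32  # becomes the concurrency in the cycle-based model

    if checkpoint_interval is not None and checkpoint_interval < batch_size:
        # Checkpoints are only evaluated between batches, so the effective
        # minimum interval is one batch; make that explicit.
        checkpoint_interval = batch_size

    print(checkpoint_interval)  # 32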
@@ -1211,29 +1386,67 @@ class DataSetGenerator:
 
         include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg
 
-        topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
-
+        # Calculate total samples requested
         total_samples = num_steps * batch_size
         data_creation_prompt = self._get_cot_prompt_template()
 
+        # Ensure checkpoint_interval is at least as large as concurrency/batch_size
+        # so checkpoints align with batch boundaries
+        if (
+            self.config.checkpoint_interval is not None
+            and self.config.checkpoint_interval < batch_size
+        ):
+            logger.warning(
+                "checkpoint_interval (%d) is less than concurrency/batch_size (%d), "
+                "adjusting to %d to align with batch boundaries",
+                self.config.checkpoint_interval,
+                batch_size,
+                batch_size,
+            )
+            self.config.checkpoint_interval = batch_size
+
         root_topic_prompt = None
         topic_model_type = None
         if topic_model is not None:
             root_topic_prompt = getattr(topic_model, "topic_prompt", None)
             topic_model_type = type(topic_model).__name__.lower()
 
-        async for event in self._run_generation_loop_async(
-            num_steps=num_steps,
-            batch_size=batch_size,
-            total_samples=total_samples,
-            topic_paths=topic_paths or [],
-            data_creation_prompt=data_creation_prompt,
-            num_example_demonstrations=num_example_demonstrations,
-            include_sys_msg=include_sys_msg,
-            root_topic_prompt=root_topic_prompt,
-            topic_model_type=topic_model_type,
-        ):
-            yield event
+        # Use cycle-based generation when a topic model is provided
+        if topic_model is not None:
+            unique_topics, cycles_needed = self._prepare_unique_topics(total_samples, topic_model)
+
+            # batch_size becomes concurrency in the new model
+            concurrency = batch_size
+
+            async for event in self._run_cycle_based_generation_async(
+                unique_topics=unique_topics,
+                cycles_needed=cycles_needed,
+                total_samples=total_samples,
+                concurrency=concurrency,
+                data_creation_prompt=data_creation_prompt,
+                num_example_demonstrations=num_example_demonstrations,
+                include_sys_msg=include_sys_msg,
+                root_topic_prompt=root_topic_prompt,
+                topic_model_type=topic_model_type,
+                topic_model=topic_model,
+            ):
+                yield event
+        else:
+            # Fall back to step-based generation when no topic model
+            topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
+
+            async for event in self._run_generation_loop_async(
+                num_steps=num_steps,
+                batch_size=batch_size,
+                total_samples=total_samples,
+                topic_paths=topic_paths or [],
+                data_creation_prompt=data_creation_prompt,
+                num_example_demonstrations=num_example_demonstrations,
+                include_sys_msg=include_sys_msg,
+                root_topic_prompt=root_topic_prompt,
+                topic_model_type=topic_model_type,
+            ):
+                yield event
 
     async def _run_generation_loop_async(  # noqa: PLR0912, PLR0915
         self,
@@ -1248,6 +1461,17 @@ class DataSetGenerator:
         topic_model_type: str | None = None,
     ) -> AsyncGenerator[dict | HFDataset, None]:
         """Run the main generation loop yielding progress events."""
+        # Verify topic paths cover all expected samples (only when a topic model is used)
+        expected_prompts = num_steps * batch_size
+        actual_paths = len(topic_paths) if topic_paths else 0
+        if topic_paths and actual_paths < expected_prompts:
+            logger.warning(
+                "Topic paths (%d) < expected samples (%d). Steps beyond path %d will produce 0 samples.",
+                actual_paths,
+                expected_prompts,
+                actual_paths // batch_size,
+            )
+
         # Initialize checkpoint paths if checkpointing is enabled
         if self.config.checkpoint_interval is not None:
             self._initialize_checkpoint_paths()
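
The `actual_paths // batch_size` in the warning is the (0-indexed) first step that cannot fill a full batch. Worked through with hypothetical numbers:

    # Hypothetical numbers for illustration.
    num_steps, batch_size = 5, 32
    actual_paths = 100
    expected_prompts = num_steps * batch_size  # 160

    first_starved_step = actual_paths // batch_size  # 3
    # Steps 0-2 draw paths 0..95 in full; step 3 only finds paths 96..99,
    # and later steps find none, so 160 - 100 = 60 samples go missing.
    print(first_starved_step, expected_prompts - actual_paths)  # 3 60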
@@ -1257,6 +1481,7 @@ class DataSetGenerator:
         samples_in_current_batch: list[dict] = []
         failures_in_current_batch: list[dict] = []
         topic_paths_in_current_batch: list[TopicPath | None] = []
+        topics_exhausted_count = 0  # Samples not generated due to topic exhaustion
 
         try:
             yield {
@@ -1267,11 +1492,12 @@ class DataSetGenerator:
                 "total_samples": total_samples,
                 "root_topic_prompt": root_topic_prompt,
                 "topic_model_type": topic_model_type,
-                "resumed_from_checkpoint": len(self._processed_ids) > 0,
-                "previously_processed": len(self._processed_ids),
+                "resumed_from_checkpoint": len(self._completed) > 0,
+                "previously_processed": len(self._completed),
                 "resumed_samples": self._flushed_samples_count,
                 "resumed_failures": self._flushed_failures_count,
                 "checkpoint_enabled": self.config.checkpoint_interval is not None,
+                "checkpoint_interval": self.config.checkpoint_interval,
             }
 
             for step in range(num_steps):
@@ -1290,12 +1516,36 @@ class DataSetGenerator:
                     num_example_demonstrations,
                 )
 
-                # Filter out already-processed topics when resuming
-                if self._processed_ids:
+                # Handle topic exhaustion - when we've run out of topics
+                if not prompts and topic_paths:
+                    # Topics exhausted - remaining steps will produce nothing
+                    exhausted_in_step = batch_size  # All samples in this batch had no topics
+                    topics_exhausted_count += exhausted_in_step
+                    logger.warning(
+                        "Step %d: Topics exhausted at index %d (only %d topics available for %d expected)",
+                        step + 1,
+                        start_idx,
+                        len(topic_paths),
+                        total_samples,
+                    )
+                    yield {
+                        "event": "step_complete",
+                        "step": step + 1,
+                        "samples_generated": 0,
+                        "success": True,
+                        "failed_in_step": 0,
+                        "failure_reasons": [],
+                        "topics_exhausted": exhausted_in_step,
+                    }
+                    continue
+
+                # Filter out already-completed topics when resuming
+                if self._completed:
                     filtered_prompts = []
                     filtered_topic_paths: list[TopicPath | None] = []
                     for prompt, tp in zip(prompts, used_topic_paths, strict=False):
-                        if not self._is_topic_processed(tp):
+                        # Check if this topic_id has been completed in any cycle
+                        if tp is None or not self._is_uuid_completed_any_cycle(tp.topic_id):
                             filtered_prompts.append(prompt)
                             filtered_topic_paths.append(tp)
 
@@ -1335,10 +1585,14 @@ class DataSetGenerator:
                     self.config.checkpoint_interval is not None
                     and samples_since_checkpoint >= self.config.checkpoint_interval
                 ):
+                    # Convert topic_paths to (uuid, cycle) tuples (cycle=0 for step-based)
+                    completed_items = [
+                        (tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+                    ]
                     self._save_checkpoint(
                         samples_in_current_batch,
                         failures_in_current_batch,
-                        topic_paths_in_current_batch,
+                        completed_items,
                     )
                     samples_in_current_batch = []
                     failures_in_current_batch = []
@@ -1360,11 +1614,11 @@ class DataSetGenerator:
                     }
                     return  # Exit generator cleanly
 
-                failed_in_batch = len(self.failed_samples) - failed_before
+                # Use new_failures captured BEFORE checkpoint clear, not recalculated after
+                failed_in_batch = len(new_failures)
                 failure_reasons: list[str] = []
-                if failed_in_batch > 0 and self.failed_samples:
-                    recent_failures = self.failed_samples[-failed_in_batch:]
-                    for f in recent_failures[:3]:
+                if failed_in_batch > 0 and new_failures:
+                    for f in new_failures[:3]:
                         if isinstance(f, dict):
                             failure_reasons.append(f.get("error", str(f)))
                         else:
@@ -1390,10 +1644,13 @@ class DataSetGenerator:
             if self.config.checkpoint_interval is not None and (
                 samples_in_current_batch or failures_in_current_batch
             ):
+                completed_items = [
+                    (tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+                ]
                 self._save_checkpoint(
                     samples_in_current_batch,
                     failures_in_current_batch,
-                    topic_paths_in_current_batch,
+                    completed_items,
                 )
                 yield {
                     "event": "checkpoint_saved",
@@ -1403,13 +1660,30 @@ class DataSetGenerator:
                 }
 
             # Calculate total counts including flushed data
-            total_samples = self._flushed_samples_count + len(self._samples)
-            total_failures = self._flushed_failures_count + len(self.failed_samples)
+            actual_samples = self._flushed_samples_count + len(self._samples)
+            actual_failures = self._flushed_failures_count + len(self.failed_samples)
+            total_accounted = actual_samples + actual_failures + topics_exhausted_count
+            true_unaccounted = total_samples - total_accounted
+
+            # Log accounting summary for debugging
+            logger.info(
+                "Generation complete: expected=%d, generated=%d, failed=%d, topics_exhausted=%d, "
+                "accounted=%d, unaccounted=%d",
+                total_samples,  # This is the parameter passed in (expected count)
+                actual_samples,
+                actual_failures,
+                topics_exhausted_count,
+                total_accounted,
+                true_unaccounted,
+            )
 
             yield {
                 "event": "generation_complete",
-                "total_samples": total_samples,
-                "failed_samples": total_failures,
+                "total_samples": actual_samples,
+                "failed_samples": actual_failures,
+                "expected_samples": total_samples,
+                "topics_exhausted": topics_exhausted_count,
+                "unaccounted": true_unaccounted,
             }
 
         except KeyboardInterrupt:
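
The accounting identity logged above says every expected sample must be either generated, failed, or skipped via topic exhaustion; any remainder is a silent drop. With hypothetical counts:

    # Hypothetical counts for illustration.
    expected = 160
    generated, failed, topics_exhausted = 120, 8, 32

    accounted = generated + failed + topics_exhausted  # 160
    unaccounted = expected - accounted                 # 0 when nothing is dropped

    # A nonzero remainder signals a silent drop somewhere in the pipeline.
    assert unaccounted == 0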
@@ -1417,10 +1691,13 @@ class DataSetGenerator:
             if self.config.checkpoint_interval is not None and (
                 samples_in_current_batch or failures_in_current_batch
             ):
+                completed_items = [
+                    (tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+                ]
                 self._save_checkpoint(
                     samples_in_current_batch,
                     failures_in_current_batch,
-                    topic_paths_in_current_batch,
+                    completed_items,
                 )
                 yield {
                     "event": "generation_interrupted",
@@ -1434,10 +1711,327 @@ class DataSetGenerator:
             if self.config.checkpoint_interval is not None and (
                 samples_in_current_batch or failures_in_current_batch
             ):
+                completed_items = [
+                    (tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+                ]
                 self._save_checkpoint(
                     samples_in_current_batch,
                     failures_in_current_batch,
-                    topic_paths_in_current_batch,
+                    completed_items,
+                )
+                yield {"event": "generation_error", "error": str(e)}
+                self.print_failure_summary()
+                self._save_samples_to_file(ERROR_DATASET_FILENAME)
+                raise DataSetGeneratorError("failed") from e
+
+        # Build final dataset: if samples were flushed to disk, load them from checkpoint
+        if self._flushed_samples_count > 0:
+            all_samples = self._load_all_samples_from_checkpoint()
+            yield HFDataset.from_list(all_samples) if all_samples else HFDataset.from_list([])
+        else:
+            yield (HFDataset.from_list(self._samples) if self._samples else HFDataset.from_list([]))
+
+    async def _run_cycle_based_generation_async(  # noqa: PLR0912, PLR0915
+        self,
+        unique_topics: list[Topic],
+        cycles_needed: int,
+        total_samples: int,
+        concurrency: int,
+        data_creation_prompt: str,
+        num_example_demonstrations: int,
+        include_sys_msg: bool,
+        root_topic_prompt: str | None = None,
+        topic_model_type: str | None = None,
+        topic_model: "TopicModel | None" = None,
+    ) -> AsyncGenerator[dict | HFDataset, None]:
+        """Run cycle-based generation loop yielding progress events.
+
+        This is the new generation model that iterates over unique topics (by UUID)
+        for multiple cycles, rather than the old step-based batching approach.
+
+        Args:
+            unique_topics: List of Topic namedtuples with (uuid, topic).
+            cycles_needed: Number of cycles to iterate through topics.
+            total_samples: Total number of samples to generate.
+            concurrency: Maximum parallel LLM calls (semaphore limit).
+            data_creation_prompt: The prompt template for data creation.
+            num_example_demonstrations: Number of example demonstrations to include.
+            include_sys_msg: Whether to include system message in output.
+            root_topic_prompt: Original topic prompt for display.
+            topic_model_type: Type of topic model (tree, graph) for display.
+            topic_model: Topic model for path lookup (recovers full path context).
+
+        Yields:
+            Progress event dicts and final HFDataset.
+        """
+        unique_topic_count = len(unique_topics)
+        final_cycle_size = total_samples - (cycles_needed - 1) * unique_topic_count
+
+        # Initialize checkpoint paths if checkpointing is enabled
+        if self.config.checkpoint_interval is not None:
+            self._initialize_checkpoint_paths()
+
+        # Track samples for checkpointing
+        samples_since_checkpoint = 0
+        pending_samples: list[dict] = []
+        pending_failures: list[dict] = []
+        pending_completed: list[tuple[str, int]] = []
+        samples_generated = 0
+
+        # Create semaphore for concurrency control
+        semaphore = asyncio.Semaphore(concurrency)
+
+        try:
+            yield {
+                "event": "generation_start",
+                "model_name": self.model_name,
+                "unique_topics": unique_topic_count,
+                "cycles_needed": cycles_needed,
+                "final_cycle_size": final_cycle_size,
+                "concurrency": concurrency,
+                "total_samples": total_samples,
+                "root_topic_prompt": root_topic_prompt,
+                "topic_model_type": topic_model_type,
+                "resumed_from_checkpoint": len(self._completed) > 0,
+                "previously_completed": len(self._completed),
+                "resumed_samples": self._flushed_samples_count,
+                "resumed_failures": self._flushed_failures_count,
+                "checkpoint_enabled": self.config.checkpoint_interval is not None,
+                "checkpoint_interval": self.config.checkpoint_interval,
+            }
+
+            for cycle in range(cycles_needed):
+                # Determine how many topics to process in this cycle
+                if cycle == cycles_needed - 1:
+                    # Final cycle may be partial
+                    topics_in_cycle = min(final_cycle_size, unique_topic_count)
+                else:
+                    topics_in_cycle = unique_topic_count
+
+                yield {
+                    "event": "cycle_start",
+                    "cycle": cycle + 1,
+                    "total_cycles": cycles_needed,
+                    "topics_in_cycle": topics_in_cycle,
+                }
+
+                # Collect topics to process in this cycle
+                topics_to_process: list[Topic] = []
+                for topic_idx, topic in enumerate(unique_topics):
+                    if cycle == cycles_needed - 1 and topic_idx >= final_cycle_size:
+                        break  # Partial final cycle - stop early
+
+                    if not self._is_completed(topic.uuid, cycle):
+                        topics_to_process.append(topic)
+
+                if not topics_to_process:
+                    # All topics in this cycle already completed (resume scenario)
+                    yield {
+                        "event": "cycle_complete",
+                        "cycle": cycle + 1,
+                        "samples_in_cycle": 0,
+                        "skipped": topics_in_cycle,
+                    }
+                    continue
+
+                # Process topics with concurrency control
+                cycle_samples = 0
+                cycle_failures = 0
+
+                async def process_topic(
+                    topic: Topic, cycle_num: int, sample_idx: int
+                ) -> tuple[tuple[str, int], bool, int]:
+                    """Process a single topic with semaphore-controlled concurrency.
+
+                    Note: Samples/failures are NOT extracted here because concurrent
+                    tasks share self._samples. Extracting per-task would cause
+                    duplicates since each task's slice overlaps with others.
+                    Instead, samples are captured in bulk after asyncio.gather.
+
+                    Returns:
+                        Tuple of (completed_item, success, count)
+                    """
+                    async with semaphore:
+                        # Recover full path context (root -> ... -> leaf) for prompt quality
+                        full_path = topic_model.get_path_by_id(topic.uuid) if topic_model else None
+                        subtopics = full_path if full_path else [topic.topic]
+
+                        # Build prompt for this topic
+                        sample_prompt = self.build_prompt(
+                            data_creation_prompt=data_creation_prompt,
+                            num_example_demonstrations=num_example_demonstrations,
+                            subtopics_list=subtopics,
+                        )
+
+                        # Use existing batch processing for a single sample
+                        topic_path = TopicPath(path=subtopics, topic_id=topic.uuid)
+                        success, count = await self._process_batch_with_retries_async(
+                            prompts=[sample_prompt],
+                            include_sys_msg=include_sys_msg,
+                            start_sample_idx=sample_idx,
+                            topic_paths_for_batch=[topic_path],
+                        )
+
+                        completed_item = (topic.uuid, cycle_num)
+                        return completed_item, success, count
+
+                # Process topics in batches for checkpoint saving
+                for batch_start in range(0, len(topics_to_process), concurrency):
+                    batch_end = min(batch_start + concurrency, len(topics_to_process))
+                    batch_topics = topics_to_process[batch_start:batch_end]
+
+                    # Snapshot list lengths before tasks start so we can
+                    # capture new items in one safe slice after all complete.
+                    samples_before_gather = len(self._samples)
+                    failures_before_gather = len(self.failed_samples)
+
+                    # Create concurrent tasks
+                    # Pass current samples_generated as starting index for each task
+                    tasks = [
+                        asyncio.create_task(process_topic(topic, cycle, samples_generated + i))
+                        for i, topic in enumerate(batch_topics)
+                    ]
+
+                    # Process results as each task finishes (not waiting for
+                    # the whole batch) so the progress bar advances per-sample.
+                    batch_samples = 0
+                    batch_failures = 0
+                    for future in asyncio.as_completed(tasks):
+                        completed_item, success, count = await future
+                        pending_completed.append(completed_item)
+
+                        if success:
+                            cycle_samples += count
+                            samples_generated += count
+                            batch_samples += count
+                        else:
+                            cycle_failures += 1
+                            batch_failures += 1
+
+                        samples_since_checkpoint += 1
+
+                        # Emit per-sample progress so progress bars advance immediately
+                        yield {
+                            "event": "batch_complete",
+                            "samples_generated": count if success else 0,
+                            "samples_failed": 1 if not success else 0,
+                        }
+
+                    # After all tasks complete, capture samples/failures in
+                    # one safe slice (no concurrent tasks are running now).
+                    batch_new_samples = list(self._samples[samples_before_gather:])
+                    batch_new_failures = list(self.failed_samples[failures_before_gather:])
+                    # Tag each failure with the cycle number so retry logic can
+                    # target only the specific (uuid, cycle) that failed
+                    for failure in batch_new_failures:
+                        failure["cycle"] = cycle
+                    pending_samples.extend(batch_new_samples)
+                    pending_failures.extend(batch_new_failures)
+
+                    # Save checkpoint if we've reached the interval
+                    if (
+                        self.config.checkpoint_interval is not None
+                        and samples_since_checkpoint >= self.config.checkpoint_interval
+                    ):
+                        self._save_checkpoint(
+                            pending_samples,
+                            pending_failures,
+                            pending_completed,
+                        )
+                        pending_samples = []
+                        pending_failures = []
+                        pending_completed = []
+                        samples_since_checkpoint = 0
+                        yield {
+                            "event": "checkpoint_saved",
+                            "total_samples": self._flushed_samples_count,
+                            "total_failures": self._flushed_failures_count,
+                        }
+
+                    # Check for graceful stop request
+                    if self.stop_requested:
+                        yield {
+                            "event": "generation_stopped",
+                            "message": "Stopped at checkpoint as requested",
+                            "total_samples": self._flushed_samples_count,
+                            "total_failures": self._flushed_failures_count,
+                        }
+                        return
+
+                yield {
+                    "event": "cycle_complete",
+                    "cycle": cycle + 1,
+                    "samples_in_cycle": cycle_samples,
+                    "failures_in_cycle": cycle_failures,
+                }
+
+            # Save final checkpoint with any remaining samples
+            if self.config.checkpoint_interval is not None and (
+                pending_samples or pending_failures
+            ):
+                self._save_checkpoint(
+                    pending_samples,
+                    pending_failures,
+                    pending_completed,
+                )
+                yield {
+                    "event": "checkpoint_saved",
+                    "total_samples": self._flushed_samples_count,
+                    "total_failures": self._flushed_failures_count,
+                    "final": True,
+                }
+
+            # Calculate total counts including flushed data
+            actual_samples = self._flushed_samples_count + len(self._samples)
+            actual_failures = self._flushed_failures_count + len(self.failed_samples)
+            unaccounted = total_samples - actual_samples - actual_failures
+
+            logger.info(
+                "Generation complete: expected=%d, generated=%d, failed=%d, "
+                "accounted=%d, unaccounted=%d",
+                total_samples,
+                actual_samples,
+                actual_failures,
+                actual_samples + actual_failures,
+                unaccounted,
+            )
+
+            yield {
+                "event": "generation_complete",
+                "total_samples": actual_samples,
+                "failed_samples": actual_failures,
+                "expected_samples": total_samples,
+                "unaccounted": unaccounted,
+                "cycles_completed": cycles_needed,
+                "unique_topics": unique_topic_count,
+            }
+
+        except KeyboardInterrupt:
+            # Save checkpoint on interrupt
+            if self.config.checkpoint_interval is not None and (
+                pending_samples or pending_failures
+            ):
+                self._save_checkpoint(
+                    pending_samples,
+                    pending_failures,
+                    pending_completed,
+                )
+            yield {
+                "event": "generation_interrupted",
+                "message": "Generation interrupted by user.",
+            }
+            self.print_failure_summary()
+            self._save_samples_to_file(INTERRUPTED_DATASET_FILENAME)
+
+        except Exception as e:  # noqa: BLE001
+            # Save checkpoint on error
+            if self.config.checkpoint_interval is not None and (
+                pending_samples or pending_failures
+            ):
+                self._save_checkpoint(
+                    pending_samples,
+                    pending_failures,
+                    pending_completed,
                 )
                 yield {"event": "generation_error", "error": str(e)}
                 self.print_failure_summary()
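
Under the new loop is a generic asyncio pattern: a `Semaphore` caps in-flight LLM calls while `asyncio.as_completed` surfaces each result as soon as it finishes, so progress events can be emitted per sample. A stripped-down, runnable sketch where `work` stands in for the real per-topic generation:

    import asyncio

    async def main() -> None:
        semaphore = asyncio.Semaphore(3)  # cap of 3 concurrent "LLM calls"

        async def work(uuid: str, cycle: int) -> tuple[str, int]:
            async with semaphore:          # at most 3 bodies run at once
                await asyncio.sleep(0.01)  # stand-in for the model call
                return uuid, cycle

        tasks = [asyncio.create_task(work(f"t{i}", 0)) for i in range(10)]

        completed: set[tuple[str, int]] = set()
        for future in asyncio.as_completed(tasks):  # yields in finish order
            completed.add(await future)             # progress can advance here
        print(len(completed))  # 10

    asyncio.run(main())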
@@ -1465,6 +2059,17 @@ class DataSetGenerator:
             prompts, include_sys_msg, start_sample_idx, topic_paths_for_batch
         )
 
+        # Verify all prompts are accounted for (no silent drops)
+        accounted = len(samples) + len(failed_responses)
+        if accounted != len(prompts):
+            logger.warning(
+                "Sample accounting mismatch: %d prompts, %d samples, %d failures (missing %d)",
+                len(prompts),
+                len(samples),
+                len(failed_responses),
+                len(prompts) - accounted,
+            )
+
         # Update failed samples
         self.failed_samples.extend(failed_responses)