DeepFabric 4.10.1__py3-none-any.whl → 4.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/cli.py +83 -27
- deepfabric/cloud_upload.py +1 -1
- deepfabric/config.py +6 -4
- deepfabric/constants.py +1 -1
- deepfabric/dataset_manager.py +264 -62
- deepfabric/generator.py +687 -82
- deepfabric/graph.py +25 -1
- deepfabric/llm/retry_handler.py +28 -9
- deepfabric/progress.py +42 -0
- deepfabric/topic_manager.py +22 -2
- deepfabric/topic_model.py +26 -0
- deepfabric/tree.py +41 -16
- deepfabric/tui.py +448 -349
- deepfabric/utils.py +4 -1
- {deepfabric-4.10.1.dist-info → deepfabric-4.11.0.dist-info}/METADATA +3 -1
- {deepfabric-4.10.1.dist-info → deepfabric-4.11.0.dist-info}/RECORD +19 -19
- {deepfabric-4.10.1.dist-info → deepfabric-4.11.0.dist-info}/licenses/LICENSE +1 -1
- {deepfabric-4.10.1.dist-info → deepfabric-4.11.0.dist-info}/WHEEL +0 -0
- {deepfabric-4.10.1.dist-info → deepfabric-4.11.0.dist-info}/entry_points.txt +0 -0
deepfabric/generator.py
CHANGED
@@ -46,7 +46,7 @@ from .prompts import (
 from .schemas import Conversation, ToolRegistry, get_conversation_schema
 from .tools import BUILTIN_TOOL_REGISTRY
 from .tools.loader import load_tools_from_dict, load_tools_from_endpoint
-from .topic_model import TopicModel, TopicPath
+from .topic_model import Topic, TopicModel, TopicPath
 from .utils import ensure_not_running_loop, get_checkpoint_dir, is_validation_error

 # Handle circular import for type hints
@@ -275,13 +275,16 @@ class DataSetGenerator:

 # Checkpoint state
 self._checkpoint_samples_since_save = 0
-self.
+self._completed: set[tuple[str, int]] = set() # Track (uuid, cycle) tuples
 self._checkpoint_metadata_path: Path | None = None
 self._checkpoint_samples_path: Path | None = None
 self._checkpoint_failures_path: Path | None = None
 # Memory optimization: track flushed counts for checkpoint mode
 self._flushed_samples_count = 0
 self._flushed_failures_count = 0
+# Generation state for cycle-based iteration
+self._unique_topics: int = 0 # Number of unique topics
+self._cycles_needed: int = 0 # Total cycles for requested samples

 # Graceful stop flag - set by signal handler to stop at next checkpoint
 self.stop_requested = False
@@ -383,7 +386,7 @@ class DataSetGenerator:
 self,
 new_samples: list[dict],
 new_failures: list[dict],
-
+completed_items: list[tuple[str, int]],
 flush_memory: bool = True,
 ) -> None:
 """Save checkpoint data incrementally.
@@ -391,7 +394,7 @@ class DataSetGenerator:
 Args:
 new_samples: New successful samples to append
 new_failures: New failed samples to append
-
+completed_items: List of (uuid, cycle) tuples that were completed
 flush_memory: If True, clear flushed samples from memory (memory optimization)
 """
 if self._checkpoint_samples_path is None:
@@ -409,10 +412,9 @@ class DataSetGenerator:
 for failure in new_failures:
 f.write(json.dumps(failure, separators=(",", ":")) + "\n")

-# Track
-for
-
-self._processed_ids.add(topic_path.topic_id)
+# Track completed (uuid, cycle) tuples
+for item in completed_items:
+self._completed.add(item)

 # Memory optimization: track flushed counts and clear in-memory lists
 # Must happen BEFORE saving metadata so counts are accurate
@@ -427,10 +429,10 @@ class DataSetGenerator:
 self._save_checkpoint_metadata()

 logger.debug(
-"Checkpoint saved: %d samples, %d failures, %d
+"Checkpoint saved: %d samples, %d failures, %d completed tuples (flushed=%s)",
 len(new_samples),
 len(new_failures),
-len(self.
+len(self._completed),
 flush_memory,
 )
@@ -443,6 +445,9 @@ class DataSetGenerator:
 total_samples = self._flushed_samples_count + len(self._samples)
 total_failures = self._flushed_failures_count + len(self.failed_samples)

+# Convert completed set of (uuid, cycle) tuples to list of [uuid, cycle] arrays
+completed_list = [[uuid, cycle] for uuid, cycle in self._completed]
+
 metadata = {
 "version": CHECKPOINT_VERSION,
 "created_at": datetime.now(timezone.utc).isoformat(),
@@ -452,7 +457,9 @@ class DataSetGenerator:
 "reasoning_style": self.config.reasoning_style,
 "total_samples": total_samples,
 "total_failures": total_failures,
-"
+"completed": completed_list, # List of [uuid, cycle] arrays
+"unique_topics": self._unique_topics,
+"cycles_needed": self._cycles_needed,
 "checkpoint_interval": self.config.checkpoint_interval,
 "topics_file": self.config.topics_file,
 }
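Note: based on the fields added above, a v4 checkpoint metadata record would have roughly the following shape. The concrete values below are invented for illustration and are not taken from the package; "completed" stores [uuid, cycle] pairs.

# Illustrative only: the shape of the checkpoint metadata implied by the diff above.
example_metadata = {
    "version": 4,                      # assumed value of CHECKPOINT_VERSION
    "created_at": "2025-01-01T00:00:00+00:00",
    "total_samples": 120,
    "total_failures": 3,
    "completed": [["topic-uuid-1", 0], ["topic-uuid-2", 0], ["topic-uuid-1", 1]],
    "unique_topics": 1875,
    "cycles_needed": 3,
    "checkpoint_interval": 50,
    "topics_file": "topics.jsonl",     # hypothetical path
}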
@@ -528,12 +535,26 @@ class DataSetGenerator:
 version = metadata.get("version")
 if version is None:
 error_msg = "Missing 'version' field in checkpoint metadata"
-elif version
-
+elif version < CHECKPOINT_VERSION:
+# Old checkpoint format - require fresh start
+error_msg = (
+f"Checkpoint format v{version} is incompatible with current version v{CHECKPOINT_VERSION}. "
+"Please delete the checkpoint and restart: rm -rf .checkpoints/"
+)
+elif version > CHECKPOINT_VERSION:
+error_msg = (
+f"Checkpoint version {version} is newer than supported version {CHECKPOINT_VERSION}"
+)

-# Check required fields
+# Check required fields for v4 format
 if error_msg is None:
-required_fields = [
+required_fields = [
+"created_at",
+"total_samples",
+"completed",
+"unique_topics",
+"cycles_needed",
+]
 for field in required_fields:
 if field not in metadata:
 error_msg = f"Missing required field in checkpoint metadata: {field}"
@@ -579,8 +600,7 @@ class DataSetGenerator:

 self._initialize_checkpoint_paths()
 return (
-self._checkpoint_metadata_path is not None
-and self._checkpoint_metadata_path.exists()
+self._checkpoint_metadata_path is not None and self._checkpoint_metadata_path.exists()
 )

 def load_checkpoint(self, retry_failed: bool = False) -> bool:
@@ -614,8 +634,11 @@ class DataSetGenerator:
 # Validate config compatibility
 self._validate_checkpoint_compatibility(metadata)

-# Restore
-
+# Restore completed (uuid, cycle) tuples
+completed_list = metadata.get("completed", [])
+self._completed = {(item[0], item[1]) for item in completed_list}
+self._unique_topics = metadata.get("unique_topics", 0)
+self._cycles_needed = metadata.get("cycles_needed", 0)

 # Count existing samples (don't load into memory - they're already on disk)
 # Memory optimization: track as flushed counts instead of loading into RAM
@@ -631,36 +654,47 @@ class DataSetGenerator:
 failed_ids: set[str] = set()
 if self._checkpoint_failures_path and self._checkpoint_failures_path.exists():
 failure_count = 0
+failed_tuples: set[tuple[str, int]] = set()
 with open(self._checkpoint_failures_path, encoding="utf-8") as f:
 for raw_line in f:
 stripped = raw_line.strip()
 if stripped:
 failure = json.loads(stripped)
 failure_count += 1
-# Track
+# Track failed (topic_id, cycle) for targeted retry
 if "topic_id" in failure:
 failed_ids.add(failure["topic_id"])
+if "cycle" in failure:
+failed_tuples.add((failure["topic_id"], failure["cycle"]))
 self._flushed_failures_count = failure_count

-# If retry_failed is True, remove failed
+# If retry_failed is True, remove failed entries from completed set
 # so they will be retried during generation
 if retry_failed and failed_ids:
-
-
+if failed_tuples:
+# Targeted retry: only remove the specific (uuid, cycle) that failed
+tuples_to_remove = self._completed & failed_tuples
+else:
+# Legacy fallback: no cycle info, remove all cycles for failed UUIDs
+tuples_to_remove = {
+(uuid, cycle) for uuid, cycle in self._completed if uuid in failed_ids
+}
+self._completed -= tuples_to_remove
 # Clear failures file since we're retrying
 if self._checkpoint_failures_path and self._checkpoint_failures_path.exists():
 os.remove(self._checkpoint_failures_path)
 self._flushed_failures_count = 0
 logger.info(
-"Retry mode: %d failed
-len(
+"Retry mode: %d failed UUIDs (%d tuples) will be retried",
+len(failed_ids),
+len(tuples_to_remove),
 )

 logger.info(
-"Loaded checkpoint: %d samples, %d failures, %d
+"Loaded checkpoint: %d samples, %d failures, %d completed (uuid, cycle) tuples",
 self._flushed_samples_count,
 self._flushed_failures_count,
-len(self.
+len(self._completed),
 )
 except Exception as e: # noqa: BLE001
 logger.warning("Failed to load checkpoint: %s", e)
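Note: the targeted-retry logic above is ordinary set arithmetic over (uuid, cycle) tuples. A minimal standalone sketch of the same idea, with hypothetical names rather than DeepFabric's API:

# Minimal sketch of targeted retry on a completed set of (uuid, cycle) tuples.
completed = {("a", 0), ("a", 1), ("b", 0), ("c", 0)}
failed_tuples = {("a", 1), ("c", 0)}   # failures that recorded a cycle
failed_ids = {"a", "c"}                # failures without cycle info (legacy)

# Targeted retry: drop only the exact (uuid, cycle) pairs that failed.
completed -= completed & failed_tuples                            # {("a", 0), ("b", 0)}

# Legacy fallback: drop every cycle for a failed uuid.
completed -= {(u, c) for u, c in completed if u in failed_ids}    # {("b", 0)}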
@@ -676,7 +710,7 @@ class DataSetGenerator:
 os.remove(self._checkpoint_samples_path)
 if self._checkpoint_failures_path and self._checkpoint_failures_path.exists():
 os.remove(self._checkpoint_failures_path)
-self.
+self._completed.clear()
 self._flushed_samples_count = 0
 self._flushed_failures_count = 0
 logger.info("Checkpoint files cleared")
@@ -723,18 +757,31 @@ class DataSetGenerator:

 return all_failures

-def
-"""Check if a
+def _is_completed(self, uuid: str, cycle: int) -> bool:
+"""Check if a (uuid, cycle) combination has been completed.

 Args:
-
+uuid: Topic UUID
+cycle: Cycle number (0-indexed)

 Returns:
-True if
+True if this (uuid, cycle) was already completed
 """
-
-
-
+return (uuid, cycle) in self._completed
+
+def _is_uuid_completed_any_cycle(self, uuid: str) -> bool:
+"""Check if a UUID has been completed in any cycle.
+
+Used during transition period when step-based generation
+still needs to check for completed topics.
+
+Args:
+uuid: Topic UUID to check
+
+Returns:
+True if this UUID has been completed in at least one cycle
+"""
+return any(u == uuid for u, _cycle in self._completed)

 def _validate_create_data_params(
 self,
@@ -766,20 +813,109 @@ class DataSetGenerator:
 topic_paths = topic_model.get_all_paths_with_ids()
 total_paths = len(topic_paths)
 required_samples = num_steps * batch_size
+logger.info(
+"Topic preparation: total_paths=%d, required_samples=%d, num_steps=%d, batch_size=%d",
+total_paths,
+required_samples,
+num_steps,
+batch_size,
+)

 if required_samples > total_paths:
 # Cycle through topics to generate more samples than paths
 # Each topic will be used multiple times for even coverage
+# IMPORTANT: Create unique topic_ids using global index position.
+# This handles two cases:
+# 1. Cycling: same path appears multiple times across cycles
+# 2. Graph duplicates: multiple paths can share the same node topic_id
+# Using index ensures every position in the list has a unique ID.
 multiplier = math.ceil(required_samples / total_paths)
-
+cycled_paths: list[TopicPath] = []
+for _cycle in range(multiplier):
+for _path_idx, tp in enumerate(topic_paths):
+if len(cycled_paths) >= required_samples:
+break
+# Use global index for uniqueness: handles both cycling and graph duplicates
+global_idx = len(cycled_paths)
+unique_id = f"{tp.topic_id}_idx_{global_idx}"
+cycled_paths.append(TopicPath(path=tp.path, topic_id=unique_id))
+if len(cycled_paths) >= required_samples:
+break
+topic_paths = cycled_paths
+logger.info(
+"Topics cycled: %d original paths × %d cycles = %d total (trimmed to %d)",
+total_paths,
+multiplier,
+total_paths * multiplier,
+len(topic_paths),
+)
 elif required_samples < total_paths:
 # Sample subset (percentage case or explicit count < total)
 # Bandit: not a security function
 topic_paths = random.sample(topic_paths, required_samples) # nosec
-
+# Assign unique IDs to handle graphs with duplicate node IDs
+topic_paths = [
+TopicPath(path=tp.path, topic_id=f"{tp.topic_id}_idx_{idx}")
+for idx, tp in enumerate(topic_paths)
+]
+else:
+# required_samples == total_paths - use all paths but with unique IDs
+# Graphs can have duplicate topic_ids (multiple paths to same node)
+topic_paths = [
+TopicPath(path=tp.path, topic_id=f"{tp.topic_id}_idx_{idx}")
+for idx, tp in enumerate(topic_paths)
+]
+
+logger.info("Topic paths after preparation: %d paths", len(topic_paths))

 return topic_paths, num_steps

+def _prepare_unique_topics(
+self,
+total_samples: int,
+topic_model: "TopicModel",
+) -> tuple[list[Topic], int]:
+"""Prepare unique topics and calculate cycles needed for generation.
+
+This method supports the new cycle-based generation model where we iterate
+over unique topics (by UUID) multiple times (cycles) to generate the
+requested number of samples.
+
+Args:
+total_samples: Total number of samples to generate.
+topic_model: The topic model (Tree or Graph) to extract topics from.
+
+Returns:
+Tuple of (unique_topics, cycles_needed):
+- unique_topics: List of Topic namedtuples with (uuid, topic)
+- cycles_needed: Number of times to iterate through all topics
+"""
+unique_topics = topic_model.get_unique_topics()
+unique_count = len(unique_topics)
+
+if unique_count == 0:
+raise DataSetGeneratorError(
+"Topic model has no unique topics. Ensure the topic model was built successfully."
+)
+
+# Calculate cycles needed to cover the requested samples
+# e.g., 5000 samples from 1875 topics = ceil(5000/1875) = 3 cycles
+cycles_needed = math.ceil(total_samples / unique_count)
+
+# Calculate how many samples we'll generate in the final (partial) cycle
+final_cycle_size = total_samples - (cycles_needed - 1) * unique_count
+
+logger.info(
+"Topic preparation: unique_topics=%d, requested_samples=%d, cycles_needed=%d, "
+"final_cycle_size=%d",
+unique_count,
+total_samples,
+cycles_needed,
+final_cycle_size,
+)
+
+return unique_topics, cycles_needed
+
 def _generate_batch_prompts(
 self,
 batch_size: int,
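Note: the cycle math in the new _prepare_unique_topics is a ceiling division plus a remainder for the final, partial cycle. The 5000/1875 case from the code comment works out as follows (standalone sketch, not library code):

import math

total_samples = 5000
unique_count = 1875

cycles_needed = math.ceil(total_samples / unique_count)                  # 3
final_cycle_size = total_samples - (cycles_needed - 1) * unique_count   # 5000 - 2*1875 = 1250

# Two full cycles of 1875 topics plus a partial final cycle of 1250 topics.
assert cycles_needed == 3
assert final_cycle_size == 1250
assert (cycles_needed - 1) * unique_count + final_cycle_size == total_samples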
@@ -1050,9 +1186,11 @@ class DataSetGenerator:
 return "malformed_responses"

 def summarize_failures(self) -> dict:
-"""Generate a summary of all failures."""
+"""Generate a summary of all failures including those flushed to checkpoint."""
+# Include both in-memory and flushed failures for accurate total
+total_failures = self._flushed_failures_count + len(self.failed_samples)
 summary = {
-"total_failures":
+"total_failures": total_failures,
 "failure_types": {k: len(v) for k, v in self.failure_analysis.items()},
 "failure_examples": {},
 }
@@ -1155,22 +1293,59 @@ class DataSetGenerator:

 include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg

-
-
+# Calculate total samples requested
 total_samples = num_steps * batch_size
 data_creation_prompt = self._get_cot_prompt_template()

-
-
-
-
-
-topic_paths=topic_paths or [],
-data_creation_prompt=data_creation_prompt,
-num_example_demonstrations=num_example_demonstrations,
-include_sys_msg=include_sys_msg,
+# Ensure checkpoint_interval is at least as large as concurrency/batch_size
+# so checkpoints align with batch boundaries
+if (
+self.config.checkpoint_interval is not None
+and self.config.checkpoint_interval < batch_size
 ):
-
+logger.warning(
+"checkpoint_interval (%d) is less than concurrency/batch_size (%d), "
+"adjusting to %d to align with batch boundaries",
+self.config.checkpoint_interval,
+batch_size,
+batch_size,
+)
+self.config.checkpoint_interval = batch_size
+
+final_result: HFDataset | dict | None = None
+
+# Use cycle-based generation when a topic model is provided
+if topic_model is not None:
+unique_topics, cycles_needed = self._prepare_unique_topics(total_samples, topic_model)
+
+# batch_size becomes concurrency in the new model
+concurrency = batch_size
+
+async for event in self._run_cycle_based_generation_async(
+unique_topics=unique_topics,
+cycles_needed=cycles_needed,
+total_samples=total_samples,
+concurrency=concurrency,
+data_creation_prompt=data_creation_prompt,
+num_example_demonstrations=num_example_demonstrations,
+include_sys_msg=include_sys_msg,
+topic_model=topic_model,
+):
+final_result = event
+else:
+# Fall back to step-based generation when no topic model
+topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
+
+async for event in self._run_generation_loop_async(
+num_steps=num_steps,
+batch_size=batch_size,
+total_samples=total_samples,
+topic_paths=topic_paths or [],
+data_creation_prompt=data_creation_prompt,
+num_example_demonstrations=num_example_demonstrations,
+include_sys_msg=include_sys_msg,
+):
+final_result = event

 if isinstance(final_result, HFDataset):
 trace(
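Note: in the cycle-based path above, batch_size is reused as the concurrency limit, and checkpoint_interval is clamped up to it so a checkpoint never lands mid-batch. A minimal standalone sketch of that clamp, with a hypothetical helper name rather than the library's API:

# Standalone sketch of the checkpoint_interval clamp shown above.
def clamp_checkpoint_interval(checkpoint_interval: int | None, batch_size: int) -> int | None:
    """Return an interval that is never smaller than one batch/concurrency window."""
    if checkpoint_interval is not None and checkpoint_interval < batch_size:
        return batch_size
    return checkpoint_interval

assert clamp_checkpoint_interval(10, 32) == 32       # too small: raised to the batch size
assert clamp_checkpoint_interval(100, 32) == 100     # already large enough: unchanged
assert clamp_checkpoint_interval(None, 32) is None   # checkpointing disabled: untouched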
@@ -1211,29 +1386,67 @@ class DataSetGenerator:

 include_sys_msg = sys_msg if sys_msg is not None else self.config.sys_msg

-
-
+# Calculate total samples requested
 total_samples = num_steps * batch_size
 data_creation_prompt = self._get_cot_prompt_template()

+# Ensure checkpoint_interval is at least as large as concurrency/batch_size
+# so checkpoints align with batch boundaries
+if (
+self.config.checkpoint_interval is not None
+and self.config.checkpoint_interval < batch_size
+):
+logger.warning(
+"checkpoint_interval (%d) is less than concurrency/batch_size (%d), "
+"adjusting to %d to align with batch boundaries",
+self.config.checkpoint_interval,
+batch_size,
+batch_size,
+)
+self.config.checkpoint_interval = batch_size
+
 root_topic_prompt = None
 topic_model_type = None
 if topic_model is not None:
 root_topic_prompt = getattr(topic_model, "topic_prompt", None)
 topic_model_type = type(topic_model).__name__.lower()

-
-
-
-
-
-
-
-
-
-
-
-
+# Use cycle-based generation when a topic model is provided
+if topic_model is not None:
+unique_topics, cycles_needed = self._prepare_unique_topics(total_samples, topic_model)
+
+# batch_size becomes concurrency in the new model
+concurrency = batch_size
+
+async for event in self._run_cycle_based_generation_async(
+unique_topics=unique_topics,
+cycles_needed=cycles_needed,
+total_samples=total_samples,
+concurrency=concurrency,
+data_creation_prompt=data_creation_prompt,
+num_example_demonstrations=num_example_demonstrations,
+include_sys_msg=include_sys_msg,
+root_topic_prompt=root_topic_prompt,
+topic_model_type=topic_model_type,
+topic_model=topic_model,
+):
+yield event
+else:
+# Fall back to step-based generation when no topic model
+topic_paths, num_steps = self._prepare_topic_paths(num_steps, batch_size, topic_model)
+
+async for event in self._run_generation_loop_async(
+num_steps=num_steps,
+batch_size=batch_size,
+total_samples=total_samples,
+topic_paths=topic_paths or [],
+data_creation_prompt=data_creation_prompt,
+num_example_demonstrations=num_example_demonstrations,
+include_sys_msg=include_sys_msg,
+root_topic_prompt=root_topic_prompt,
+topic_model_type=topic_model_type,
+):
+yield event

 async def _run_generation_loop_async( # noqa: PLR0912, PLR0915
 self,
@@ -1248,6 +1461,17 @@ class DataSetGenerator:
 topic_model_type: str | None = None,
 ) -> AsyncGenerator[dict | HFDataset, None]:
 """Run the main generation loop yielding progress events."""
+# Verify topic paths cover all expected samples (only when a topic model is used)
+expected_prompts = num_steps * batch_size
+actual_paths = len(topic_paths) if topic_paths else 0
+if topic_paths and actual_paths < expected_prompts:
+logger.warning(
+"Topic paths (%d) < expected samples (%d). Steps beyond path %d will produce 0 samples.",
+actual_paths,
+expected_prompts,
+actual_paths // batch_size,
+)
+
 # Initialize checkpoint paths if checkpointing is enabled
 if self.config.checkpoint_interval is not None:
 self._initialize_checkpoint_paths()
@@ -1257,6 +1481,7 @@ class DataSetGenerator:
 samples_in_current_batch: list[dict] = []
 failures_in_current_batch: list[dict] = []
 topic_paths_in_current_batch: list[TopicPath | None] = []
+topics_exhausted_count = 0 # Samples not generated due to topic exhaustion

 try:
 yield {
@@ -1267,11 +1492,12 @@ class DataSetGenerator:
 "total_samples": total_samples,
 "root_topic_prompt": root_topic_prompt,
 "topic_model_type": topic_model_type,
-"resumed_from_checkpoint": len(self.
-"previously_processed": len(self.
+"resumed_from_checkpoint": len(self._completed) > 0,
+"previously_processed": len(self._completed),
 "resumed_samples": self._flushed_samples_count,
 "resumed_failures": self._flushed_failures_count,
 "checkpoint_enabled": self.config.checkpoint_interval is not None,
+"checkpoint_interval": self.config.checkpoint_interval,
 }

 for step in range(num_steps):
@@ -1290,12 +1516,36 @@ class DataSetGenerator:
 num_example_demonstrations,
 )

-#
-if
+# Handle topic exhaustion - when we've run out of topics
+if not prompts and topic_paths:
+# Topics exhausted - remaining steps will produce nothing
+exhausted_in_step = batch_size # All samples in this batch had no topics
+topics_exhausted_count += exhausted_in_step
+logger.warning(
+"Step %d: Topics exhausted at index %d (only %d topics available for %d expected)",
+step + 1,
+start_idx,
+len(topic_paths),
+total_samples,
+)
+yield {
+"event": "step_complete",
+"step": step + 1,
+"samples_generated": 0,
+"success": True,
+"failed_in_step": 0,
+"failure_reasons": [],
+"topics_exhausted": exhausted_in_step,
+}
+continue
+
+# Filter out already-completed topics when resuming
+if self._completed:
 filtered_prompts = []
 filtered_topic_paths: list[TopicPath | None] = []
 for prompt, tp in zip(prompts, used_topic_paths, strict=False):
-if
+# Check if this topic_id has been completed in any cycle
+if tp is None or not self._is_uuid_completed_any_cycle(tp.topic_id):
 filtered_prompts.append(prompt)
 filtered_topic_paths.append(tp)
@@ -1335,10 +1585,14 @@ class DataSetGenerator:
 self.config.checkpoint_interval is not None
 and samples_since_checkpoint >= self.config.checkpoint_interval
 ):
+# Convert topic_paths to (uuid, cycle) tuples (cycle=0 for step-based)
+completed_items = [
+(tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+]
 self._save_checkpoint(
 samples_in_current_batch,
 failures_in_current_batch,
-
+completed_items,
 )
 samples_in_current_batch = []
 failures_in_current_batch = []
@@ -1360,11 +1614,11 @@ class DataSetGenerator:
 }
 return # Exit generator cleanly

-
+# Use new_failures captured BEFORE checkpoint clear, not recalculated after
+failed_in_batch = len(new_failures)
 failure_reasons: list[str] = []
-if failed_in_batch > 0 and
-
-for f in recent_failures[:3]:
+if failed_in_batch > 0 and new_failures:
+for f in new_failures[:3]:
 if isinstance(f, dict):
 failure_reasons.append(f.get("error", str(f)))
 else:
@@ -1390,10 +1644,13 @@ class DataSetGenerator:
 if self.config.checkpoint_interval is not None and (
 samples_in_current_batch or failures_in_current_batch
 ):
+completed_items = [
+(tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+]
 self._save_checkpoint(
 samples_in_current_batch,
 failures_in_current_batch,
-
+completed_items,
 )
 yield {
 "event": "checkpoint_saved",
@@ -1403,13 +1660,30 @@ class DataSetGenerator:
 }

 # Calculate total counts including flushed data
-
-
+actual_samples = self._flushed_samples_count + len(self._samples)
+actual_failures = self._flushed_failures_count + len(self.failed_samples)
+total_accounted = actual_samples + actual_failures + topics_exhausted_count
+true_unaccounted = total_samples - total_accounted
+
+# Log accounting summary for debugging
+logger.info(
+"Generation complete: expected=%d, generated=%d, failed=%d, topics_exhausted=%d, "
+"accounted=%d, unaccounted=%d",
+total_samples, # This is the parameter passed in (expected count)
+actual_samples,
+actual_failures,
+topics_exhausted_count,
+total_accounted,
+true_unaccounted,
+)

 yield {
 "event": "generation_complete",
-"total_samples":
-"failed_samples":
+"total_samples": actual_samples,
+"failed_samples": actual_failures,
+"expected_samples": total_samples,
+"topics_exhausted": topics_exhausted_count,
+"unaccounted": true_unaccounted,
 }

 except KeyboardInterrupt:
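Note: the completion accounting logged above is simple bookkeeping; every requested sample should end up counted as generated, failed, or skipped due to topic exhaustion, and any remainder is reported as unaccounted. A small check with invented numbers:

# Illustrative numbers only: how the completion accounting above adds up.
expected = 200           # total_samples requested (num_steps * batch_size)
generated = 180          # flushed + in-memory samples
failed = 12              # flushed + in-memory failures
topics_exhausted = 8     # samples skipped because topic paths ran out

accounted = generated + failed + topics_exhausted
unaccounted = expected - accounted
assert accounted == 200 and unaccounted == 0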
@@ -1417,10 +1691,13 @@ class DataSetGenerator:
 if self.config.checkpoint_interval is not None and (
 samples_in_current_batch or failures_in_current_batch
 ):
+completed_items = [
+(tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+]
 self._save_checkpoint(
 samples_in_current_batch,
 failures_in_current_batch,
-
+completed_items,
 )
 yield {
 "event": "generation_interrupted",
@@ -1434,10 +1711,327 @@ class DataSetGenerator:
 if self.config.checkpoint_interval is not None and (
 samples_in_current_batch or failures_in_current_batch
 ):
+completed_items = [
+(tp.topic_id, 0) for tp in topic_paths_in_current_batch if tp is not None
+]
 self._save_checkpoint(
 samples_in_current_batch,
 failures_in_current_batch,
-
+completed_items,
+)
+yield {"event": "generation_error", "error": str(e)}
+self.print_failure_summary()
+self._save_samples_to_file(ERROR_DATASET_FILENAME)
+raise DataSetGeneratorError("failed") from e
+
+# Build final dataset: if samples were flushed to disk, load them from checkpoint
+if self._flushed_samples_count > 0:
+all_samples = self._load_all_samples_from_checkpoint()
+yield HFDataset.from_list(all_samples) if all_samples else HFDataset.from_list([])
+else:
+yield (HFDataset.from_list(self._samples) if self._samples else HFDataset.from_list([]))
+
+async def _run_cycle_based_generation_async( # noqa: PLR0912, PLR0915
+self,
+unique_topics: list[Topic],
+cycles_needed: int,
+total_samples: int,
+concurrency: int,
+data_creation_prompt: str,
+num_example_demonstrations: int,
+include_sys_msg: bool,
+root_topic_prompt: str | None = None,
+topic_model_type: str | None = None,
+topic_model: "TopicModel | None" = None,
+) -> AsyncGenerator[dict | HFDataset, None]:
+"""Run cycle-based generation loop yielding progress events.
+
+This is the new generation model that iterates over unique topics (by UUID)
+for multiple cycles, rather than the old step-based batching approach.
+
+Args:
+unique_topics: List of Topic namedtuples with (uuid, topic).
+cycles_needed: Number of cycles to iterate through topics.
+total_samples: Total number of samples to generate.
+concurrency: Maximum parallel LLM calls (semaphore limit).
+data_creation_prompt: The prompt template for data creation.
+num_example_demonstrations: Number of example demonstrations to include.
+include_sys_msg: Whether to include system message in output.
+root_topic_prompt: Original topic prompt for display.
+topic_model_type: Type of topic model (tree, graph) for display.
+topic_model: Topic model for path lookup (recovers full path context).
+
+Yields:
+Progress event dicts and final HFDataset.
+"""
+unique_topic_count = len(unique_topics)
+final_cycle_size = total_samples - (cycles_needed - 1) * unique_topic_count
+
+# Initialize checkpoint paths if checkpointing is enabled
+if self.config.checkpoint_interval is not None:
+self._initialize_checkpoint_paths()
+
+# Track samples for checkpointing
+samples_since_checkpoint = 0
+pending_samples: list[dict] = []
+pending_failures: list[dict] = []
+pending_completed: list[tuple[str, int]] = []
+samples_generated = 0
+
+# Create semaphore for concurrency control
+semaphore = asyncio.Semaphore(concurrency)
+
+try:
+yield {
+"event": "generation_start",
+"model_name": self.model_name,
+"unique_topics": unique_topic_count,
+"cycles_needed": cycles_needed,
+"final_cycle_size": final_cycle_size,
+"concurrency": concurrency,
+"total_samples": total_samples,
+"root_topic_prompt": root_topic_prompt,
+"topic_model_type": topic_model_type,
+"resumed_from_checkpoint": len(self._completed) > 0,
+"previously_completed": len(self._completed),
+"resumed_samples": self._flushed_samples_count,
+"resumed_failures": self._flushed_failures_count,
+"checkpoint_enabled": self.config.checkpoint_interval is not None,
+"checkpoint_interval": self.config.checkpoint_interval,
+}
+
+for cycle in range(cycles_needed):
+# Determine how many topics to process in this cycle
+if cycle == cycles_needed - 1:
+# Final cycle may be partial
+topics_in_cycle = min(final_cycle_size, unique_topic_count)
+else:
+topics_in_cycle = unique_topic_count
+
+yield {
+"event": "cycle_start",
+"cycle": cycle + 1,
+"total_cycles": cycles_needed,
+"topics_in_cycle": topics_in_cycle,
+}
+
+# Collect topics to process in this cycle
+topics_to_process: list[Topic] = []
+for topic_idx, topic in enumerate(unique_topics):
+if cycle == cycles_needed - 1 and topic_idx >= final_cycle_size:
+break # Partial final cycle - stop early
+
+if not self._is_completed(topic.uuid, cycle):
+topics_to_process.append(topic)
+
+if not topics_to_process:
+# All topics in this cycle already completed (resume scenario)
+yield {
+"event": "cycle_complete",
+"cycle": cycle + 1,
+"samples_in_cycle": 0,
+"skipped": topics_in_cycle,
+}
+continue
+
+# Process topics with concurrency control
+cycle_samples = 0
+cycle_failures = 0
+
+async def process_topic(
+topic: Topic, cycle_num: int, sample_idx: int
+) -> tuple[tuple[str, int], bool, int]:
+"""Process a single topic with semaphore-controlled concurrency.
+
+Note: Samples/failures are NOT extracted here because concurrent
+tasks share self._samples. Extracting per-task would cause
+duplicates since each task's slice overlaps with others.
+Instead, samples are captured in bulk after asyncio.gather.
+
+Returns:
+Tuple of (completed_item, success, count)
+"""
+async with semaphore:
+# Recover full path context (root -> ... -> leaf) for prompt quality
+full_path = topic_model.get_path_by_id(topic.uuid) if topic_model else None
+subtopics = full_path if full_path else [topic.topic]
+
+# Build prompt for this topic
+sample_prompt = self.build_prompt(
+data_creation_prompt=data_creation_prompt,
+num_example_demonstrations=num_example_demonstrations,
+subtopics_list=subtopics,
+)
+
+# Use existing batch processing for a single sample
+topic_path = TopicPath(path=subtopics, topic_id=topic.uuid)
+success, count = await self._process_batch_with_retries_async(
+prompts=[sample_prompt],
+include_sys_msg=include_sys_msg,
+start_sample_idx=sample_idx,
+topic_paths_for_batch=[topic_path],
+)
+
+completed_item = (topic.uuid, cycle_num)
+return completed_item, success, count
+
+# Process topics in batches for checkpoint saving
+for batch_start in range(0, len(topics_to_process), concurrency):
+batch_end = min(batch_start + concurrency, len(topics_to_process))
+batch_topics = topics_to_process[batch_start:batch_end]
+
+# Snapshot list lengths before tasks start so we can
+# capture new items in one safe slice after all complete.
+samples_before_gather = len(self._samples)
+failures_before_gather = len(self.failed_samples)
+
+# Create concurrent tasks
+# Pass current samples_generated as starting index for each task
+tasks = [
+asyncio.create_task(process_topic(topic, cycle, samples_generated + i))
+for i, topic in enumerate(batch_topics)
+]
+
+# Process results as each task finishes (not waiting for
+# the whole batch) so the progress bar advances per-sample.
+batch_samples = 0
+batch_failures = 0
+for future in asyncio.as_completed(tasks):
+completed_item, success, count = await future
+pending_completed.append(completed_item)
+
+if success:
+cycle_samples += count
+samples_generated += count
+batch_samples += count
+else:
+cycle_failures += 1
+batch_failures += 1
+
+samples_since_checkpoint += 1
+
+# Emit per-sample progress so progress bars advance immediately
+yield {
+"event": "batch_complete",
+"samples_generated": count if success else 0,
+"samples_failed": 1 if not success else 0,
+}
+
+# After all tasks complete, capture samples/failures in
+# one safe slice (no concurrent tasks are running now).
+batch_new_samples = list(self._samples[samples_before_gather:])
+batch_new_failures = list(self.failed_samples[failures_before_gather:])
+# Tag each failure with the cycle number so retry logic can
+# target only the specific (uuid, cycle) that failed
+for failure in batch_new_failures:
+failure["cycle"] = cycle
+pending_samples.extend(batch_new_samples)
+pending_failures.extend(batch_new_failures)
+
+# Save checkpoint if we've reached the interval
+if (
+self.config.checkpoint_interval is not None
+and samples_since_checkpoint >= self.config.checkpoint_interval
+):
+self._save_checkpoint(
+pending_samples,
+pending_failures,
+pending_completed,
+)
+pending_samples = []
+pending_failures = []
+pending_completed = []
+samples_since_checkpoint = 0
+yield {
+"event": "checkpoint_saved",
+"total_samples": self._flushed_samples_count,
+"total_failures": self._flushed_failures_count,
+}
+
+# Check for graceful stop request
+if self.stop_requested:
+yield {
+"event": "generation_stopped",
+"message": "Stopped at checkpoint as requested",
+"total_samples": self._flushed_samples_count,
+"total_failures": self._flushed_failures_count,
+}
+return
+
+yield {
+"event": "cycle_complete",
+"cycle": cycle + 1,
+"samples_in_cycle": cycle_samples,
+"failures_in_cycle": cycle_failures,
+}
+
+# Save final checkpoint with any remaining samples
+if self.config.checkpoint_interval is not None and (
+pending_samples or pending_failures
+):
+self._save_checkpoint(
+pending_samples,
+pending_failures,
+pending_completed,
+)
+yield {
+"event": "checkpoint_saved",
+"total_samples": self._flushed_samples_count,
+"total_failures": self._flushed_failures_count,
+"final": True,
+}
+
+# Calculate total counts including flushed data
+actual_samples = self._flushed_samples_count + len(self._samples)
+actual_failures = self._flushed_failures_count + len(self.failed_samples)
+unaccounted = total_samples - actual_samples - actual_failures
+
+logger.info(
+"Generation complete: expected=%d, generated=%d, failed=%d, "
+"accounted=%d, unaccounted=%d",
+total_samples,
+actual_samples,
+actual_failures,
+actual_samples + actual_failures,
+unaccounted,
+)
+
+yield {
+"event": "generation_complete",
+"total_samples": actual_samples,
+"failed_samples": actual_failures,
+"expected_samples": total_samples,
+"unaccounted": unaccounted,
+"cycles_completed": cycles_needed,
+"unique_topics": unique_topic_count,
+}
+
+except KeyboardInterrupt:
+# Save checkpoint on interrupt
+if self.config.checkpoint_interval is not None and (
+pending_samples or pending_failures
+):
+self._save_checkpoint(
+pending_samples,
+pending_failures,
+pending_completed,
+)
+yield {
+"event": "generation_interrupted",
+"message": "Generation interrupted by user.",
+}
+self.print_failure_summary()
+self._save_samples_to_file(INTERRUPTED_DATASET_FILENAME)
+
+except Exception as e: # noqa: BLE001
+# Save checkpoint on error
+if self.config.checkpoint_interval is not None and (
+pending_samples or pending_failures
+):
+self._save_checkpoint(
+pending_samples,
+pending_failures,
+pending_completed,
 )
 yield {"event": "generation_error", "error": str(e)}
 self.print_failure_summary()
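Note: the new _run_cycle_based_generation_async combines a semaphore (to bound concurrent LLM calls) with asyncio.as_completed (to emit progress as each task finishes rather than once per batch). A self-contained sketch of that general pattern, using generic names rather than DeepFabric's API:

import asyncio
import random

async def bounded_worker(semaphore: asyncio.Semaphore, item: str) -> tuple[str, bool]:
    """Run one unit of work while holding a concurrency slot."""
    async with semaphore:
        await asyncio.sleep(random.random() / 10)  # stand-in for an LLM call
        return item, True

async def run_all(items: list[str], concurrency: int) -> None:
    semaphore = asyncio.Semaphore(concurrency)
    tasks = [asyncio.create_task(bounded_worker(semaphore, item)) for item in items]
    # Report each result as soon as it is ready, not when the whole batch is done.
    for future in asyncio.as_completed(tasks):
        item, success = await future
        print(f"done: {item} success={success}")

asyncio.run(run_all([f"topic-{i}" for i in range(10)], concurrency=3))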
@@ -1465,6 +2059,17 @@ class DataSetGenerator:
 prompts, include_sys_msg, start_sample_idx, topic_paths_for_batch
 )

+# Verify all prompts are accounted for (no silent drops)
+accounted = len(samples) + len(failed_responses)
+if accounted != len(prompts):
+logger.warning(
+"Sample accounting mismatch: %d prompts, %d samples, %d failures (missing %d)",
+len(prompts),
+len(samples),
+len(failed_responses),
+len(prompts) - accounted,
+)
+
 # Update failed samples
 self.failed_samples.extend(failed_responses)