DeepFabric 4.9.0__py3-none-any.whl → 4.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/builders.py +7 -21
- deepfabric/builders_agent.py +0 -542
- deepfabric/cli.py +505 -74
- deepfabric/config.py +57 -73
- deepfabric/config_manager.py +8 -6
- deepfabric/constants.py +6 -0
- deepfabric/dataset_manager.py +107 -11
- deepfabric/evaluation/parser.py +7 -7
- deepfabric/generator.py +656 -103
- deepfabric/graph.py +46 -1
- deepfabric/prompts.py +0 -39
- deepfabric/schemas.py +4 -3
- deepfabric/topic_model.py +32 -0
- deepfabric/tree.py +23 -1
- deepfabric/tui.py +66 -21
- deepfabric/utils.py +184 -0
- deepfabric/validation.py +47 -77
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/METADATA +5 -6
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/RECORD +22 -22
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/WHEEL +0 -0
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/entry_points.txt +0 -0
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.0.dist-info}/licenses/LICENSE +0 -0
deepfabric/generator.py
CHANGED
|
@@ -2,9 +2,12 @@ import asyncio
|
|
|
2
2
|
import json
|
|
3
3
|
import logging
|
|
4
4
|
import math
|
|
5
|
+
import os
|
|
5
6
|
import random
|
|
6
7
|
|
|
7
8
|
from collections.abc import AsyncGenerator
|
|
9
|
+
from datetime import datetime, timezone
|
|
10
|
+
from pathlib import Path
|
|
8
11
|
from typing import TYPE_CHECKING, Any, Literal
|
|
9
12
|
|
|
10
13
|
from datasets import Dataset as HFDataset
|
|
@@ -14,6 +17,10 @@ from .builders import ConversationBuilderFactory
|
|
|
14
17
|
from .config import _normalize_reasoning_style
|
|
15
18
|
from .constants import (
|
|
16
19
|
API_ERROR_INDICATORS,
|
|
20
|
+
CHECKPOINT_FAILURES_SUFFIX,
|
|
21
|
+
CHECKPOINT_METADATA_SUFFIX,
|
|
22
|
+
CHECKPOINT_SAMPLES_SUFFIX,
|
|
23
|
+
CHECKPOINT_VERSION,
|
|
17
24
|
DEFAULT_MAX_RETRIES,
|
|
18
25
|
DEFAULT_REQUEST_TIMEOUT,
|
|
19
26
|
DEFAULT_SAMPLE_RETRIES,
|
|
@@ -30,7 +37,6 @@ from .llm import LLMClient
|
|
|
30
37
|
from .metrics import trace
|
|
31
38
|
from .progress import ProgressReporter
|
|
32
39
|
from .prompts import (
|
|
33
|
-
AGENT_COT_MULTI_TURN_PROMPT,
|
|
34
40
|
AGENT_COT_TOOLS_PROMPT,
|
|
35
41
|
CONVERSATION_GENERATION_PROMPT,
|
|
36
42
|
FREETEXT_COT_PROMPT,
|
|
@@ -40,8 +46,8 @@ from .prompts import (
|
|
|
40
46
|
from .schemas import Conversation, ToolRegistry, get_conversation_schema
|
|
41
47
|
from .tools import BUILTIN_TOOL_REGISTRY
|
|
42
48
|
from .tools.loader import load_tools_from_dict, load_tools_from_endpoint
|
|
43
|
-
from .topic_model import TopicModel
|
|
44
|
-
from .utils import ensure_not_running_loop, is_validation_error
|
|
49
|
+
from .topic_model import TopicModel, TopicPath
|
|
50
|
+
from .utils import ensure_not_running_loop, get_checkpoint_dir, is_validation_error
|
|
45
51
|
|
|
46
52
|
# Handle circular import for type hints
|
|
47
53
|
if TYPE_CHECKING:
|
|
@@ -143,12 +149,7 @@ class DataSetGeneratorConfig(BaseModel):
|
|
|
143
149
|
"""Normalize deprecated reasoning_style values."""
|
|
144
150
|
return _normalize_reasoning_style(v)
|
|
145
151
|
|
|
146
|
-
|
|
147
|
-
default=None,
|
|
148
|
-
description="Agent mode: single_turn (one-shot tool use), multi_turn (extended agent conversations). Requires tools to be configured.",
|
|
149
|
-
)
|
|
150
|
-
|
|
151
|
-
# Tool configuration (used when agent_mode is enabled or for tool_calling)
|
|
152
|
+
# Tool configuration (used when tools are configured for agent mode)
|
|
152
153
|
tool_components: dict[str, list[str]] = Field(
|
|
153
154
|
default_factory=dict,
|
|
154
155
|
description=(
|
|
@@ -194,28 +195,32 @@ class DataSetGeneratorConfig(BaseModel):
|
|
|
194
195
|
description="Path for tool execution when using tools_endpoint (e.g., '/mock/execute'). Combined with spin_endpoint.",
|
|
195
196
|
)
|
|
196
197
|
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
ge=1,
|
|
201
|
-
le=10,
|
|
202
|
-
description="Minimum number of conversation turns for multi-turn agent mode",
|
|
198
|
+
tool_inclusion_strategy: Literal["all", "used_only"] = Field(
|
|
199
|
+
default="used_only",
|
|
200
|
+
description="Which tools to include in each sample: 'all' includes full catalog, 'used_only' includes only tools actually called (recommended for training)",
|
|
203
201
|
)
|
|
204
|
-
|
|
205
|
-
|
|
202
|
+
|
|
203
|
+
# Checkpoint configuration
|
|
204
|
+
checkpoint_interval: int | None = Field(
|
|
205
|
+
default=None,
|
|
206
206
|
ge=1,
|
|
207
|
-
|
|
208
|
-
description="Maximum number of conversation turns for multi-turn agent mode",
|
|
207
|
+
description="Save checkpoint every N samples. None disables checkpointing.",
|
|
209
208
|
)
|
|
210
|
-
|
|
211
|
-
default=
|
|
212
|
-
|
|
213
|
-
le=20,
|
|
214
|
-
description="Minimum number of tool calls required before allowing early conversation conclusion",
|
|
209
|
+
checkpoint_path: str | None = Field(
|
|
210
|
+
default=None,
|
|
211
|
+
description="Directory to store checkpoint files. None uses fallback '.checkpoints'",
|
|
215
212
|
)
|
|
216
|
-
|
|
217
|
-
default=
|
|
218
|
-
description="
|
|
213
|
+
checkpoint_retry_failed: bool = Field(
|
|
214
|
+
default=False,
|
|
215
|
+
description="When resuming, retry previously failed samples",
|
|
216
|
+
)
|
|
217
|
+
output_save_as: str | None = Field(
|
|
218
|
+
default=None,
|
|
219
|
+
description="Output file path (used to derive checkpoint file names)",
|
|
220
|
+
)
|
|
221
|
+
topics_file: str | None = Field(
|
|
222
|
+
default=None,
|
|
223
|
+
description="Topics file path (stored in checkpoint metadata for auto-resume)",
|
|
219
224
|
)
|
|
220
225
|
|
|
221
226
|
|
|
@@ -260,18 +265,27 @@ class DataSetGenerator:
|
|
|
260
265
|
# Store generation prompt for content generation
|
|
261
266
|
self.generation_prompt = self.config.generation_system_prompt
|
|
262
267
|
|
|
263
|
-
# Initialize tool registry when
|
|
268
|
+
# Initialize tool registry when tools are configured (enables agent mode)
|
|
264
269
|
self.tool_registry = None
|
|
265
|
-
if
|
|
266
|
-
self.config.agent_mode is not None
|
|
267
|
-
or self.config.tool_components
|
|
268
|
-
or self.config.custom_tools
|
|
269
|
-
):
|
|
270
|
+
if self.config.tool_components or self.config.custom_tools:
|
|
270
271
|
self._initialize_tool_registry()
|
|
271
272
|
|
|
272
273
|
# Progress reporter for streaming feedback (set by external callers)
|
|
273
274
|
self.progress_reporter: ProgressReporter | None = None
|
|
274
275
|
|
|
276
|
+
# Checkpoint state
|
|
277
|
+
self._checkpoint_samples_since_save = 0
|
|
278
|
+
self._processed_ids: set[str] = set() # Track processed topic IDs (UUIDs)
|
|
279
|
+
self._checkpoint_metadata_path: Path | None = None
|
|
280
|
+
self._checkpoint_samples_path: Path | None = None
|
|
281
|
+
self._checkpoint_failures_path: Path | None = None
|
|
282
|
+
# Memory optimization: track flushed counts for checkpoint mode
|
|
283
|
+
self._flushed_samples_count = 0
|
|
284
|
+
self._flushed_failures_count = 0
|
|
285
|
+
|
|
286
|
+
# Graceful stop flag - set by signal handler to stop at next checkpoint
|
|
287
|
+
self.stop_requested = False
|
|
288
|
+
|
|
275
289
|
def _initialize_tool_registry(self):
|
|
276
290
|
"""Initialize tool registry from component configuration.
|
|
277
291
|
|
|
@@ -328,6 +342,400 @@ class DataSetGenerator:
|
|
|
328
342
|
except Exception as e: # noqa: BLE001
|
|
329
343
|
raise DataSetGeneratorError(f"Failed to initialize tool registry: {str(e)}") from e
|
|
330
344
|
|
|
345
|
+
def _get_checkpoint_paths(self) -> tuple[Path, Path, Path]:
    """Compute the three checkpoint file locations for this run.

    File names are built from the stem of ``config.output_save_as`` plus the
    metadata / samples / failures suffix constants. They live in
    ``config.checkpoint_path``, or in the directory returned by
    ``get_checkpoint_dir(config_path=None)`` when no path was configured.
    The directory is created if it does not exist.

    Returns:
        ``(metadata_path, samples_path, failures_path)``

    Raises:
        DataSetGeneratorError: If ``output_save_as`` is not configured.
    """
    save_as = self.config.output_save_as
    if not save_as:
        raise DataSetGeneratorError(
            "Cannot create checkpoint paths: output_save_as not configured"
        )

    # Prefer the directory resolved by the CLI; otherwise use the helper's fallback.
    base_dir = Path(self.config.checkpoint_path or get_checkpoint_dir(config_path=None))
    base_dir.mkdir(parents=True, exist_ok=True)

    stem = Path(save_as).stem
    metadata_path, samples_path, failures_path = (
        base_dir / f"{stem}{suffix}"
        for suffix in (
            CHECKPOINT_METADATA_SUFFIX,
            CHECKPOINT_SAMPLES_SUFFIX,
            CHECKPOINT_FAILURES_SUFFIX,
        )
    )
    return metadata_path, samples_path, failures_path
|
|
368
|
+
|
|
369
|
+
def _initialize_checkpoint_paths(self) -> None:
    """Cache the checkpoint file paths on the instance when checkpointing is on.

    When ``config.checkpoint_interval`` is ``None`` nothing happens. Otherwise
    the metadata, samples and failures paths are taken from
    ``_get_checkpoint_paths()`` and an info message is logged.
    """
    interval = self.config.checkpoint_interval
    if interval is None:
        return

    (
        self._checkpoint_metadata_path,
        self._checkpoint_samples_path,
        self._checkpoint_failures_path,
    ) = self._get_checkpoint_paths()

    logger.info(
        "Checkpointing enabled: saving every %d samples to %s",
        interval,
        self._checkpoint_samples_path,
    )
|
|
381
|
+
|
|
382
|
+
def _save_checkpoint(
    self,
    new_samples: list[dict],
    new_failures: list[dict],
    processed_topic_paths: list[TopicPath | None],
    flush_memory: bool = True,
) -> None:
    """Append one batch of results to the checkpoint and refresh its metadata.

    Successful samples and failure records are appended, one compact JSON
    document per line, to the samples and failures files. The batch's topic
    IDs are added to the processed set and the metadata file is rewritten.
    Does nothing when checkpoint paths have not been initialized.

    Args:
        new_samples: Samples generated since the previous save.
        new_failures: Failure records generated since the previous save.
        processed_topic_paths: TopicPaths handled in this batch; ``None``
            entries are skipped.
        flush_memory: When True, add the batch sizes to the flushed counters
            and empty ``self._samples`` and ``self.failed_samples``.
    """
    samples_file = self._checkpoint_samples_path
    if samples_file is None:
        return

    def _append_jsonl(target: Path, records: list[dict]) -> None:
        # Compact separators keep each record on a single short line.
        with open(target, "a", encoding="utf-8") as handle:
            handle.writelines(
                json.dumps(record, separators=(",", ":")) + "\n" for record in records
            )

    if new_samples:
        _append_jsonl(samples_file, new_samples)

    failures_file = self._checkpoint_failures_path
    if new_failures and failures_file:
        _append_jsonl(failures_file, new_failures)

    # NOTE(review): presumably this includes topics whose generation failed;
    # load_checkpoint(retry_failed=True) subtracts failed IDs from this set.
    self._processed_ids.update(
        topic_path.topic_id
        for topic_path in processed_topic_paths
        if topic_path is not None
    )

    if flush_memory:
        # Counters are bumped before the metadata write so its totals include this batch.
        self._flushed_samples_count += len(new_samples)
        self._flushed_failures_count += len(new_failures)
        # NOTE(review): this empties the whole in-memory buffers, not only the
        # records passed in; it assumes callers hand over everything buffered.
        self._samples.clear()
        self.failed_samples.clear()

    self._save_checkpoint_metadata()

    logger.debug(
        "Checkpoint saved: %d samples, %d failures, %d total IDs processed (flushed=%s)",
        len(new_samples),
        len(new_failures),
        len(self._processed_ids),
        flush_memory,
    )
|
|
436
|
+
|
|
437
|
+
def _save_checkpoint_metadata(self) -> None:
    """Rewrite the checkpoint metadata file from the generator's current state.

    Totals include both records already flushed to the JSONL files and records
    still held in memory. The JSON is first written to a temporary sibling
    file and then moved over the real path with ``os.replace``. A crash or
    interrupt in the middle of the write therefore leaves the previous
    metadata intact, instead of a truncated file that ``load_checkpoint``
    could not parse. Does nothing when checkpoint paths have not been
    initialized.
    """
    metadata_path = self._checkpoint_metadata_path
    if metadata_path is None:
        return

    # Total counts include both flushed (on disk) and in-memory records.
    total_samples = self._flushed_samples_count + len(self._samples)
    total_failures = self._flushed_failures_count + len(self.failed_samples)

    metadata = {
        "version": CHECKPOINT_VERSION,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "provider": self.provider,
        "model_name": self.model_name,
        "conversation_type": self.config.conversation_type,
        "reasoning_style": self.config.reasoning_style,
        "total_samples": total_samples,
        "total_failures": total_failures,
        "processed_ids": list(self._processed_ids),
        "checkpoint_interval": self.config.checkpoint_interval,
        "topics_file": self.config.topics_file,
    }

    tmp_path = metadata_path.with_name(metadata_path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        # os.replace swaps the file in one step (atomic on POSIX), so readers
        # see either the old or the new metadata, never a partial write.
        os.replace(tmp_path, metadata_path)
    except BaseException:
        # Best effort: remove the partial temp file, then propagate the original error.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise
|
|
462
|
+
|
|
463
|
+
def _validate_checkpoint_compatibility(self, metadata: dict) -> None:
    """Log a warning if the checkpoint was written under a different config.

    Compares provider, model name, conversation type and reasoning style. A
    setting is only compared when the checkpoint records a truthy value for
    it. Differences never block resumption; they are reported together in a
    single warning.

    Args:
        metadata: Parsed checkpoint metadata dictionary.
    """
    # (metadata key, zero-argument getter for the current value). The getter
    # runs only when the checkpoint actually records a value for that key.
    comparisons = (
        ("provider", lambda: self.provider),
        ("model_name", lambda: self.model_name),
        ("conversation_type", lambda: self.config.conversation_type),
        ("reasoning_style", lambda: self.config.reasoning_style),
    )

    differences: list[str] = []
    for key, current_value in comparisons:
        recorded = metadata.get(key)
        if not recorded:
            continue
        current = current_value()
        if recorded != current:
            differences.append(f"{key}: checkpoint={recorded}, current={current}")

    if not differences:
        return
    logger.warning(
        "Config mismatch with checkpoint. Resuming may produce inconsistent results. "
        "Differences: %s",
        "; ".join(differences),
    )
|
|
509
|
+
|
|
510
|
+
def _validate_checkpoint_integrity(self, metadata: dict) -> tuple[bool, str | None]:
|
|
511
|
+
"""Validate checkpoint file integrity.
|
|
512
|
+
|
|
513
|
+
Checks that:
|
|
514
|
+
1. Metadata version is supported
|
|
515
|
+
2. Required metadata fields are present
|
|
516
|
+
3. Sample count in metadata matches actual file line count
|
|
517
|
+
4. Sample file contains valid JSON on each line
|
|
518
|
+
|
|
519
|
+
Args:
|
|
520
|
+
metadata: Checkpoint metadata dictionary
|
|
521
|
+
|
|
522
|
+
Returns:
|
|
523
|
+
Tuple of (is_valid, error_message). error_message is None if valid.
|
|
524
|
+
"""
|
|
525
|
+
error_msg: str | None = None
|
|
526
|
+
|
|
527
|
+
# Check version
|
|
528
|
+
version = metadata.get("version")
|
|
529
|
+
if version is None:
|
|
530
|
+
error_msg = "Missing 'version' field in checkpoint metadata"
|
|
531
|
+
elif version != CHECKPOINT_VERSION:
|
|
532
|
+
error_msg = f"Unsupported checkpoint version: {version} (expected {CHECKPOINT_VERSION})"
|
|
533
|
+
|
|
534
|
+
# Check required fields
|
|
535
|
+
if error_msg is None:
|
|
536
|
+
required_fields = ["created_at", "total_samples", "processed_ids"]
|
|
537
|
+
for field in required_fields:
|
|
538
|
+
if field not in metadata:
|
|
539
|
+
error_msg = f"Missing required field in checkpoint metadata: {field}"
|
|
540
|
+
break
|
|
541
|
+
|
|
542
|
+
# Validate sample count matches file
|
|
543
|
+
if error_msg is None:
|
|
544
|
+
expected_samples = metadata.get("total_samples", 0)
|
|
545
|
+
if self._checkpoint_samples_path and self._checkpoint_samples_path.exists():
|
|
546
|
+
actual_count = 0
|
|
547
|
+
try:
|
|
548
|
+
with open(self._checkpoint_samples_path, encoding="utf-8") as f:
|
|
549
|
+
for line_num, raw_line in enumerate(f, 1):
|
|
550
|
+
stripped = raw_line.strip()
|
|
551
|
+
if stripped:
|
|
552
|
+
try:
|
|
553
|
+
json.loads(stripped)
|
|
554
|
+
actual_count += 1
|
|
555
|
+
except json.JSONDecodeError as e:
|
|
556
|
+
error_msg = f"Invalid JSON on line {line_num} of checkpoint samples: {e}"
|
|
557
|
+
break
|
|
558
|
+
except OSError as e:
|
|
559
|
+
error_msg = f"Failed to read checkpoint samples file: {e}"
|
|
560
|
+
|
|
561
|
+
if error_msg is None and actual_count != expected_samples:
|
|
562
|
+
error_msg = (
|
|
563
|
+
f"Sample count mismatch: metadata says {expected_samples}, "
|
|
564
|
+
f"file has {actual_count} samples"
|
|
565
|
+
)
|
|
566
|
+
elif expected_samples > 0:
|
|
567
|
+
error_msg = f"Checkpoint metadata expects {expected_samples} samples but samples file missing"
|
|
568
|
+
|
|
569
|
+
return (error_msg is None, error_msg)
|
|
570
|
+
|
|
571
|
+
def has_checkpoint(self) -> bool:
    """Report whether checkpoint metadata exists on disk, without loading it.

    Always False when checkpointing is disabled (``checkpoint_interval`` is
    ``None``). Otherwise the checkpoint paths are initialized first and the
    metadata file is looked up.

    Returns:
        True if the checkpoint metadata file exists, False otherwise.
    """
    checkpointing_enabled = self.config.checkpoint_interval is not None
    if checkpointing_enabled:
        self._initialize_checkpoint_paths()
        metadata_path = self._checkpoint_metadata_path
        if metadata_path is not None:
            return metadata_path.exists()
    return False
|
|
585
|
+
|
|
586
|
+
def load_checkpoint(self, retry_failed: bool = False) -> bool:
    """Restore generator state from an existing checkpoint.

    The metadata is integrity-checked and compared against the current
    config; config mismatches only produce a warning. Samples are not loaded
    into memory; only their count is restored. Failure records are parsed to
    learn which topic IDs failed.

    State (processed IDs and flushed counters) is committed to the instance
    only after the whole checkpoint has been read successfully. A corrupt or
    unreadable checkpoint therefore leaves the generator exactly as it was,
    rather than half-restored with stale processed IDs.

    Args:
        retry_failed: If True, topic IDs recorded in the failures file are
            removed from the processed set so they are generated again, and
            the failures file is deleted.

    Returns:
        True if a checkpoint was loaded. False if checkpointing is disabled,
        no checkpoint exists, or it could not be loaded.
    """
    if self.config.checkpoint_interval is None:
        return False

    self._initialize_checkpoint_paths()

    metadata_path = self._checkpoint_metadata_path
    if metadata_path is None or not metadata_path.exists():
        return False

    samples_path = self._checkpoint_samples_path
    failures_path = self._checkpoint_failures_path
    retrying = False
    retried_ids: set[str] = set()

    try:
        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)

        is_valid, error_msg = self._validate_checkpoint_integrity(metadata)
        if not is_valid:
            logger.error("Checkpoint integrity check failed: %s", error_msg)
            return False

        # Config differences only produce a warning; resumption continues.
        self._validate_checkpoint_compatibility(metadata)

        # Everything below goes into locals; nothing on ``self`` changes
        # until the whole checkpoint has been read.
        processed_ids = set(metadata.get("processed_ids", []))

        # Samples stay on disk; only their count is needed (memory optimization).
        sample_count: int | None = None
        if samples_path and samples_path.exists():
            with open(samples_path, encoding="utf-8") as f:
                sample_count = sum(1 for raw_line in f if raw_line.strip())

        # Failure records are read to learn which topic IDs failed.
        failure_count: int | None = None
        failed_ids: set[str] = set()
        if failures_path and failures_path.exists():
            failure_count = 0
            with open(failures_path, encoding="utf-8") as f:
                for raw_line in f:
                    stripped = raw_line.strip()
                    if not stripped:
                        continue
                    failure = json.loads(stripped)
                    failure_count += 1
                    if "topic_id" in failure:
                        failed_ids.add(failure["topic_id"])

        if retry_failed and failed_ids:
            retrying = True
            retried_ids = processed_ids & failed_ids
            processed_ids -= retried_ids
            # The failed samples will be regenerated, so drop their old records.
            if failures_path and failures_path.exists():
                failures_path.unlink()
            failure_count = 0
    except Exception as e:  # noqa: BLE001
        logger.warning("Failed to load checkpoint: %s", e)
        return False

    # Commit the restored state. A counter whose file is absent keeps its value.
    self._processed_ids = processed_ids
    if sample_count is not None:
        self._flushed_samples_count = sample_count
    if failure_count is not None:
        self._flushed_failures_count = failure_count

    if retrying:
        logger.info(
            "Retry mode: %d failed IDs will be retried",
            len(retried_ids),
        )
    logger.info(
        "Loaded checkpoint: %d samples, %d failures, %d IDs processed",
        self._flushed_samples_count,
        self._flushed_failures_count,
        len(self._processed_ids),
    )
    return True
|
|
670
|
+
|
|
671
|
+
def clear_checkpoint(self) -> None:
    """Delete the checkpoint files and reset the in-memory checkpoint state.

    Paths that are unset or point to missing files are skipped. Afterwards
    the processed-ID set is empty and both flushed counters are zero.
    """
    for checkpoint_file in (
        self._checkpoint_metadata_path,
        self._checkpoint_samples_path,
        self._checkpoint_failures_path,
    ):
        if checkpoint_file and checkpoint_file.exists():
            checkpoint_file.unlink()

    self._processed_ids.clear()
    self._flushed_samples_count = self._flushed_failures_count = 0
    logger.info("Checkpoint files cleared")
|
|
683
|
+
|
|
684
|
+
def _load_all_samples_from_checkpoint(self) -> list[dict]:
    """Read every sample back from the checkpoint samples file.

    This is used at the end of generation, when samples were flushed to disk
    during the run and the final dataset must be assembled from the file.
    Blank lines are skipped.

    Returns:
        The parsed sample dictionaries in file order, or an empty list if
        there is no samples file.
    """
    samples_file = self._checkpoint_samples_path
    if not samples_file or not samples_file.exists():
        return []
    with open(samples_file, encoding="utf-8") as handle:
        return [json.loads(text) for text in map(str.strip, handle) if text]
|
|
701
|
+
|
|
702
|
+
def get_all_failures(self) -> list[dict]:
    """Return every failure record, both flushed to disk and still in memory.

    Records from the checkpoint failures file, if it exists, come first,
    followed by those in ``self.failed_samples``. A new list is returned.

    Returns:
        List of all failure dictionaries.
    """
    failures_file = self._checkpoint_failures_path
    on_disk: list[dict] = []
    if failures_file and failures_file.exists():
        with open(failures_file, encoding="utf-8") as handle:
            stripped_lines = (line.strip() for line in handle)
            on_disk = [json.loads(text) for text in stripped_lines if text]
    return [*on_disk, *self.failed_samples]
|
|
725
|
+
|
|
726
|
+
def _is_topic_processed(self, topic_path: TopicPath | None) -> bool:
    """Tell whether a topic was already handled in an earlier run.

    Args:
        topic_path: The topic to look up. ``None`` is never considered processed.

    Returns:
        True if ``topic_path.topic_id`` is in the processed-ID set.
    """
    # NOTE(review): when topic paths are cycled to produce more samples than
    # there are paths, the repeated entries share one topic_id. Once any copy
    # is recorded, every copy reports as processed, so a resumed run may skip
    # intended repeats. Confirm this is acceptable.
    return topic_path is not None and topic_path.topic_id in self._processed_ids
|
|
738
|
+
|
|
331
739
|
def _validate_create_data_params(
|
|
332
740
|
self,
|
|
333
741
|
num_steps: int,
|
|
@@ -351,34 +759,24 @@ class DataSetGenerator:
|
|
|
351
759
|
num_steps: int,
|
|
352
760
|
batch_size: int,
|
|
353
761
|
topic_model: "TopicModel | None" = None,
|
|
354
|
-
) -> tuple[list | None, int]:
|
|
762
|
+
) -> tuple[list[TopicPath] | None, int]:
|
|
355
763
|
"""Prepare and validate topic paths for data generation."""
|
|
356
|
-
topic_paths = None
|
|
764
|
+
topic_paths: list[TopicPath] | None = None
|
|
357
765
|
if topic_model is not None:
|
|
358
|
-
topic_paths = topic_model.
|
|
766
|
+
topic_paths = topic_model.get_all_paths_with_ids()
|
|
359
767
|
total_paths = len(topic_paths)
|
|
360
768
|
required_samples = num_steps * batch_size
|
|
361
769
|
|
|
362
770
|
if required_samples > total_paths:
|
|
363
|
-
#
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
f"Recommendations:\n"
|
|
373
|
-
f" • Reduce --num-steps to {max_steps_for_batch} (with current batch size {batch_size})\n"
|
|
374
|
-
f" • Reduce --batch-size to {max_batch_for_steps} (with current {num_steps} steps)\n"
|
|
375
|
-
f" • Increase topic tree/graph depth or degree to generate more paths"
|
|
376
|
-
)
|
|
377
|
-
raise DataSetGeneratorError(error_msg)
|
|
378
|
-
|
|
379
|
-
# Bandit: not a security function
|
|
380
|
-
topic_paths = random.sample(topic_paths, required_samples) # nosec
|
|
381
|
-
num_steps = math.ceil(len(topic_paths) / batch_size)
|
|
771
|
+
# Cycle through topics to generate more samples than paths
|
|
772
|
+
# Each topic will be used multiple times for even coverage
|
|
773
|
+
multiplier = math.ceil(required_samples / total_paths)
|
|
774
|
+
topic_paths = (topic_paths * multiplier)[:required_samples]
|
|
775
|
+
elif required_samples < total_paths:
|
|
776
|
+
# Sample subset (percentage case or explicit count < total)
|
|
777
|
+
# Bandit: not a security function
|
|
778
|
+
topic_paths = random.sample(topic_paths, required_samples) # nosec
|
|
779
|
+
# else: required_samples == total_paths - use all paths as-is
|
|
382
780
|
|
|
383
781
|
return topic_paths, num_steps
|
|
384
782
|
|
|
@@ -386,23 +784,25 @@ class DataSetGenerator:
|
|
|
386
784
|
self,
|
|
387
785
|
batch_size: int,
|
|
388
786
|
start_idx: int,
|
|
389
|
-
topic_paths: list,
|
|
787
|
+
topic_paths: list[TopicPath],
|
|
390
788
|
data_creation_prompt: str,
|
|
391
789
|
num_example_demonstrations: int,
|
|
392
|
-
) -> tuple[list[str], list[
|
|
393
|
-
"""Generate prompts for a batch and return the associated
|
|
790
|
+
) -> tuple[list[str], list[TopicPath | None]]:
|
|
791
|
+
"""Generate prompts for a batch and return the associated TopicPaths used.
|
|
394
792
|
|
|
395
793
|
Returns:
|
|
396
|
-
(prompts,
|
|
794
|
+
(prompts, used_topic_paths) where used_topic_paths aligns with prompts order.
|
|
397
795
|
"""
|
|
398
796
|
prompts: list[str] = []
|
|
399
|
-
|
|
797
|
+
used_topic_paths: list[TopicPath | None] = []
|
|
400
798
|
for i in range(batch_size):
|
|
401
|
-
|
|
799
|
+
topic_path: TopicPath | None = None
|
|
800
|
+
path: list[str] | None = None
|
|
402
801
|
if topic_paths:
|
|
403
802
|
current_idx = start_idx + i
|
|
404
803
|
if current_idx < len(topic_paths):
|
|
405
|
-
|
|
804
|
+
topic_path = topic_paths[current_idx]
|
|
805
|
+
path = topic_path.path
|
|
406
806
|
else:
|
|
407
807
|
break
|
|
408
808
|
|
|
@@ -412,8 +812,8 @@ class DataSetGenerator:
|
|
|
412
812
|
subtopics_list=path,
|
|
413
813
|
)
|
|
414
814
|
prompts.append(sample_prompt)
|
|
415
|
-
|
|
416
|
-
return prompts,
|
|
815
|
+
used_topic_paths.append(topic_path)
|
|
816
|
+
return prompts, used_topic_paths
|
|
417
817
|
|
|
418
818
|
def _get_minimal_schema(self) -> type:
|
|
419
819
|
"""Get the conversation schema for the current config."""
|
|
@@ -447,7 +847,7 @@ class DataSetGenerator:
|
|
|
447
847
|
prompts: list[str],
|
|
448
848
|
include_sys_msg: bool,
|
|
449
849
|
start_sample_idx: int = 0,
|
|
450
|
-
|
|
850
|
+
topic_paths_for_batch: list[TopicPath | None] | None = None,
|
|
451
851
|
) -> tuple[list, list]:
|
|
452
852
|
"""Generate structured samples using builder pattern.
|
|
453
853
|
|
|
@@ -455,6 +855,7 @@ class DataSetGenerator:
|
|
|
455
855
|
prompts: List of topic prompts to generate samples for
|
|
456
856
|
include_sys_msg: Whether to include system message in output
|
|
457
857
|
start_sample_idx: Starting sample index for progress reporting
|
|
858
|
+
topic_paths_for_batch: TopicPath objects for each sample (includes topic_id)
|
|
458
859
|
|
|
459
860
|
Returns:
|
|
460
861
|
Tuple of (successful samples, failed responses)
|
|
@@ -470,7 +871,7 @@ class DataSetGenerator:
|
|
|
470
871
|
config = self.config.model_copy(update={"sys_msg": include_sys_msg})
|
|
471
872
|
|
|
472
873
|
async def _generate_with_retry(
|
|
473
|
-
prompt: str, sample_idx: int,
|
|
874
|
+
prompt: str, sample_idx: int, topic_path_info: TopicPath | None
|
|
474
875
|
) -> tuple[bool, Exception | Conversation]:
|
|
475
876
|
"""Generate a single sample with per-sample retry for validation errors.
|
|
476
877
|
|
|
@@ -496,6 +897,9 @@ class DataSetGenerator:
|
|
|
496
897
|
self.config.sample_retries,
|
|
497
898
|
)
|
|
498
899
|
|
|
900
|
+
# Extract path for progress reporting
|
|
901
|
+
path_info = topic_path_info.path if topic_path_info else None
|
|
902
|
+
|
|
499
903
|
for attempt in range(max_attempts):
|
|
500
904
|
# Notify progress reporter about which sample we're working on
|
|
501
905
|
if self.progress_reporter:
|
|
@@ -533,8 +937,8 @@ class DataSetGenerator:
|
|
|
533
937
|
return False, last_error or Exception("Sample generation failed")
|
|
534
938
|
|
|
535
939
|
else:
|
|
536
|
-
# Validate tool execution count for agent
|
|
537
|
-
if self.
|
|
940
|
+
# Validate tool execution count for agent mode (when tools configured)
|
|
941
|
+
if self.tool_registry is not None:
|
|
538
942
|
if (
|
|
539
943
|
not conversation.tool_context
|
|
540
944
|
or not conversation.tool_context.executions
|
|
@@ -566,6 +970,12 @@ class DataSetGenerator:
|
|
|
566
970
|
]
|
|
567
971
|
)
|
|
568
972
|
|
|
973
|
+
# Add topic_id to conversation metadata for traceability
|
|
974
|
+
if topic_path_info and hasattr(conversation, "metadata"):
|
|
975
|
+
if conversation.metadata is None:
|
|
976
|
+
conversation.metadata = {}
|
|
977
|
+
conversation.metadata["topic_id"] = topic_path_info.topic_id
|
|
978
|
+
|
|
569
979
|
return True, conversation
|
|
570
980
|
|
|
571
981
|
return False, last_error or Exception("Sample generation failed")
|
|
@@ -573,11 +983,13 @@ class DataSetGenerator:
|
|
|
573
983
|
# Generate all samples concurrently with sample indices
|
|
574
984
|
tasks = []
|
|
575
985
|
for idx, prompt in enumerate(prompts):
|
|
576
|
-
|
|
577
|
-
if
|
|
578
|
-
|
|
986
|
+
topic_path_info = None
|
|
987
|
+
if topic_paths_for_batch and idx < len(topic_paths_for_batch):
|
|
988
|
+
topic_path_info = topic_paths_for_batch[idx]
|
|
579
989
|
tasks.append(
|
|
580
|
-
asyncio.create_task(
|
|
990
|
+
asyncio.create_task(
|
|
991
|
+
_generate_with_retry(prompt, start_sample_idx + idx, topic_path_info)
|
|
992
|
+
)
|
|
581
993
|
)
|
|
582
994
|
results = await asyncio.gather(*tasks)
|
|
583
995
|
|
|
@@ -587,12 +999,18 @@ class DataSetGenerator:
|
|
|
587
999
|
else:
|
|
588
1000
|
error = payload
|
|
589
1001
|
error_msg = f"Generation failed: {error}"
|
|
590
|
-
# Build failure record with raw content if available
|
|
591
|
-
failure_record = {"error": error_msg}
|
|
1002
|
+
# Build failure record with raw content and topic_id if available
|
|
1003
|
+
failure_record: dict[str, str | None] = {"error": error_msg}
|
|
592
1004
|
if isinstance(error, Exception):
|
|
593
1005
|
context = getattr(error, "context", None)
|
|
594
1006
|
if isinstance(context, dict) and "raw_content" in context:
|
|
595
1007
|
failure_record["raw_content"] = context["raw_content"]
|
|
1008
|
+
# Include topic_id and path for checkpoint retry functionality
|
|
1009
|
+
if topic_paths_for_batch and idx < len(topic_paths_for_batch):
|
|
1010
|
+
tp = topic_paths_for_batch[idx]
|
|
1011
|
+
if tp:
|
|
1012
|
+
failure_record["topic_id"] = tp.topic_id
|
|
1013
|
+
failure_record["path"] = " -> ".join(tp.path)
|
|
596
1014
|
failed_responses.append(failure_record)
|
|
597
1015
|
failure_type = self.analyze_failure(
|
|
598
1016
|
str(error), error=error if isinstance(error, Exception) else None
|
|
@@ -817,12 +1235,12 @@ class DataSetGenerator:
|
|
|
817
1235
|
):
|
|
818
1236
|
yield event
|
|
819
1237
|
|
|
820
|
-
async def _run_generation_loop_async( # noqa: PLR0912
|
|
1238
|
+
async def _run_generation_loop_async( # noqa: PLR0912, PLR0915
|
|
821
1239
|
self,
|
|
822
1240
|
num_steps: int,
|
|
823
1241
|
batch_size: int,
|
|
824
1242
|
total_samples: int,
|
|
825
|
-
topic_paths: list,
|
|
1243
|
+
topic_paths: list[TopicPath],
|
|
826
1244
|
data_creation_prompt: str,
|
|
827
1245
|
num_example_demonstrations: int,
|
|
828
1246
|
include_sys_msg: bool,
|
|
@@ -830,6 +1248,16 @@ class DataSetGenerator:
|
|
|
830
1248
|
topic_model_type: str | None = None,
|
|
831
1249
|
) -> AsyncGenerator[dict | HFDataset, None]:
|
|
832
1250
|
"""Run the main generation loop yielding progress events."""
|
|
1251
|
+
# Initialize checkpoint paths if checkpointing is enabled
|
|
1252
|
+
if self.config.checkpoint_interval is not None:
|
|
1253
|
+
self._initialize_checkpoint_paths()
|
|
1254
|
+
|
|
1255
|
+
# Track samples added in this run for checkpointing
|
|
1256
|
+
samples_since_checkpoint = 0
|
|
1257
|
+
samples_in_current_batch: list[dict] = []
|
|
1258
|
+
failures_in_current_batch: list[dict] = []
|
|
1259
|
+
topic_paths_in_current_batch: list[TopicPath | None] = []
|
|
1260
|
+
|
|
833
1261
|
try:
|
|
834
1262
|
yield {
|
|
835
1263
|
"event": "generation_start",
|
|
@@ -839,6 +1267,11 @@ class DataSetGenerator:
|
|
|
839
1267
|
"total_samples": total_samples,
|
|
840
1268
|
"root_topic_prompt": root_topic_prompt,
|
|
841
1269
|
"topic_model_type": topic_model_type,
|
|
1270
|
+
"resumed_from_checkpoint": len(self._processed_ids) > 0,
|
|
1271
|
+
"previously_processed": len(self._processed_ids),
|
|
1272
|
+
"resumed_samples": self._flushed_samples_count,
|
|
1273
|
+
"resumed_failures": self._flushed_failures_count,
|
|
1274
|
+
"checkpoint_enabled": self.config.checkpoint_interval is not None,
|
|
842
1275
|
}
|
|
843
1276
|
|
|
844
1277
|
for step in range(num_steps):
|
|
@@ -849,7 +1282,7 @@ class DataSetGenerator:
|
|
|
849
1282
|
}
|
|
850
1283
|
|
|
851
1284
|
start_idx = step * batch_size
|
|
852
|
-
prompts,
|
|
1285
|
+
prompts, used_topic_paths = self._generate_batch_prompts(
|
|
853
1286
|
batch_size,
|
|
854
1287
|
start_idx,
|
|
855
1288
|
topic_paths,
|
|
@@ -857,17 +1290,85 @@ class DataSetGenerator:
|
|
|
857
1290
|
num_example_demonstrations,
|
|
858
1291
|
)
|
|
859
1292
|
|
|
1293
|
+
# Filter out already-processed topics when resuming
|
|
1294
|
+
if self._processed_ids:
|
|
1295
|
+
filtered_prompts = []
|
|
1296
|
+
filtered_topic_paths: list[TopicPath | None] = []
|
|
1297
|
+
for prompt, tp in zip(prompts, used_topic_paths, strict=False):
|
|
1298
|
+
if not self._is_topic_processed(tp):
|
|
1299
|
+
filtered_prompts.append(prompt)
|
|
1300
|
+
filtered_topic_paths.append(tp)
|
|
1301
|
+
|
|
1302
|
+
if not filtered_prompts:
|
|
1303
|
+
# All topics in this batch were already processed
|
|
1304
|
+
yield {
|
|
1305
|
+
"event": "step_complete",
|
|
1306
|
+
"step": step + 1,
|
|
1307
|
+
"samples_generated": 0,
|
|
1308
|
+
"success": True,
|
|
1309
|
+
"failed_in_step": 0,
|
|
1310
|
+
"failure_reasons": [],
|
|
1311
|
+
"skipped": len(prompts),
|
|
1312
|
+
}
|
|
1313
|
+
continue
|
|
1314
|
+
|
|
1315
|
+
prompts = filtered_prompts
|
|
1316
|
+
used_topic_paths = filtered_topic_paths
|
|
1317
|
+
|
|
860
1318
|
failed_before = len(self.failed_samples)
|
|
1319
|
+
samples_before = len(self._samples)
|
|
861
1320
|
|
|
862
1321
|
success, samples_generated = await self._process_batch_with_retries_async(
|
|
863
|
-
prompts, include_sys_msg, start_idx,
|
|
1322
|
+
prompts, include_sys_msg, start_idx, used_topic_paths
|
|
864
1323
|
)
|
|
865
1324
|
|
|
1325
|
+
# Track new samples and failures for checkpointing
|
|
1326
|
+
new_samples = self._samples[samples_before:]
|
|
1327
|
+
new_failures = self.failed_samples[failed_before:]
|
|
1328
|
+
samples_in_current_batch.extend(new_samples)
|
|
1329
|
+
failures_in_current_batch.extend(new_failures)
|
|
1330
|
+
topic_paths_in_current_batch.extend(used_topic_paths)
|
|
1331
|
+
samples_since_checkpoint += samples_generated
|
|
1332
|
+
|
|
1333
|
+
# Save checkpoint if we've reached the interval
|
|
1334
|
+
if (
|
|
1335
|
+
self.config.checkpoint_interval is not None
|
|
1336
|
+
and samples_since_checkpoint >= self.config.checkpoint_interval
|
|
1337
|
+
):
|
|
1338
|
+
self._save_checkpoint(
|
|
1339
|
+
samples_in_current_batch,
|
|
1340
|
+
failures_in_current_batch,
|
|
1341
|
+
topic_paths_in_current_batch,
|
|
1342
|
+
)
|
|
1343
|
+
samples_in_current_batch = []
|
|
1344
|
+
failures_in_current_batch = []
|
|
1345
|
+
topic_paths_in_current_batch = []
|
|
1346
|
+
samples_since_checkpoint = 0
|
|
1347
|
+
yield {
|
|
1348
|
+
"event": "checkpoint_saved",
|
|
1349
|
+
"total_samples": self._flushed_samples_count,
|
|
1350
|
+
"total_failures": self._flushed_failures_count,
|
|
1351
|
+
}
|
|
1352
|
+
|
|
1353
|
+
# Check for graceful stop request after checkpoint save
|
|
1354
|
+
if self.stop_requested:
|
|
1355
|
+
yield {
|
|
1356
|
+
"event": "generation_stopped",
|
|
1357
|
+
"message": "Stopped at checkpoint as requested",
|
|
1358
|
+
"total_samples": self._flushed_samples_count,
|
|
1359
|
+
"total_failures": self._flushed_failures_count,
|
|
1360
|
+
}
|
|
1361
|
+
return # Exit generator cleanly
|
|
1362
|
+
|
|
866
1363
|
failed_in_batch = len(self.failed_samples) - failed_before
|
|
867
|
-
failure_reasons = []
|
|
1364
|
+
failure_reasons: list[str] = []
|
|
868
1365
|
if failed_in_batch > 0 and self.failed_samples:
|
|
869
1366
|
recent_failures = self.failed_samples[-failed_in_batch:]
|
|
870
|
-
|
|
1367
|
+
for f in recent_failures[:3]:
|
|
1368
|
+
if isinstance(f, dict):
|
|
1369
|
+
failure_reasons.append(f.get("error", str(f)))
|
|
1370
|
+
else:
|
|
1371
|
+
failure_reasons.append(str(f))
|
|
871
1372
|
|
|
872
1373
|
yield {
|
|
873
1374
|
"event": "step_complete",
|
|
@@ -885,13 +1386,42 @@ class DataSetGenerator:
|
|
|
885
1386
|
"message": f"Failed to process batch {step + 1} after all retries",
|
|
886
1387
|
}
|
|
887
1388
|
|
|
1389
|
+
# Save final checkpoint with any remaining samples
|
|
1390
|
+
if self.config.checkpoint_interval is not None and (
|
|
1391
|
+
samples_in_current_batch or failures_in_current_batch
|
|
1392
|
+
):
|
|
1393
|
+
self._save_checkpoint(
|
|
1394
|
+
samples_in_current_batch,
|
|
1395
|
+
failures_in_current_batch,
|
|
1396
|
+
topic_paths_in_current_batch,
|
|
1397
|
+
)
|
|
1398
|
+
yield {
|
|
1399
|
+
"event": "checkpoint_saved",
|
|
1400
|
+
"total_samples": self._flushed_samples_count,
|
|
1401
|
+
"total_failures": self._flushed_failures_count,
|
|
1402
|
+
"final": True,
|
|
1403
|
+
}
|
|
1404
|
+
|
|
1405
|
+
# Calculate total counts including flushed data
|
|
1406
|
+
total_samples = self._flushed_samples_count + len(self._samples)
|
|
1407
|
+
total_failures = self._flushed_failures_count + len(self.failed_samples)
|
|
1408
|
+
|
|
888
1409
|
yield {
|
|
889
1410
|
"event": "generation_complete",
|
|
890
|
-
"total_samples":
|
|
891
|
-
"failed_samples":
|
|
1411
|
+
"total_samples": total_samples,
|
|
1412
|
+
"failed_samples": total_failures,
|
|
892
1413
|
}
|
|
893
1414
|
|
|
894
1415
|
except KeyboardInterrupt:
|
|
1416
|
+
# Save checkpoint on interrupt
|
|
1417
|
+
if self.config.checkpoint_interval is not None and (
|
|
1418
|
+
samples_in_current_batch or failures_in_current_batch
|
|
1419
|
+
):
|
|
1420
|
+
self._save_checkpoint(
|
|
1421
|
+
samples_in_current_batch,
|
|
1422
|
+
failures_in_current_batch,
|
|
1423
|
+
topic_paths_in_current_batch,
|
|
1424
|
+
)
|
|
895
1425
|
yield {
|
|
896
1426
|
"event": "generation_interrupted",
|
|
897
1427
|
"message": "Generation interrupted by user.",
|
|
@@ -900,25 +1430,39 @@ class DataSetGenerator:
|
|
|
900
1430
|
self._save_samples_to_file(INTERRUPTED_DATASET_FILENAME)
|
|
901
1431
|
|
|
902
1432
|
except Exception as e: # noqa: BLE001
|
|
1433
|
+
# Save checkpoint on error
|
|
1434
|
+
if self.config.checkpoint_interval is not None and (
|
|
1435
|
+
samples_in_current_batch or failures_in_current_batch
|
|
1436
|
+
):
|
|
1437
|
+
self._save_checkpoint(
|
|
1438
|
+
samples_in_current_batch,
|
|
1439
|
+
failures_in_current_batch,
|
|
1440
|
+
topic_paths_in_current_batch,
|
|
1441
|
+
)
|
|
903
1442
|
yield {"event": "generation_error", "error": str(e)}
|
|
904
1443
|
self.print_failure_summary()
|
|
905
1444
|
self._save_samples_to_file(ERROR_DATASET_FILENAME)
|
|
906
1445
|
raise DataSetGeneratorError("failed") from e
|
|
907
1446
|
|
|
908
|
-
|
|
1447
|
+
# Build final dataset: if samples were flushed to disk, load them from checkpoint
|
|
1448
|
+
if self._flushed_samples_count > 0:
|
|
1449
|
+
all_samples = self._load_all_samples_from_checkpoint()
|
|
1450
|
+
yield HFDataset.from_list(all_samples) if all_samples else HFDataset.from_list([])
|
|
1451
|
+
else:
|
|
1452
|
+
yield (HFDataset.from_list(self._samples) if self._samples else HFDataset.from_list([]))
|
|
909
1453
|
|
|
910
1454
|
async def _process_batch_with_retries_async(
|
|
911
1455
|
self,
|
|
912
1456
|
prompts: list[str],
|
|
913
1457
|
include_sys_msg: bool,
|
|
914
1458
|
start_sample_idx: int = 0,
|
|
915
|
-
|
|
1459
|
+
topic_paths_for_batch: list[TopicPath | None] | None = None,
|
|
916
1460
|
) -> tuple[bool, int]:
|
|
917
1461
|
"""Process a batch with retry logic."""
|
|
918
1462
|
for attempt in range(self.config.max_retries):
|
|
919
1463
|
try:
|
|
920
1464
|
samples, failed_responses = await self._generate_structured_samples_async(
|
|
921
|
-
prompts, include_sys_msg, start_sample_idx,
|
|
1465
|
+
prompts, include_sys_msg, start_sample_idx, topic_paths_for_batch
|
|
922
1466
|
)
|
|
923
1467
|
|
|
924
1468
|
# Update failed samples
|
|
@@ -948,14 +1492,33 @@ class DataSetGenerator:
|
|
|
948
1492
|
error_msg = f"API error for provider '{self.provider}': {str(e)[:100]}..."
|
|
949
1493
|
self.failure_analysis["api_errors"].append(error_msg)
|
|
950
1494
|
|
|
951
|
-
|
|
1495
|
+
# Build failure records for each topic path in the batch
|
|
1496
|
+
if topic_paths_for_batch:
|
|
1497
|
+
for tp in topic_paths_for_batch:
|
|
1498
|
+
failure_record: dict[str, str | None] = {"error": error_msg}
|
|
1499
|
+
if tp:
|
|
1500
|
+
failure_record["topic_id"] = tp.topic_id
|
|
1501
|
+
failure_record["path"] = " -> ".join(tp.path)
|
|
1502
|
+
self.failed_samples.append(failure_record)
|
|
1503
|
+
else:
|
|
1504
|
+
self.failed_samples.append({"error": error_msg})
|
|
952
1505
|
logger.exception("API error: %s", error_msg)
|
|
953
1506
|
return False, 0 # Don't retry authentication/API errors
|
|
954
1507
|
except Exception as e:
|
|
955
1508
|
if attempt == self.config.max_retries - 1:
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
1509
|
+
error_msg = str(e)
|
|
1510
|
+
# Build failure records for each topic path in the batch
|
|
1511
|
+
if topic_paths_for_batch:
|
|
1512
|
+
for tp in topic_paths_for_batch:
|
|
1513
|
+
failure_record_exc: dict[str, str | None] = {"error": error_msg}
|
|
1514
|
+
if tp:
|
|
1515
|
+
failure_record_exc["topic_id"] = tp.topic_id
|
|
1516
|
+
failure_record_exc["path"] = " -> ".join(tp.path)
|
|
1517
|
+
self.failed_samples.append(failure_record_exc)
|
|
1518
|
+
else:
|
|
1519
|
+
self.failed_samples.append({"error": error_msg})
|
|
1520
|
+
failure_type = self.analyze_failure(error_msg, error=e)
|
|
1521
|
+
self.failure_analysis[failure_type].append(error_msg)
|
|
959
1522
|
return False, 0
|
|
960
1523
|
else:
|
|
961
1524
|
# If no exception and no samples, return False, 0
|
|
@@ -1015,7 +1578,7 @@ class DataSetGenerator:
|
|
|
1015
1578
|
return f"\nHere are output examples:\n<examples>\n{examples_text}\n</examples>\n"
|
|
1016
1579
|
|
|
1017
1580
|
def build_tools_text(self) -> str:
|
|
1018
|
-
"""Build formatted tools text for XLAM
|
|
1581
|
+
"""Build formatted tools text for XLAM prompts."""
|
|
1019
1582
|
if not self.tool_registry:
|
|
1020
1583
|
return "No tools available"
|
|
1021
1584
|
|
|
@@ -1046,8 +1609,8 @@ class DataSetGenerator:
|
|
|
1046
1609
|
|
|
1047
1610
|
# Handle chain of thought conversations
|
|
1048
1611
|
if self.config.conversation_type == "cot":
|
|
1049
|
-
# Agent mode with tools - use agent prompts
|
|
1050
|
-
if self.
|
|
1612
|
+
# Agent mode with tools - use agent prompts (implicit when tools configured)
|
|
1613
|
+
if self.tool_registry:
|
|
1051
1614
|
# Use agent prompt for single-turn tool calling
|
|
1052
1615
|
return (
|
|
1053
1616
|
AgentPromptBuilder.build_tool_context_prompt(
|
|
@@ -1057,16 +1620,6 @@ class DataSetGenerator:
|
|
|
1057
1620
|
or AGENT_COT_TOOLS_PROMPT
|
|
1058
1621
|
)
|
|
1059
1622
|
|
|
1060
|
-
if self.config.agent_mode == "multi_turn" and self.tool_registry:
|
|
1061
|
-
# Standard multi-turn agent
|
|
1062
|
-
return (
|
|
1063
|
-
AgentPromptBuilder.build_multi_turn_context_prompt(
|
|
1064
|
-
self.tool_registry,
|
|
1065
|
-
max_tools_per_query=self.config.max_tools_per_query,
|
|
1066
|
-
)
|
|
1067
|
-
or AGENT_COT_MULTI_TURN_PROMPT
|
|
1068
|
-
)
|
|
1069
|
-
|
|
1070
1623
|
# Non-agent CoT - select based on reasoning style
|
|
1071
1624
|
if self.config.reasoning_style == "freetext":
|
|
1072
1625
|
return FREETEXT_COT_PROMPT
|
|
@@ -1079,7 +1632,7 @@ class DataSetGenerator:
|
|
|
1079
1632
|
def _save_samples_to_file(self, save_path: str):
|
|
1080
1633
|
"""Save the current samples to a JSONL file."""
|
|
1081
1634
|
|
|
1082
|
-
with open(save_path, "w") as f:
|
|
1635
|
+
with open(save_path, "w", encoding="utf-8") as f:
|
|
1083
1636
|
for sample in self._samples:
|
|
1084
1637
|
f.write(json.dumps(sample, separators=(",", ":")) + "\n")
|
|
1085
1638
|
|