DeepFabric 4.8.3__py3-none-any.whl → 4.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deepfabric/cli.py CHANGED
@@ -1,7 +1,11 @@
1
1
  import contextlib
2
+ import json
3
+ import math
2
4
  import os
5
+ import signal
3
6
  import sys
4
7
 
8
+ from pathlib import Path
5
9
  from typing import Literal, NoReturn, cast
6
10
 
7
11
  import click
@@ -13,6 +17,11 @@ from pydantic import ValidationError as PydanticValidationError
13
17
  from .auth import auth as auth_group
14
18
  from .config import DeepFabricConfig
15
19
  from .config_manager import apply_cli_overrides, get_final_parameters, load_config
20
+ from .constants import (
21
+ CHECKPOINT_FAILURES_SUFFIX,
22
+ CHECKPOINT_METADATA_SUFFIX,
23
+ CHECKPOINT_SAMPLES_SUFFIX,
24
+ )
16
25
  from .dataset_manager import create_dataset, save_dataset
17
26
  from .exceptions import ConfigurationError
18
27
  from .generator import DataSetGenerator
@@ -21,9 +30,15 @@ from .llm import VerificationStatus, verify_provider_api_key
21
30
  from .metrics import set_trace_debug, trace
22
31
  from .topic_manager import load_or_build_topic_model, save_topic_model
23
32
  from .topic_model import TopicModel
24
- from .tui import configure_tui, get_tui
33
+ from .tui import configure_tui, get_dataset_tui, get_tui
25
34
  from .update_checker import check_for_updates
26
- from .utils import get_bool_env
35
+ from .utils import (
36
+ check_dir_writable,
37
+ check_path_writable,
38
+ get_bool_env,
39
+ get_checkpoint_dir,
40
+ parse_num_samples,
41
+ )
27
42
  from .validation import show_validation_success, validate_path_requirements
28
43
 
29
44
  OverrideValue = str | int | float | bool | None
@@ -45,6 +60,35 @@ def handle_error(ctx: click.Context, error: Exception) -> NoReturn:
45
60
  sys.exit(1)
46
61
 
47
62
 
63
def _get_checkpoint_topics_path(
    checkpoint_dir: str,
    output_save_as: str,
) -> str | None:
    """
    Read checkpoint metadata to get the topics path used in the original run.

    The metadata file is looked up as
    ``<checkpoint_dir>/<stem of output_save_as><CHECKPOINT_METADATA_SUFFIX>``.
    This is a best-effort lookup: any problem reading or interpreting the
    metadata yields ``None`` rather than an exception, so callers can fall
    back to regenerating or explicitly loading topics.

    Args:
        checkpoint_dir: Directory containing checkpoint files
        output_save_as: Output file path (used to derive checkpoint file names)

    Returns:
        Topics file path from checkpoint metadata (the ``topics_file`` key,
        falling back to ``topics_save_as``), or None if the metadata file is
        missing, unreadable, not a JSON object, or holds no usable path.
    """
    output_stem = Path(output_save_as).stem
    metadata_path = Path(checkpoint_dir) / f"{output_stem}{CHECKPOINT_METADATA_SUFFIX}"

    try:
        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError (corrupt / non-UTF-8 file).
        return None

    # Valid JSON is not necessarily an object (e.g. a truncated write could
    # leave "[]" or "null"); only a mapping can carry the keys we need.
    if not isinstance(metadata, dict):
        return None

    for key in ("topics_file", "topics_save_as"):
        candidate = metadata.get(key)
        # Skip empty or non-string values so the declared return type holds.
        if isinstance(candidate, str) and candidate:
            return candidate
    return None
90
+
91
+
48
92
  @click.group()
49
93
  @click.version_option()
50
94
  @click.option(
@@ -89,7 +133,7 @@ class GenerateOptions(BaseModel):
89
133
  temperature: float | None = None
90
134
  degree: int | None = None
91
135
  depth: int | None = None
92
- num_samples: int | None = None
136
+ num_samples: int | str | None = None
93
137
  batch_size: int | None = None
94
138
  base_url: str | None = None
95
139
  include_system_message: bool | None = None
@@ -101,16 +145,19 @@ class GenerateOptions(BaseModel):
101
145
  # Modular conversation configuration
102
146
  conversation_type: Literal["basic", "cot"] | None = None
103
147
  reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = None
104
- agent_mode: Literal["single_turn", "multi_turn"] | None = None
105
-
106
- # Multi-turn configuration
107
- min_turns: int | None = None
108
- max_turns: int | None = None
109
- min_tool_calls: int | None = None
148
+ agent_mode: Literal["single_turn", "multi_turn"] | None = (
149
+ None # Deprecated, kept for backward compat
150
+ )
110
151
 
111
152
  # Cloud upload (experimental)
112
153
  cloud_upload: Literal["all", "dataset", "graph", "none"] | None = None
113
154
 
155
+ # Checkpointing options
156
+ checkpoint_interval: int | None = None
157
+ checkpoint_path: str | None = None
158
+ resume: bool = False
159
+ retry_failed: bool = False
160
+
114
161
  @model_validator(mode="after")
115
162
  def validate_mode_constraints(self) -> "GenerateOptions":
116
163
  if self.topic_only and self.topics_load:
@@ -126,7 +173,7 @@ class GenerationPreparation(BaseModel):
126
173
  config: DeepFabricConfig
127
174
  topics_overrides: OverrideMap = Field(default_factory=dict)
128
175
  generation_overrides: OverrideMap = Field(default_factory=dict)
129
- num_samples: int
176
+ num_samples: int | str # Can be int, "auto", or percentage like "50%"
130
177
  batch_size: int
131
178
  depth: int
132
179
  degree: int
@@ -134,7 +181,8 @@ class GenerationPreparation(BaseModel):
134
181
 
135
182
  @model_validator(mode="after")
136
183
  def validate_positive_dimensions(self) -> "GenerationPreparation":
137
- if self.num_samples <= 0:
184
+ # Skip num_samples validation for dynamic values (auto or percentage)
185
+ if isinstance(self.num_samples, int) and self.num_samples <= 0:
138
186
  raise ValueError("num_samples must be greater than zero")
139
187
  if self.batch_size <= 0:
140
188
  raise ValueError("batch_size must be greater than zero")
@@ -222,6 +270,8 @@ def _validate_api_keys(
222
270
 
223
271
  def _load_and_prepare_generation_context(
224
272
  options: GenerateOptions,
273
+ *,
274
+ skip_path_validation: bool = False,
225
275
  ) -> GenerationPreparation:
226
276
  """Load configuration, compute overrides, and validate derived parameters."""
227
277
  tui = get_tui()
@@ -277,23 +327,25 @@ def _load_and_prepare_generation_context(
277
327
 
278
328
  loading_existing = bool(options.topics_load)
279
329
 
280
- validate_path_requirements(
281
- mode=options.mode,
282
- depth=final_depth,
283
- degree=final_degree,
284
- num_steps=final_num_samples,
285
- batch_size=final_batch_size,
286
- loading_existing=loading_existing,
287
- )
330
+ # Skip path validation for topic-only mode since we're not generating dataset samples
331
+ if not skip_path_validation:
332
+ validate_path_requirements(
333
+ mode=options.mode,
334
+ depth=final_depth,
335
+ degree=final_degree,
336
+ num_samples=final_num_samples,
337
+ batch_size=final_batch_size,
338
+ loading_existing=loading_existing,
339
+ )
288
340
 
289
- show_validation_success(
290
- mode=options.mode,
291
- depth=final_depth,
292
- degree=final_degree,
293
- num_steps=final_num_samples,
294
- batch_size=final_batch_size,
295
- loading_existing=loading_existing,
296
- )
341
+ show_validation_success(
342
+ mode=options.mode,
343
+ depth=final_depth,
344
+ degree=final_degree,
345
+ num_samples=final_num_samples,
346
+ batch_size=final_batch_size,
347
+ loading_existing=loading_existing,
348
+ )
297
349
 
298
350
  try:
299
351
  return GenerationPreparation(
@@ -372,29 +424,145 @@ def _run_generation(
372
424
  preparation: GenerationPreparation,
373
425
  topic_model: TopicModel,
374
426
  options: GenerateOptions,
427
+ checkpoint_dir: str,
375
428
  ) -> None:
376
429
  """Create the dataset using the prepared configuration and topic model."""
430
+ tui = get_tui()
377
431
 
378
- generation_params = preparation.config.get_generation_params(**preparation.generation_overrides)
379
- engine = DataSetGenerator(**generation_params)
432
+ # Apply CLI checkpoint overrides
433
+ checkpoint_overrides = {}
434
+ if options.checkpoint_interval is not None:
435
+ checkpoint_overrides["checkpoint_interval"] = options.checkpoint_interval
436
+ if options.checkpoint_path is not None:
437
+ checkpoint_overrides["checkpoint_path"] = options.checkpoint_path
438
+ if options.retry_failed:
439
+ checkpoint_overrides["checkpoint_retry_failed"] = options.retry_failed
440
+
441
+ generation_params = preparation.config.get_generation_params(
442
+ **preparation.generation_overrides, **checkpoint_overrides
443
+ )
380
444
 
381
- dataset = create_dataset(
382
- engine=engine,
383
- topic_model=topic_model,
384
- config=preparation.config,
385
- num_samples=preparation.num_samples,
386
- batch_size=preparation.batch_size,
387
- include_system_message=options.include_system_message,
388
- provider=options.provider,
389
- model=options.model,
390
- generation_overrides=preparation.generation_overrides,
391
- debug=options.debug,
445
+ # Use provided checkpoint_dir if not explicitly set via CLI
446
+ if generation_params.get("checkpoint_path") is None:
447
+ generation_params["checkpoint_path"] = checkpoint_dir
448
+
449
+ # Resolve and pass topics_file for checkpoint metadata
450
+ # Prioritize: loaded file > save path > config > default
451
+ # Store absolute path for reliable resume from any working directory
452
+ topics_mode = preparation.config.topics.mode
453
+ default_topics_path = "topic_graph.json" if topics_mode == "graph" else "topic_tree.jsonl"
454
+ resolved_topics_path = (
455
+ options.topics_load
456
+ or options.topics_save_as
457
+ or preparation.config.topics.save_as
458
+ or default_topics_path
392
459
  )
460
+ generation_params["topics_file"] = str(Path(resolved_topics_path).resolve())
461
+
462
+ engine = DataSetGenerator(**generation_params)
463
+
464
+ # Check for existing checkpoint when not resuming
465
+ if not options.resume and engine.has_checkpoint():
466
+ tui.warning("Existing checkpoint found for this configuration")
467
+ tui.console.print()
468
+ tui.console.print(" [cyan]1)[/cyan] Resume from checkpoint")
469
+ tui.console.print(" [cyan]2)[/cyan] Clear checkpoint and start fresh")
470
+ tui.console.print(" [cyan]3)[/cyan] Abort")
471
+ tui.console.print()
472
+
473
+ choice = click.prompt(
474
+ "Choose an option",
475
+ type=click.Choice(["1", "2", "3"]),
476
+ default="1",
477
+ )
478
+
479
+ if choice == "1":
480
+ # User wants to resume
481
+ options.resume = True
482
+ elif choice == "2":
483
+ # Clear and start fresh
484
+ engine.clear_checkpoint()
485
+ tui.info("Checkpoint cleared, starting fresh generation")
486
+ else:
487
+ # Abort
488
+ tui.info("Aborted")
489
+ sys.exit(0)
490
+
491
+ # Handle resume from checkpoint
492
+ if options.resume:
493
+ if engine.load_checkpoint(retry_failed=options.retry_failed):
494
+ samples_done = engine._flushed_samples_count
495
+ failures_done = engine._flushed_failures_count
496
+ ids_processed = len(engine._processed_ids)
497
+ retry_msg = " (retrying failed samples)" if options.retry_failed else ""
498
+
499
+ # Update TUI status panel with checkpoint progress
500
+ get_dataset_tui().set_checkpoint_resume_status(samples_done, failures_done)
501
+
502
+ # Log resume info including failures
503
+ if failures_done > 0:
504
+ tui.info(
505
+ f"Resuming from checkpoint: {samples_done} samples, "
506
+ f"{failures_done} failed, {ids_processed} IDs processed{retry_msg}"
507
+ )
508
+ else:
509
+ tui.info(
510
+ f"Resuming from checkpoint: {samples_done} samples, "
511
+ f"{ids_processed} IDs processed{retry_msg}"
512
+ )
513
+ else:
514
+ tui.info("No checkpoint found, starting fresh generation")
515
+
516
+ # Set up graceful Ctrl+C handling for checkpoint-based stop
517
+ interrupt_count = 0
518
+
519
+ def handle_sigint(_signum, _frame):
520
+ nonlocal interrupt_count
521
+ interrupt_count += 1
522
+
523
+ if interrupt_count == 1:
524
+ engine.stop_requested = True
525
+ tui.warning("Stopping after current checkpoint... (Ctrl+C again to force quit)")
526
+ dataset_tui = get_dataset_tui()
527
+ dataset_tui.log_event("⚠ Graceful stop requested")
528
+ dataset_tui.status_stop_requested()
529
+ else:
530
+ tui.error("Force quit!")
531
+ sys.exit(1)
532
+
533
+ original_handler = signal.signal(signal.SIGINT, handle_sigint)
534
+ try:
535
+ dataset = create_dataset(
536
+ engine=engine,
537
+ topic_model=topic_model,
538
+ config=preparation.config,
539
+ num_samples=preparation.num_samples,
540
+ batch_size=preparation.batch_size,
541
+ include_system_message=options.include_system_message,
542
+ provider=options.provider,
543
+ model=options.model,
544
+ generation_overrides=preparation.generation_overrides,
545
+ debug=options.debug,
546
+ )
547
+ finally:
548
+ signal.signal(signal.SIGINT, original_handler)
549
+
550
+ # If gracefully stopped, don't save partial dataset or clean up checkpoints
551
+ if engine.stop_requested:
552
+ return
393
553
 
394
554
  output_config = preparation.config.get_output_config()
395
555
  output_save_path = options.output_save_as or output_config["save_as"]
396
556
  save_dataset(dataset, output_save_path, preparation.config, engine=engine)
397
557
 
558
+ # Clean up checkpoint files after successful completion
559
+ if generation_params.get("checkpoint_interval") is not None:
560
+ try:
561
+ engine.clear_checkpoint()
562
+ tui.info("Checkpoint files cleaned up after successful generation")
563
+ except OSError as e:
564
+ tui.warning(f"Failed to clean up checkpoint files: {e}")
565
+
398
566
  trace(
399
567
  "dataset_generated",
400
568
  {"samples": len(dataset)},
@@ -429,7 +597,11 @@ def _run_generation(
429
597
  @click.option("--temperature", type=float, help="Temperature setting")
430
598
  @click.option("--degree", type=int, help="Degree (branching factor)")
431
599
  @click.option("--depth", type=int, help="Depth setting")
432
- @click.option("--num-samples", type=int, help="Number of samples to generate")
600
+ @click.option(
601
+ "--num-samples",
602
+ type=str,
603
+ help="Number of samples: integer, 'auto' (all topics), or percentage like '50%'",
604
+ )
433
605
  @click.option("--batch-size", type=int, help="Batch size")
434
606
  @click.option("--base-url", help="Base URL for LLM provider API endpoint")
435
607
  @click.option(
@@ -473,29 +645,34 @@ def _run_generation(
473
645
  @click.option(
474
646
  "--agent-mode",
475
647
  type=click.Choice(["single_turn", "multi_turn"]),
476
- help="Agent mode: single_turn (one-shot tool use), multi_turn (extended conversations). Requires tools.",
648
+ help="[Deprecated] Agent mode is now implicit when tools are configured. 'multi_turn' is no longer supported.",
477
649
  )
478
650
  @click.option(
479
- "--min-turns",
480
- type=int,
481
- help="Minimum conversation turns for multi_turn agent mode",
651
+ "--cloud-upload",
652
+ type=click.Choice(["all", "dataset", "graph", "none"], case_sensitive=False),
653
+ default=None,
654
+ help="Upload to DeepFabric Cloud (experimental): all, dataset, graph, or none. "
655
+ "Enables headless mode for CI. Requires DEEPFABRIC_API_KEY or prior auth.",
482
656
  )
483
657
  @click.option(
484
- "--max-turns",
658
+ "--checkpoint-interval",
485
659
  type=int,
486
- help="Maximum conversation turns for multi_turn agent mode",
660
+ help="Save checkpoint every N samples. Enables resumable generation.",
487
661
  )
488
662
  @click.option(
489
- "--min-tool-calls",
490
- type=int,
491
- help="Minimum tool calls before allowing conversation conclusion",
663
+ "--checkpoint-path",
664
+ type=click.Path(),
665
+ help="Override checkpoint directory (default: XDG data dir)",
492
666
  )
493
667
  @click.option(
494
- "--cloud-upload",
495
- type=click.Choice(["all", "dataset", "graph", "none"], case_sensitive=False),
496
- default=None,
497
- help="Upload to DeepFabric Cloud (experimental): all, dataset, graph, or none. "
498
- "Enables headless mode for CI. Requires DEEPFABRIC_API_KEY or prior auth.",
668
+ "--resume",
669
+ is_flag=True,
670
+ help="Resume from existing checkpoint if available",
671
+ )
672
+ @click.option(
673
+ "--retry-failed",
674
+ is_flag=True,
675
+ help="When resuming, retry previously failed samples",
499
676
  )
500
677
  def generate( # noqa: PLR0913
501
678
  config_file: str | None,
@@ -511,7 +688,7 @@ def generate( # noqa: PLR0913
511
688
  temperature: float | None = None,
512
689
  degree: int | None = None,
513
690
  depth: int | None = None,
514
- num_samples: int | None = None,
691
+ num_samples: str | None = None,
515
692
  batch_size: int | None = None,
516
693
  base_url: str | None = None,
517
694
  include_system_message: bool | None = None,
@@ -521,13 +698,28 @@ def generate( # noqa: PLR0913
521
698
  conversation_type: Literal["basic", "cot"] | None = None,
522
699
  reasoning_style: Literal["freetext", "agent"] | None = None,
523
700
  agent_mode: Literal["single_turn", "multi_turn"] | None = None,
524
- min_turns: int | None = None,
525
- max_turns: int | None = None,
526
- min_tool_calls: int | None = None,
527
701
  cloud_upload: Literal["all", "dataset", "graph", "none"] | None = None,
528
702
  tui: Literal["rich", "simple"] = "rich",
703
+ checkpoint_interval: int | None = None,
704
+ checkpoint_path: str | None = None,
705
+ resume: bool = False,
706
+ retry_failed: bool = False,
529
707
  ) -> None:
530
708
  """Generate training data from a YAML configuration file or CLI parameters."""
709
+ # Handle deprecated --agent-mode flag
710
+ if agent_mode == "multi_turn":
711
+ click.echo(
712
+ "Error: --agent-mode multi_turn is deprecated and no longer supported. "
713
+ "Omit --agent-mode and the default supported agent mode will be used.",
714
+ err=True,
715
+ )
716
+ sys.exit(1)
717
+ elif agent_mode == "single_turn":
718
+ click.echo(
719
+ "Note: --agent-mode single_turn is deprecated. "
720
+ "Single-turn agent mode is now implicit when tools are configured."
721
+ )
722
+
531
723
  set_trace_debug(debug)
532
724
  trace(
533
725
  "cli_generate",
@@ -540,6 +732,9 @@ def generate( # noqa: PLR0913
540
732
  )
541
733
 
542
734
  try:
735
+ # Parse num_samples from CLI string (could be int, "auto", or "50%")
736
+ parsed_num_samples = parse_num_samples(num_samples)
737
+
543
738
  options = GenerateOptions(
544
739
  config_file=config_file,
545
740
  output_system_prompt=output_system_prompt,
@@ -554,7 +749,7 @@ def generate( # noqa: PLR0913
554
749
  temperature=temperature,
555
750
  degree=degree,
556
751
  depth=depth,
557
- num_samples=num_samples,
752
+ num_samples=parsed_num_samples,
558
753
  batch_size=batch_size,
559
754
  base_url=base_url,
560
755
  include_system_message=include_system_message,
@@ -564,13 +759,14 @@ def generate( # noqa: PLR0913
564
759
  conversation_type=conversation_type,
565
760
  reasoning_style=reasoning_style,
566
761
  agent_mode=agent_mode,
567
- min_turns=min_turns,
568
- max_turns=max_turns,
569
- min_tool_calls=min_tool_calls,
570
762
  cloud_upload=cloud_upload,
571
763
  tui=tui,
764
+ checkpoint_interval=checkpoint_interval,
765
+ checkpoint_path=checkpoint_path,
766
+ resume=resume,
767
+ retry_failed=retry_failed,
572
768
  )
573
- except PydanticValidationError as error:
769
+ except (PydanticValidationError, ValueError) as error:
574
770
  handle_error(click.get_current_context(), ConfigurationError(str(error)))
575
771
  return
576
772
 
@@ -583,7 +779,27 @@ def generate( # noqa: PLR0913
583
779
  tui.info("Initializing DeepFabric...") # type: ignore
584
780
  print()
585
781
 
586
- preparation = _load_and_prepare_generation_context(options)
782
+ preparation = _load_and_prepare_generation_context(options, skip_path_validation=topic_only)
783
+
784
+ # Compute checkpoint directory once for consistent use throughout generation
785
+ # Use config file for hash, fallback to output path for config-less runs
786
+ path_source = options.config_file or options.output_save_as or preparation.config.output.save_as
787
+ checkpoint_dir = options.checkpoint_path or get_checkpoint_dir(path_source)
788
+
789
+ # Auto-infer topics-load when resuming from checkpoint
790
+ if options.resume and not options.topics_load:
791
+ output_path = options.output_save_as or preparation.config.output.save_as
792
+
793
+ inferred_topics_path = _get_checkpoint_topics_path(checkpoint_dir, output_path)
794
+ if inferred_topics_path:
795
+ if Path(inferred_topics_path).exists():
796
+ tui.info(f"Resume: auto-loading topics from {inferred_topics_path}")
797
+ options.topics_load = inferred_topics_path
798
+ else:
799
+ tui.warning(
800
+ f"Checkpoint references topics at {inferred_topics_path} but file not found. "
801
+ "Topic graph will be regenerated."
802
+ )
587
803
 
588
804
  topic_model = _initialize_topic_model(
589
805
  preparation=preparation,
@@ -603,6 +819,7 @@ def generate( # noqa: PLR0913
603
819
  preparation=preparation,
604
820
  topic_model=topic_model,
605
821
  options=options,
822
+ checkpoint_dir=checkpoint_dir,
606
823
  )
607
824
 
608
825
  except ConfigurationError as e:
@@ -1024,7 +1241,12 @@ def visualize(graph_file: str, output: str) -> None:
1024
1241
 
1025
1242
  @cli.command()
1026
1243
  @click.argument("config_file", type=click.Path(exists=True))
1027
- def validate(config_file: str) -> None: # noqa: PLR0912
1244
+ @click.option(
1245
+ "--check-api/--no-check-api",
1246
+ default=True,
1247
+ help="Validate API keys by making test calls (default: enabled)",
1248
+ )
1249
+ def validate(config_file: str, check_api: bool) -> None: # noqa: PLR0912
1028
1250
  """Validate a DeepFabric configuration file."""
1029
1251
  try:
1030
1252
  # Try to load the configuration
@@ -1053,24 +1275,43 @@ def validate(config_file: str) -> None: # noqa: PLR0912
1053
1275
  for error in errors:
1054
1276
  tui.console.print(f" - {error}", style="red")
1055
1277
  sys.exit(1)
1056
- else:
1057
- tui.success("Configuration is valid")
1058
1278
 
1059
1279
  if warnings:
1060
- tui.console.print("\nWarnings:", style="yellow bold")
1280
+ tui.console.print("Warnings:", style="yellow bold")
1061
1281
  for warning in warnings:
1062
1282
  tui.warning(warning)
1283
+ tui.console.print()
1063
1284
 
1064
1285
  # Print configuration summary
1065
- tui.console.print("\nConfiguration Summary:", style="cyan bold")
1066
- tui.info(
1067
- f"Topics: mode={config.topics.mode}, depth={config.topics.depth}, degree={config.topics.degree}"
1068
- )
1286
+ tui.console.print("Configuration Summary:", style="cyan bold")
1069
1287
 
1288
+ # Topics summary with estimated paths
1289
+ depth = config.topics.depth
1290
+ degree = config.topics.degree
1291
+ # Estimated paths = degree^depth (each level branches by degree)
1292
+ estimated_paths = degree**depth
1070
1293
  tui.info(
1071
- f"Output: num_samples={config.output.num_samples}, batch_size={config.output.batch_size}"
1294
+ f"Topics: mode={config.topics.mode}, depth={depth}, degree={degree}, "
1295
+ f"estimated_paths={estimated_paths} ({degree}^{depth})"
1072
1296
  )
1073
1297
 
1298
+ # Output summary with step size and checkpoint info
1299
+ num_samples = config.output.num_samples
1300
+ batch_size = config.output.batch_size
1301
+ # Calculate num_steps - handle 'auto' and percentage strings
1302
+ if isinstance(num_samples, int):
1303
+ num_steps = math.ceil(num_samples / batch_size)
1304
+ output_info = f"Output: num_samples={num_samples}, batch_size={batch_size}, num_steps={num_steps}"
1305
+ else:
1306
+ # For 'auto' or percentage, we can't compute steps without topic count
1307
+ output_info = f"Output: num_samples={num_samples}, batch_size={batch_size}"
1308
+
1309
+ # Add checkpoint info if enabled
1310
+ if config.output.checkpoint:
1311
+ checkpoint = config.output.checkpoint
1312
+ output_info += f", checkpoint_interval={checkpoint.interval}"
1313
+ tui.info(output_info)
1314
+
1074
1315
  if config.huggingface:
1075
1316
  hf_config = config.get_huggingface_config()
1076
1317
  tui.info(f"Hugging Face: repo={hf_config.get('repository', 'not set')}")
@@ -1079,6 +1320,58 @@ def validate(config_file: str) -> None: # noqa: PLR0912
1079
1320
  kaggle_config = config.get_kaggle_config()
1080
1321
  tui.info(f"Kaggle: handle={kaggle_config.get('handle', 'not set')}")
1081
1322
 
1323
+ # Check path writability
1324
+ tui.console.print("\nPath Writability:", style="cyan bold")
1325
+ path_errors = []
1326
+
1327
+ # Check topics.save_as if configured
1328
+ if config.topics.save_as:
1329
+ is_writable, error_msg = check_path_writable(config.topics.save_as, "topics.save_as")
1330
+ if is_writable:
1331
+ tui.success(f"topics.save_as: {config.topics.save_as}")
1332
+ else:
1333
+ path_errors.append(error_msg)
1334
+ tui.error(f"topics.save_as: {error_msg}")
1335
+
1336
+ # Check output.save_as
1337
+ if config.output.save_as:
1338
+ is_writable, error_msg = check_path_writable(config.output.save_as, "output.save_as")
1339
+ if is_writable:
1340
+ tui.success(f"output.save_as: {config.output.save_as}")
1341
+ else:
1342
+ path_errors.append(error_msg)
1343
+ tui.error(f"output.save_as: {error_msg}")
1344
+
1345
+ # Check checkpoint directory if enabled
1346
+ if config.output.checkpoint:
1347
+ checkpoint_path = config.output.checkpoint.path or get_checkpoint_dir(config_file)
1348
+ is_writable, error_msg = check_dir_writable(checkpoint_path, "checkpoint directory")
1349
+ if is_writable:
1350
+ tui.success(f"checkpoint directory: {checkpoint_path}")
1351
+ else:
1352
+ path_errors.append(error_msg)
1353
+ tui.error(f"checkpoint directory: {error_msg}")
1354
+
1355
+ if path_errors:
1356
+ tui.console.print()
1357
+ tui.error("Some paths are not writable. Fix permissions or choose different paths.")
1358
+ sys.exit(1)
1359
+
1360
+ # Validate API keys if requested
1361
+ if check_api:
1362
+ tui.console.print("\nAPI Keys:", style="cyan bold")
1363
+ try:
1364
+ _validate_api_keys(config)
1365
+ except ConfigurationError as e:
1366
+ tui.error(str(e))
1367
+ sys.exit(1)
1368
+ else:
1369
+ tui.console.print("\nSkipping API key validation (use --check-api to enable)")
1370
+
1371
+ # Final success message
1372
+ tui.console.print()
1373
+ tui.success("Configuration is valid")
1374
+
1082
1375
  except FileNotFoundError:
1083
1376
  handle_error(
1084
1377
  click.get_current_context(),
@@ -1563,5 +1856,143 @@ def import_tools(
1563
1856
  sys.exit(1)
1564
1857
 
1565
1858
 
1859
def _count_nonblank_lines(path: Path) -> int:
    """Return the number of non-blank lines in ``path``, or 0 if it does not exist."""
    if not path.exists():
        return 0
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())


def _load_checkpoint_failures(path: Path) -> tuple[int, list[dict]]:
    """Count failure records in a JSONL file and parse the ones that are JSON objects.

    Returns:
        ``(count, details)`` where ``count`` is the number of non-blank lines
        and ``details`` holds only the lines that decoded to a JSON object.
    """
    if not path.exists():
        return 0, []

    count = 0
    details: list[dict] = []
    with open(path, encoding="utf-8") as f:
        for raw_line in f:
            stripped = raw_line.strip()
            if not stripped:
                continue
            # Every non-blank line counts as a failure, even if it cannot be parsed.
            count += 1
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError:
                continue
            # Only objects can be queried for an "error" field when displayed.
            if isinstance(record, dict):
                details.append(record)
    return count, details


@cli.command("checkpoint-status")
@click.argument("config_file", type=click.Path(exists=True))
def checkpoint_status(config_file: str) -> None:
    """Show checkpoint status for a generation config.

    Displays the current state of any checkpoint files associated with
    the given configuration file, including progress, failures, and
    resume instructions.

    Exits with status 1 if the config cannot be loaded, if it has no
    output.save_as, or if the checkpoint metadata cannot be read.
    """
    tui = get_tui()

    try:
        config = DeepFabricConfig.from_yaml(config_file)
    except Exception as e:
        # CLI boundary: report any load failure to the user instead of a traceback.
        tui.error(f"Failed to load config: {e}")
        sys.exit(1)

    # Get checkpoint configuration
    checkpoint_config = config.get_checkpoint_config()
    output_config = config.get_output_config()
    checkpoint_dir = checkpoint_config.get("path") or get_checkpoint_dir(config_file)
    save_as = output_config.get("save_as")

    if not save_as:
        tui.error("Config does not specify output.save_as - cannot determine checkpoint paths")
        sys.exit(1)

    # Derive checkpoint paths: all files share the output file's stem
    output_stem = Path(save_as).stem
    checkpoint_dir_path = Path(checkpoint_dir)
    metadata_path = checkpoint_dir_path / f"{output_stem}{CHECKPOINT_METADATA_SUFFIX}"
    samples_path = checkpoint_dir_path / f"{output_stem}{CHECKPOINT_SAMPLES_SUFFIX}"
    failures_path = checkpoint_dir_path / f"{output_stem}{CHECKPOINT_FAILURES_SUFFIX}"

    # Check if checkpoint exists
    if not metadata_path.exists():
        tui.info(f"No checkpoint found at: {metadata_path}")
        tui.info("\nTo enable checkpointing, run:")
        tui.info(f" deepfabric generate {config_file} --checkpoint-interval 10")
        return

    # Load checkpoint metadata; read as UTF-8 explicitly so the result does not
    # depend on the platform's default encoding.
    try:
        with open(metadata_path, encoding="utf-8") as f:
            metadata = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        tui.error(f"Failed to read checkpoint metadata: {e}")
        sys.exit(1)

    if not isinstance(metadata, dict):
        tui.error("Failed to read checkpoint metadata: expected a JSON object")
        sys.exit(1)

    checkpoint_sample_count = _count_nonblank_lines(samples_path)
    checkpoint_failures, failure_details = _load_checkpoint_failures(failures_path)

    # Get target samples from config
    # num_samples is the total target, not per-batch. It can be int, "auto", or percentage like "50%"
    target_samples = output_config.get("num_samples", 0)

    # Display status
    tui.console.print()
    tui.console.print(f"[bold]Checkpoint Status:[/bold] {metadata_path}")
    tui.console.print()

    # Progress
    if isinstance(target_samples, str):
        # "auto" or percentage strings can't be resolved without the topic model,
        # so show the raw target instead of a misleading "N/0 (0.0%)".
        progress_text = f"{checkpoint_sample_count} samples (target: {target_samples})"
    else:
        total_target = target_samples or 0
        progress_pct = (checkpoint_sample_count / total_target * 100) if total_target > 0 else 0
        progress_text = f"{checkpoint_sample_count}/{total_target} samples ({progress_pct:.1f}%)"
    tui.console.print(f" [cyan]Progress:[/cyan] {progress_text}")
    tui.console.print(f" [cyan]Failed:[/cyan] {checkpoint_failures} samples")

    # Paths processed ("or []" guards against an explicit null in the metadata)
    processed_ids = metadata.get("processed_ids") or []
    tui.console.print(f" [cyan]Paths done:[/cyan] {len(processed_ids)}")

    # Config info
    tui.console.print()
    tui.console.print(f" [dim]Provider:[/dim] {metadata.get('provider', 'unknown')}")
    tui.console.print(f" [dim]Model:[/dim] {metadata.get('model_name', 'unknown')}")
    tui.console.print(f" [dim]Conv type:[/dim] {metadata.get('conversation_type', 'unknown')}")
    if metadata.get("reasoning_style"):
        tui.console.print(f" [dim]Reasoning:[/dim] {metadata.get('reasoning_style')}")
    tui.console.print(f" [dim]Last saved:[/dim] {metadata.get('created_at', 'unknown')}")

    # Show topics file path if available
    topics_path = metadata.get("topics_file") or metadata.get("topics_save_as")
    if topics_path:
        topics_exists = Path(topics_path).exists()
        status = "[green]exists[/green]" if topics_exists else "[red]missing[/red]"
        tui.console.print(f" [dim]Topics file:[/dim] {topics_path} ({status})")

    # Show failed topics if any
    max_failures_to_show = 5
    max_error_length = 60
    if failure_details:
        tui.console.print()
        tui.console.print("[yellow]Failed Topics:[/yellow]")
        for failure in failure_details[:max_failures_to_show]:
            # Coerce to str: the recorded error may be null or a structured value.
            error_msg = str(failure.get("error") or "Unknown error")
            # Truncate long error messages
            if len(error_msg) > max_error_length:
                error_msg = error_msg[: max_error_length - 3] + "..."
            tui.console.print(f" - {error_msg}")
        if len(failure_details) > max_failures_to_show:
            remaining = len(failure_details) - max_failures_to_show
            tui.console.print(f" ... and {remaining} more failures")

    # Resume instructions
    tui.console.print()
    checkpoint_interval_arg = metadata.get("checkpoint_interval", 10)
    tui.console.print("[green]Resume with:[/green]")
    tui.console.print(
        f" deepfabric generate {config_file} --checkpoint-interval {checkpoint_interval_arg} --resume"
    )
    # "or 0" keeps the comparison valid when total_failures is null in the metadata.
    if (metadata.get("total_failures") or 0) > 0:
        tui.console.print("[green]Retry failed:[/green]")
        tui.console.print(
            f" deepfabric generate {config_file} --checkpoint-interval {checkpoint_interval_arg} --resume --retry-failed"
        )
+ )
1995
+
1996
+
1566
1997
  if __name__ == "__main__":
1567
1998
  cli()