DeepFabric 4.9.0__py3-none-any.whl → 4.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/builders.py +7 -21
- deepfabric/builders_agent.py +0 -542
- deepfabric/cli.py +505 -74
- deepfabric/config.py +57 -73
- deepfabric/config_manager.py +8 -6
- deepfabric/constants.py +6 -0
- deepfabric/dataset_manager.py +107 -11
- deepfabric/evaluation/parser.py +7 -7
- deepfabric/generator.py +656 -103
- deepfabric/graph.py +46 -1
- deepfabric/prompts.py +0 -39
- deepfabric/schemas.py +4 -3
- deepfabric/topic_model.py +32 -0
- deepfabric/tree.py +23 -1
- deepfabric/tui.py +66 -21
- deepfabric/utils.py +184 -0
- deepfabric/validation.py +47 -77
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.1.dist-info}/METADATA +5 -6
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.1.dist-info}/RECORD +22 -22
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.1.dist-info}/WHEEL +0 -0
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.1.dist-info}/entry_points.txt +0 -0
- {deepfabric-4.9.0.dist-info → deepfabric-4.10.1.dist-info}/licenses/LICENSE +0 -0
deepfabric/graph.py
CHANGED
|
@@ -26,7 +26,7 @@ from .prompts import (
|
|
|
26
26
|
)
|
|
27
27
|
from .schemas import GraphSubtopics
|
|
28
28
|
from .stream_simulator import simulate_stream
|
|
29
|
-
from .topic_model import TopicModel
|
|
29
|
+
from .topic_model import TopicModel, TopicPath
|
|
30
30
|
|
|
31
31
|
if TYPE_CHECKING: # only for type hints to avoid runtime cycles
|
|
32
32
|
from .progress import ProgressReporter
|
|
@@ -231,6 +231,9 @@ class Graph(TopicModel):
|
|
|
231
231
|
|
|
232
232
|
def save(self, save_path: str) -> None:
    """Save the topic graph to a file as JSON.

    Missing parent directories of ``save_path`` are created. The graph is
    serialized with ``self.to_json()`` *before* the destination is opened, so
    a serialization error leaves any existing file at ``save_path`` intact
    instead of truncating it to zero bytes.

    Args:
        save_path: Destination file path.
    """
    from pathlib import Path  # noqa: PLC0415

    # Serialize first: opening with "w" truncates, so a failure after open
    # would destroy the previously saved graph.
    payload = self.to_json()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "w") as f:
        f.write(payload)
|
|
236
239
|
|
|
@@ -570,6 +573,48 @@ class Graph(TopicModel):
|
|
|
570
573
|
self._dfs_paths(self.root, [self.root.topic], paths, visited)
|
|
571
574
|
return paths
|
|
572
575
|
|
|
576
|
+
def get_all_paths_with_ids(self) -> list[TopicPath]:
    """Collect every root-to-leaf path of the graph together with its leaf ID.

    Returns:
        A list of ``TopicPath`` entries, one per leaf reached from the root.
        Each ``topic_id`` is taken from the leaf node's ``"uuid"`` metadata,
        falling back to ``str(node.id)`` when that key is missing.
    """
    collected: list[TopicPath] = []
    # Start the walk at the root with a fresh per-path visited set.
    self._dfs_paths_with_ids(self.root, [self.root.topic], collected, set())
    return collected
|
|
587
|
+
|
|
588
|
+
def _dfs_paths_with_ids(
    self,
    node: Node,
    current_path: list[str],
    result: list[TopicPath],
    visited: set[int],
) -> None:
    """Depth-first walk that appends one ``TopicPath`` for every leaf reached.

    Args:
        node: Node currently being expanded.
        current_path: Topics from the root down to ``node`` (inclusive).
        result: Output list; a ``TopicPath`` is appended for each leaf.
        visited: IDs of the nodes on the current root-to-node path. A node
            already in this set is skipped, which breaks cycles. The node is
            removed again on the way back up, so other branches can still
            reach it.
    """
    node_id = node.id
    if node_id in visited:
        return

    visited.add(node_id)
    children = node.children
    if children:
        for child in children:
            self._dfs_paths_with_ids(child, [*current_path, child.topic], result, visited)
    else:
        # Leaf: prefer the UUID stored in metadata, else the stringified node id.
        leaf_id = node.metadata.get("uuid", str(node_id))
        result.append(TopicPath(path=current_path, topic_id=leaf_id))
    visited.remove(node_id)
|
|
617
|
+
|
|
573
618
|
def _dfs_paths(
|
|
574
619
|
self, node: Node, current_path: list[str], paths: list[list[str]], visited: set[int]
|
|
575
620
|
) -> None:
|
deepfabric/prompts.py
CHANGED
|
@@ -165,35 +165,6 @@ ARGUMENT REQUIREMENTS:
|
|
|
165
165
|
|
|
166
166
|
Generate a complete agent reasoning example using structured output with tool_executions list."""
|
|
167
167
|
|
|
168
|
-
@staticmethod
|
|
169
|
-
def build_multi_turn_context_prompt(tool_registry, max_tools_per_query: int = 3) -> str:
|
|
170
|
-
"""Build context for multi-turn conversations.
|
|
171
|
-
|
|
172
|
-
Returns a template with {{{{instructions}}}} and {{{{subtopics}}}} placeholders
|
|
173
|
-
that will be filled in by build_prompt() with actual topic paths from the tree.
|
|
174
|
-
"""
|
|
175
|
-
tool_signatures = []
|
|
176
|
-
for tool in tool_registry.tools:
|
|
177
|
-
tool_signatures.append(f"- {tool.to_signature()}")
|
|
178
|
-
|
|
179
|
-
return f"""Generate a multi-turn agent conversation with evolving tool usage.
|
|
180
|
-
|
|
181
|
-
Available tools:
|
|
182
|
-
{chr(10).join(tool_signatures)}
|
|
183
|
-
|
|
184
|
-
You may use 1 to {max_tools_per_query} tools per query. Show tool dependencies and reasoning across conversation turns.
|
|
185
|
-
|
|
186
|
-
ARGUMENT REQUIREMENTS:
|
|
187
|
-
- All argument values must be concrete and realistic (e.g., owner="acme-corp", repo="web-app", issue_number=42)
|
|
188
|
-
- Never use template placeholders like {{{{owner}}}} or {{{{repo}}}}
|
|
189
|
-
- Never use null values - omit optional parameters entirely if not needed
|
|
190
|
-
- String fields must contain actual content, not empty strings
|
|
191
|
-
|
|
192
|
-
{{{{{{{{instructions}}}}}}}}
|
|
193
|
-
{{{{{{{{subtopics}}}}}}}}
|
|
194
|
-
|
|
195
|
-
Generate a complete multi-turn conversation using structured output with tool_executions list."""
|
|
196
|
-
|
|
197
168
|
|
|
198
169
|
# Simplified prompts that delegate to structured generation
|
|
199
170
|
AGENT_COT_TOOLS_PROMPT = """Generate an agent tool-calling training example using the available tool definitions.
|
|
@@ -224,16 +195,6 @@ Focus on teaching both the reasoning process AND multi-tool usage patterns.
|
|
|
224
195
|
{{{{examples}}}}
|
|
225
196
|
{{{{subtopics}}}}"""
|
|
226
197
|
|
|
227
|
-
AGENT_COT_MULTI_TURN_PROMPT = """Generate a multi-turn agent conversation with tool usage across turns.
|
|
228
|
-
|
|
229
|
-
Show how reasoning evolves: tool dependencies, progressive refinement, and result synthesis.
|
|
230
|
-
|
|
231
|
-
Create realistic tool chaining patterns and decision-making processes.
|
|
232
|
-
|
|
233
|
-
{{{{instructions}}}}
|
|
234
|
-
{{{{examples}}}}
|
|
235
|
-
{{{{subtopics}}}}"""
|
|
236
|
-
|
|
237
198
|
CONVERSATION_GENERATION_PROMPT = """Generate a training conversation for a language model with this system prompt:
|
|
238
199
|
|
|
239
200
|
<system_prompt>
|
deepfabric/schemas.py
CHANGED
|
@@ -842,10 +842,11 @@ class ToolContext(BaseModel):
|
|
|
842
842
|
|
|
843
843
|
|
|
844
844
|
class AgentContext(BaseModel):
|
|
845
|
-
"""Agent capability - present when
|
|
845
|
+
"""Agent capability - present when tools are configured for agent mode."""
|
|
846
846
|
|
|
847
|
-
mode: Literal["single_turn"
|
|
848
|
-
|
|
847
|
+
mode: Literal["single_turn"] = Field(
|
|
848
|
+
default="single_turn",
|
|
849
|
+
description="Agent interaction mode (single_turn is the only supported mode)",
|
|
849
850
|
)
|
|
850
851
|
planning_trace: str | None = Field(
|
|
851
852
|
default=None, description="Agent's planning and reasoning about tool usage strategy"
|
deepfabric/topic_model.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import NamedTuple
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TopicPath(NamedTuple):
    """A root-to-leaf topic path paired with an identifier for its leaf.

    Being a NamedTuple, it also unpacks as ``(path, topic_id)``.
    """

    path: list[str]  # topic strings ordered from the root down to the leaf
    topic_id: str  # identifier of the leaf topic of ``path``
|
|
2
10
|
|
|
3
11
|
|
|
4
12
|
class TopicModel(ABC):
|
|
@@ -18,3 +26,27 @@ class TopicModel(ABC):
|
|
|
18
26
|
def get_all_paths(self) -> list[list[str]]:
|
|
19
27
|
"""Returns all the paths in the topic model."""
|
|
20
28
|
raise NotImplementedError
|
|
29
|
+
|
|
30
|
+
@abstractmethod
def get_all_paths_with_ids(self) -> list[TopicPath]:
    """Return each path of the topic model paired with its identifier.

    Concrete topic models must override this.

    Returns:
        ``TopicPath`` entries of ``(path, topic_id)``, where ``topic_id`` is a
        stable identifier for the leaf node of that path.
    """
    raise NotImplementedError()
|
|
39
|
+
|
|
40
|
+
def get_path_by_id(self, topic_id: str) -> list[str] | None:
|
|
41
|
+
"""Look up a path by its topic_id.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
topic_id: The unique identifier for a topic path.
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
The path list if found, None otherwise.
|
|
48
|
+
"""
|
|
49
|
+
for topic_path in self.get_all_paths_with_ids():
|
|
50
|
+
if topic_path.topic_id == topic_id:
|
|
51
|
+
return topic_path.path
|
|
52
|
+
return None
|
deepfabric/tree.py
CHANGED
|
@@ -21,7 +21,7 @@ from .metrics import trace
|
|
|
21
21
|
from .prompts import TreePromptBuilder
|
|
22
22
|
from .schemas import TopicList
|
|
23
23
|
from .stream_simulator import simulate_stream
|
|
24
|
-
from .topic_model import TopicModel
|
|
24
|
+
from .topic_model import TopicModel, TopicPath
|
|
25
25
|
|
|
26
26
|
warnings.filterwarnings("ignore", message=".*Pydantic serializer warnings:.*")
|
|
27
27
|
|
|
@@ -242,6 +242,25 @@ class Tree(TopicModel):
|
|
|
242
242
|
"""Returns all the paths in the topic model."""
|
|
243
243
|
return self.tree_paths
|
|
244
244
|
|
|
245
|
+
def get_all_paths_with_ids(self) -> list[TopicPath]:
    """Returns all paths with stable, content-derived identifiers.

    Each ``topic_id`` is the first 16 hex characters of a SHA-256 digest of
    the path, so the same path always gets the same ID across runs.

    Key scheme:
    - Paths whose topics contain no ``":"`` hash their ``"::"``-joined form.
      This is the original scheme, kept so previously issued IDs stay valid.
    - Any other path would be ambiguous under that join. For example,
      ``["a::b", "c"]`` and ``["a", "b::c"]`` both join to ``"a::b::c"``.
      Such paths hash ``":" + json.dumps(path)`` instead.

    In a ``"::"``-join of colon-free topics, every colon run has even
    length. The alternate key begins with a single colon, so the two key
    spaces cannot collide. Identical duplicate paths still share one ID.

    Returns:
        List of TopicPath namedtuples containing (path, topic_id).
    """
    import hashlib  # noqa: PLC0415

    result: list[TopicPath] = []
    for path in self.tree_paths:
        if any(":" in topic for topic in path):
            # Unambiguous encoding for paths the "::" join could confuse.
            key = ":" + json.dumps(path)
        else:
            key = "::".join(path)
        topic_id = hashlib.sha256(key.encode()).hexdigest()[:16]
        result.append(TopicPath(path=path, topic_id=topic_id))
    return result
|
|
263
|
+
|
|
245
264
|
async def get_subtopics(
|
|
246
265
|
self, system_prompt: str, node_path: list[str], num_subtopics: int
|
|
247
266
|
) -> list[str]:
|
|
@@ -385,6 +404,9 @@ class Tree(TopicModel):
|
|
|
385
404
|
|
|
386
405
|
def save(self, save_path: str) -> None:
|
|
387
406
|
"""Save the topic tree to a file."""
|
|
407
|
+
from pathlib import Path # noqa: PLC0415
|
|
408
|
+
|
|
409
|
+
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
|
388
410
|
with open(save_path, "w") as f:
|
|
389
411
|
for path in self.tree_paths:
|
|
390
412
|
f.write(json.dumps({"path": path}) + "\n")
|
deepfabric/tui.py
CHANGED
|
@@ -275,19 +275,19 @@ class DeepFabricTUI:
|
|
|
275
275
|
|
|
276
276
|
def success(self, message: str) -> None:
    """Print ``message`` after a check-mark prefix, styled green."""
    text = f"✓ {message}"
    self.console.print(text, style="green")
|
|
279
279
|
|
|
280
280
|
def warning(self, message: str) -> None:
    """Print ``message`` after a warning-sign prefix, styled yellow."""
    text = f"⚠ {message}"
    self.console.print(text, style="yellow")
|
|
283
283
|
|
|
284
284
|
def error(self, message: str) -> None:
    """Print ``message`` after a cross-mark prefix, styled red."""
    text = f"✗ {message}"
    self.console.print(text, style="red")
|
|
287
287
|
|
|
288
288
|
def info(self, message: str) -> None:
    """Print ``message`` after a bullet prefix, styled blue."""
    text = f"• {message}"
    self.console.print(text, style="blue")
|
|
291
291
|
|
|
292
292
|
|
|
293
293
|
class TreeBuildingTUI(TopicBuildingMixin, StreamObserver):
|
|
@@ -846,6 +846,13 @@ class DatasetGenerationTUI(StreamObserver):
|
|
|
846
846
|
self.status_samples_done = 0
|
|
847
847
|
self.status_failed_total = 0
|
|
848
848
|
self.status_step_started_at = 0.0
|
|
849
|
+
self.status_last_step_duration = 0.0
|
|
850
|
+
# Checkpoint tracking for status panel
|
|
851
|
+
self.checkpoint_enabled = False # Set to True when checkpointing is configured
|
|
852
|
+
self.checkpoint_count = 0
|
|
853
|
+
self.last_checkpoint_samples = 0
|
|
854
|
+
self._resumed_from_checkpoint = False # Set by set_checkpoint_resume_status()
|
|
855
|
+
self._stop_requested = False # Set when graceful stop requested via Ctrl+C
|
|
849
856
|
# Retry tracking for simple mode
|
|
850
857
|
self.step_retries: list[dict] = [] # Retries in current step
|
|
851
858
|
|
|
@@ -919,18 +926,8 @@ class DatasetGenerationTUI(StreamObserver):
|
|
|
919
926
|
type_map = {
|
|
920
927
|
"basic": "Basic Q&A",
|
|
921
928
|
"cot": "Chain of Thought",
|
|
922
|
-
"single_turn_agent": "Single-Turn Agent (Tool Calling)",
|
|
923
|
-
"multi_turn_agent": "Multi-Turn Agent (Tool Calling)",
|
|
924
929
|
}
|
|
925
930
|
self.current_sample_type = type_map.get(conv_type, conv_type)
|
|
926
|
-
elif "agent_mode" in metadata:
|
|
927
|
-
agent_mode = metadata["agent_mode"]
|
|
928
|
-
if agent_mode == "single_turn":
|
|
929
|
-
self.current_sample_type = "Single-Turn Agent (Tool Calling)"
|
|
930
|
-
elif agent_mode == "multi_turn":
|
|
931
|
-
self.current_sample_type = "Multi-Turn Agent (Tool Calling)"
|
|
932
|
-
else:
|
|
933
|
-
self.current_sample_type = f"Agent ({agent_mode})"
|
|
934
931
|
|
|
935
932
|
# Update current topic path if provided
|
|
936
933
|
topic_path = metadata.get("topic_path") if isinstance(metadata, dict) else None
|
|
@@ -1041,13 +1038,21 @@ class DatasetGenerationTUI(StreamObserver):
|
|
|
1041
1038
|
return
|
|
1042
1039
|
|
|
1043
1040
|
# --- Status Panel helpers ---
|
|
1044
|
-
def init_status(
|
|
1041
|
+
def init_status(
    self, total_steps: int, total_samples: int, checkpoint_enabled: bool = False
) -> None:
    """Reset the status-panel state for a generation run.

    Step counters and step timing are always reset. The sample, failure and
    checkpoint counters are zeroed only when the TUI has not been marked as
    resumed (``_resumed_from_checkpoint``), so progress restored by
    ``set_checkpoint_resume_status`` is kept.

    Args:
        total_steps: Number of steps planned for the run.
        total_samples: Number of samples planned for the run.
        checkpoint_enabled: Whether the status panel should show checkpoint info.
    """
    resuming = getattr(self, "_resumed_from_checkpoint", False)

    self.status_total_steps = total_steps
    self.status_total_samples = total_samples
    self.status_current_step = 0
    if not resuming:
        for counter in (
            "status_samples_done",
            "status_failed_total",
            "checkpoint_count",
            "last_checkpoint_samples",
        ):
            setattr(self, counter, 0)
    self.status_step_started_at = 0.0
    self.status_last_step_duration = 0.0
    self.checkpoint_enabled = checkpoint_enabled
|
|
1051
1056
|
|
|
1052
1057
|
def status_step_start(self, step: int, total_steps: int | None = None) -> None:
|
|
1053
1058
|
self.status_current_step = step
|
|
@@ -1057,22 +1062,62 @@ class DatasetGenerationTUI(StreamObserver):
|
|
|
1057
1062
|
self.update_status_panel()
|
|
1058
1063
|
|
|
1059
1064
|
def status_step_complete(self, samples_generated: int, failed_in_step: int = 0) -> None:
    """Close out the current step and fold its results into the totals.

    If a step start time was recorded, the step's duration is stored and the
    start marker is cleared. Negative counts are clamped to zero before being
    added. The status panel is then refreshed.

    Args:
        samples_generated: Samples produced during the step.
        failed_in_step: Samples that failed during the step.
    """
    started = self.status_step_started_at
    if started:
        self.status_last_step_duration = max(0.0, monotonic() - started)
        self.status_step_started_at = 0.0  # cleared so the next step starts fresh

    self.status_samples_done += max(0, int(samples_generated))
    self.status_failed_total += max(0, int(failed_in_step))
    self.update_status_panel()
|
|
1063
1072
|
|
|
1073
|
+
def set_checkpoint_resume_status(
    self, samples_done: int, failed_total: int, checkpoint_count: int = 0
) -> None:
    """Initialize status counters from checkpoint data when resuming.

    Marks the TUI as resumed so a later ``init_status`` keeps these values.
    All counts are converted with ``int()``; sample and failure counts are
    clamped at zero. The checkpoint row is only updated when
    ``checkpoint_count`` is positive. It then records the same normalized
    sample count that the "Generated" row shows.

    Args:
        samples_done: Number of samples already generated in checkpoint
        failed_total: Number of failures already recorded in checkpoint
        checkpoint_count: Number of checkpoints already saved (optional)
    """
    self._resumed_from_checkpoint = True
    self.status_samples_done = max(0, int(samples_done))
    self.status_failed_total = max(0, int(failed_total))
    checkpoints = int(checkpoint_count)
    if checkpoints > 0:
        self.checkpoint_count = checkpoints
        # Reuse the normalized value so "Checkpoints" and "Generated" agree.
        self.last_checkpoint_samples = self.status_samples_done
    self.update_status_panel()
|
|
1090
|
+
|
|
1091
|
+
def status_checkpoint_saved(self, total_samples: int) -> None:
    """Count a newly saved checkpoint and refresh the status panel.

    Args:
        total_samples: Sample count to display for the most recent checkpoint.
    """
    self.checkpoint_count = self.checkpoint_count + 1
    self.last_checkpoint_samples = total_samples
    self.update_status_panel()
|
|
1096
|
+
|
|
1097
|
+
def status_stop_requested(self) -> None:
    """Flag a pending graceful stop and redraw the panel so it is visible."""
    self._stop_requested = True
    self.update_status_panel()
|
|
1101
|
+
|
|
1064
1102
|
def _status_panel(self) -> Panel:
    """Build the "Status" panel: a two-column label/value table of run progress.

    Rows, in order: step counter; last step duration (when > 0); generated
    samples; failures (when non-zero); checkpoint info (when checkpointing
    is enabled); and a stop notice (when a graceful stop was requested).
    """
    rows: list[tuple[str, str]] = [
        ("Step:", f"{self.status_current_step}/{self.status_total_steps}"),
    ]
    if self.status_last_step_duration > 0:
        rows.append(("Last Step:", f"{self.status_last_step_duration:0.1f}s"))
    rows.append(("Generated:", f"{self.status_samples_done}/{self.status_total_samples}"))
    if self.status_failed_total:
        rows.append(("Failed:", str(self.status_failed_total)))
    if self.checkpoint_enabled:
        if self.checkpoint_count > 0:
            checkpoint_text = f"{self.checkpoint_count} ({self.last_checkpoint_samples} samples)"
        else:
            checkpoint_text = "0 (enabled)"
        rows.append(("Checkpoints:", checkpoint_text))
    if self._stop_requested:
        rows.append(("[yellow]Stopping:[/yellow]", "[yellow]at next checkpoint[/yellow]"))

    table = Table(show_header=False, box=None, padding=(0, 1))
    table.add_column(style="cyan", no_wrap=True)
    table.add_column(style="white")
    for label, value in rows:
        table.add_row(label, value)
    return Panel(table, title="Status", border_style="dim", padding=(0, 1))
|
|
1077
1122
|
|
|
1078
1123
|
def update_status_panel(self) -> None:
|
deepfabric/utils.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
import ast
|
|
2
2
|
import asyncio
|
|
3
|
+
import hashlib
|
|
3
4
|
import importlib
|
|
4
5
|
import json
|
|
5
6
|
import os
|
|
6
7
|
import re
|
|
8
|
+
import sys
|
|
7
9
|
|
|
10
|
+
from pathlib import Path
|
|
8
11
|
from typing import Any
|
|
9
12
|
|
|
10
13
|
VALIDATION_ERROR_INDICATORS = [
|
|
@@ -155,6 +158,51 @@ def read_topic_tree_from_jsonl(file_path: str) -> list[dict]:
|
|
|
155
158
|
return topic_tree
|
|
156
159
|
|
|
157
160
|
|
|
161
|
+
def parse_num_samples(value: int | str | None) -> int | str | None:
|
|
162
|
+
"""Parse and validate num_samples: integer, 'auto', or percentage like '50%'.
|
|
163
|
+
|
|
164
|
+
This is a shared utility used by both CLI argument parsing and config validation.
|
|
165
|
+
|
|
166
|
+
Args:
|
|
167
|
+
value: Raw value - can be int, string, or None
|
|
168
|
+
|
|
169
|
+
Returns:
|
|
170
|
+
Parsed value: int, "auto", percentage string like "50%", or None
|
|
171
|
+
|
|
172
|
+
Raises:
|
|
173
|
+
ValueError: If the value is invalid
|
|
174
|
+
"""
|
|
175
|
+
if value is None:
|
|
176
|
+
return None
|
|
177
|
+
if isinstance(value, int):
|
|
178
|
+
if value < 1:
|
|
179
|
+
raise ValueError("num_samples must be at least 1")
|
|
180
|
+
return value
|
|
181
|
+
if isinstance(value, str):
|
|
182
|
+
normalized = value.strip().lower()
|
|
183
|
+
if normalized == "auto":
|
|
184
|
+
return "auto"
|
|
185
|
+
if normalized.endswith("%"):
|
|
186
|
+
try:
|
|
187
|
+
pct = float(normalized[:-1])
|
|
188
|
+
except ValueError as e:
|
|
189
|
+
raise ValueError(f"Invalid percentage format: {value}") from e
|
|
190
|
+
if pct <= 0:
|
|
191
|
+
raise ValueError("Percentage must be greater than 0")
|
|
192
|
+
return normalized
|
|
193
|
+
# Try to parse as integer string
|
|
194
|
+
try:
|
|
195
|
+
parsed = int(normalized)
|
|
196
|
+
except ValueError as e:
|
|
197
|
+
raise ValueError(
|
|
198
|
+
f"Invalid num_samples value: {value}. Use integer, 'auto', or percentage like '50%'"
|
|
199
|
+
) from e
|
|
200
|
+
if parsed < 1:
|
|
201
|
+
raise ValueError("num_samples must be at least 1")
|
|
202
|
+
return parsed
|
|
203
|
+
raise ValueError(f"num_samples must be int or string, got {type(value).__name__}")
|
|
204
|
+
|
|
205
|
+
|
|
158
206
|
def get_bool_env(key: str, default: bool = False) -> bool:
|
|
159
207
|
"""Get a boolean environment variable.
|
|
160
208
|
|
|
@@ -195,3 +243,139 @@ def import_optional_dependency(
|
|
|
195
243
|
else:
|
|
196
244
|
msg = f"The '{module_name}' library is required but is not installed."
|
|
197
245
|
raise ModuleNotFoundError(msg) from None
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def check_path_writable(path: str, path_description: str) -> tuple[bool, str | None]:
|
|
249
|
+
"""Check if a path is writable.
|
|
250
|
+
|
|
251
|
+
Checks whether the specified file path can be written to by verifying:
|
|
252
|
+
1. If the file exists, whether it's writable
|
|
253
|
+
2. If the file doesn't exist, whether the parent directory exists and is writable
|
|
254
|
+
|
|
255
|
+
Args:
|
|
256
|
+
path: The file path to check
|
|
257
|
+
path_description: Human-readable description for error messages
|
|
258
|
+
|
|
259
|
+
Returns:
|
|
260
|
+
Tuple of (is_writable, error_message). error_message is None if writable.
|
|
261
|
+
"""
|
|
262
|
+
file_path = Path(path)
|
|
263
|
+
parent_dir = file_path.parent
|
|
264
|
+
error_msg: str | None = None
|
|
265
|
+
|
|
266
|
+
# If the file exists, check if it's writable
|
|
267
|
+
if file_path.exists():
|
|
268
|
+
if not os.access(file_path, os.W_OK):
|
|
269
|
+
error_msg = f"{path_description} exists but is not writable: {path}"
|
|
270
|
+
elif not parent_dir.exists():
|
|
271
|
+
# File doesn't exist and parent doesn't exist
|
|
272
|
+
# Walk up to find the first existing ancestor
|
|
273
|
+
ancestor = parent_dir
|
|
274
|
+
while not ancestor.exists() and ancestor != ancestor.parent:
|
|
275
|
+
ancestor = ancestor.parent
|
|
276
|
+
|
|
277
|
+
if not ancestor.exists():
|
|
278
|
+
error_msg = (
|
|
279
|
+
f"{path_description} parent directory does not exist "
|
|
280
|
+
f"and cannot be created: {parent_dir}"
|
|
281
|
+
)
|
|
282
|
+
elif not os.access(ancestor, os.W_OK):
|
|
283
|
+
error_msg = (
|
|
284
|
+
f"{path_description} cannot create parent directory "
|
|
285
|
+
f"(no write access to {ancestor}): {parent_dir}"
|
|
286
|
+
)
|
|
287
|
+
elif not os.access(parent_dir, os.W_OK):
|
|
288
|
+
# Parent exists but is not writable
|
|
289
|
+
error_msg = f"{path_description} parent directory is not writable: {parent_dir}"
|
|
290
|
+
|
|
291
|
+
return (error_msg is None, error_msg)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def check_dir_writable(path: str, path_description: str) -> tuple[bool, str | None]:
|
|
295
|
+
"""Check if a directory path is writable.
|
|
296
|
+
|
|
297
|
+
Checks whether files can be created in the specified directory by verifying:
|
|
298
|
+
1. If the directory exists, whether it's writable
|
|
299
|
+
2. If the directory doesn't exist, whether we can create it
|
|
300
|
+
|
|
301
|
+
Args:
|
|
302
|
+
path: The directory path to check
|
|
303
|
+
path_description: Human-readable description for error messages
|
|
304
|
+
|
|
305
|
+
Returns:
|
|
306
|
+
Tuple of (is_writable, error_message). error_message is None if writable.
|
|
307
|
+
"""
|
|
308
|
+
dir_path = Path(path)
|
|
309
|
+
|
|
310
|
+
# If the directory exists, check if it's writable
|
|
311
|
+
if dir_path.exists():
|
|
312
|
+
if not dir_path.is_dir():
|
|
313
|
+
return False, f"{path_description} exists but is not a directory: {path}"
|
|
314
|
+
if not os.access(dir_path, os.W_OK):
|
|
315
|
+
return False, f"{path_description} directory is not writable: {path}"
|
|
316
|
+
return True, None
|
|
317
|
+
|
|
318
|
+
# Directory doesn't exist - check if we can create it
|
|
319
|
+
ancestor = dir_path
|
|
320
|
+
while not ancestor.exists() and ancestor != ancestor.parent:
|
|
321
|
+
ancestor = ancestor.parent
|
|
322
|
+
|
|
323
|
+
if not ancestor.exists():
|
|
324
|
+
return False, f"{path_description} cannot be created (root does not exist): {path}"
|
|
325
|
+
|
|
326
|
+
if not os.access(ancestor, os.W_OK):
|
|
327
|
+
return False, f"{path_description} cannot be created (no write access to {ancestor}): {path}"
|
|
328
|
+
|
|
329
|
+
return True, None
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
# Checkpoint directory resolution
|
|
333
|
+
APP_NAME = "deepfabric"
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _get_deepfabric_data_dir() -> Path:
|
|
337
|
+
"""Get the DeepFabric data directory using platformdirs or fallback."""
|
|
338
|
+
try:
|
|
339
|
+
from platformdirs import user_data_dir # noqa: PLC0415
|
|
340
|
+
|
|
341
|
+
return Path(user_data_dir(APP_NAME))
|
|
342
|
+
except ImportError:
|
|
343
|
+
# Fallback if platformdirs not available
|
|
344
|
+
if os.name == "nt":
|
|
345
|
+
# Windows: APPDATA
|
|
346
|
+
base = os.environ.get("APPDATA") or os.path.expanduser(r"~\AppData\Roaming")
|
|
347
|
+
elif sys.platform == "darwin":
|
|
348
|
+
# macOS: ~/Library/Application Support
|
|
349
|
+
base = os.path.expanduser("~/Library/Application Support")
|
|
350
|
+
else:
|
|
351
|
+
# Linux and other Unix: XDG_DATA_HOME
|
|
352
|
+
base = os.environ.get("XDG_DATA_HOME") or os.path.expanduser("~/.local/share")
|
|
353
|
+
return Path(base) / APP_NAME
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def get_checkpoint_dir(config_path: str | None = None) -> str:
|
|
357
|
+
"""
|
|
358
|
+
Get the checkpoint directory for a given config file.
|
|
359
|
+
|
|
360
|
+
Uses ~/.deepfabric/checkpoints/{hash}/ where hash is derived from
|
|
361
|
+
the absolute path of the config file. This ensures:
|
|
362
|
+
- Consistent location regardless of current working directory
|
|
363
|
+
- No conflicts between different projects with same output filename
|
|
364
|
+
|
|
365
|
+
Args:
|
|
366
|
+
config_path: Path to the config file. If None, uses a default subdirectory.
|
|
367
|
+
|
|
368
|
+
Returns:
|
|
369
|
+
Path to the checkpoint directory (not created, just resolved)
|
|
370
|
+
"""
|
|
371
|
+
base_dir = _get_deepfabric_data_dir() / "checkpoints"
|
|
372
|
+
|
|
373
|
+
if config_path is None:
|
|
374
|
+
# No config file - use a "default" subdirectory
|
|
375
|
+
return str(base_dir / "default")
|
|
376
|
+
|
|
377
|
+
# Create a short hash from the absolute path of the config file
|
|
378
|
+
abs_path = str(Path(config_path).resolve())
|
|
379
|
+
path_hash = hashlib.sha256(abs_path.encode()).hexdigest()[:12]
|
|
380
|
+
|
|
381
|
+
return str(base_dir / path_hash)
|