DeepFabric 4.10.1__py3-none-any.whl → 4.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deepfabric/graph.py CHANGED
@@ -26,7 +26,7 @@ from .prompts import (
 )
 from .schemas import GraphSubtopics
 from .stream_simulator import simulate_stream
-from .topic_model import TopicModel, TopicPath
+from .topic_model import Topic, TopicModel, TopicPath
 
 if TYPE_CHECKING:  # only for type hints to avoid runtime cycles
     from .progress import ProgressReporter
@@ -70,6 +70,11 @@ class GraphConfig(BaseModel):
         le=20,
         description="Maximum concurrent LLM calls during graph expansion (helps avoid rate limits)",
     )
+    max_tokens: int = Field(
+        default=DEFAULT_MAX_TOKENS,
+        ge=1,
+        description="Maximum tokens for topic generation LLM calls",
+    )
     base_url: str | None = Field(
         default=None,
         description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
@@ -156,6 +161,7 @@ class Graph(TopicModel):
         self.degree = self.config.degree
         self.depth = self.config.depth
         self.max_concurrent = self.config.max_concurrent
+        self.max_tokens = self.config.max_tokens
         self.prompt_style = self.config.prompt_style
 
         # Initialize LLM client
@@ -211,6 +217,139 @@ class Graph(TopicModel):
             if parent_node not in child_node.parents:
                 child_node.parents.append(parent_node)
 
+    def find_node_by_uuid(self, uuid: str) -> Node | None:
+        """Find a node by its UUID.
+
+        Args:
+            uuid: The UUID string to search for.
+
+        Returns:
+            The Node if found, None otherwise.
+        """
+        for node in self.nodes.values():
+            if node.metadata.get("uuid") == uuid:
+                return node
+        return None
+
+    def remove_node(self, node_id: int) -> None:
+        """Remove a single node from the graph, cleaning up bidirectional references.
+
+        Does not remove children — use remove_subtree() for cascading removal.
+
+        Args:
+            node_id: The ID of the node to remove.
+
+        Raises:
+            ValueError: If node_id is the root node or does not exist.
+        """
+        if node_id == self.root.id:
+            raise ValueError("Cannot remove the root node")  # noqa: TRY003
+        node = self.nodes.get(node_id)
+        if node is None:
+            raise ValueError(f"Node {node_id} not found in graph")  # noqa: TRY003
+
+        for parent in node.parents:
+            if node in parent.children:
+                parent.children.remove(node)
+
+        for child in node.children:
+            if node in child.parents:
+                child.parents.remove(node)
+
+        del self.nodes[node_id]
+
+    def remove_subtree(self, node_id: int) -> list[int]:
+        """Remove a node and all its descendants from the graph.
+
+        Args:
+            node_id: The ID of the node to remove (along with all descendants).
+
+        Returns:
+            List of removed node IDs.
+
+        Raises:
+            ValueError: If node_id is the root node or does not exist.
+        """
+        if node_id == self.root.id:
+            raise ValueError("Cannot remove the root node")  # noqa: TRY003
+        node = self.nodes.get(node_id)
+        if node is None:
+            raise ValueError(f"Node {node_id} not found in graph")  # noqa: TRY003
+
+        # BFS to collect all descendant node IDs
+        to_remove: list[int] = []
+        queue = [node]
+        visited: set[int] = set()
+        while queue:
+            current = queue.pop(0)
+            if current.id in visited:
+                continue
+            visited.add(current.id)
+            to_remove.append(current.id)
+            for child in current.children:
+                if child.id not in visited:
+                    queue.append(child)
+
+        # Remove in reverse order (leaves first)
+        for nid in reversed(to_remove):
+            self.remove_node(nid)
+
+        return to_remove
+
+    def prune_at_level(self, max_depth: int) -> list[int]:
+        """Remove all nodes below the given depth level.
+
+        Nodes at exactly max_depth become leaf nodes. Root is depth 0.
+
+        Args:
+            max_depth: Maximum depth to keep (inclusive).
+                0 = keep only root, 1 = root and its children, etc.
+
+        Returns:
+            List of removed node IDs.
+
+        Raises:
+            ValueError: If max_depth is negative.
+        """
+        if max_depth < 0:
+            raise ValueError("max_depth must be non-negative")  # noqa: TRY003
+
+        # BFS from root to compute node depths
+        node_depths: dict[int, int] = {}
+        queue: list[tuple[Node, int]] = [(self.root, 0)]
+        visited: set[int] = set()
+        while queue:
+            current, depth = queue.pop(0)
+            if current.id in visited:
+                continue
+            visited.add(current.id)
+            node_depths[current.id] = depth
+            for child in current.children:
+                if child.id not in visited:
+                    queue.append((child, depth + 1))
+
+        to_remove_set = {nid for nid, d in node_depths.items() if d > max_depth}
+
+        # Sever children links from boundary nodes
+        for nid, d in node_depths.items():
+            if d == max_depth:
+                self.nodes[nid].children = [
+                    c for c in self.nodes[nid].children if c.id not in to_remove_set
+                ]
+
+        # Remove deeper nodes
+        for nid in to_remove_set:
+            node = self.nodes[nid]
+            for parent in node.parents:
+                if node in parent.children:
+                    parent.children.remove(node)
+            for child in node.children:
+                if node in child.parents:
+                    child.parents.remove(node)
+            del self.nodes[nid]
+
+        return list(to_remove_set)
+
     def to_pydantic(self) -> GraphModel:
         """Converts the runtime graph to its Pydantic model representation."""
         return GraphModel(
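The three removal methods above compose into the pruning workflow introduced later in this diff. A minimal sketch of the in-memory semantics, assuming `graph` is an already-loaded Graph and the node IDs are hypothetical:

    # Splice out a single node; its children keep any other parents they have,
    # but are not re-attached to the removed node's parents.
    graph.remove_node(42)

    # Cascade: BFS collects the subtree, then nodes are deleted leaves-first
    # so remove_node() never encounters a dangling child reference.
    removed = graph.remove_subtree(17)

    # Truncate to depth 2; nodes at exactly depth 2 become leaves.
    removed_deeper = graph.prune_at_level(2)

One subtlety: because nodes may have multiple parents, prune_at_level() keeps any node whose BFS (shortest-path) depth from the root is within the limit, even if it is also reachable through a longer path.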
@@ -237,6 +376,13 @@ class Graph(TopicModel):
         with open(save_path, "w") as f:
             f.write(self.to_json())
 
+        # Save failed generations if any
+        if self.failed_generations:
+            failed_path = save_path.replace(".json", "_failed.jsonl")
+            with open(failed_path, "w") as f:
+                for failed in self.failed_generations:
+                    f.write(json.dumps({"failed_generation": failed}) + "\n")
+
     @classmethod
     def from_json(cls, json_path: str, params: dict) -> "Graph":
         """Load a topic graph from a JSON file."""
@@ -268,6 +414,36 @@ class Graph(TopicModel):
         graph._next_node_id = max(graph.nodes.keys()) + 1
         return graph
 
+    @classmethod
+    def load(cls, json_path: str) -> "Graph":
+        """Load a graph from JSON without initializing LLM client.
+
+        Intended for inspection and manipulation operations that don't
+        require LLM generation capabilities. Restores provider, model,
+        and temperature from the file metadata so saves preserve them.
+        """
+        params = {
+            "topic_prompt": "loaded",
+            "model_name": "placeholder/model",
+            "degree": 3,
+            "depth": 2,
+            "temperature": 0.7,
+        }
+        graph = cls.from_json(json_path, params)
+
+        # Restore original metadata so save() preserves provenance
+        with open(json_path) as f:
+            raw = json.load(f)
+        file_meta = raw.get("metadata") or {}
+        if file_meta.get("provider"):
+            graph.provider = file_meta["provider"]
+        if file_meta.get("model"):
+            graph.model_name = file_meta["model"]
+        if file_meta.get("temperature") is not None:
+            graph.temperature = file_meta["temperature"]
+
+        return graph
+
     def visualize(self, save_path: str) -> None:
         """Visualize the graph and save it to a file."""
         try:
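Graph.load() is what makes offline editing possible: it loads with placeholder generation params, then restores provider/model/temperature so a later save() round-trips the file's provenance. A short usage sketch (path and UUID are hypothetical):

    from deepfabric.graph import Graph

    graph = Graph.load("topic_graph.json")   # no LLM credentials needed
    node = graph.find_node_by_uuid("3b9c...")
    if node is not None:
        graph.remove_subtree(node.id)
    graph.save("topic_graph.json")           # metadata preserved from the file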
@@ -454,7 +630,7 @@ class Graph(TopicModel):
             prompt=prompt,
             schema=GraphSubtopics,
             max_retries=1,  # Don't retry inside - we handle it here
-            max_tokens=DEFAULT_MAX_TOKENS,
+            max_tokens=self.max_tokens,
             temperature=self.temperature,
         )
 
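With this hunk the cap set in GraphConfig.max_tokens reaches every structured topic-expansion call, replacing the hard-coded DEFAULT_MAX_TOKENS. A hedged sketch of configuring it; the field names below are borrowed from Graph.load()'s params elsewhere in this diff, and whether GraphConfig requires additional fields is an assumption:

    config = GraphConfig(
        topic_prompt="Machine learning fundamentals",
        model_name="placeholder/model",
        degree=3,
        depth=2,
        temperature=0.7,
        max_tokens=4096,  # new knob; pydantic enforces ge=1
    )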
@@ -615,6 +791,30 @@ class Graph(TopicModel):
 
         visited.remove(node.id)
 
+    def get_unique_topics(self) -> list[Topic]:
+        """Returns deduplicated topics by UUID.
+
+        Iterates through all nodes in the graph and returns unique topics.
+        Each node has a UUID in its metadata, ensuring uniqueness.
+
+        Returns:
+            List of Topic namedtuples containing (uuid, topic).
+            Each UUID appears exactly once.
+        """
+        seen_uuids: set[str] = set()
+        result: list[Topic] = []
+
+        for node in self.nodes.values():
+            # Skip root node — it holds the generation seed prompt, not a topic
+            if node.id == self.root.id:
+                continue
+            node_uuid = node.metadata.get("uuid")
+            if node_uuid and node_uuid not in seen_uuids:
+                seen_uuids.add(node_uuid)
+                result.append(Topic(uuid=node_uuid, topic=node.topic))
+
+        return result
+
     def _dfs_paths(
         self, node: Node, current_path: list[str], paths: list[list[str]], visited: set[int]
     ) -> None:
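get_unique_topics() gives dataset builders a flat, deduplicated topic list without walking root-to-leaf paths. A brief sketch, assuming Topic is the (uuid, topic) namedtuple added to the import at the top of this file:

    graph = Graph.load("topic_graph.json")  # hypothetical path
    for t in graph.get_unique_topics():
        print(t.uuid, t.topic)  # each UUID once; the root seed prompt is skipped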
@@ -0,0 +1,122 @@
+"""Graph pruning operations for deepfabric CLI."""
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal
+
+from .graph import Graph
+
+
+@dataclass
+class PruneResult:
+    """Result of a pruning operation."""
+
+    operation: Literal["level", "uuid"]
+    removed_count: int
+    removed_node_ids: list[int]
+    remaining_nodes: int
+    remaining_paths: int
+    output_path: str
+
+
+def load_graph_for_pruning(file_path: str) -> Graph:
+    """Load a graph from JSON for pruning operations.
+
+    Args:
+        file_path: Path to the graph JSON file.
+
+    Returns:
+        Loaded Graph instance.
+
+    Raises:
+        FileNotFoundError: If the file does not exist.
+        ValueError: If the file is not a JSON graph file.
+    """
+    path = Path(file_path)
+    if not path.exists():
+        raise FileNotFoundError(f"Graph file not found: {file_path}")
+    if path.suffix != ".json":
+        raise ValueError(
+            f"Expected a JSON graph file, got: {path.suffix}. "
+            "Pruning is only supported for graph format files."
+        )
+    return Graph.load(file_path)
+
+
+def prune_graph_at_level(
+    file_path: str,
+    max_depth: int,
+    output_path: str | None = None,
+) -> PruneResult:
+    """Prune a graph file by removing all nodes below a depth level.
+
+    Args:
+        file_path: Path to the input graph JSON file.
+        max_depth: Maximum depth to keep (0=root only, 1=root+children, etc.).
+        output_path: Output file path. If None, derives from input filename.
+
+    Returns:
+        PruneResult with operation details.
+    """
+    graph = load_graph_for_pruning(file_path)
+    removed_ids = graph.prune_at_level(max_depth)
+
+    final_output = output_path or _derive_output_path(file_path, f"pruned_level{max_depth}")
+    graph.save(final_output)
+
+    return PruneResult(
+        operation="level",
+        removed_count=len(removed_ids),
+        removed_node_ids=removed_ids,
+        remaining_nodes=len(graph.nodes),
+        remaining_paths=len(graph.get_all_paths()),
+        output_path=final_output,
+    )
+
+
+def prune_graph_by_uuid(
+    file_path: str,
+    uuid: str,
+    output_path: str | None = None,
+) -> PruneResult:
+    """Remove a node (by UUID) and its entire subtree from a graph file.
+
+    Args:
+        file_path: Path to the input graph JSON file.
+        uuid: UUID of the node to remove.
+        output_path: Output file path. If None, derives from input filename.
+
+    Returns:
+        PruneResult with operation details.
+
+    Raises:
+        ValueError: If UUID not found or targets the root node.
+    """
+    graph = load_graph_for_pruning(file_path)
+    node = graph.find_node_by_uuid(uuid)
+
+    if node is None:
+        raise ValueError(f"No node found with UUID: {uuid}")
+
+    removed_ids = graph.remove_subtree(node.id)
+
+    final_output = output_path or _derive_output_path(file_path, "pruned")
+    graph.save(final_output)
+
+    return PruneResult(
+        operation="uuid",
+        removed_count=len(removed_ids),
+        removed_node_ids=removed_ids,
+        remaining_nodes=len(graph.nodes),
+        remaining_paths=len(graph.get_all_paths()),
+        output_path=final_output,
+    )
+
+
+def _derive_output_path(input_path: str, suffix: str) -> str:
+    """Derive a non-destructive output path from the input path.
+
+    Example: topic_graph.json -> topic_graph_pruned_level2.json
+    """
+    p = Path(input_path)
+    return str(p.with_stem(f"{p.stem}_{suffix}"))
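This new module (its file path is not shown in this diff, so the import below is an assumption) is a thin file-in/file-out wrapper over the Graph methods added above. A usage sketch:

    from deepfabric.graph_pruning import (  # hypothetical module name
        prune_graph_at_level,
        prune_graph_by_uuid,
    )

    # Keep root plus two levels; defaults to topic_graph_pruned_level2.json
    result = prune_graph_at_level("topic_graph.json", max_depth=2)
    print(result.removed_count, result.remaining_nodes, result.output_path)

    # Drop one node and its subtree by UUID (value hypothetical)
    result = prune_graph_by_uuid("topic_graph.json", uuid="3b9c...")

Both helpers write to a derived output path by default (via _derive_output_path), so the input file is never overwritten unless output_path points back at it.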
@@ -7,15 +7,21 @@ import time
 
 from collections.abc import Callable, Coroutine
 from functools import wraps
-from typing import Any, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar
 
 from .rate_limit_config import BackoffStrategy, RateLimitConfig
 from .rate_limit_detector import RateLimitDetector
 
+if TYPE_CHECKING:
+    from deepfabric.progress import ProgressReporter
+
 logger = logging.getLogger(__name__)
 
 T = TypeVar("T")
 
+# Max chars for error summaries emitted through progress reporter
+_ERROR_SUMMARY_MAX_LENGTH = 200
+
 
 class RetryHandler:
     """Intelligent retry handler for LLM API calls with provider-aware backoff."""
@@ -30,6 +36,7 @@ class RetryHandler:
         self.config = config
         self.provider = provider
        self.detector = RateLimitDetector()
+        self.progress_reporter: ProgressReporter | None = None
 
     def should_retry(self, exception: Exception) -> bool:
         """Determine if an exception should trigger a retry.
@@ -126,14 +133,26 @@ class RetryHandler:
         if quota_info.quota_type:
             quota_info_str = f" (quota_type: {quota_info.quota_type})"
 
-        logger.warning(
-            "Rate limit/transient error for %s on attempt %d, backing off %.2fs%s: %s",
-            self.provider,
-            tries,
-            wait,
-            quota_info_str,
-            exception,
-        )
+        if self.progress_reporter:
+            error_summary = str(exception)
+            if len(error_summary) > _ERROR_SUMMARY_MAX_LENGTH:
+                error_summary = error_summary[:_ERROR_SUMMARY_MAX_LENGTH] + "..."
+            self.progress_reporter.emit_llm_retry(
+                provider=self.provider,
+                attempt=tries,
+                wait=wait,
+                error_summary=error_summary,
+                quota_type=quota_info_str.strip(" ()") if quota_info_str else "",
+            )
+        else:
+            logger.warning(
+                "Rate limit/transient error for %s on attempt %d, backing off %.2fs%s: %s",
+                self.provider,
+                tries,
+                wait,
+                quota_info_str,
+                exception,
+            )
 
     def on_giveup_handler(self, details: dict[str, Any]) -> None:
         """Callback when giving up after max retries.
deepfabric/progress.py CHANGED
@@ -81,6 +81,25 @@ class StreamObserver(Protocol):
         """
         ...
 
+    def on_llm_retry(
+        self,
+        provider: str,
+        attempt: int,
+        wait: float,
+        error_summary: str,
+        metadata: dict[str, Any],
+    ) -> None:
+        """Called when an LLM API call is retried due to rate limiting or transient error.
+
+        Args:
+            provider: LLM provider name (e.g., "gemini", "openai")
+            attempt: Current attempt number (1-based)
+            wait: Backoff delay in seconds
+            error_summary: Brief description of the error
+            metadata: Additional context (e.g., quota_type)
+        """
+        ...
+
 
 class ProgressReporter:
     """Central progress reporter that notifies observers of generation events.
@@ -184,6 +203,29 @@ class ProgressReporter:
         if hasattr(observer, "on_retry"):
             observer.on_retry(sample_idx, attempt, max_attempts, error_summary, metadata)
 
+    def emit_llm_retry(
+        self,
+        provider: str,
+        attempt: int,
+        wait: float,
+        error_summary: str,
+        **metadata,
+    ) -> None:
+        """Emit an LLM retry event to all observers.
+
+        Used to track LLM API rate limits and transient errors.
+
+        Args:
+            provider: LLM provider name
+            attempt: Current attempt number (1-based)
+            wait: Backoff delay in seconds
+            error_summary: Brief description of the error
+            **metadata: Additional context as keyword arguments
+        """
+        for observer in self._observers:
+            if hasattr(observer, "on_llm_retry"):
+                observer.on_llm_retry(provider, attempt, wait, error_summary, metadata)
+
     def emit_tool_execution(
         self,
         tool_name: str,
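Together with the RetryHandler change above, this closes the loop: once a reporter is attached, backoff events route through ProgressReporter.emit_llm_retry() instead of the logger, and the protocol method is duck-typed via hasattr, so observers can opt in. A minimal sketch; how observers are registered is not shown in this diff, so the append to _observers (and the zero-argument ProgressReporter() constructor) are assumptions:

    class RetryPrinter:
        def on_llm_retry(self, provider, attempt, wait, error_summary, metadata):
            print(f"[{provider}] retry #{attempt}, backing off {wait:.1f}s: "
                  f"{error_summary} {metadata}")

    reporter = ProgressReporter()
    reporter._observers.append(RetryPrinter())  # assumption: real registration API not shown

    retry_handler.progress_reporter = reporter  # retry_handler: an existing RetryHandler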