deepfabric-4.4.0-py3-none-any.whl
This diff represents the content of a publicly available package version released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/graph.py
ADDED
@@ -0,0 +1,545 @@

import asyncio
import json
import textwrap

from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, ConfigDict, Field

from .constants import (
    DEFAULT_MAX_TOKENS,
    MAX_RETRY_ATTEMPTS,
    TOPIC_GRAPH_DEFAULT_MODEL,
    TOPIC_GRAPH_DEFAULT_TEMPERATURE,
    TOPIC_GRAPH_SUMMARY,
)
from .llm import LLMClient
from .llm.rate_limit_detector import RateLimitDetector
from .metrics import trace
from .prompts import GRAPH_EXPANSION_PROMPT
from .schemas import GraphSubtopics
from .stream_simulator import simulate_stream
from .topic_model import TopicModel

if TYPE_CHECKING:  # only for type hints to avoid runtime cycles
    from .progress import ProgressReporter

RETRY_BASE_DELAY = 0.5  # seconds
ERROR_MESSAGE_MAX_LENGTH = 40  # Max chars for error messages in TUI


class GraphConfig(BaseModel):
    """Configuration for constructing a topic graph."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    topic_prompt: str = Field(
        ..., min_length=1, description="The initial prompt to start the topic graph"
    )
    topic_system_prompt: str = Field(
        default="", description="System prompt for topic exploration and generation"
    )
    provider: str = Field(
        default="ollama",
        min_length=1,
        description="LLM provider (openai, anthropic, gemini, ollama)",
    )
    model_name: str = Field(
        default=TOPIC_GRAPH_DEFAULT_MODEL,
        min_length=1,
        description="The name of the model to be used",
    )
    temperature: float = Field(
        default=TOPIC_GRAPH_DEFAULT_TEMPERATURE,
        ge=0.0,
        le=2.0,
        description="Temperature for model generation",
    )
    degree: int = Field(default=3, ge=1, le=10, description="The branching factor of the graph")
    depth: int = Field(default=2, ge=1, le=5, description="The depth of the graph")
    max_concurrent: int = Field(
        default=4,
        ge=1,
        le=20,
        description="Maximum concurrent LLM calls during graph expansion (helps avoid rate limits)",
    )
    base_url: str | None = Field(
        default=None,
        description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
    )


class NodeModel(BaseModel):
    """Pydantic model for a node in the graph."""

    id: int
    topic: str
    children: list[int] = Field(default_factory=list)
    parents: list[int] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)


class GraphModel(BaseModel):
    """Pydantic model for the entire topic graph."""

    nodes: dict[int, NodeModel]
    root_id: int


class Node:
    """Represents a node in the Graph for runtime manipulation."""

    def __init__(self, topic: str, node_id: int, metadata: dict[str, Any] | None = None):
        self.topic: str = topic
        self.id: int = node_id
        self.children: list[Node] = []
        self.parents: list[Node] = []
        self.metadata: dict[str, Any] = metadata.copy() if metadata is not None else {}

    def to_pydantic(self) -> NodeModel:
        """Converts the runtime Node to its Pydantic model representation."""
        return NodeModel(
            id=self.id,
            topic=self.topic,
            children=[child.id for child in self.children],
            parents=[parent.id for parent in self.parents],
            metadata=self.metadata,
        )


class Graph(TopicModel):
    """Represents the topic graph and manages its structure."""

    def __init__(self, **kwargs):
        try:
            self.config = GraphConfig.model_validate(kwargs)
        except Exception as e:
            raise ValueError(f"Invalid graph configuration: {str(e)}") from e  # noqa: TRY003

        # Initialize from config
        self.topic_prompt = self.config.topic_prompt
        self.model_system_prompt = self.config.topic_system_prompt
        self.provider = self.config.provider
        self.model_name = self.config.model_name
        self.temperature = self.config.temperature
        self.degree = self.config.degree
        self.depth = self.config.depth
        self.max_concurrent = self.config.max_concurrent

        # Initialize LLM client
        llm_kwargs = {}
        if self.config.base_url:
            llm_kwargs["base_url"] = self.config.base_url

        self.llm_client = LLMClient(
            provider=self.provider,
            model_name=self.model_name,
            **llm_kwargs,
        )

        # Progress reporter for streaming feedback (set by topic_manager)
        self.progress_reporter: ProgressReporter | None = None

        trace(
            "graph_created",
            {
                "provider": self.provider,
                "model_name": self.model_name,
                "degree": self.degree,
                "depth": self.depth,
            },
        )

        self.root: Node = Node(self.config.topic_prompt, 0)
        self.nodes: dict[int, Node] = {0: self.root}
        self._next_node_id: int = 1
        self.failed_generations: list[dict[str, Any]] = []

    def _wrap_text(self, text: str, width: int = 30) -> str:
        """Wrap text to a specified width."""
        return "\n".join(textwrap.wrap(text, width=width))

    def add_node(self, topic: str, metadata: dict[str, Any] | None = None) -> Node:
        """Adds a new node to the graph."""
        node = Node(topic, self._next_node_id, metadata)
        self.nodes[node.id] = node
        self._next_node_id += 1
        return node

    def add_edge(self, parent_id: int, child_id: int) -> None:
        """Adds a directed edge from a parent to a child node, avoiding duplicates."""
        parent_node = self.nodes.get(parent_id)
        child_node = self.nodes.get(child_id)
        if parent_node and child_node:
            if child_node not in parent_node.children:
                parent_node.children.append(child_node)
            if parent_node not in child_node.parents:
                child_node.parents.append(parent_node)

    def to_pydantic(self) -> GraphModel:
        """Converts the runtime graph to its Pydantic model representation."""
        return GraphModel(
            nodes={node_id: node.to_pydantic() for node_id, node in self.nodes.items()},
            root_id=self.root.id,
        )

    def to_json(self) -> str:
        """Returns a JSON representation of the graph."""
        pydantic_model = self.to_pydantic()
        return pydantic_model.model_dump_json(indent=2)

    def save(self, save_path: str) -> None:
        """Save the topic graph to a file."""
        with open(save_path, "w") as f:
            f.write(self.to_json())

    @classmethod
    def from_json(cls, json_path: str, params: dict) -> "Graph":
        """Load a topic graph from a JSON file."""
        with open(json_path) as f:
            data = json.load(f)

        graph_model = GraphModel(**data)
        graph = cls(**params)
        graph.nodes = {}

        # Create nodes
        for node_model in graph_model.nodes.values():
            node = Node(node_model.topic, node_model.id, node_model.metadata)
            graph.nodes[node.id] = node
            if node.id == graph_model.root_id:
                graph.root = node

        # Create edges
        for node_model in graph_model.nodes.values():
            for child_id in node_model.children:
                graph.add_edge(node_model.id, child_id)

        graph._next_node_id = max(graph.nodes.keys()) + 1
        return graph

    def visualize(self, save_path: str) -> None:
        """Visualize the graph and save it to a file."""
        try:
            from mermaid import Mermaid  # noqa: PLC0415
        except ImportError as err:
            raise ImportError(
                "Mermaid package is required for graph visualization. "
                "Please install it via 'pip install mermaid'."
            ) from err

        graph_definition = "graph TD\n"
        for node in self.nodes.values():
            graph_definition += f'    {node.id}["{self._wrap_text(node.topic)}"]\n'

        for node in self.nodes.values():
            for child in node.children:
                graph_definition += f"    {node.id} --> {child.id}\n"

        mermaid = Mermaid(graph_definition)
        mermaid.to_svg(f"{save_path}.svg")

    async def build_async(self):
        """Builds the graph by iteratively calling the LLM to get subtopics and connections.

        Yields:
            dict: Progress events with event type and associated data
        """

        def _raise_if_build_failed():
            """Check if build failed completely and raise appropriate error."""
            if len(self.nodes) == 1 and self.failed_generations:
                # Surface the actual first error instead of a generic message
                first_error = self.failed_generations[0]["last_error"]
                raise RuntimeError(first_error)

        try:
            for depth in range(self.depth):
                leaf_nodes = [node for node in self.nodes.values() if not node.children]
                yield {"event": "depth_start", "depth": depth + 1, "leaf_count": len(leaf_nodes)}

                if leaf_nodes:
                    # Use semaphore to limit concurrent LLM calls and avoid rate limits
                    semaphore = asyncio.Semaphore(self.max_concurrent)

                    async def bounded_expand(
                        node: Node, sem: asyncio.Semaphore = semaphore
                    ) -> tuple[int, int]:
                        async with sem:
                            return await self.get_subtopics_and_connections(node, self.degree)

                    tasks = [bounded_expand(node) for node in leaf_nodes]
                    results = await asyncio.gather(*tasks)

                    for node, (subtopics_added, connections_added) in zip(
                        leaf_nodes, results, strict=True
                    ):
                        yield {
                            "event": "node_expanded",
                            "node_topic": node.topic,
                            "subtopics_added": subtopics_added,
                            "connections_added": connections_added,
                        }

                yield {"event": "depth_complete", "depth": depth + 1}

            # Check if build was completely unsuccessful (only root node exists)
            _raise_if_build_failed()

            trace(
                "graph_built",
                {
                    "provider": self.provider,
                    "model_name": self.model_name,
                    "nodes_count": len(self.nodes),
                    "failed_generations": len(self.failed_generations),
                    "success": len(self.nodes) > 1,
                },
            )

            yield {
                "event": "build_complete",
                "nodes_count": len(self.nodes),
                "failed_generations": len(self.failed_generations),
            }

        except Exception as e:
            yield {"event": "error", "error": str(e)}
            raise

    def _process_subtopics_response(
        self, response: GraphSubtopics, parent_node: Node
    ) -> tuple[int, int]:
        """Process a GraphSubtopics response, adding nodes and edges to the graph.

        Args:
            response: The structured response containing subtopics and connections.
            parent_node: The parent node to connect new subtopics to.

        Returns:
            A tuple of (subtopics_added, connections_added).
        """
        subtopics_added = 0
        connections_added = 0

        for subtopic_data in response.subtopics:
            new_node = self.add_node(subtopic_data.topic)
            self.add_edge(parent_node.id, new_node.id)
            subtopics_added += 1
            for connection_id in subtopic_data.connections:
                if connection_id in self.nodes:
                    self.add_edge(connection_id, new_node.id)
                    connections_added += 1

        return subtopics_added, connections_added

    def _get_friendly_error_message(self, exception: Exception) -> str:
        """Convert an exception into a user-friendly error message for TUI display.

        Args:
            exception: The exception that occurred during generation.

        Returns:
            A concise, user-friendly error message suitable for display.
        """
        # Check for rate limit errors using the detector
        if RateLimitDetector.is_rate_limit_error(exception, self.provider):
            return self._format_rate_limit_message(exception)

        error_str = str(exception).lower()

        # Check for validation/schema errors (Pydantic issues)
        validation_indicators = ["validation failed", "validation error", "validationerror"]
        if any(ind in error_str for ind in validation_indicators):
            return "Response format issue - retrying"

        # Check for network/connection errors
        network_indicators = ["timeout", "connection", "network", "socket"]
        if any(ind in error_str for ind in network_indicators):
            return "Connection issue - retrying"

        # Check for server errors
        server_indicators = ["503", "502", "500", "504", "server error", "service unavailable"]
        if any(ind in error_str for ind in server_indicators):
            return "Server error - retrying"

        # Fallback: truncate the original error for display
        return self._truncate_error_message(str(exception))

    def _format_rate_limit_message(self, exception: Exception) -> str:
        """Format a rate limit error into a user-friendly message."""
        quota_info = RateLimitDetector.extract_quota_info(exception, self.provider)
        if quota_info.daily_quota_exhausted:
            return "Daily quota exhausted - waiting"
        if quota_info.quota_type:
            return f"Rate limit ({quota_info.quota_type}) - backing off"
        return "Rate limit reached - backing off"

    def _truncate_error_message(self, message: str) -> str:
        """Truncate an error message to fit within the TUI display limit."""
        if len(message) > ERROR_MESSAGE_MAX_LENGTH:
            return message[: ERROR_MESSAGE_MAX_LENGTH - 3] + "..."
        return message

    async def _generate_subtopics_with_retry(
        self, prompt: str, parent_node: Node
    ) -> GraphSubtopics | None:
        """Generate subtopics with retry logic and exponential backoff.

        Args:
            prompt: The prompt to send to the LLM.
            parent_node: The parent node (used for error tracking and retry events).

        Returns:
            The GraphSubtopics response, or None if all retries failed.

        Raises:
            RuntimeError: If authentication fails (API key errors are not retried).
        """
        last_error: str | None = None

        for attempt in range(MAX_RETRY_ATTEMPTS):
            try:
                response = await self.llm_client.generate_async(
                    prompt=prompt,
                    schema=GraphSubtopics,
                    max_retries=1,  # Don't retry inside - we handle it here
                    max_tokens=DEFAULT_MAX_TOKENS,
                    temperature=self.temperature,
                )

                # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
                simulate_stream(
                    self.progress_reporter,
                    response.model_dump_json(indent=2),
                    source="graph_generation",
                )

            except Exception as e:
                last_error = str(e)
                error_str = str(e).lower()

                # Check if it's an API key related error - don't retry these
                if any(
                    keyword in error_str
                    for keyword in ["api_key", "api key", "authentication", "unauthorized"]
                ):
                    error_msg = (
                        f"Authentication failed for provider '{self.provider}'. "
                        "Please set the required API key environment variable."
                    )
                    raise RuntimeError(error_msg) from e

                # Log retry attempt if not the last one
                if attempt < MAX_RETRY_ATTEMPTS - 1:
                    if self.progress_reporter:
                        # Use friendly error message for TUI display
                        friendly_error = self._get_friendly_error_message(e)
                        self.progress_reporter.emit_node_retry(
                            node_topic=parent_node.topic,
                            attempt=attempt + 1,
                            max_attempts=MAX_RETRY_ATTEMPTS,
                            error_summary=friendly_error,
                        )
                    # Brief delay before retry with exponential backoff
                    delay = (2**attempt) * RETRY_BASE_DELAY
                    await asyncio.sleep(delay)

            else:
                # Success - return the response
                return response

        # All retries exhausted - record failure
        self.failed_generations.append(
            {"node_id": parent_node.id, "attempts": MAX_RETRY_ATTEMPTS, "last_error": last_error}
        )
        return None

    async def get_subtopics_and_connections(
        self, parent_node: Node, num_subtopics: int
    ) -> tuple[int, int]:
        """Generate subtopics and connections for a given node.

        Args:
            parent_node: The node to generate subtopics for.
            num_subtopics: The number of subtopics to generate.

        Returns:
            A tuple of (subtopics_added, connections_added).
        """
        graph_summary = (
            self.to_json()
            if len(self.nodes) <= TOPIC_GRAPH_SUMMARY
            else "Graph too large to display"
        )

        graph_prompt = GRAPH_EXPANSION_PROMPT.replace("{{current_graph_summary}}", graph_summary)
        graph_prompt = graph_prompt.replace("{{current_topic}}", parent_node.topic)
        graph_prompt = graph_prompt.replace("{{num_subtopics}}", str(num_subtopics))

        response = await self._generate_subtopics_with_retry(graph_prompt, parent_node)
        if response is None:
            return 0, 0  # No subtopics or connections added on failure

        return self._process_subtopics_response(response, parent_node)

    def get_all_paths(self) -> list[list[str]]:
        """Returns all paths from the root to leaf nodes."""
        paths = []
        visited: set[int] = set()
        self._dfs_paths(self.root, [self.root.topic], paths, visited)
        return paths

    def _dfs_paths(
        self, node: Node, current_path: list[str], paths: list[list[str]], visited: set[int]
    ) -> None:
        """Helper function for DFS traversal to find all paths.

        Args:
            node: Current node being visited
            current_path: Path from root to current node
            paths: Accumulated list of complete paths
            visited: Set of node IDs already visited in current path to prevent cycles
        """
        # Prevent cycles by tracking visited nodes in the current path
        if node.id in visited:
            return

        visited.add(node.id)

        if not node.children:
            paths.append(current_path)

        for child in node.children:
            self._dfs_paths(child, current_path + [child.topic], paths, visited)

        # Remove node from visited when backtracking to allow it in other paths
        visited.remove(node.id)

    def _has_cycle_util(self, node: Node, visited: set[int], recursion_stack: set[int]) -> bool:
        """Utility function for cycle detection."""
        visited.add(node.id)
        recursion_stack.add(node.id)

        for child in node.children:
            if child.id not in visited:
                if self._has_cycle_util(child, visited, recursion_stack):
                    return True
            elif child.id in recursion_stack:
                return True

        recursion_stack.remove(node.id)
        return False

    def has_cycle(self) -> bool:
        """Checks if the graph contains a cycle."""
        visited: set[int] = set()
        recursion_stack: set[int] = set()
        for node_id in self.nodes:
            if node_id not in visited and self._has_cycle_util(
                self.nodes[node_id], visited, recursion_stack
            ):
                return True
        return False