DeepFabric 4.10.1__py3-none-any.whl → 4.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/cli.py +624 -33
- deepfabric/cloud_upload.py +1 -1
- deepfabric/config.py +14 -5
- deepfabric/config_manager.py +6 -1
- deepfabric/constants.py +1 -1
- deepfabric/dataset_manager.py +264 -62
- deepfabric/generator.py +687 -82
- deepfabric/graph.py +202 -2
- deepfabric/graph_pruner.py +122 -0
- deepfabric/llm/retry_handler.py +28 -9
- deepfabric/progress.py +42 -0
- deepfabric/topic_inspector.py +237 -0
- deepfabric/topic_manager.py +54 -2
- deepfabric/topic_model.py +26 -0
- deepfabric/tree.py +81 -41
- deepfabric/tui.py +448 -349
- deepfabric/utils.py +4 -1
- {deepfabric-4.10.1.dist-info → deepfabric-4.12.0.dist-info}/METADATA +3 -1
- {deepfabric-4.10.1.dist-info → deepfabric-4.12.0.dist-info}/RECORD +22 -20
- {deepfabric-4.10.1.dist-info → deepfabric-4.12.0.dist-info}/licenses/LICENSE +1 -1
- {deepfabric-4.10.1.dist-info → deepfabric-4.12.0.dist-info}/WHEEL +0 -0
- {deepfabric-4.10.1.dist-info → deepfabric-4.12.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Topic file inspection utilities for deepfabric CLI."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Literal
|
|
9
|
+
|
|
10
|
+
from .graph import Graph
|
|
11
|
+
from .utils import read_topic_tree_from_jsonl
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class TopicInspectionResult:
|
|
16
|
+
"""Result of inspecting a topic file."""
|
|
17
|
+
|
|
18
|
+
format: Literal["tree", "graph"]
|
|
19
|
+
total_paths: int
|
|
20
|
+
max_depth: int
|
|
21
|
+
paths_at_level: list[list[str]] | None
|
|
22
|
+
expanded_paths: list[list[str]] | None # Paths from level onwards (with --expand)
|
|
23
|
+
all_paths: list[list[str]] | None
|
|
24
|
+
metadata: dict[str, Any]
|
|
25
|
+
source_file: str
|
|
26
|
+
# Maps path tuple to UUID/topic_id (for --uuid flag)
|
|
27
|
+
path_to_uuid: dict[tuple[str, ...], str] = field(default_factory=dict)
|
|
28
|
+
# Maps topic name to UUID (for graph format, all nodes)
|
|
29
|
+
topic_to_uuid: dict[str, str] = field(default_factory=dict)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def detect_format(file_path: str) -> Literal["tree", "graph"]:
    """Auto-detect topic file format based on content.

    A graph file is a single JSON object with "nodes" and "root_id"
    keys; a tree file is JSONL where each line carries a "path" key.

    Args:
        file_path: Path to the topic file

    Returns:
        "tree" for JSONL format, "graph" for JSON format

    Raises:
        ValueError: If the file is empty or format cannot be detected
        FileNotFoundError: If file doesn't exist
    """
    path = Path(file_path)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    content = path.read_text(encoding="utf-8").strip()
    if not content:
        # Include the path so the error is actionable, matching the
        # other error messages raised here.
        raise ValueError(f"Empty file: {file_path}")

    # Try to parse as a complete JSON object (Graph format)
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(data, dict) and "nodes" in data and "root_id" in data:
            return "graph"

    # Try to parse first line as JSONL (Tree format). partition() takes
    # just the first line without splitting the entire content.
    first_line = content.partition("\n")[0].strip()
    try:
        first_obj = json.loads(first_line)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(first_obj, dict) and "path" in first_obj:
            return "tree"

    raise ValueError(f"Unable to detect format for: {file_path}")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _load_tree_paths(file_path: str) -> tuple[list[list[str]], dict[tuple[str, ...], str]]:
    """Load tree paths directly from JSONL without initializing LLM.

    Args:
        file_path: Path to the JSONL file

    Returns:
        Tuple of (paths, path_to_uuid mapping). IDs are derived exactly
        as in Tree._path_to_id so they match the IDs the tree reports.
    """
    dict_list = read_topic_tree_from_jsonl(file_path)
    paths: list[list[str]] = []
    path_to_uuid: dict[tuple[str, ...], str] = {}

    for d in dict_list:
        if "path" not in d:
            continue
        path = d["path"]
        paths.append(path)
        # Hash the JSON-encoded path, matching Tree._path_to_id.
        # (Previously this hashed " > ".join(path), which diverged from
        # tree.py's derivation and is ambiguous for topics that contain
        # the literal " > " separator.)
        topic_id = hashlib.sha256(json.dumps(path).encode()).hexdigest()[:16]
        path_to_uuid[tuple(path)] = topic_id

    return paths, path_to_uuid
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _load_graph_data(
    file_path: str,
) -> tuple[list[list[str]], dict[str, Any], dict[tuple[str, ...], str], dict[str, str]]:
    """Load graph data and extract paths and metadata.

    Args:
        file_path: Path to the JSON file

    Returns:
        Tuple of (paths, metadata, path_to_uuid mapping, topic_to_uuid mapping)
    """
    graph = Graph.load(file_path)

    # Leaf paths already carry their topic IDs.
    paths_with_ids = graph.get_all_paths_with_ids()
    all_paths = [item.path for item in paths_with_ids]
    path_to_uuid: dict[tuple[str, ...], str] = {}
    for item in paths_with_ids:
        path_to_uuid[tuple(item.path)] = item.topic_id

    # Every node contributes to the name -> UUID map, not just leaves.
    topic_to_uuid: dict[str, str] = {
        node.topic: node.metadata.get("uuid", "")
        for node in graph.nodes.values()
        if node.metadata.get("uuid", "")
    }

    metadata: dict[str, Any] = {
        "total_nodes": len(graph.nodes),
        "has_cycles": graph.has_cycle(),
        "root_topic": graph.root.topic if graph.root else None,
    }

    # Graph.from_json doesn't restore provider/model, so read the
    # graph-level metadata straight from the JSON file.
    with open(file_path, encoding="utf-8") as f:
        raw_data = json.load(f)

    file_metadata = raw_data.get("metadata") or {}
    for key in ("created_at", "provider", "model"):
        if file_metadata.get(key):
            metadata[key] = file_metadata[key]

    return all_paths, metadata, path_to_uuid, topic_to_uuid
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _unique_topics_at_level(
    all_paths: list[list[str]],
    level: int,
    path_to_uuid: dict[tuple[str, ...], str],
) -> list[list[str]]:
    """Collect unique topic names at *level*, returned as single-element paths.

    As a side effect, maps each leaf topic's single-element tuple to its
    original UUID in *path_to_uuid* so the --uuid flag can display it.
    """
    seen_topics: set[str] = set()
    unique_topics: list[str] = []
    for path in all_paths:
        if len(path) > level:
            topic_at_level = path[level]
            if topic_at_level not in seen_topics:
                seen_topics.add(topic_at_level)
                unique_topics.append(topic_at_level)
                # A path ending exactly at level+1 means this topic is a leaf.
                if len(path) == level + 1:
                    original_uuid = path_to_uuid.get(tuple(path), "")
                    if original_uuid:
                        path_to_uuid[(topic_at_level,)] = original_uuid
    # Store as single-element paths for consistency with expanded_paths.
    return [[t] for t in unique_topics]


def _expanded_paths_from_level(
    all_paths: list[list[str]],
    level: int,
    expand_depth: int,
    path_to_uuid: dict[tuple[str, ...], str],
) -> list[list[str]]:
    """Trim paths to start at *level*, limited to *expand_depth* sublevels (-1 = all).

    Deduplicates trimmed paths and maps each retained path to its
    original UUID in *path_to_uuid* (for --uuid display).
    """
    seen_paths: set[tuple[str, ...]] = set()
    expanded: list[list[str]] = []
    for path in all_paths:
        if len(path) <= level:
            continue
        original_uuid = path_to_uuid.get(tuple(path), "")
        # Trim path to start from the specified level.
        trimmed_path = path[level:]
        if expand_depth != -1 and len(trimmed_path) > expand_depth + 1:
            trimmed_path = trimmed_path[: expand_depth + 1]
        # After trimming, many paths may collapse to the same prefix.
        path_key = tuple(trimmed_path)
        if path_key not in seen_paths:
            seen_paths.add(path_key)
            expanded.append(trimmed_path)
            if original_uuid and path_key not in path_to_uuid:
                path_to_uuid[path_key] = original_uuid
    return expanded


def inspect_topic_file(
    file_path: str,
    level: int | None = None,
    expand_depth: int | None = None,
    show_all: bool = False,
) -> TopicInspectionResult:
    """Inspect a topic file and return structured results.

    Args:
        file_path: Path to the topic file
        level: Specific level to show (0=root), or None
        expand_depth: Number of sublevels to show (-1 for all), or None for no expansion
        show_all: Whether to include all paths in result

    Returns:
        TopicInspectionResult with inspection data
    """
    format_type = detect_format(file_path)

    # Load paths and metadata based on format.
    topic_to_uuid: dict[str, str] = {}
    if format_type == "graph":
        all_paths, metadata, path_to_uuid, topic_to_uuid = _load_graph_data(file_path)
    else:
        all_paths, path_to_uuid = _load_tree_paths(file_path)
        # Tree files carry no metadata; derive the root topic from paths.
        metadata = {}
        if all_paths:
            metadata["root_topic"] = all_paths[0][0]

    max_depth = max(len(p) for p in all_paths) if all_paths else 0

    # Level 0 = root, Level 1 = children of root, etc.
    paths_at_level = None
    expanded_paths = None
    if level is not None:
        paths_at_level = _unique_topics_at_level(all_paths, level, path_to_uuid)
        if expand_depth is not None:
            expanded_paths = _expanded_paths_from_level(
                all_paths, level, expand_depth, path_to_uuid
            )

    return TopicInspectionResult(
        format=format_type,
        total_paths=len(all_paths),
        max_depth=max_depth,
        paths_at_level=paths_at_level,
        expanded_paths=expanded_paths,
        all_paths=all_paths if show_all else None,
        metadata=metadata,
        source_file=file_path,
        path_to_uuid=path_to_uuid,
        topic_to_uuid=topic_to_uuid,
    )
|
deepfabric/topic_manager.py
CHANGED
|
@@ -45,6 +45,8 @@ async def _process_graph_events(graph: Graph, debug: bool = False) -> dict | Non
|
|
|
45
45
|
progress_reporter = ProgressReporter()
|
|
46
46
|
progress_reporter.attach(tui)
|
|
47
47
|
graph.progress_reporter = progress_reporter
|
|
48
|
+
if hasattr(graph, "llm_client"):
|
|
49
|
+
graph.llm_client.retry_handler.progress_reporter = progress_reporter
|
|
48
50
|
|
|
49
51
|
tui_started = False
|
|
50
52
|
|
|
@@ -88,6 +90,19 @@ async def _process_graph_events(graph: Graph, debug: bool = False) -> dict | Non
|
|
|
88
90
|
tui.finish_building(failed_generations)
|
|
89
91
|
final_event = event
|
|
90
92
|
|
|
93
|
+
if failed_generations > 0 and hasattr(graph, "failed_generations"):
|
|
94
|
+
truncated = sum(
|
|
95
|
+
1
|
|
96
|
+
for f in graph.failed_generations
|
|
97
|
+
if "EOF while parsing" in f.get("last_error", "")
|
|
98
|
+
)
|
|
99
|
+
if truncated:
|
|
100
|
+
get_tui().warning(
|
|
101
|
+
f"Hint: {truncated} of {failed_generations} failures appear to be "
|
|
102
|
+
f"truncated responses. Consider increasing max_tokens "
|
|
103
|
+
f"(currently {graph.max_tokens})."
|
|
104
|
+
)
|
|
105
|
+
|
|
91
106
|
if debug and failed_generations > 0 and hasattr(graph, "failed_generations"):
|
|
92
107
|
get_tui().error("\nDebug: Graph generation failures:")
|
|
93
108
|
for idx, failure in enumerate(graph.failed_generations, 1):
|
|
@@ -116,6 +131,8 @@ async def _process_tree_events(tree: Tree, debug: bool = False) -> dict | None:
|
|
|
116
131
|
progress_reporter = ProgressReporter()
|
|
117
132
|
progress_reporter.attach(tui)
|
|
118
133
|
tree.progress_reporter = progress_reporter
|
|
134
|
+
if hasattr(tree, "llm_client"):
|
|
135
|
+
tree.llm_client.retry_handler.progress_reporter = progress_reporter
|
|
119
136
|
|
|
120
137
|
final_event = None
|
|
121
138
|
try:
|
|
@@ -129,6 +146,8 @@ async def _process_tree_events(tree: Tree, debug: bool = False) -> dict | None:
|
|
|
129
146
|
tui.add_failure()
|
|
130
147
|
if debug and "error" in event:
|
|
131
148
|
get_tui().error(f"Debug: Tree generation failure - {event['error']}")
|
|
149
|
+
else:
|
|
150
|
+
tui.advance_simple_progress()
|
|
132
151
|
elif event["event"] == "build_complete":
|
|
133
152
|
total_paths = (
|
|
134
153
|
int(event["total_paths"]) if isinstance(event["total_paths"], str | int) else 0
|
|
@@ -141,6 +160,19 @@ async def _process_tree_events(tree: Tree, debug: bool = False) -> dict | None:
|
|
|
141
160
|
tui.finish_building(total_paths, failed_generations)
|
|
142
161
|
final_event = event
|
|
143
162
|
|
|
163
|
+
if failed_generations > 0 and hasattr(tree, "failed_generations"):
|
|
164
|
+
truncated = sum(
|
|
165
|
+
1
|
|
166
|
+
for f in tree.failed_generations
|
|
167
|
+
if "EOF while parsing" in f.get("error", "")
|
|
168
|
+
)
|
|
169
|
+
if truncated:
|
|
170
|
+
get_tui().warning(
|
|
171
|
+
f"Hint: {truncated} of {failed_generations} failures appear to be "
|
|
172
|
+
f"truncated responses. Consider increasing max_tokens "
|
|
173
|
+
f"(currently {tree.max_tokens})."
|
|
174
|
+
)
|
|
175
|
+
|
|
144
176
|
if debug and failed_generations > 0 and hasattr(tree, "failed_generations"):
|
|
145
177
|
get_tui().error("\nDebug: Tree generation failures:")
|
|
146
178
|
for idx, failure in enumerate(tree.failed_generations, 1):
|
|
@@ -233,8 +265,22 @@ def load_or_build_topic_model(
|
|
|
233
265
|
tui = get_tui()
|
|
234
266
|
|
|
235
267
|
if topics_load:
|
|
236
|
-
#
|
|
237
|
-
is_graph = config.topics.mode == "graph"
|
|
268
|
+
# Config mode takes precedence; file extension is only used to warn on mismatch
|
|
269
|
+
is_graph = config.topics.mode == "graph"
|
|
270
|
+
|
|
271
|
+
# Warn if file extension doesn't match the configured mode
|
|
272
|
+
if not is_graph and topics_load.endswith(".json"):
|
|
273
|
+
tui.warning(
|
|
274
|
+
f"File '{topics_load}' has .json extension (typically a graph) "
|
|
275
|
+
f"but mode is '{config.topics.mode}'. "
|
|
276
|
+
"If this is a graph set mode: graph in config."
|
|
277
|
+
)
|
|
278
|
+
elif is_graph and topics_load.endswith(".jsonl"):
|
|
279
|
+
tui.warning(
|
|
280
|
+
f"File '{topics_load}' has .jsonl extension (typically a tree) "
|
|
281
|
+
"but mode is 'graph'. "
|
|
282
|
+
"If this is a tree set mode: tree in config."
|
|
283
|
+
)
|
|
238
284
|
|
|
239
285
|
if is_graph:
|
|
240
286
|
tui.info(f"Reading topic graph from JSON file: {topics_load}")
|
|
@@ -293,6 +339,9 @@ def save_topic_model(
|
|
|
293
339
|
try:
|
|
294
340
|
tree_save_path = topics_save_as or config.topics.save_as or "topic_tree.jsonl"
|
|
295
341
|
topic_model.save(tree_save_path)
|
|
342
|
+
if getattr(topic_model, "failed_generations", None):
|
|
343
|
+
failed_path = tree_save_path.replace(".jsonl", "_failed.jsonl")
|
|
344
|
+
tui.warning(f"Failed generations saved to: {failed_path}")
|
|
296
345
|
tui.success(f"Topic tree saved to {tree_save_path}")
|
|
297
346
|
tui.info(f"Total paths: {len(topic_model.tree_paths)}")
|
|
298
347
|
except Exception as e:
|
|
@@ -302,6 +351,9 @@ def save_topic_model(
|
|
|
302
351
|
try:
|
|
303
352
|
graph_save_path = topics_save_as or config.topics.save_as or "topic_graph.json"
|
|
304
353
|
topic_model.save(graph_save_path)
|
|
354
|
+
if getattr(topic_model, "failed_generations", None):
|
|
355
|
+
failed_path = graph_save_path.replace(".json", "_failed.jsonl")
|
|
356
|
+
tui.warning(f"Failed generations saved to: {failed_path}")
|
|
305
357
|
tui.success(f"Topic graph saved to {graph_save_path}")
|
|
306
358
|
except Exception as e:
|
|
307
359
|
raise ConfigurationError(f"Error saving topic graph: {str(e)}") from e
|
deepfabric/topic_model.py
CHANGED
|
@@ -9,6 +9,18 @@ class TopicPath(NamedTuple):
|
|
|
9
9
|
topic_id: str
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
class Topic(NamedTuple):
    """A unique topic with its UUID and content.

    Generation iterates over unique topics (keyed by UUID) rather than
    over paths. In a graph several paths can reach the same node, so
    this deduplicated view is essential there.
    """

    uuid: str
    topic: str  # The topic text/content
|
|
22
|
+
|
|
23
|
+
|
|
12
24
|
class TopicModel(ABC):
|
|
13
25
|
"""Abstract base class for topic models like Tree and Graph."""
|
|
14
26
|
|
|
@@ -37,6 +49,20 @@ class TopicModel(ABC):
|
|
|
37
49
|
"""
|
|
38
50
|
raise NotImplementedError
|
|
39
51
|
|
|
52
|
+
@abstractmethod
|
|
53
|
+
def get_unique_topics(self) -> list[Topic]:
|
|
54
|
+
"""Returns deduplicated topics by UUID.
|
|
55
|
+
|
|
56
|
+
For generation, we iterate over unique topics rather than paths.
|
|
57
|
+
This is important for graphs where multiple paths can lead to the
|
|
58
|
+
same topic node - we only want to generate one sample per unique topic.
|
|
59
|
+
|
|
60
|
+
Returns:
|
|
61
|
+
List of Topic namedtuples containing (uuid, topic).
|
|
62
|
+
Each UUID appears exactly once.
|
|
63
|
+
"""
|
|
64
|
+
raise NotImplementedError
|
|
65
|
+
|
|
40
66
|
def get_path_by_id(self, topic_id: str) -> list[str] | None:
|
|
41
67
|
"""Look up a path by its topic_id.
|
|
42
68
|
|
deepfabric/tree.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
+
import hashlib
|
|
2
3
|
import json
|
|
3
4
|
import time
|
|
4
5
|
import warnings
|
|
@@ -21,7 +22,7 @@ from .metrics import trace
|
|
|
21
22
|
from .prompts import TreePromptBuilder
|
|
22
23
|
from .schemas import TopicList
|
|
23
24
|
from .stream_simulator import simulate_stream
|
|
24
|
-
from .topic_model import TopicModel, TopicPath
|
|
25
|
+
from .topic_model import Topic, TopicModel, TopicPath
|
|
25
26
|
|
|
26
27
|
warnings.filterwarnings("ignore", message=".*Pydantic serializer warnings:.*")
|
|
27
28
|
|
|
@@ -82,6 +83,17 @@ class TreeConfig(BaseModel):
|
|
|
82
83
|
default=None,
|
|
83
84
|
description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
|
|
84
85
|
)
|
|
86
|
+
max_concurrent: int = Field(
|
|
87
|
+
default=4,
|
|
88
|
+
ge=1,
|
|
89
|
+
le=20,
|
|
90
|
+
description="Maximum concurrent LLM calls during tree expansion (helps avoid rate limits)",
|
|
91
|
+
)
|
|
92
|
+
max_tokens: int = Field(
|
|
93
|
+
default=DEFAULT_MAX_TOKENS,
|
|
94
|
+
ge=1,
|
|
95
|
+
description="Maximum tokens for topic generation LLM calls",
|
|
96
|
+
)
|
|
85
97
|
|
|
86
98
|
|
|
87
99
|
class TreeValidator:
|
|
@@ -147,6 +159,8 @@ class Tree(TopicModel):
|
|
|
147
159
|
self.temperature = self.config.temperature
|
|
148
160
|
self.provider = self.config.provider
|
|
149
161
|
self.model_name = self.config.model_name
|
|
162
|
+
self.max_concurrent = self.config.max_concurrent
|
|
163
|
+
self.max_tokens = self.config.max_tokens
|
|
150
164
|
|
|
151
165
|
# Initialize LLM client
|
|
152
166
|
llm_kwargs = {}
|
|
@@ -242,24 +256,41 @@ class Tree(TopicModel):
|
|
|
242
256
|
"""Returns all the paths in the topic model."""
|
|
243
257
|
return self.tree_paths
|
|
244
258
|
|
|
259
|
+
@staticmethod
|
|
260
|
+
def _path_to_id(path: list[str]) -> str:
|
|
261
|
+
"""Compute a deterministic topic ID from a tree path."""
|
|
262
|
+
return hashlib.sha256(json.dumps(path).encode()).hexdigest()[:16]
|
|
263
|
+
|
|
264
|
+
def _add_path(self, path: list[str]) -> None:
|
|
265
|
+
"""Add a path to the tree.
|
|
266
|
+
|
|
267
|
+
Args:
|
|
268
|
+
path: The topic path to add.
|
|
269
|
+
"""
|
|
270
|
+
self.tree_paths.append(path)
|
|
271
|
+
|
|
245
272
|
def get_all_paths_with_ids(self) -> list[TopicPath]:
|
|
246
273
|
"""Returns all paths with their unique identifiers.
|
|
247
274
|
|
|
248
|
-
For Tree, we generate stable IDs by hashing the path content.
|
|
249
|
-
This ensures the same path always gets the same ID across runs.
|
|
250
|
-
|
|
251
275
|
Returns:
|
|
252
276
|
List of TopicPath namedtuples containing (path, topic_id).
|
|
277
|
+
The topic_id is computed deterministically from the path content.
|
|
253
278
|
"""
|
|
254
|
-
|
|
279
|
+
return [TopicPath(path=path, topic_id=self._path_to_id(path)) for path in self.tree_paths]
|
|
255
280
|
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
281
|
+
def get_unique_topics(self) -> list[Topic]:
|
|
282
|
+
"""Returns all leaf topics with computed IDs.
|
|
283
|
+
|
|
284
|
+
For Trees, each path is unique by definition, so this returns
|
|
285
|
+
all leaf topics with deterministic path-based IDs.
|
|
286
|
+
|
|
287
|
+
Returns:
|
|
288
|
+
List of Topic namedtuples containing (uuid, topic).
|
|
289
|
+
"""
|
|
290
|
+
return [
|
|
291
|
+
Topic(uuid=self._path_to_id(path), topic=path[-1] if path else "")
|
|
292
|
+
for path in self.tree_paths
|
|
293
|
+
]
|
|
263
294
|
|
|
264
295
|
async def get_subtopics(
|
|
265
296
|
self, system_prompt: str, node_path: list[str], num_subtopics: int
|
|
@@ -282,7 +313,7 @@ class Tree(TopicModel):
|
|
|
282
313
|
prompt=prompt,
|
|
283
314
|
schema=TopicList,
|
|
284
315
|
max_retries=MAX_RETRY_ATTEMPTS,
|
|
285
|
-
max_tokens=
|
|
316
|
+
max_tokens=self.max_tokens,
|
|
286
317
|
temperature=self.temperature,
|
|
287
318
|
)
|
|
288
319
|
|
|
@@ -293,19 +324,11 @@ class Tree(TopicModel):
|
|
|
293
324
|
source="tree_generation",
|
|
294
325
|
)
|
|
295
326
|
|
|
296
|
-
# Extract
|
|
327
|
+
# Extract subtopics — accept whatever the LLM returned
|
|
297
328
|
subtopics = topic_response.subtopics
|
|
298
|
-
if len(subtopics) >= num_subtopics:
|
|
299
|
-
return subtopics[:num_subtopics]
|
|
300
|
-
|
|
301
|
-
# If insufficient subtopics, pad with defaults
|
|
302
|
-
while len(subtopics) < num_subtopics:
|
|
303
|
-
subtopics.append(f"subtopic_{len(subtopics) + 1}_for_{node_path[-1]}")
|
|
304
|
-
|
|
305
329
|
return subtopics[:num_subtopics]
|
|
306
330
|
|
|
307
331
|
except Exception as e:
|
|
308
|
-
# Log the failure and return default subtopics
|
|
309
332
|
self.failed_generations.append(
|
|
310
333
|
{
|
|
311
334
|
"node_path": node_path,
|
|
@@ -313,9 +336,7 @@ class Tree(TopicModel):
|
|
|
313
336
|
"timestamp": time.time(),
|
|
314
337
|
}
|
|
315
338
|
)
|
|
316
|
-
|
|
317
|
-
# Generate default subtopics
|
|
318
|
-
return [f"subtopic_{i + 1}_for_{node_path[-1]}" for i in range(num_subtopics)]
|
|
339
|
+
return []
|
|
319
340
|
|
|
320
341
|
def _detect_domain(self, system_prompt: str, node_path: list[str]) -> str:
|
|
321
342
|
"""Detect the appropriate domain for prompt examples based on context."""
|
|
@@ -361,7 +382,7 @@ class Tree(TopicModel):
|
|
|
361
382
|
yield {"event": "subtree_start", "node_path": node_path, "depth": current_depth}
|
|
362
383
|
|
|
363
384
|
if current_depth > total_depth:
|
|
364
|
-
self.
|
|
385
|
+
self._add_path(node_path)
|
|
365
386
|
yield {"event": "leaf_reached", "path": node_path}
|
|
366
387
|
return
|
|
367
388
|
|
|
@@ -383,27 +404,43 @@ class Tree(TopicModel):
|
|
|
383
404
|
yield event
|
|
384
405
|
|
|
385
406
|
if not subtopics:
|
|
386
|
-
self.
|
|
407
|
+
self._add_path(node_path)
|
|
387
408
|
yield {"event": "leaf_reached", "path": node_path}
|
|
388
409
|
return
|
|
389
410
|
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
411
|
+
event_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue()
|
|
412
|
+
semaphore = asyncio.Semaphore(self.max_concurrent)
|
|
413
|
+
|
|
414
|
+
async def _expand_child(child_subtopic: str) -> None:
|
|
415
|
+
async with semaphore:
|
|
416
|
+
child_path = node_path + [child_subtopic]
|
|
417
|
+
async for child_event in self._build_subtree_generator(
|
|
418
|
+
child_path, system_prompt, total_depth, n_child, current_depth + 1
|
|
419
|
+
):
|
|
420
|
+
await event_queue.put(child_event)
|
|
398
421
|
|
|
399
|
-
tasks = [asyncio.create_task(
|
|
422
|
+
tasks = [asyncio.create_task(_expand_child(s)) for s in subtopics]
|
|
400
423
|
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
424
|
+
async def _signal_done() -> None:
|
|
425
|
+
await asyncio.gather(*tasks)
|
|
426
|
+
await event_queue.put(None)
|
|
427
|
+
|
|
428
|
+
done_task = asyncio.create_task(_signal_done())
|
|
429
|
+
|
|
430
|
+
while True:
|
|
431
|
+
event = await event_queue.get()
|
|
432
|
+
if event is None:
|
|
433
|
+
break
|
|
434
|
+
yield event
|
|
435
|
+
|
|
436
|
+
await done_task
|
|
404
437
|
|
|
405
438
|
def save(self, save_path: str) -> None:
|
|
406
|
-
"""Save the topic tree to a file.
|
|
439
|
+
"""Save the topic tree to a file.
|
|
440
|
+
|
|
441
|
+
Format: {"path": [...]}
|
|
442
|
+
IDs are computed on-the-fly from path content, not persisted.
|
|
443
|
+
"""
|
|
407
444
|
from pathlib import Path # noqa: PLC0415
|
|
408
445
|
|
|
409
446
|
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
|
|
@@ -446,8 +483,11 @@ class Tree(TopicModel):
|
|
|
446
483
|
def from_dict_list(self, dict_list: list[dict[str, Any]]) -> None:
|
|
447
484
|
"""Construct the topic tree from a list of dictionaries.
|
|
448
485
|
|
|
486
|
+
Accepts both the current format (``{"path": [...]}``) and the
|
|
487
|
+
legacy format that included a ``leaf_uuid`` field (silently ignored).
|
|
488
|
+
|
|
449
489
|
Args:
|
|
450
|
-
dict_list
|
|
490
|
+
dict_list: The list of dictionaries representing the topic tree.
|
|
451
491
|
"""
|
|
452
492
|
# Clear existing data
|
|
453
493
|
self.tree_paths = []
|