DeepFabric 4.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/tree.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import time
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
from typing import TYPE_CHECKING, Any, TypedDict
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
9
|
+
|
|
10
|
+
from .constants import (
|
|
11
|
+
DEFAULT_MAX_TOKENS,
|
|
12
|
+
MAX_RETRY_ATTEMPTS,
|
|
13
|
+
TOPIC_TREE_DEFAULT_DEGREE,
|
|
14
|
+
TOPIC_TREE_DEFAULT_DEPTH,
|
|
15
|
+
TOPIC_TREE_DEFAULT_MODEL,
|
|
16
|
+
TOPIC_TREE_DEFAULT_TEMPERATURE,
|
|
17
|
+
)
|
|
18
|
+
from .exceptions import TreeError
|
|
19
|
+
from .llm import LLMClient
|
|
20
|
+
from .metrics import trace
|
|
21
|
+
from .prompts import TreePromptBuilder
|
|
22
|
+
from .schemas import TopicList
|
|
23
|
+
from .stream_simulator import simulate_stream
|
|
24
|
+
from .topic_model import TopicModel
|
|
25
|
+
|
|
26
|
+
# Suppress warnings whose message contains "Pydantic serializer warnings:".
# The filter is installed at import time, so it applies process-wide.
warnings.filterwarnings("ignore", message=".*Pydantic serializer warnings:.*")

if TYPE_CHECKING:  # only for type hints to avoid runtime cycles
    from .progress import ProgressReporter

# Inclusive upper bounds applied to TreeConfig.degree and TreeConfig.depth.
UPPER_DEGREE = 50
UPPER_DEPTH = 10
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ValidationResult(TypedDict, total=False):
    """Result returned by ``TreeValidator.validate_configuration``.

    Declared with ``total=False``, so every key is optional for type checkers.
    The validator always sets the first three keys and adds
    ``recommended_batch_size`` only when the configuration is invalid.
    """

    # True when the requested path count does not exceed the tree's path count.
    valid: bool
    # Number of paths the tree provides (degree ** depth).
    total_tree_paths: int
    # num_steps * batch_size, as requested by the caller.
    total_requested_paths: int
    # min(batch_size, total_tree_paths); present only when ``valid`` is False.
    recommended_batch_size: int
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class TreeConfig(BaseModel):
    """Configuration for constructing a topic tree.

    Only ``topic_prompt`` is required; every other field has a default.
    Numeric fields are range-checked by pydantic: ``degree`` must lie in
    ``[1, UPPER_DEGREE]``, ``depth`` in ``[1, UPPER_DEPTH]`` and
    ``temperature`` in ``[0.0, 2.0]``.
    """

    # NOTE(review): none of the fields visible here uses a non-pydantic type,
    # so this setting looks precautionary -- confirm before removing.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Root prompt of the tree; must be a non-empty string.
    topic_prompt: str = Field(
        ..., min_length=1, description="The initial prompt to start the topic tree"
    )
    # Optional system prompt; defaults to the empty string.
    topic_system_prompt: str = Field(
        default="", description="System prompt for topic exploration and generation"
    )
    # Number of subtopics requested per node.
    degree: int = Field(
        default=TOPIC_TREE_DEFAULT_DEGREE,
        ge=1,
        le=UPPER_DEGREE,
        description="The branching factor of the tree",
    )
    # Number of expansion levels.
    depth: int = Field(
        default=TOPIC_TREE_DEFAULT_DEPTH,
        ge=1,
        le=UPPER_DEPTH,
        description="The depth of the tree",
    )
    # Provider identifier handed to LLMClient; defaults to "ollama".
    provider: str = Field(
        default="ollama",
        min_length=1,
        description="LLM provider (openai, anthropic, gemini, ollama)",
    )
    # Model identifier handed to LLMClient.
    model_name: str = Field(
        default=TOPIC_TREE_DEFAULT_MODEL,
        min_length=1,
        description="The name of the model to be used",
    )
    # Sampling temperature forwarded with each generation request.
    temperature: float = Field(
        default=TOPIC_TREE_DEFAULT_TEMPERATURE,
        ge=0.0,
        le=2.0,
        description="Temperature for model generation",
    )
    # When set (truthy), forwarded to LLMClient as ``base_url``.
    base_url: str | None = Field(
        default=None,
        description="Base URL for API endpoint (e.g., custom OpenAI-compatible servers)",
    )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class TreeValidator:
    """Validate a tree shape against a requested number of paths.

    The tree is described only by its branching factor (``degree``) and its
    ``depth``; the number of available paths is ``degree ** depth``.
    """

    def __init__(self, degree: int, depth: int):
        """Store the tree shape.

        Args:
            degree: Branching factor of the tree.
            depth: Depth of the tree.
        """
        self.degree = degree
        self.depth = depth

    def calculate_paths(self) -> int:
        """Return the total number of paths in the tree (``degree ** depth``)."""
        return self.degree**self.depth

    def validate_configuration(self, num_steps: int, batch_size: int) -> ValidationResult:
        """Check that the requested number of paths fits in the tree.

        The request size is ``num_steps * batch_size``. Both totals are
        printed to stdout, followed by a diagnostic line when the request
        exceeds what the tree provides.

        Args:
            num_steps: Number of steps requested.
            batch_size: Batch size per step.

        Returns:
            A ValidationResult with ``valid``, ``total_tree_paths`` and
            ``total_requested_paths``. When the configuration is invalid it
            also carries ``recommended_batch_size``, computed as
            ``min(batch_size, total_tree_paths)``.
        """
        total_requested_paths = num_steps * batch_size
        total_tree_paths = self.calculate_paths()

        print(f"Total tree paths available: {total_tree_paths}")
        print(f"Total requested paths: {total_requested_paths}")

        result: ValidationResult = {
            "valid": total_requested_paths <= total_tree_paths,
            "total_tree_paths": total_tree_paths,
            "total_requested_paths": total_requested_paths,
        }

        if not result["valid"]:
            # print() does not interpolate %-style arguments the way logging
            # calls do, so the message has to be formatted explicitly.
            print(
                f"Requested paths ({total_requested_paths}) exceed available "
                f"tree paths ({total_tree_paths})."
            )
            result["recommended_batch_size"] = min(batch_size, total_tree_paths)

        return result
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class Tree(TopicModel):
|
|
133
|
+
"""A class to represent and build a hierarchical topic tree."""
|
|
134
|
+
|
|
135
|
+
def __init__(self, **kwargs):
    """Build a tree from keyword arguments validated as a ``TreeConfig``.

    Args:
        **kwargs: Field values for ``TreeConfig``.

    Raises:
        TreeError: If the keyword arguments fail ``TreeConfig`` validation.
    """
    try:
        validated = TreeConfig.model_validate(kwargs)
    except Exception as e:
        raise TreeError(f"Invalid tree configuration: {str(e)}") from e  # noqa: TRY003
    self.config = validated

    # Mirror the validated settings onto the instance.
    self.topic_prompt = validated.topic_prompt
    self.model_system_prompt = validated.topic_system_prompt
    self.system_prompt = validated.topic_system_prompt
    self.degree = validated.degree
    self.depth = validated.depth
    self.temperature = validated.temperature
    self.provider = validated.provider
    self.model_name = validated.model_name

    # Results accumulated while the tree is being built.
    self.tree_paths: list[list[str]] = []
    self.failed_generations: list[dict[str, Any]] = []

    # Progress reporter for streaming feedback (set by topic_manager)
    self.progress_reporter: ProgressReporter | None = None

    # Only forward base_url to the client when one was configured.
    client_options = {"base_url": validated.base_url} if validated.base_url else {}
    self.llm_client = LLMClient(
        provider=self.provider,
        model_name=self.model_name,
        **client_options,
    )

    creation_details = {
        "provider": self.provider,
        "model_name": self.model_name,
        "degree": self.degree,
        "depth": self.depth,
    }
    trace("tree_created", creation_details)
|
|
179
|
+
|
|
180
|
+
async def build_async(self, model_name: str | None = None):
|
|
181
|
+
"""Build the complete topic tree.
|
|
182
|
+
|
|
183
|
+
Args:
|
|
184
|
+
model_name: Optional model name to override the configured model
|
|
185
|
+
|
|
186
|
+
Yields:
|
|
187
|
+
dict: Progress events with event type and associated data
|
|
188
|
+
"""
|
|
189
|
+
if model_name:
|
|
190
|
+
self.model_name = model_name
|
|
191
|
+
|
|
192
|
+
yield {
|
|
193
|
+
"event": "build_start",
|
|
194
|
+
"model_name": self.model_name,
|
|
195
|
+
"depth": self.config.depth,
|
|
196
|
+
"degree": self.config.degree,
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
def _raise_if_build_failed():
|
|
200
|
+
"""Check if build failed completely and raise appropriate error."""
|
|
201
|
+
if len(self.tree_paths) == 0 and self.failed_generations:
|
|
202
|
+
error_msg = f"Tree build failed completely: all {len(self.failed_generations)} generation attempts failed. No topic paths created."
|
|
203
|
+
raise RuntimeError(error_msg)
|
|
204
|
+
|
|
205
|
+
try:
|
|
206
|
+
async for event in self._build_subtree_generator(
|
|
207
|
+
[self.config.topic_prompt],
|
|
208
|
+
self.config.topic_system_prompt,
|
|
209
|
+
self.config.depth,
|
|
210
|
+
self.config.degree,
|
|
211
|
+
1,
|
|
212
|
+
):
|
|
213
|
+
yield event
|
|
214
|
+
|
|
215
|
+
# Check if build was completely unsuccessful (no paths generated)
|
|
216
|
+
_raise_if_build_failed()
|
|
217
|
+
|
|
218
|
+
trace(
|
|
219
|
+
"tree_built",
|
|
220
|
+
{
|
|
221
|
+
"provider": self.provider,
|
|
222
|
+
"model_name": self.model_name,
|
|
223
|
+
"total_paths": len(self.tree_paths),
|
|
224
|
+
"failed_generations": len(self.failed_generations),
|
|
225
|
+
"success": len(self.tree_paths) > 0,
|
|
226
|
+
},
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
yield {
|
|
230
|
+
"event": "build_complete",
|
|
231
|
+
"total_paths": len(self.tree_paths),
|
|
232
|
+
"failed_generations": len(self.failed_generations),
|
|
233
|
+
}
|
|
234
|
+
except Exception as e:
|
|
235
|
+
yield {"event": "error", "error": str(e)}
|
|
236
|
+
# Save partial results before re-raising
|
|
237
|
+
if self.tree_paths:
|
|
238
|
+
self.save("partial_tree.jsonl")
|
|
239
|
+
raise
|
|
240
|
+
|
|
241
|
+
def get_all_paths(self) -> list[list[str]]:
    """Return the list of topic paths collected so far.

    The internal ``tree_paths`` list itself is returned, not a copy.
    """
    return self.tree_paths
|
|
244
|
+
|
|
245
|
+
async def get_subtopics(
    self, system_prompt: str, node_path: list[str], num_subtopics: int
) -> list[str]:
    """Ask the LLM for subtopics of the node at the end of ``node_path``.

    Args:
        system_prompt: System prompt for topic generation; a falsy value is
            passed to the prompt builder as an empty string.
        node_path: Path from the root to the node being expanded.
        num_subtopics: Number of subtopics to return.

    Returns:
        Exactly ``num_subtopics`` names. Surplus model output is truncated,
        and a short answer is padded with placeholder names of the form
        ``subtopic_<n>_for_<last path element>``. If generation raises, the
        failure is appended to ``failed_generations`` and a list made
        entirely of placeholders is returned.
    """
    # Domain is picked from the system prompt and path content.
    expansion_prompt = TreePromptBuilder.build_expansion_prompt(
        topic_path=node_path,
        num_subtopics=num_subtopics,
        system_prompt=system_prompt or "",
        domain=self._detect_domain(system_prompt, node_path),
    )

    try:
        # Always use non-streaming for reliable structured output
        response = await self.llm_client.generate_async(
            prompt=expansion_prompt,
            schema=TopicList,
            max_retries=MAX_RETRY_ATTEMPTS,
            max_tokens=DEFAULT_MAX_TOKENS,
            temperature=self.temperature,
        )

        # Fire-and-forget: simulate streaming for TUI preview (non-blocking)
        rendered = response.model_dump_json(indent=2)
        simulate_stream(self.progress_reporter, rendered, source="tree_generation")

        subtopics = response.subtopics
        already_have = len(subtopics)
        if already_have < num_subtopics:
            # Pad a short answer in place with numbered placeholder names.
            subtopics.extend(
                f"subtopic_{position + 1}_for_{node_path[-1]}"
                for position in range(already_have, num_subtopics)
            )
        return subtopics[:num_subtopics]

    except Exception as e:
        # Record the failure, then fall back to placeholder subtopics.
        failure_record = {
            "node_path": node_path,
            "error": str(e),
            "timestamp": time.time(),
        }
        self.failed_generations.append(failure_record)

        fallback = []
        for position in range(num_subtopics):
            fallback.append(f"subtopic_{position + 1}_for_{node_path[-1]}")
        return fallback
|
|
300
|
+
|
|
301
|
+
def _detect_domain(self, system_prompt: str, node_path: list[str]) -> str:
    """Pick the domain label used to choose prompt examples.

    The system prompt and all path elements are joined and lower-cased, then
    searched for keywords by substring match. Domains are tried in the order
    educational, technical, conversational; the first domain with any
    matching keyword wins.

    Args:
        system_prompt: System prompt text.
        node_path: Topic path whose elements are included in the search text.

    Returns:
        ``"educational"``, ``"technical"``, ``"conversational"`` or, when no
        keyword matches, ``"general"``.
    """
    haystack = f"{system_prompt} {' '.join(node_path)}".lower()

    # Order matters: earlier entries take precedence over later ones.
    keyword_table = (
        ("educational", ("math", "calculus", "algebra", "geometry", "equation")),
        ("technical", ("programming", "code", "software", "python", "algorithm")),
        ("conversational", ("chat", "conversation", "talk", "friendly", "assistant")),
    )
    for domain, keywords in keyword_table:
        for keyword in keywords:
            if keyword in haystack:
                return domain
    return "general"
|
|
321
|
+
|
|
322
|
+
async def _build_subtree_generator(
    self,
    node_path: list[str],
    system_prompt: str,
    total_depth: int,
    n_child: int,
    current_depth: int,
):
    """Recursively expand ``node_path`` into a subtree, yielding progress events.

    A node deeper than ``total_depth``, or one for which no subtopics come
    back, is recorded in ``tree_paths`` as a leaf. Otherwise one task per
    subtopic is started, all tasks are awaited together, and each child's
    buffered events are then yielded in subtopic order.

    Args:
        node_path: Current path in the tree.
        system_prompt: System prompt for topic generation.
        total_depth: Maximum depth of the tree.
        n_child: Number of child nodes requested per parent.
        current_depth: Depth of ``node_path`` in the tree.

    Yields:
        dict: ``subtree_start``, ``subtopics_generated`` and ``leaf_reached``
        events.
    """
    yield {"event": "subtree_start", "node_path": node_path, "depth": current_depth}

    # Past the configured depth: this node is a leaf.
    if current_depth > total_depth:
        self.tree_paths.append(node_path)
        yield {"event": "leaf_reached", "path": node_path}
        return

    subtopics = await self.get_subtopics(system_prompt, node_path, n_child)

    summary = {
        "event": "subtopics_generated",
        "parent_path": node_path,
        "count": len(subtopics),
        "success": bool(subtopics),
    }
    # Attach the most recent recorded failure when nothing was generated.
    if not summary["success"] and self.failed_generations:
        summary["error"] = self.failed_generations[-1].get("error", "Unknown error")
    yield summary

    # Nothing to expand: treat this node as a leaf.
    if not subtopics:
        self.tree_paths.append(node_path)
        yield {"event": "leaf_reached", "path": node_path}
        return

    async def _drain_child(child_subtopic: str) -> list[dict[str, Any]]:
        # Run one child's generator to completion, buffering its events.
        child_events = self._build_subtree_generator(
            node_path + [child_subtopic],
            system_prompt,
            total_depth,
            n_child,
            current_depth + 1,
        )
        return [child_event async for child_event in child_events]

    pending = [asyncio.create_task(_drain_child(name)) for name in subtopics]

    # gather() returns results in task order, so events stay grouped per child.
    for buffered in await asyncio.gather(*pending):
        for child_event in buffered:
            yield child_event
|
|
385
|
+
|
|
386
|
+
def save(self, save_path: str) -> None:
    """Write the topic tree, and any failed generations, as JSON Lines.

    Every path is written to ``save_path`` as one ``{"path": [...]}`` object
    per line. If generation failures were recorded they go to a separate
    file: ``save_path`` with a trailing ``.jsonl`` replaced by
    ``_failed.jsonl`` (``tree.jsonl`` -> ``tree_failed.jsonl``). When
    ``save_path`` does not end in ``.jsonl`` the suffix ``_failed.jsonl`` is
    appended instead, so the failures file never overwrites the tree file.

    Args:
        save_path: Destination file for the tree paths.
    """
    with open(save_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps({"path": path}) + "\n" for path in self.tree_paths)

    # Save failed generations if any
    if self.failed_generations:
        # Only the trailing extension is rewritten. str.replace() left paths
        # without ".jsonl" unchanged (clobbering the file written above) and
        # also rewrote ".jsonl" occurring inside directory names.
        failed_path = save_path.removesuffix(".jsonl") + "_failed.jsonl"
        with open(failed_path, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps({"failed_generation": failed}) + "\n"
                for failed in self.failed_generations
            )
|
|
398
|
+
|
|
399
|
+
def print_tree(self) -> None:
    """Print a header line, then each path with its elements joined by `` -> ``."""
    print("Topic Tree Structure:")
    for rendered in map(" -> ".join, self.tree_paths):
        print(rendered)
|
|
404
|
+
|
|
405
|
+
def to_dict(self) -> dict[str, Any]:
    """Return the tree's paths, recorded failures and key settings as a dict.

    Returns:
        dict: ``tree_paths`` and ``failed_generations`` (the instance's own
        lists, not copies) plus a ``config`` mapping with the topic prompt,
        degree, depth, temperature, provider and model name.
    """
    config_fields = (
        "topic_prompt",
        "degree",
        "depth",
        "temperature",
        "provider",
        "model_name",
    )
    snapshot: dict[str, Any] = {
        "tree_paths": self.tree_paths,
        "failed_generations": self.failed_generations,
    }
    snapshot["config"] = {field: getattr(self, field) for field in config_fields}
    return snapshot
|
|
423
|
+
|
|
424
|
+
def from_dict_list(self, dict_list: list[dict[str, Any]]) -> None:
    """Replace the tree's contents with entries read from ``dict_list``.

    Existing paths and failures are discarded first. An entry with a
    ``"path"`` key contributes a path; otherwise an entry with a
    ``"failed_generation"`` key contributes a failure record. Entries with
    neither key are ignored.

    Args:
        dict_list (list[dict]): The list of dictionaries representing the topic tree.
    """
    # Start from fresh lists so earlier contents never leak through.
    self.tree_paths, self.failed_generations = [], []

    for entry in dict_list:
        if "path" in entry:
            self.tree_paths.append(entry["path"])
            continue
        if "failed_generation" in entry:
            self.failed_generations.append(entry["failed_generation"])
|