deepfabric-4.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/evaluation/reporters/base.py
ADDED
@@ -0,0 +1,56 @@
+"""Base classes for evaluation result reporting."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ..evaluator import EvaluationResult
+    from ..metrics import SampleEvaluation
+
+
+class BaseReporter(ABC):
+    """Base class for evaluation result reporters.
+
+    Reporters handle the output of evaluation results. They can write to
+    local files, send to cloud services, or perform other actions with
+    the evaluation data.
+    """
+
+    def __init__(self, config: dict | None = None):
+        """Initialize reporter with optional configuration.
+
+        Args:
+            config: Optional reporter-specific configuration
+        """
+        self.config = config or {}
+
+    @abstractmethod
+    def report(self, result: EvaluationResult) -> None:
+        """Report complete evaluation results.
+
+        Called once at the end of evaluation with all results.
+
+        Args:
+            result: Complete evaluation result
+        """
+
+    def report_sample(self, sample_eval: SampleEvaluation) -> None:  # noqa: B027
+        """Report individual sample evaluation (optional, for streaming).
+
+        Called after each sample is evaluated, allowing real-time reporting.
+        Default implementation does nothing.
+
+        Args:
+            sample_eval: Individual sample evaluation result
+        """
+        pass  # Optional hook for subclasses
+
+    def get_name(self) -> str:
+        """Return reporter identifier.
+
+        Returns:
+            Reporter name (e.g., "file", "cloud")
+        """
+        return self.__class__.__name__.replace("Reporter", "").lower()
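
For orientation, the interface above only requires `report()`; `report_sample()` and `get_name()` come with defaults. A minimal sketch of a custom reporter follows; the `StdoutReporter` name and the summary it prints are illustrative, not part of the package.

from deepfabric.evaluation.reporters.base import BaseReporter


class StdoutReporter(BaseReporter):
    """Hypothetical reporter that prints a one-line summary of a run."""

    def report(self, result) -> None:
        # result is an EvaluationResult; overall_score is the same field the
        # CloudReporter below uploads in its metrics payload.
        print(f"[{self.get_name()}] overall_score={result.metrics.overall_score}")
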
deepfabric/evaluation/reporters/cloud_reporter.py
ADDED
@@ -0,0 +1,195 @@
+"""Cloud-based reporter for sending results to DeepFabric Cloud."""
+
+from __future__ import annotations
+
+import json
+import os
+
+from datetime import UTC, datetime
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import httpx
+
+from rich.console import Console
+
+from .base import BaseReporter
+
+if TYPE_CHECKING:
+    from ..evaluator import EvaluationResult
+    from ..metrics import SampleEvaluation
+
+console = Console()
+
+
+def get_auth_token() -> str | None:
+    """Get authentication token from CLI config."""
+    config_file = Path.home() / ".deepfabric" / "config.json"
+    if not config_file.exists():
+        return None
+
+    try:
+        with open(config_file) as f:
+            config = json.load(f)
+        # Return API key if present, otherwise access token
+        return config.get("api_key") or config.get("access_token")
+    except (json.JSONDecodeError, OSError):
+        return None
+
+
+class CloudReporter(BaseReporter):
+    """Posts evaluation results to DeepFabric cloud service."""
+
+    def __init__(self, config: dict | None = None):
+        """Initialize cloud reporter.
+
+        Args:
+            config: Optional configuration with:
+                - api_url: DeepFabric API URL (default: "https://api.deepfabric.dev")
+                - project_id: Project ID to associate results with
+                - auth_token: Authentication token (if not provided, will read from config file)
+                - enabled: Whether to enable cloud reporting (default: True if authenticated)
+        """
+        super().__init__(config)
+
+        # Get API URL from config or environment
+        self.api_url = os.getenv("DEEPFABRIC_API_URL", "https://api.deepfabric.dev")
+        if config and "api_url" in config:
+            self.api_url = config["api_url"]
+
+        # Get auth token from config or CLI config file
+        self.auth_token = None
+        if config and "auth_token" in config:
+            self.auth_token = config["auth_token"]
+        else:
+            self.auth_token = get_auth_token()
+
+        # Get project ID from config
+        self.project_id = config.get("project_id") if config else None
+
+        # Enable cloud reporting if authenticated
+        self.enabled = (
+            config.get("enabled", bool(self.auth_token)) if config else bool(self.auth_token)
+        )
+
+        # Generate unique run ID for this evaluation
+        self.run_id = None  # Will be set when creating run
+        self.evaluation_run_id = None  # Backend run ID
+
+    def report(self, result: EvaluationResult) -> None:
+        """Upload complete evaluation results to cloud service.
+
+        Args:
+            result: Complete evaluation result
+        """
+        if not self.enabled:
+            return
+
+        if not self.auth_token:
+            console.print(
+                "[yellow]Cloud reporting skipped: Not authenticated. "
+                "Run 'deepfabric auth login' to enable cloud sync.[/yellow]"
+            )
+            return
+
+        if not self.project_id:
+            console.print("[yellow]Cloud reporting skipped: No project_id configured.[/yellow]")
+            return
+
+        try:
+            console.print("[cyan]Uploading evaluation results to cloud...[/cyan]")
+
+            # Create evaluation run
+            run_data = {
+                "project_id": self.project_id,
+                "name": f"Evaluation - {datetime.now(UTC).strftime('%Y-%m-%d %H:%M')}",
+                "model_name": result.config.inference_config.model_path,
+                "model_provider": result.config.inference_config.backend,
+                "config": {
+                    "evaluators": getattr(result.config, "evaluators", ["tool_calling"]),
+                    "inference": result.config.inference_config.model_dump(),
+                },
+                "status": "completed",
+            }
+
+            with httpx.Client(timeout=30.0) as client:
+                # Create run
+                response = client.post(
+                    f"{self.api_url}/api/v1/evaluations/runs",
+                    json=run_data,
+                    headers={
+                        "Authorization": f"Bearer {self.auth_token}",
+                        "Content-Type": "application/json",
+                    },
+                )
+                response.raise_for_status()
+                run_response = response.json()
+                self.evaluation_run_id = run_response["id"]
+
+                console.print(f"[green]✓[/green] Created evaluation run: {self.evaluation_run_id}")
+
+                # Upload metrics
+                metrics_data = {
+                    "overall_score": result.metrics.overall_score,
+                    "tool_selection_accuracy": result.metrics.tool_selection_accuracy,
+                    "parameter_accuracy": result.metrics.parameter_accuracy,
+                    "execution_success_rate": result.metrics.execution_success_rate,
+                    "response_quality": result.metrics.response_quality,
+                    "samples_evaluated": result.metrics.samples_evaluated,
+                    "samples_processed": result.metrics.samples_processed,
+                    "processing_errors": result.metrics.processing_errors,
+                }
+
+                response = client.post(
+                    f"{self.api_url}/api/v1/evaluations/runs/{self.evaluation_run_id}/metrics",
+                    json=metrics_data,
+                    headers={
+                        "Authorization": f"Bearer {self.auth_token}",
+                        "Content-Type": "application/json",
+                    },
+                )
+                response.raise_for_status()
+
+                console.print("[green]✓[/green] Uploaded metrics")
+
+                # Upload samples in batches
+                batch_size = 100
+                samples = []
+                for s in result.predictions:
+                    sample_dict = s.model_dump()
+                    # Convert sample_id to string (backend expects str, CLI uses int)
+                    sample_dict["sample_id"] = str(sample_dict["sample_id"])
+                    samples.append(sample_dict)
+
+                for i in range(0, len(samples), batch_size):
+                    batch = samples[i : i + batch_size]
+                    response = client.post(
+                        f"{self.api_url}/api/v1/evaluations/runs/{self.evaluation_run_id}/samples",
+                        json={"samples": batch},
+                        headers={
+                            "Authorization": f"Bearer {self.auth_token}",
+                            "Content-Type": "application/json",
+                        },
+                    )
+                    response.raise_for_status()
+
+                console.print(f"[green]✓[/green] Uploaded {len(samples)} samples")
+                console.print("[green]Results uploaded successfully![/green]")
+                console.print(
+                    f"View at: {self.api_url.replace(':8080', ':3000')}/studio/evaluations/{self.evaluation_run_id}"
+                )
+
+        except httpx.HTTPError as e:
+            console.print(f"[red]Cloud upload failed: {e}[/red]")
+        except Exception as e:
+            console.print(f"[red]Cloud upload error: {e}[/red]")
+
+    def report_sample(self, sample_eval: SampleEvaluation) -> None:  # noqa: ARG002
+        """Stream individual sample to cloud for real-time progress tracking.
+
+        Args:
+            sample_eval: Individual sample evaluation result
+        """
+        # Real-time streaming not implemented yet
+        # Samples are uploaded in batch in report()
+        pass
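
As a usage illustration of the reporter above, the config dict can supply everything explicitly instead of relying on ~/.deepfabric/config.json and the DEEPFABRIC_API_URL environment variable. The values below are placeholders, and `result` stands for an EvaluationResult produced by the evaluator.

from deepfabric.evaluation.reporters.cloud_reporter import CloudReporter

reporter = CloudReporter(
    config={
        "api_url": "https://api.deepfabric.dev",  # default when omitted
        "project_id": "my-project-id",            # placeholder
        "auth_token": "df_xxx",                   # placeholder; otherwise read from the CLI config file
        "enabled": True,
    }
)
reporter.report(result)  # creates a run, uploads metrics, then uploads samples in batches of 100
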
deepfabric/evaluation/reporters/file_reporter.py
ADDED
@@ -0,0 +1,61 @@
+"""File-based reporter for writing results to local JSON files."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from rich.console import Console
+
+from .base import BaseReporter
+
+if TYPE_CHECKING:
+    from ..evaluator import EvaluationResult
+    from ..metrics import SampleEvaluation
+
+console = Console()
+
+
+class FileReporter(BaseReporter):
+    """Writes evaluation results to local JSON file.
+
+    This is the default reporter that maintains backwards compatibility
+    with the original file-based output.
+    """
+
+    def __init__(self, config: dict | None = None):
+        """Initialize file reporter.
+
+        Args:
+            config: Optional configuration with 'path' key
+        """
+        super().__init__(config)
+        self.output_path = config.get("path") if config else None
+
+    def report(self, result: EvaluationResult) -> None:
+        """Write evaluation results to JSON file.
+
+        Args:
+            result: Complete evaluation result
+        """
+        # Use path from config, or fall back to result's config
+        output_path = self.output_path or result.config.output_path
+
+        if output_path is None:
+            console.print("[yellow]No output path specified, skipping file write[/yellow]")
+            return
+
+        path = Path(output_path)
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        with path.open("w") as f:
+            f.write(result.model_dump_json(indent=2))
+
+        console.print(f"[green]Results saved to {path}[/green]")
+
+    def report_sample(self, sample_eval: SampleEvaluation) -> None:
+        """File reporter doesn't support streaming (waits for final results).
+
+        Args:
+            sample_eval: Individual sample evaluation (ignored)
+        """
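
A short sketch of the file reporter above; the path is a placeholder, and if it is omitted the reporter falls back to `result.config.output_path`.

from deepfabric.evaluation.reporters.file_reporter import FileReporter

file_reporter = FileReporter(config={"path": "results/evaluation.json"})  # placeholder path
file_reporter.report(result)  # writes result.model_dump_json(indent=2) to the path
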
deepfabric/evaluation/reporters/multi_reporter.py
ADDED
@@ -0,0 +1,56 @@
+"""Multi-reporter for running multiple reporters simultaneously."""
+
+from __future__ import annotations
+
+from contextlib import suppress
+from typing import TYPE_CHECKING
+
+from rich.console import Console
+
+from .base import BaseReporter
+
+if TYPE_CHECKING:
+    from ..evaluator import EvaluationResult
+    from ..metrics import SampleEvaluation
+
+console = Console()
+
+
+class MultiReporter(BaseReporter):
+    """Runs multiple reporters (e.g., file + cloud).
+
+    This reporter allows sending results to multiple destinations
+    simultaneously. Errors in one reporter don't affect others.
+    """
+
+    def __init__(self, reporters: list[BaseReporter]):
+        """Initialize multi-reporter.
+
+        Args:
+            reporters: List of reporter instances to run
+        """
+        super().__init__()
+        self.reporters = reporters
+
+    def report(self, result: EvaluationResult) -> None:
+        """Report to all reporters.
+
+        Args:
+            result: Complete evaluation result
+        """
+        for reporter in self.reporters:
+            try:
+                reporter.report(result)
+            except Exception as e:  # noqa: BLE001
+                console.print(f"[red]Reporter {reporter.get_name()} failed: {e}[/red]")
+
+    def report_sample(self, sample_eval: SampleEvaluation) -> None:
+        """Report sample to all reporters.
+
+        Args:
+            sample_eval: Individual sample evaluation result
+        """
+        for reporter in self.reporters:
+            # Silently fail on sample streaming errors
+            with suppress(Exception):
+                reporter.report_sample(sample_eval)
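
A sketch of composing the reporters in this package, assuming the FileReporter and CloudReporter shown earlier; a failure in one reporter is logged and the others still run.

from deepfabric.evaluation.reporters.cloud_reporter import CloudReporter
from deepfabric.evaluation.reporters.file_reporter import FileReporter
from deepfabric.evaluation.reporters.multi_reporter import MultiReporter

multi = MultiReporter(
    reporters=[
        FileReporter(config={"path": "results/evaluation.json"}),  # placeholder path
        CloudReporter(config={"project_id": "my-project-id"}),     # placeholder project
    ]
)
multi.report(result)  # every reporter receives the same EvaluationResult
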
deepfabric/exceptions.py
ADDED
@@ -0,0 +1,67 @@
+class DeepFabricError(Exception):
+    """Base exception class for DeepFabric."""
+
+    def __init__(self, message: str, context: dict | None = None):
+        super().__init__(message)
+        self.message = message
+        self.context = context or {}
+
+
+class ConfigurationError(DeepFabricError):
+    """Raised when there is an error in configuration."""
+
+    pass
+
+
+class ValidationError(DeepFabricError):
+    """Raised when data validation fails."""
+
+    pass
+
+
+class ModelError(DeepFabricError):
+    """Raised when there is an error with LLM model operations."""
+
+    pass
+
+
+class TreeError(DeepFabricError):
+    """Raised when there is an error in topic tree operations."""
+
+    pass
+
+
+class DataSetGeneratorError(DeepFabricError):
+    """Raised when there is an error in data engine operations."""
+
+    pass
+
+
+class DatasetError(DeepFabricError):
+    """Raised when there is an error in dataset operations."""
+
+    pass
+
+
+class HubUploadError(DeepFabricError):
+    """Raised when there is an error uploading to Hugging Face Hub."""
+
+    pass
+
+
+class JSONParsingError(ValidationError):
+    """Raised when JSON parsing fails."""
+
+    pass
+
+
+class APIError(ModelError):
+    """Raised when API calls fail."""
+
+    pass
+
+
+class RetryExhaustedError(ModelError):
+    """Raised when maximum retries are exceeded."""
+
+    pass
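
Since the exception classes above form a hierarchy rooted at DeepFabricError, callers can catch narrowly or broadly. A brief sketch, with the generation call itself hypothetical:

from deepfabric.exceptions import DeepFabricError, JSONParsingError

try:
    run_generation()  # hypothetical call into DeepFabric
except JSONParsingError as e:
    # JSONParsingError -> ValidationError -> DeepFabricError
    print("bad JSON from the model:", e.message, e.context)
except DeepFabricError as e:
    # also catches ConfigurationError, ModelError, DatasetError, etc.
    print("deepfabric operation failed:", e.message, e.context)
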
deepfabric/factory.py
ADDED
@@ -0,0 +1,26 @@
+from .config import DeepFabricConfig
+from .graph import Graph
+from .topic_model import TopicModel
+from .tree import Tree
+
+
+def create_topic_generator(
+    config: DeepFabricConfig,
+    topics_overrides: dict | None = None,
+) -> TopicModel:
+    """Factory function to create a topic generator based on the configuration.
+
+    Args:
+        config: DeepFabricConfig object with topics configuration
+        topics_overrides: Override parameters for topic generation
+
+    Returns:
+        TopicModel (Tree or Graph) based on topics.mode
+    """
+    topics_params = config.get_topics_params(**(topics_overrides or {}))
+
+    if config.topics.mode == "graph":
+        return Graph(**topics_params)
+
+    # Default to tree mode
+    return Tree(**topics_params)
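
A hedged usage sketch of the factory above; how a DeepFabricConfig is loaded is not shown in this diff, so the loader call and the override key below are illustrative only.

from deepfabric.factory import create_topic_generator

config = load_deepfabric_config("config.yaml")  # hypothetical helper; must return a DeepFabricConfig

topic_generator = create_topic_generator(
    config,
    topics_overrides={"depth": 3},  # assumed override key, merged via config.get_topics_params()
)
# Returns a Graph when config.topics.mode == "graph", otherwise a Tree.
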