deepfabric-4.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/cli.py
ADDED
@@ -0,0 +1,1288 @@
import contextlib
import os
import sys

from typing import Literal, NoReturn, cast

import click
import yaml

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import ValidationError as PydanticValidationError

from .auth import auth as auth_group
from .config import DeepFabricConfig
from .config_manager import apply_cli_overrides, get_final_parameters, load_config
from .dataset_manager import create_dataset, save_dataset
from .exceptions import ConfigurationError
from .generator import DataSetGenerator
from .graph import Graph
from .llm import VerificationStatus, verify_provider_api_key
from .metrics import set_trace_debug, trace
from .topic_manager import load_or_build_topic_model, save_topic_model
from .topic_model import TopicModel
from .tui import configure_tui, get_tui
from .update_checker import check_for_updates
from .validation import show_validation_success, validate_path_requirements

OverrideValue = str | int | float | bool | None
OverrideMap = dict[str, OverrideValue]
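# Note: these aliases model the scalar override values collected from CLI
# flags; apply_cli_overrides() returns plain dicts of this shape, which are
# cast to OverrideMap before validation in GenerationPreparation below.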

def handle_error(ctx: click.Context, error: Exception) -> NoReturn:
    """Handle errors in CLI commands."""
    _ = ctx  # Unused but required for click context
    tui = get_tui()
    # Check if this is a formatted error from our event handlers
    error_msg = str(error)
    if not error_msg.startswith("Error: "):
        tui.error(f"Error: {error_msg}")
    else:
        tui.error(error_msg)

    sys.exit(1)


@click.group()
@click.version_option()
@click.option(
    "--debug",
    is_flag=True,
    envvar="DEEPFABRIC_DEBUG",
    help="Enable debug mode for detailed output",
)
@click.pass_context
def cli(ctx: click.Context, debug: bool):
    """DeepFabric CLI - Generate synthetic training data for language models."""
    # Store debug flag in context for subcommands to access
    ctx.ensure_object(dict)
    ctx.obj["debug"] = debug

    # Check for updates on CLI startup (silently fail if any issues occur)
    with contextlib.suppress(Exception):
        check_for_updates()


class GenerateOptions(BaseModel):
    """
    Validated command options for dataset generation.
    These options can be provided via CLI arguments or a configuration file,
    so they are marked as optional here.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    config_file: str | None = None
    # New naming convention
    output_system_prompt: str | None = None
    topic_prompt: str | None = None
    topics_system_prompt: str | None = None
    generation_system_prompt: str | None = None
    topics_save_as: str | None = None
    topics_load: str | None = None
    output_save_as: str | None = None
    provider: str | None = None
    model: str | None = None
    temperature: float | None = None
    degree: int | None = None
    depth: int | None = None
    num_samples: int | None = None
    batch_size: int | None = None
    base_url: str | None = None
    include_system_message: bool | None = None
    mode: Literal["tree", "graph"] = Field(default="tree")
    debug: bool = False
    topic_only: bool = False
    tui: Literal["rich", "simple"] = Field(default="rich")

    # Modular conversation configuration
    conversation_type: Literal["basic", "chain_of_thought"] | None = None
    reasoning_style: Literal["freetext", "agent", "structured", "hybrid"] | None = None
    agent_mode: Literal["single_turn", "multi_turn"] | None = None

    # Multi-turn configuration
    min_turns: int | None = None
    max_turns: int | None = None
    min_tool_calls: int | None = None

    @model_validator(mode="after")
    def validate_mode_constraints(self) -> "GenerateOptions":
        if self.topic_only and self.topics_load:
            raise ValueError("--topic-only cannot be used with --topics-load")
        return self
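# Example (illustrative): the same validation applies when the model is built
# programmatically, e.g. GenerateOptions(topic_prompt="Linux internals",
# mode="graph", depth=3, degree=4); combining topic_only=True with a
# topics_load path raises a ValueError via validate_mode_constraints.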

class GenerationPreparation(BaseModel):
    """Validated state required to run dataset generation."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    config: DeepFabricConfig
    topics_overrides: OverrideMap = Field(default_factory=dict)
    generation_overrides: OverrideMap = Field(default_factory=dict)
    num_samples: int
    batch_size: int
    depth: int
    degree: int
    loading_existing: bool

    @model_validator(mode="after")
    def validate_positive_dimensions(self) -> "GenerationPreparation":
        if self.num_samples <= 0:
            raise ValueError("num_samples must be greater than zero")
        if self.batch_size <= 0:
            raise ValueError("batch_size must be greater than zero")
        if self.depth <= 0:
            raise ValueError("depth must be greater than zero")
        if self.degree <= 0:
            raise ValueError("degree must be greater than zero")
        return self


def _validate_api_keys(
    config: DeepFabricConfig,
    provider_override: str | None = None,
) -> None:
    """Validate that required API keys are present and working for configured providers.

    Args:
        config: The loaded configuration
        provider_override: Optional CLI provider override that takes precedence

    Raises:
        ConfigurationError: If any required API key is missing or invalid
    """
    tui = get_tui()

    # Get providers from config
    providers = config.get_configured_providers()

    # If there's a provider override from CLI, that takes precedence for all components
    if provider_override:
        providers = {provider_override}

    # Display what we're checking
    provider_list = ", ".join(sorted(providers))
    tui.info(f"Validating API keys for: {provider_list}")

    # Verify each provider's API key
    errors = []
    validated_providers = []
    for provider in providers:
        result = verify_provider_api_key(provider)

        # Determine the primary env var to show in error messages
        env_var = result.api_key_env_var or f"{provider.upper()}_API_KEY"
        primary_var = env_var.split(" or ")[0] if " or " in env_var else env_var

        if result.status == VerificationStatus.MISSING:
            errors.append(
                f"  {provider}: API key not found.\n"
                f"    Export it with: export {primary_var}=your-api-key"
            )
        elif result.status == VerificationStatus.INVALID:
            errors.append(
                f"  {provider}: API key is invalid.\n"
                f"    Check your key and re-export: export {primary_var}=your-api-key"
            )
        elif result.status == VerificationStatus.CONNECTION_ERROR:
            if provider == "ollama":
                errors.append(
                    f"  {provider}: Cannot connect to Ollama server.\n"
                    "    Make sure Ollama is running: ollama serve"
                )
            else:
                errors.append(
                    f"  {provider}: Connection failed.\n"
                    "    Check your internet connection and try again."
                )
        elif result.status == VerificationStatus.RATE_LIMITED:
            errors.append(
                f"  {provider}: Rate limit exceeded.\n"
                "    Wait a moment and try again, or check your API quota."
            )
        else:
            # VALID or NOT_APPLICABLE (e.g., ollama)
            validated_providers.append(provider)

    if errors:
        error_list = "\n".join(errors)
        raise ConfigurationError(f"API key verification failed:\n\n{error_list}")

    # Show success message
    tui.success(f"API keys validated for: {', '.join(sorted(validated_providers))}")
    print()  # Visual separator before next section


def _load_and_prepare_generation_context(
    options: GenerateOptions,
) -> GenerationPreparation:
    """Load configuration, compute overrides, and validate derived parameters."""
    tui = get_tui()

    # Step 1: Load configuration
    tui.info("Loading configuration...")
    config = load_config(
        config_file=options.config_file,
        topic_prompt=options.topic_prompt,
        topics_system_prompt=options.topics_system_prompt,
        generation_system_prompt=options.generation_system_prompt,
        output_system_prompt=options.output_system_prompt,
        provider=options.provider,
        model=options.model,
        temperature=options.temperature,
        degree=options.degree,
        depth=options.depth,
        num_samples=options.num_samples,
        batch_size=options.batch_size,
        topics_save_as=options.topics_save_as,
        output_save_as=options.output_save_as,
        include_system_message=options.include_system_message,
        mode=options.mode,
        conversation_type=options.conversation_type,
        reasoning_style=options.reasoning_style,
        agent_mode=options.agent_mode,
    )

    # Step 2: Validate API keys EARLY - this is critical for user feedback
    # This must happen before any LLM operations and should be clearly visible
    _validate_api_keys(config, options.provider)

    topics_overrides_raw, generation_overrides_raw = apply_cli_overrides(
        output_system_prompt=options.output_system_prompt,
        topic_prompt=options.topic_prompt,
        topics_system_prompt=options.topics_system_prompt,
        generation_system_prompt=options.generation_system_prompt,
        provider=options.provider,
        model=options.model,
        temperature=options.temperature,
        degree=options.degree,
        depth=options.depth,
        base_url=options.base_url,
    )

    final_num_samples, final_batch_size, final_depth, final_degree = get_final_parameters(
        config=config,
        num_samples=options.num_samples,
        batch_size=options.batch_size,
        depth=options.depth,
        degree=options.degree,
    )

    loading_existing = bool(options.topics_load)

    validate_path_requirements(
        mode=options.mode,
        depth=final_depth,
        degree=final_degree,
        num_steps=final_num_samples,
        batch_size=final_batch_size,
        loading_existing=loading_existing,
    )

    show_validation_success(
        mode=options.mode,
        depth=final_depth,
        degree=final_degree,
        num_steps=final_num_samples,
        batch_size=final_batch_size,
        loading_existing=loading_existing,
    )

    try:
        return GenerationPreparation(
            config=config,
            topics_overrides=cast(OverrideMap, topics_overrides_raw),
            generation_overrides=cast(OverrideMap, generation_overrides_raw),
            num_samples=final_num_samples,
            batch_size=final_batch_size,
            depth=final_depth,
            degree=final_degree,
            loading_existing=loading_existing,
        )
    except (ValueError, PydanticValidationError) as error:
        raise ConfigurationError(str(error)) from error


def _initialize_topic_model(
    *,
    preparation: GenerationPreparation,
    options: GenerateOptions,
) -> TopicModel:
    """Load existing topic structures or build new ones and persist when needed."""

    topic_model = load_or_build_topic_model(
        config=preparation.config,
        topics_load=options.topics_load,
        topics_overrides=preparation.topics_overrides,
        provider=options.provider,
        model=options.model,
        base_url=options.base_url,
        debug=options.debug,
    )

    if not options.topics_load:
        save_topic_model(
            topic_model=topic_model,
            config=preparation.config,
            topics_save_as=options.topics_save_as,
        )

    return topic_model


def _run_generation(
    *,
    preparation: GenerationPreparation,
    topic_model: TopicModel,
    options: GenerateOptions,
) -> None:
    """Create the dataset using the prepared configuration and topic model."""

    generation_params = preparation.config.get_generation_params(**preparation.generation_overrides)
    engine = DataSetGenerator(**generation_params)

    dataset = create_dataset(
        engine=engine,
        topic_model=topic_model,
        config=preparation.config,
        num_samples=preparation.num_samples,
        batch_size=preparation.batch_size,
        include_system_message=options.include_system_message,
        provider=options.provider,
        model=options.model,
        generation_overrides=preparation.generation_overrides,
        debug=options.debug,
    )

    output_config = preparation.config.get_output_config()
    output_save_path = options.output_save_as or output_config["save_as"]
    save_dataset(dataset, output_save_path, preparation.config, engine=engine)

    trace(
        "dataset_generated",
        {"samples": len(dataset)},
    )
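# Taken together, the three helpers above form the generate pipeline:
# _load_and_prepare_generation_context() loads config and verifies API keys,
# _initialize_topic_model() loads or builds (and persists) the topic tree or
# graph, and _run_generation() drives DataSetGenerator and saves the dataset.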

@cli.command()
@click.argument("config_file", type=click.Path(exists=True), required=False)
@click.option(
    "--output-system-prompt",
    help="System prompt for final dataset output (if include_system_message is true)",
)
@click.option("--topic-prompt", help="Starting topic/seed for topic generation")
@click.option("--topics-system-prompt", help="System prompt for topic generation")
@click.option("--generation-system-prompt", help="System prompt for dataset content generation")
@click.option("--topics-save-as", help="Save path for the generated topics")
@click.option(
    "--topics-load",
    type=click.Path(exists=True),
    help="Path to existing topics file (JSONL for tree, JSON for graph)",
)
@click.option("--output-save-as", help="Save path for the dataset")
@click.option("--provider", help="LLM provider (e.g., ollama, openai)")
@click.option("--model", help="Model name (e.g., qwen3:8b, gpt-4o)")
@click.option("--temperature", type=float, help="Temperature setting")
@click.option("--degree", type=int, help="Degree (branching factor)")
@click.option("--depth", type=int, help="Depth setting")
@click.option("--num-samples", type=int, help="Number of samples to generate")
@click.option("--batch-size", type=int, help="Batch size")
@click.option("--base-url", help="Base URL for LLM provider API endpoint")
@click.option(
    "--include-system-message/--no-system-message",
    default=None,
    help="Include system message in dataset output (default: true)",
)
@click.option(
    "--mode",
    type=click.Choice(["tree", "graph"]),
    default="tree",
    help="Topic generation mode (default: tree)",
)
@click.option(
    "--debug",
    is_flag=True,
    help="Enable debug mode for detailed error output",
)
@click.option(
    "--tui",
    type=click.Choice(["rich", "simple"]),
    default="rich",
    show_default=True,
    help="TUI mode: rich (two-pane with preview) or simple (headless-friendly)",
)
@click.option(
    "--topic-only",
    is_flag=True,
    help="Only create topic assets, no dataset",
)
@click.option(
    "--conversation-type",
    type=click.Choice(["basic", "chain_of_thought"]),
    help="Base conversation type: basic (simple chat), chain_of_thought (with reasoning)",
)
@click.option(
    "--reasoning-style",
    type=click.Choice(["freetext", "agent"]),
    help="Reasoning style for chain_of_thought: freetext (natural language) or agent (structured for tool-calling)",
)
@click.option(
    "--agent-mode",
    type=click.Choice(["single_turn", "multi_turn"]),
    help="Agent mode: single_turn (one-shot tool use), multi_turn (extended conversations). Requires tools.",
)
@click.option(
    "--min-turns",
    type=int,
    help="Minimum conversation turns for multi_turn agent mode",
)
@click.option(
    "--max-turns",
    type=int,
    help="Maximum conversation turns for multi_turn agent mode",
)
@click.option(
    "--min-tool-calls",
    type=int,
    help="Minimum tool calls before allowing conversation conclusion",
)
def generate(  # noqa: PLR0913
    config_file: str | None,
    output_system_prompt: str | None = None,
    topic_prompt: str | None = None,
    topics_system_prompt: str | None = None,
    generation_system_prompt: str | None = None,
    topics_save_as: str | None = None,
    topics_load: str | None = None,
    output_save_as: str | None = None,
    provider: str | None = None,
    model: str | None = None,
    temperature: float | None = None,
    degree: int | None = None,
    depth: int | None = None,
    num_samples: int | None = None,
    batch_size: int | None = None,
    base_url: str | None = None,
    include_system_message: bool | None = None,
    mode: Literal["tree", "graph"] = "tree",
    debug: bool = False,
    topic_only: bool = False,
    conversation_type: Literal["basic", "chain_of_thought"] | None = None,
    reasoning_style: Literal["freetext", "agent"] | None = None,
    agent_mode: Literal["single_turn", "multi_turn"] | None = None,
    min_turns: int | None = None,
    max_turns: int | None = None,
    min_tool_calls: int | None = None,
    tui: Literal["rich", "simple"] = "rich",
) -> None:
    """Generate training data from a YAML configuration file or CLI parameters."""
    set_trace_debug(debug)
    trace(
        "cli_generate",
        {
            "mode": mode,
            "has_config": config_file is not None,
            "provider": provider,
            "model": model,
        },
    )

    try:
        options = GenerateOptions(
            config_file=config_file,
            output_system_prompt=output_system_prompt,
            topic_prompt=topic_prompt,
            topics_system_prompt=topics_system_prompt,
            generation_system_prompt=generation_system_prompt,
            topics_save_as=topics_save_as,
            topics_load=topics_load,
            output_save_as=output_save_as,
            provider=provider,
            model=model,
            temperature=temperature,
            degree=degree,
            depth=depth,
            num_samples=num_samples,
            batch_size=batch_size,
            base_url=base_url,
            include_system_message=include_system_message,
            mode=mode,
            debug=debug,
            topic_only=topic_only,
            conversation_type=conversation_type,
            reasoning_style=reasoning_style,
            agent_mode=agent_mode,
            min_turns=min_turns,
            max_turns=max_turns,
            min_tool_calls=min_tool_calls,
            tui=tui,
        )
    except PydanticValidationError as error:
        handle_error(click.get_current_context(), ConfigurationError(str(error)))
        return

    try:
        # Configure TUI before any output
        configure_tui(options.tui)
        tui = get_tui()  # type: ignore

        # Show initialization header
        tui.info("Initializing DeepFabric...")  # type: ignore
        print()

        preparation = _load_and_prepare_generation_context(options)

        topic_model = _initialize_topic_model(
            preparation=preparation,
            options=options,
        )

        if topic_only:
            return

        _run_generation(
            preparation=preparation,
            topic_model=topic_model,
            options=options,
        )

    except ConfigurationError as e:
        handle_error(click.get_current_context(), e)
    except Exception as e:
        tui = get_tui()  # type: ignore
        tui.error(f"Unexpected error: {str(e)}")  # type: ignore
        sys.exit(1)
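# Example invocations (illustrative; every flag used here is defined in the
# decorators above, and file/prompt values are placeholders):
#
#   deepfabric generate config.yaml --provider openai --model gpt-4o
#   deepfabric generate --topic-prompt "Kubernetes networking" --mode graph \
#       --depth 3 --degree 4 --num-samples 100 --output-save-as dataset.jsonl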

@cli.command()
@click.argument("dataset_file", type=click.Path(exists=True))
@click.option(
    "--repo",
    required=True,
    help="Hugging Face repository (e.g., username/dataset-name)",
)
@click.option(
    "--token",
    help="Hugging Face API token (can also be set via HF_TOKEN env var)",
)
@click.option(
    "--tags",
    multiple=True,
    help="Tags for the dataset (can be specified multiple times)",
)
def upload(
    dataset_file: str,
    repo: str,
    token: str | None = None,
    tags: list[str] | None = None,
) -> None:
    """Upload a dataset to Hugging Face Hub."""
    trace("cli_upload", {"has_tags": len(tags) > 0 if tags else False})

    try:
        # Get token from CLI arg or env var
        token = token or os.getenv("HF_TOKEN")
        if not token:
            handle_error(
                click.get_current_context(),
                ValueError("Hugging Face token not provided. Set via --token or HF_TOKEN env var."),
            )

        # Lazy import to avoid slow startup when not using HF features
        from .hf_hub import HFUploader  # noqa: PLC0415

        uploader = HFUploader(token)
        result = uploader.push_to_hub(str(repo), dataset_file, tags=list(tags) if tags else [])

        tui = get_tui()
        if result["status"] == "success":
            tui.success(result["message"])
        else:
            tui.error(result["message"])
            sys.exit(1)

    except Exception as e:
        tui = get_tui()
        tui.error(f"Error uploading to Hugging Face Hub: {str(e)}")
        sys.exit(1)
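# Example (illustrative; the repository name is a placeholder):
#   HF_TOKEN=hf_xxx deepfabric upload dataset.jsonl --repo username/my-dataset \
#       --tags synthetic --tags deepfabric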

@cli.command("upload-kaggle")
@click.argument("dataset_file", type=click.Path(exists=True))
@click.option(
    "--handle",
    required=True,
    help="Kaggle dataset handle (e.g., username/dataset-name)",
)
@click.option(
    "--username",
    help="Kaggle username (can also be set via KAGGLE_USERNAME env var)",
)
@click.option(
    "--key",
    help="Kaggle API key (can also be set via KAGGLE_KEY env var)",
)
@click.option(
    "--tags",
    multiple=True,
    help="Tags for the dataset (can be specified multiple times)",
)
@click.option(
    "--version-notes",
    help="Version notes for the dataset update",
)
@click.option(
    "--description",
    help="Description for the dataset",
)
def upload_kaggle(
    dataset_file: str,
    handle: str,
    username: str | None = None,
    key: str | None = None,
    tags: list[str] | None = None,
    version_notes: str | None = None,
    description: str | None = None,
) -> None:
    """Upload a dataset to Kaggle."""
    trace("cli_upload_kaggle", {"has_tags": len(tags) > 0 if tags else False})

    try:
        # Get credentials from CLI args or env vars
        username = username or os.getenv("KAGGLE_USERNAME")
        key = key or os.getenv("KAGGLE_KEY")

        if not username or not key:
            handle_error(
                click.get_current_context(),
                ValueError(
                    "Kaggle credentials not provided. "
                    "Set via --username/--key or KAGGLE_USERNAME/KAGGLE_KEY env vars."
                ),
            )

        # Lazy import to avoid slow startup when not using Kaggle features
        from .kaggle_hub import KaggleUploader  # noqa: PLC0415

        uploader = KaggleUploader(username, key)
        result = uploader.push_to_hub(
            str(handle),
            dataset_file,
            tags=list(tags) if tags else [],
            version_notes=version_notes,
            description=description,
        )

        tui = get_tui()
        if result["status"] == "success":
            tui.success(result["message"])
        else:
            tui.error(result["message"])
            sys.exit(1)

    except Exception as e:
        tui = get_tui()
        tui.error(f"Error uploading to Kaggle: {str(e)}")
        sys.exit(1)
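# Example (illustrative; the handle is a placeholder):
#   KAGGLE_USERNAME=user KAGGLE_KEY=xxx deepfabric upload-kaggle dataset.jsonl \
#       --handle username/my-dataset --version-notes "initial upload"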

@cli.command()
@click.argument("graph_file", type=click.Path(exists=True))
@click.option(
    "--output",
    "-o",
    required=True,
    help="Output SVG file path",
)
def visualize(graph_file: str, output: str) -> None:
    """Visualize a topic graph as an SVG file."""
    try:
        # Load the graph
        with open(graph_file) as f:
            import json  # noqa: PLC0415

            graph_data = json.load(f)

        # Create a minimal Graph object for visualization
        # We need to get the args from somewhere - for now, use defaults
        from .constants import (  # noqa: PLC0415
            TOPIC_GRAPH_DEFAULT_DEGREE,
            TOPIC_GRAPH_DEFAULT_DEPTH,
        )

        # Create parameters for Graph instantiation
        graph_params = {
            "topic_prompt": "placeholder",  # Not needed for visualization
            "model_name": "placeholder/model",  # Not needed for visualization
            "degree": graph_data.get("degree", TOPIC_GRAPH_DEFAULT_DEGREE),
            "depth": graph_data.get("depth", TOPIC_GRAPH_DEFAULT_DEPTH),
            "temperature": 0.7,  # Default, not used for visualization
        }

        # Use the Graph.from_json method to properly load the graph structure
        import tempfile  # noqa: PLC0415

        # Create a temporary file with the graph data and use from_json
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp_file:
            json.dump(graph_data, tmp_file)
            temp_path = tmp_file.name

        try:
            graph = Graph.from_json(temp_path, graph_params)
        finally:
            import os  # noqa: PLC0415

            os.unlink(temp_path)

        # Visualize the graph
        graph.visualize(output)
        tui = get_tui()
        tui.success(f"Graph visualization saved to: {output}.svg")

    except Exception as e:
        tui = get_tui()
        tui.error(f"Error visualizing graph: {str(e)}")
        sys.exit(1)
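# Example (illustrative): given a graph saved by generate --mode graph,
#   deepfabric visualize topics.json --output topic_graph
# produces topic_graph.svg (the success message above suggests Graph.visualize
# appends the .svg suffix itself).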

@cli.command()
@click.argument("config_file", type=click.Path(exists=True))
def validate(config_file: str) -> None:  # noqa: PLR0912
    """Validate a DeepFabric configuration file."""
    try:
        # Try to load the configuration
        config = DeepFabricConfig.from_yaml(config_file)

        # Check required sections
        errors = []
        warnings = []

        # Check for system prompt
        if not config.generation.system_prompt:
            warnings.append("No generation.system_prompt defined")

        # Check topics configuration
        if not config.topics.prompt:
            errors.append("topics.prompt is required")

        # Check output configuration
        if not config.output.save_as:
            warnings.append("No output.save_as path defined for dataset")

        # Report results
        tui = get_tui()
        if errors:
            tui.error("Configuration validation failed:")
            for error in errors:
                tui.console.print(f"  - {error}", style="red")
            sys.exit(1)
        else:
            tui.success("Configuration is valid")

            if warnings:
                tui.console.print("\nWarnings:", style="yellow bold")
                for warning in warnings:
                    tui.warning(warning)

            # Print configuration summary
            tui.console.print("\nConfiguration Summary:", style="cyan bold")
            tui.info(
                f"Topics: mode={config.topics.mode}, depth={config.topics.depth}, degree={config.topics.degree}"
            )

            tui.info(
                f"Output: num_samples={config.output.num_samples}, batch_size={config.output.batch_size}"
            )

            if config.huggingface:
                hf_config = config.get_huggingface_config()
                tui.info(f"Hugging Face: repo={hf_config.get('repository', 'not set')}")

            if config.kaggle:
                kaggle_config = config.get_kaggle_config()
                tui.info(f"Kaggle: handle={kaggle_config.get('handle', 'not set')}")

    except FileNotFoundError:
        handle_error(
            click.get_current_context(),
            ValueError(f"Config file not found: {config_file}"),
        )
    except yaml.YAMLError as e:
        handle_error(
            click.get_current_context(),
            ValueError(f"Invalid YAML in config file: {str(e)}"),
        )
    except Exception as e:
        handle_error(
            click.get_current_context(),
            ValueError(f"Error validating config file: {str(e)}"),
        )
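# Example (illustrative):
#   deepfabric validate config.yaml
# exits non-zero when a required key such as topics.prompt is missing, and
# only warns about optional ones such as output.save_as.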

@cli.command()
def info() -> None:
    """Show DeepFabric version and configuration information."""
    try:
        import importlib.metadata  # noqa: PLC0415

        # Get version
        try:
            version = importlib.metadata.version("deepfabric")
        except importlib.metadata.PackageNotFoundError:
            version = "development"

        tui = get_tui()
        header = tui.create_header(
            f"DeepFabric v{version}",
            "Large Scale Topic based Synthetic Data Generation",
        )
        tui.console.print(header)

        tui.console.print("\nAvailable Commands:", style="cyan bold")
        commands = [
            ("generate", "Generate training data from configuration"),
            ("validate", "Validate a configuration file"),
            ("visualize", "Create SVG visualization of a topic graph"),
            ("upload", "Upload dataset to Hugging Face Hub"),
            ("upload-kaggle", "Upload dataset to Kaggle"),
            ("info", "Show this information"),
        ]
        for cmd, desc in commands:
            tui.console.print(f"  [cyan]{cmd}[/cyan] - {desc}")

        tui.console.print("\nEnvironment Variables:", style="cyan bold")
        env_vars = [
            ("OPENAI_API_KEY", "OpenAI API key"),
            ("ANTHROPIC_API_KEY", "Anthropic API key"),
            ("HF_TOKEN", "Hugging Face API token"),
        ]
        for var, desc in env_vars:
            tui.console.print(f"  [yellow]{var}[/yellow] - {desc}")

        tui.console.print(
            "\nFor more information, visit: [link]https://github.com/RedDotRocket/deepfabric[/link]"
        )

    except Exception as e:
        tui = get_tui()
        tui.error(f"Error getting info: {str(e)}")
        sys.exit(1)


@cli.command()
@click.argument("model_path", type=click.Path())
@click.argument("dataset_path", type=click.Path(exists=True))
@click.option(
    "--output",
    "-o",
    type=click.Path(),
    help="Path to save evaluation results (JSON)",
)
@click.option(
    "--adapter-path",
    type=click.Path(),
    help="Path to PEFT/LoRA adapter (for adapter-based fine-tuning)",
)
@click.option(
    "--batch-size",
    type=int,
    default=1,
    help="Batch size for evaluation",
)
@click.option(
    "--max-samples",
    type=int,
    help="Maximum number of samples to evaluate (default: all)",
)
@click.option(
    "--temperature",
    type=float,
    default=0.7,
    help="Sampling temperature",
)
@click.option(
    "--max-tokens",
    type=int,
    default=2048,
    help="Maximum tokens to generate",
)
@click.option(
    "--top-p",
    type=float,
    default=0.9,
    help="Nucleus sampling top-p",
)
@click.option(
    "--backend",
    type=click.Choice(["transformers", "ollama"]),
    default="transformers",
    help="Inference backend to use",
)
@click.option(
    "--device",
    type=str,
    help="Device to use (cuda, cpu, mps, etc.) - only for transformers backend",
)
@click.option(
    "--no-save-predictions",
    is_flag=True,
    help="Don't save individual predictions to output file",
)
def evaluate(
    model_path: str,
    dataset_path: str,
    output: str | None,
    adapter_path: str | None,
    batch_size: int,
    max_samples: int | None,
    temperature: float,
    max_tokens: int,
    top_p: float,
    backend: str,
    device: str | None,
    no_save_predictions: bool,
):
    """Evaluate a fine-tuned model on tool-calling tasks.

    MODEL_PATH: Path to base model or fine-tuned model (local directory or HuggingFace Hub ID)

    DATASET_PATH: Path to evaluation dataset (JSONL format)

    Typical workflow:

        # Full fine-tuning: evaluate checkpoint
        deepfabric evaluate ./checkpoints/final ./eval.jsonl --output results.json

        # LoRA/PEFT: evaluate adapter on base model
        deepfabric evaluate unsloth/Qwen3-4B-Instruct ./eval.jsonl \\
            --adapter-path ./lora_model \\
            --output results.json

        # Quick evaluation during development
        deepfabric evaluate ./my-model ./eval.jsonl --max-samples 50

        # Evaluate HuggingFace model
        deepfabric evaluate username/model-name ./eval.jsonl \\
            --temperature 0.5 \\
            --device cuda
    """
    tui = get_tui()

    try:
        from typing import Literal, cast  # noqa: PLC0415

        from .evaluation import EvaluatorConfig, InferenceConfig  # noqa: PLC0415
        from .evaluation.evaluator import Evaluator  # noqa: PLC0415

        # Create inference configuration
        inference_config = InferenceConfig(
            model_path=model_path,
            adapter_path=adapter_path,
            backend=cast(Literal["transformers", "ollama"], backend),
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            device=device,
            batch_size=batch_size,
        )

        # Create evaluator configuration
        evaluator_config = EvaluatorConfig(
            dataset_path=dataset_path,
            output_path=output,
            inference_config=inference_config,
            batch_size=batch_size,
            max_samples=max_samples,
            save_predictions=not no_save_predictions,
            metric_weights={
                "accuracy": 1.0,
                "exact_match": 1.0,
                "f1_score": 0.5,
            },
            evaluators=["tool_calling"],
            reporters=["console", "json"] if output else ["console"],
            cloud_api_key=os.getenv("OPENAI_API_KEY"),
        )

        # Display configuration
        tui.console.print("\n[bold]Evaluation Configuration:[/bold]")
        tui.console.print(f"  Model: {model_path}")
        tui.console.print(f"  Backend: {backend}")
        if adapter_path:
            tui.console.print(f"  Adapter: {adapter_path}")
        tui.console.print(f"  Dataset: {dataset_path}")
        if output:
            tui.console.print(f"  Output: {output}")
        tui.console.print(f"  Batch size: {batch_size}")
        if max_samples:
            tui.console.print(f"  Max samples: {max_samples}")
        tui.console.print(f"  Temperature: {temperature}")
        tui.console.print(f"  Max tokens: {max_tokens}")
        if device and backend == "transformers":
            tui.console.print(f"  Device: {device}")
        tui.console.print()

        # Create evaluator and run evaluation
        tui.console.print("[bold blue]Loading model...[/bold blue]")
        tui.console.print(
            "[dim]This may take several minutes for large models (downloading + loading into memory)[/dim]"
        )
        evaluator = Evaluator(evaluator_config)
        tui.console.print("[green]Model loaded successfully![/green]\n")

        # Track evaluation start
        trace(
            "evaluation_started",
            {
                "model_path": model_path,
                "backend": backend,
                "has_adapter": adapter_path is not None,
                "dataset_path": dataset_path,
                "batch_size": batch_size,
                "max_samples": max_samples,
                "temperature": temperature,
            },
        )

        try:
            result = evaluator.evaluate()

            # Print summary
            evaluator.print_summary(result.metrics)

            if output:
                tui.console.print(f"\n[green]Full results saved to {output}[/green]")

        finally:
            evaluator.cleanup()

    except FileNotFoundError as e:
        trace(
            "evaluation_failed",
            {
                "model_path": model_path,
                "backend": backend,
                "dataset_path": dataset_path,
                "error_type": "FileNotFoundError",
            },
        )
        handle_error(click.get_current_context(), e)
    except ValueError as e:
        trace(
            "evaluation_failed",
            {
                "model_path": model_path,
                "backend": backend,
                "dataset_path": dataset_path,
                "error_type": "ValueError",
            },
        )
        handle_error(click.get_current_context(), e)
    except Exception as e:
        trace(
            "evaluation_failed",
            {
                "model_path": model_path,
                "backend": backend,
                "dataset_path": dataset_path,
                "error_type": type(e).__name__,
            },
        )
        handle_error(click.get_current_context(), e)


# Register the auth command group
cli.add_command(auth_group)


@cli.command("import-tools")
@click.option(
    "--transport",
    type=click.Choice(["stdio", "http"]),
    required=True,
    help="MCP transport type: stdio (subprocess) or http (Streamable HTTP)",
)
@click.option(
    "--command",
    "-c",
    help="Shell command to launch MCP server (required for stdio transport)",
)
@click.option(
    "--endpoint",
    "-e",
    help="HTTP endpoint URL for MCP server (required for http transport)",
)
@click.option(
    "--output",
    "-o",
    type=click.Path(),
    help="Output file path (.json or .yaml). Optional if --spin is used.",
)
@click.option(
    "--spin",
    "-s",
    "spin_endpoint",
    help="Spin server URL to push tools to (e.g., http://localhost:3000)",
)
@click.option(
    "--format",
    "output_format",
    type=click.Choice(["deepfabric", "openai"]),
    default="deepfabric",
    help="Output format: deepfabric (native) or openai (TRL compatible)",
)
@click.option(
    "--env",
    multiple=True,
    help="Environment variables for stdio (format: KEY=VALUE, can be repeated)",
)
@click.option(
    "--header",
    multiple=True,
    help="HTTP headers for authentication (format: KEY=VALUE, can be repeated)",
)
@click.option(
    "--timeout",
    type=float,
    default=30.0,
    help="Request timeout in seconds (default: 30)",
)
def import_tools(
    transport: str,
    command: str | None,
    endpoint: str | None,
    output: str | None,
    spin_endpoint: str | None,
    output_format: str,
    env: tuple[str, ...],
    header: tuple[str, ...],
    timeout: float,
) -> None:
    """Import tool definitions from an MCP (Model Context Protocol) server.

    This command connects to an MCP server, fetches available tools via the
    tools/list method, and either saves them to a file or pushes them to a
    Spin server (or both).

    Supports both MCP transport types:

    \b
    STDIO: Launches the MCP server as a subprocess
        deepfabric import-tools --transport stdio \\
            --command "npx -y @modelcontextprotocol/server-filesystem /tmp" \\
            --output tools.json

    \b
    HTTP: Connects to a running MCP server via HTTP
        deepfabric import-tools --transport http \\
            --endpoint "http://localhost:3000/mcp" \\
            --output tools.json

    \b
    Push directly to Spin server:
        deepfabric import-tools --transport stdio \\
            --command "npx -y figma-developer-mcp --stdio" \\
            --env "FIGMA_API_KEY=your-key" \\
            --spin http://localhost:3000

    \b
    Both save to file and push to Spin:
        deepfabric import-tools --transport stdio \\
            --command "your-mcp-server" \\
            --output tools.json \\
            --spin http://localhost:3000
    """
    tui = get_tui()

    # Validate that at least one output is specified
    if not output and not spin_endpoint:
        tui.error("At least one of --output or --spin is required")
        sys.exit(1)

    def parse_key_value_pairs(pairs: tuple[str, ...], pair_type: str) -> dict[str, str]:
        """Parse a tuple of 'KEY=VALUE' strings into a dictionary."""
        result: dict[str, str] = {}
        for p in pairs:
            if "=" not in p:
                tui.error(f"Invalid {pair_type} format: {p} (expected KEY=VALUE)")
                sys.exit(1)
            key, value = p.split("=", 1)
            result[key] = value
        return result

    env_dict = parse_key_value_pairs(env, "env")
    header_dict = parse_key_value_pairs(header, "header")

    # Validate transport-specific options
    if transport == "stdio" and not command:
        tui.error("--command is required for stdio transport")
        sys.exit(1)
    if transport == "http" and not endpoint:
        tui.error("--endpoint is required for http transport")
        sys.exit(1)

    try:
        # Lazy import to avoid slow startup
        from typing import Literal, cast  # noqa: PLC0415

        from .tools.mcp_client import (  # noqa: PLC0415
            fetch_and_push_to_spin,
            fetch_tools_from_mcp,
            save_tools_to_file,
        )

        tui.info(f"Connecting to MCP server via {transport}...")

        # If pushing to Spin, use the combined function
        if spin_endpoint:
            registry, spin_result = fetch_and_push_to_spin(
                transport=cast(Literal["stdio", "http"], transport),
                spin_endpoint=spin_endpoint,
                command=command,
                endpoint=endpoint,
                env=env_dict if env_dict else None,
                headers=header_dict if header_dict else None,
                timeout=timeout,
            )
            tui.success(f"Pushed {spin_result.loaded} tools to Spin server at {spin_endpoint}")
        else:
            # Just fetch without pushing to Spin
            registry = fetch_tools_from_mcp(
                transport=cast(Literal["stdio", "http"], transport),
                command=command,
                endpoint=endpoint,
                env=env_dict if env_dict else None,
                headers=header_dict if header_dict else None,
                timeout=timeout,
            )

        if not registry.tools:
            tui.warning("No tools found from MCP server")
            sys.exit(0)

        # Save to file if output path specified
        if output:
            save_tools_to_file(
                registry,
                output,
                output_format=cast(Literal["deepfabric", "openai"], output_format),
            )
            tui.success(f"Saved {len(registry.tools)} tools to {output}")

        # List the tools
        tui.console.print("\nImported tools:", style="cyan bold")
        for tool in registry.tools:
            param_count = len(tool.parameters)
            desc = tool.description[:60] if tool.description else "(no description)"
            tui.console.print(f"  - {tool.name} ({param_count} params): {desc}...")

    except Exception as e:
        tui.error(f"Failed to import tools: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    cli()