DeepFabric 4.4.0 (deepfabric-4.4.0-py3-none-any.whl)

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (71)
  1. deepfabric/__init__.py +70 -0
  2. deepfabric/__main__.py +6 -0
  3. deepfabric/auth.py +382 -0
  4. deepfabric/builders.py +303 -0
  5. deepfabric/builders_agent.py +1304 -0
  6. deepfabric/cli.py +1288 -0
  7. deepfabric/config.py +899 -0
  8. deepfabric/config_manager.py +251 -0
  9. deepfabric/constants.py +94 -0
  10. deepfabric/dataset_manager.py +534 -0
  11. deepfabric/error_codes.py +581 -0
  12. deepfabric/evaluation/__init__.py +47 -0
  13. deepfabric/evaluation/backends/__init__.py +32 -0
  14. deepfabric/evaluation/backends/ollama_backend.py +137 -0
  15. deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
  16. deepfabric/evaluation/backends/transformers_backend.py +326 -0
  17. deepfabric/evaluation/evaluator.py +845 -0
  18. deepfabric/evaluation/evaluators/__init__.py +13 -0
  19. deepfabric/evaluation/evaluators/base.py +104 -0
  20. deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
  21. deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
  22. deepfabric/evaluation/evaluators/registry.py +66 -0
  23. deepfabric/evaluation/inference.py +155 -0
  24. deepfabric/evaluation/metrics.py +397 -0
  25. deepfabric/evaluation/parser.py +304 -0
  26. deepfabric/evaluation/reporters/__init__.py +13 -0
  27. deepfabric/evaluation/reporters/base.py +56 -0
  28. deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
  29. deepfabric/evaluation/reporters/file_reporter.py +61 -0
  30. deepfabric/evaluation/reporters/multi_reporter.py +56 -0
  31. deepfabric/exceptions.py +67 -0
  32. deepfabric/factory.py +26 -0
  33. deepfabric/generator.py +1084 -0
  34. deepfabric/graph.py +545 -0
  35. deepfabric/hf_hub.py +214 -0
  36. deepfabric/kaggle_hub.py +219 -0
  37. deepfabric/llm/__init__.py +41 -0
  38. deepfabric/llm/api_key_verifier.py +534 -0
  39. deepfabric/llm/client.py +1206 -0
  40. deepfabric/llm/errors.py +105 -0
  41. deepfabric/llm/rate_limit_config.py +262 -0
  42. deepfabric/llm/rate_limit_detector.py +278 -0
  43. deepfabric/llm/retry_handler.py +270 -0
  44. deepfabric/metrics.py +212 -0
  45. deepfabric/progress.py +262 -0
  46. deepfabric/prompts.py +290 -0
  47. deepfabric/schemas.py +1000 -0
  48. deepfabric/spin/__init__.py +6 -0
  49. deepfabric/spin/client.py +263 -0
  50. deepfabric/spin/models.py +26 -0
  51. deepfabric/stream_simulator.py +90 -0
  52. deepfabric/tools/__init__.py +5 -0
  53. deepfabric/tools/defaults.py +85 -0
  54. deepfabric/tools/loader.py +87 -0
  55. deepfabric/tools/mcp_client.py +677 -0
  56. deepfabric/topic_manager.py +303 -0
  57. deepfabric/topic_model.py +20 -0
  58. deepfabric/training/__init__.py +35 -0
  59. deepfabric/training/api_key_prompt.py +302 -0
  60. deepfabric/training/callback.py +363 -0
  61. deepfabric/training/metrics_sender.py +301 -0
  62. deepfabric/tree.py +438 -0
  63. deepfabric/tui.py +1267 -0
  64. deepfabric/update_checker.py +166 -0
  65. deepfabric/utils.py +150 -0
  66. deepfabric/validation.py +143 -0
  67. deepfabric-4.4.0.dist-info/METADATA +702 -0
  68. deepfabric-4.4.0.dist-info/RECORD +71 -0
  69. deepfabric-4.4.0.dist-info/WHEEL +4 -0
  70. deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
  71. deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/evaluation/backends/transformers_backend.py
@@ -0,0 +1,326 @@
+ import json
+ import logging
+
+ from typing import Any
+
+ import torch
+
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+ from ...schemas import ToolDefinition
+ from ..inference import InferenceBackend, InferenceConfig, ModelResponse
+ from .tool_call_parsers import ToolCallParser, get_parser
+
+ # Mistral-family architectures that require fix_mistral_regex=True
+ MISTRAL_ARCHITECTURES = frozenset(
+     {
+         "MistralForCausalLM",
+         "Mistral3ForConditionalGeneration",
+         "MixtralForCausalLM",
+         "MinistralForCausalLM",
+         "PixtralForConditionalGeneration",
+     }
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class TransformersBackend(InferenceBackend):
+     """Inference backend using HuggingFace Transformers."""
+
+     def __init__(self, config: InferenceConfig):
+         """Initialize Transformers backend.
+
+         Args:
+             config: Inference configuration
+         """
+         super().__init__(config)
+
+         # Determine device
+         if config.device:
+             self.device = config.device
+         # Auto-detect best available device
+         elif torch.cuda.is_available():
+             self.device = "cuda"
+         elif torch.backends.mps.is_available():
+             self.device = "mps"
+         else:
+             self.device = "cpu"
+
+         # Determine dtype based on device
+         if self.device == "cuda":
+             dtype = torch.float16
+             device_map = "auto"
+         elif self.device == "mps":
+             dtype = torch.float32  # MPS works best with float32
+             device_map = None
+         else:
+             dtype = torch.float32
+             device_map = None
+
+         # Detect model architecture for parser selection and tokenizer config
+         self._architectures: list[str] = []
+         tokenizer_kwargs: dict[str, Any] = {}
+         try:
+             model_config = AutoConfig.from_pretrained(config.model_path)  # nosec
+             self._architectures = getattr(model_config, "architectures", []) or []
+             if any(arch in MISTRAL_ARCHITECTURES for arch in self._architectures):
+                 tokenizer_kwargs["fix_mistral_regex"] = True
+                 logger.debug("Detected Mistral architecture, enabling fix_mistral_regex")
+         except Exception as e:
+             logger.warning("Could not detect model architecture: %s", e)
+
+         # Initialize tool call parser based on detected architecture
+         self._tool_call_parser: ToolCallParser = get_parser(self._architectures)
+         logger.info(
+             "Using %s for model architectures: %s",
+             type(self._tool_call_parser).__name__,
+             self._architectures or ["unknown"],
+         )
+
+         self.loaded_with_unsloth = False
+         # Load with Unsloth if requested
+         if config.use_unsloth:
+             try:
+                 from unsloth import FastLanguageModel  # type: ignore # noqa: PLC0415
+
+                 # Load from adapter path if provided, otherwise from model_path
+                 load_path = config.adapter_path if config.adapter_path else config.model_path
+                 self.model, self.tokenizer = FastLanguageModel.from_pretrained(
+                     model_name=load_path,
+                     max_seq_length=config.max_seq_length,
+                     dtype=dtype,
+                     load_in_4bit=config.load_in_4bit,
+                 )
+                 FastLanguageModel.for_inference(self.model)
+                 self.loaded_with_unsloth = True
+             except ImportError:
+                 logger.warning("Unsloth not installed, falling back to standard transformers")
+             except Exception as e:
+                 logger.warning(
+                     "Unsloth loading failed (%s), falling back to standard transformers", e
+                 )
+
+         # Standard transformers/PEFT loading
+         if not self.loaded_with_unsloth:
+             self.tokenizer = AutoTokenizer.from_pretrained(  # nosec
+                 config.model_path, **tokenizer_kwargs
+             )
+
+             self.model = AutoModelForCausalLM.from_pretrained(  # nosec
+                 config.model_path,
+                 device_map=device_map,
+                 dtype=dtype,
+             )
+
+             # Load PEFT adapter if provided
+             if config.adapter_path:
+                 from peft import PeftModel  # noqa: PLC0415
+
+                 self.model = PeftModel.from_pretrained(self.model, config.adapter_path)
+
+             # Move to device if not using device_map
+             if self.device in ("cpu", "mps"):
+                 self.model.to(self.device)  # type: ignore[arg-type]
+
+         # Note: torch.compile disabled - causes very slow first inference
+         # due to CUDA graph compilation overhead. For evaluation workloads
+         # with many short inferences, the compilation cost isn't amortized.
+         # Uncomment for long-running inference servers:
+         # with suppress(Exception):
+         #     self.model = torch.compile(self.model, mode="reduce-overhead")
+
+         # Set padding token if not set
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+
+     def generate(
+         self,
+         messages: list[dict[str, str]],
+         tools: list[ToolDefinition] | None = None,
+     ) -> ModelResponse:
+         """Generate response from model.
+
+         Args:
+             messages: List of message dicts with 'role' and 'content'
+             tools: Optional list of available tools for function calling
+
+         Returns:
+             ModelResponse with generated content and parsed tool calls
+         """
+         # Format messages using chat template
+         prompt = self._format_prompt(messages, tools)
+
+         # Tokenize
+         inputs = self.tokenizer(
+             prompt,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+         ).to(self.model.device)
+
+         # Generate with optimizations
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=self.config.max_tokens,
+                 temperature=self.config.temperature,
+                 top_p=self.config.top_p,
+                 do_sample=self.config.temperature > 0,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 # Performance optimizations
+                 use_cache=True,  # Enable KV cache for faster generation
+                 num_beams=1,  # Greedy decoding (faster than beam search)
+             )
+
+         # Decode output
+         generated_ids = outputs[0][inputs.input_ids.shape[1] :]
+         generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+         # Parse tool calls if present
+         tool_calls = self._tool_call_parser.parse(generated_text) if tools else []
+         tool_call = tool_calls[0] if tool_calls else None
+
+         return ModelResponse(
+             content=generated_text,
+             tool_call=tool_call,
+             tool_calls=tool_calls if tool_calls else None,
+             raw_output=generated_text,
+             finish_reason="stop",
+         )
+
+     def generate_batch(
+         self,
+         batch_messages: list[list[dict[str, str]]],
+         tools: list[ToolDefinition] | None = None,
+     ) -> list[ModelResponse]:
+         """Generate responses for a batch of message sequences.
+
+         Args:
+             batch_messages: List of message sequences
+             tools: Optional list of available tools for function calling
+
+         Returns:
+             List of ModelResponse objects
+         """
+         # Format all prompts
+         prompts = [self._format_prompt(msgs, tools) for msgs in batch_messages]
+
+         # Tokenize batch
+         inputs = self.tokenizer(
+             prompts,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+         ).to(self.model.device)
+
+         # Generate batch with optimizations
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 **inputs,
+                 max_new_tokens=self.config.max_tokens,
+                 temperature=self.config.temperature,
+                 top_p=self.config.top_p,
+                 do_sample=self.config.temperature > 0,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 eos_token_id=self.tokenizer.eos_token_id,
+                 # Performance optimizations
+                 use_cache=True,  # Enable KV cache for faster generation
+                 num_beams=1,  # Greedy decoding (faster than beam search)
+             )
+
+         # Decode outputs
+         responses = []
+         for i, output_ids in enumerate(outputs):
+             # Extract generated portion (skip input tokens)
+             generated_ids = output_ids[inputs.input_ids[i].shape[0] :]
+             generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+             # Parse tool calls if present
+             tool_calls = self._tool_call_parser.parse(generated_text) if tools else []
+             tool_call = tool_calls[0] if tool_calls else None
+
+             responses.append(
+                 ModelResponse(
+                     content=generated_text,
+                     tool_call=tool_call,
+                     tool_calls=tool_calls if tool_calls else None,
+                     raw_output=generated_text,
+                     finish_reason="stop",
+                 )
+             )
+
+         return responses
+
+     def cleanup(self) -> None:
+         """Clean up GPU memory."""
+         if hasattr(self, "model"):
+             del self.model
+         if hasattr(self, "tokenizer"):
+             del self.tokenizer
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     def _format_prompt(
+         self,
+         messages: list[dict[str, str]],
+         tools: list[ToolDefinition] | None = None,
+     ) -> str:
+         """Format messages into a prompt string.
+
+         Args:
+             messages: List of message dicts
+             tools: Optional list of tools
+
+         Returns:
+             Formatted prompt string
+         """
+         # Try to use chat template with tools support (modern approach)
+         if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template:
+             try:
+                 # Convert tools to OpenAI format for chat template compatibility
+                 tools_param = None
+                 if tools:
+                     tools_param = [tool.to_openai() for tool in tools]
+
+                 # Try with tools parameter (for models with native tool support)
+                 return self.tokenizer.apply_chat_template(
+                     messages,
+                     tools=tools_param,
+                     tokenize=False,
+                     add_generation_prompt=True,
+                 )
+             except (TypeError, KeyError):
+                 # Model's chat template doesn't support tools parameter
+                 # Try without tools parameter
+                 try:
+                     return self.tokenizer.apply_chat_template(
+                         messages,
+                         tokenize=False,
+                         add_generation_prompt=True,
+                     )
+                 except Exception:  # noqa: S110
+                     # Fallback to manual formatting
+                     pass  # nosec
+
+         # Manual formatting fallback (for models without chat templates)
+         prompt_parts = []
+
+         # Add tools if present
+         if tools:
+             tools_str = "Available tools:\n"
+             for tool in tools:
+                 tools_str += f"- {tool.name}: {tool.description}\n"
+                 params_list = [p.model_dump() for p in tool.parameters]
+                 tools_str += f"  Parameters: {json.dumps(params_list)}\n"
+             prompt_parts.append(tools_str)
+
+         # Add messages
+         for msg in messages:
+             role = msg["role"]
+             content = msg["content"]
+             prompt_parts.append(f"{role.upper()}: {content}")
+
+         prompt_parts.append("ASSISTANT:")
+         return "\n\n".join(prompt_parts)
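For context, a rough usage sketch of the backend shown above. It is not part of the package diff: the keyword arguments passed to InferenceConfig are assumptions inferred from the attributes this backend reads (model_path, adapter_path, device, use_unsloth, load_in_4bit, max_seq_length, max_tokens, temperature, top_p), the actual constructor in deepfabric/evaluation/inference.py may differ, and the model id is purely illustrative.

# Usage sketch (not part of the package). InferenceConfig field names below are
# assumptions based on the attributes TransformersBackend reads from its config.
from deepfabric.evaluation.backends.transformers_backend import TransformersBackend
from deepfabric.evaluation.inference import InferenceConfig

config = InferenceConfig(
    model_path="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative HF model id or local path
    adapter_path=None,      # optional PEFT/LoRA adapter directory
    device=None,            # None -> auto-detect cuda / mps / cpu
    use_unsloth=False,      # opt in to Unsloth fast loading if it is installed
    load_in_4bit=False,
    max_seq_length=4096,
    max_tokens=256,
    temperature=0.0,        # 0 disables sampling (do_sample=False)
    top_p=1.0,
)

backend = TransformersBackend(config)
try:
    response = backend.generate(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What does a KV cache speed up?"},
        ],
        tools=None,  # pass ToolDefinition objects to enable tool-call parsing
    )
    print(response.content)
finally:
    backend.cleanup()  # deletes model/tokenizer and empties the CUDA cache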