DeepFabric 4.4.0 (deepfabric-4.4.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/evaluation/backends/ollama_backend.py
@@ -0,0 +1,137 @@
"""Ollama-based inference backend."""

from __future__ import annotations

from typing import TYPE_CHECKING

import ollama

from ...schemas import ToolDefinition
from ..inference import InferenceBackend, ModelResponse

if TYPE_CHECKING:
    from ..inference import InferenceConfig


class OllamaBackend(InferenceBackend):
    """Inference backend using Ollama for local model serving.

    Ollama provides optimized local inference for open models with native
    Apple Silicon (M1/M2/M3) support and automatic memory management.
    """

    def __init__(self, config: InferenceConfig):
        """Initialize Ollama backend.

        Args:
            config: Inference configuration

        Note:
            - model_path should be the Ollama model name (e.g., "mistral", "llama2")
            - Ollama server must be running (ollama serve)
            - Device setting is ignored (Ollama handles device automatically)
        """
        super().__init__(config)

        # Use model_path directly as Ollama model name
        # Supports: "qwen3:8b", "hf.co/user/model:latest", etc.
        self.model_name = config.model_path

        # Verify model is available
        try:
            ollama.show(self.model_name)
        except ollama.ResponseError as e:
            msg = (
                f"Model '{self.model_name}' not found in Ollama. "
                f"Pull it first with: ollama pull {self.model_name}"
            )
            raise ValueError(msg) from e

    def generate(
        self,
        messages: list[dict[str, str]],
        tools: list[ToolDefinition] | None = None,
    ) -> ModelResponse:
        """Generate response from Ollama model.

        Args:
            messages: List of message dicts with 'role' and 'content'
            tools: Optional list of available tools for function calling

        Returns:
            ModelResponse with generated content and parsed tool calls
        """
        # Convert tools to Ollama format if provided
        ollama_tools = None
        if tools:
            ollama_tools = [self._convert_tool_to_ollama(tool) for tool in tools]

        # Call Ollama API
        response = ollama.chat(
            model=self.model_name,
            messages=messages,
            tools=ollama_tools,
            options={
                "temperature": self.config.temperature,
                "num_predict": self.config.max_tokens,
                "top_p": self.config.top_p,
            },
        )

        # Extract response content (response is a Pydantic object, not dict)
        message = response.message
        content = message.content or ""
        raw_output = content

        # Parse tool calls if present
        tool_call = None
        if hasattr(message, "tool_calls") and message.tool_calls:
            # Ollama returns tool calls in a structured format
            first_tool_call = message.tool_calls[0]
            tool_call = {
                "name": first_tool_call.function.name,
                "arguments": first_tool_call.function.arguments,
            }

        return ModelResponse(
            content=content,
            tool_call=tool_call,
            raw_output=raw_output,
            finish_reason=response.done_reason if hasattr(response, "done_reason") else None,
        )

    def generate_batch(
        self,
        batch_messages: list[list[dict[str, str]]],
        tools: list[ToolDefinition] | None = None,
    ) -> list[ModelResponse]:
        """Generate responses for a batch of message lists.

        Note: Ollama doesn't support true batching, so this processes sequentially.

        Args:
            batch_messages: List of message lists
            tools: Optional list of available tools

        Returns:
            List of ModelResponse objects
        """
        return [self.generate(messages, tools) for messages in batch_messages]

    def _convert_tool_to_ollama(self, tool: ToolDefinition) -> dict:
        """Convert ToolDefinition to Ollama tool format.

        Args:
            tool: DeepFabric ToolDefinition

        Returns:
            Ollama-formatted tool dict
        """
        # Use the built-in OpenAI schema converter (Ollama uses same format)
        return tool.to_openai()

    def cleanup(self) -> None:
        """Clean up resources.

        Note: Ollama manages resources automatically, no cleanup needed.
        """
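
Usage note (illustration, not part of the package diff): a minimal sketch of driving the backend above. It assumes an Ollama server is running with the named model already pulled, and that InferenceConfig accepts the fields the backend reads (model_path, temperature, max_tokens, top_p); the constructor's full signature does not appear in this diff.

from deepfabric.evaluation.inference import InferenceConfig
from deepfabric.evaluation.backends.ollama_backend import OllamaBackend

# Assumed InferenceConfig fields, inferred from what OllamaBackend reads above.
config = InferenceConfig(
    model_path="qwen3:8b",  # Ollama model name; pull first: ollama pull qwen3:8b
    temperature=0.2,
    max_tokens=512,
    top_p=0.9,
)
backend = OllamaBackend(config)  # raises ValueError if the model isn't pulled

response = backend.generate(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
)
print(response.content, response.tool_call)
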
deepfabric/evaluation/backends/tool_call_parsers.py
@@ -0,0 +1,409 @@
"""Tool call parsers for different model architectures.

Each model family outputs tool calls in a different format. This module provides
a registry of parsers that can extract tool calls from generated text based on
the model architecture.
"""

import json
import logging
import re

from abc import ABC, abstractmethod
from typing import Any

from pydantic import BaseModel, ValidationError, field_validator

logger = logging.getLogger(__name__)


class ToolCall(BaseModel):
    """Parsed tool call in normalized format."""

    name: str
    arguments: dict[str, Any]

    @field_validator("arguments", mode="before")
    @classmethod
    def parse_arguments_string(cls, v: Any) -> dict[str, Any]:
        """Parse arguments if they're a JSON string."""
        if isinstance(v, str):
            try:
                return json.loads(v)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON in arguments field: {e}") from e
        return v


class ToolCallParser(ABC):
    """Abstract base class for tool call parsers."""

    @abstractmethod
    def parse(self, text: str) -> list[dict[str, Any]]:
        """Parse tool calls from generated text.

        Args:
            text: Generated text from model

        Returns:
            List of tool call dicts with 'name' and 'arguments' keys
        """

    @staticmethod
    def _validate_tool_call(data: dict) -> dict | None:
        """Validate and normalize a tool call dict.

        Args:
            data: Raw parsed data

        Returns:
            Normalized tool call dict or None if invalid
        """
        try:
            tool_call = ToolCall.model_validate(data)
            return tool_call.model_dump()
        except ValidationError:
            return None


class QwenToolCallParser(ToolCallParser):
    """Parser for Qwen2.5 and Qwen3 models.

    Format: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
    """

    # Pattern matches <tool_call>...</tool_call> with content inside
    TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)

    def parse(self, text: str) -> list[dict[str, Any]]:
        """Parse Qwen-style tool calls."""
        tool_calls = []

        for match in self.TOOL_CALL_PATTERN.finditer(text):
            content = match.group(1).strip()
            try:
                data = json.loads(content)
                if validated := self._validate_tool_call(data):
                    tool_calls.append(validated)
            except json.JSONDecodeError:
                logger.debug("Failed to parse Qwen tool call JSON: %s", content[:100])

        return tool_calls


class LlamaToolCallParser(ToolCallParser):
    """Parser for Llama 3.1/3.2/3.3 models.

    Llama uses multiple formats:
    - JSON with "name" and "parameters" keys
    - <|python_tag|> for code execution
    - {"type": "function", "function": {"name": ..., "arguments": ...}}
    """

    # Pattern for function call JSON objects
    FUNCTION_PATTERN = re.compile(
        r'\{\s*"name"\s*:\s*"([^"]+)"\s*,\s*"parameters"\s*:\s*(\{[^}]*\})\s*\}',
        re.DOTALL,
    )

    # Pattern for OpenAI-style tool calls
    OPENAI_PATTERN = re.compile(
        r'\{\s*"type"\s*:\s*"function"\s*,\s*"function"\s*:\s*(\{[^}]+\})\s*\}',
        re.DOTALL,
    )

    def parse(self, text: str) -> list[dict[str, Any]]:
        """Parse Llama-style tool calls."""
        tool_calls = []

        # Try direct name/parameters format
        for match in self.FUNCTION_PATTERN.finditer(text):
            name = match.group(1)
            try:
                params = json.loads(match.group(2))
                data = {"name": name, "arguments": params}
                if validated := self._validate_tool_call(data):
                    tool_calls.append(validated)
            except json.JSONDecodeError:
                logger.debug("Failed to parse Llama parameters: %s", match.group(2)[:100])

        # Try OpenAI-style format if no matches
        if not tool_calls:
            for match in self.OPENAI_PATTERN.finditer(text):
                try:
                    func_data = json.loads(match.group(1))
                    data = {
                        "name": func_data.get("name", ""),
                        "arguments": func_data.get("arguments", {}),
                    }
                    if validated := self._validate_tool_call(data):
                        tool_calls.append(validated)
                except json.JSONDecodeError:
                    logger.debug("Failed to parse Llama OpenAI-style: %s", match.group(1)[:100])

        # Fallback: try to find any JSON with name/arguments or name/parameters
        if not tool_calls:
            tool_calls = self._fallback_json_parse(text)

        return tool_calls

    def _fallback_json_parse(self, text: str) -> list[dict[str, Any]]:
        """Fallback parser that looks for JSON objects with tool call structure."""
        tool_calls = []
        # Find all JSON-like objects
        for match in re.finditer(r"\{[^{}]*\}", text):
            try:
                data = json.loads(match.group())
                # Check for name + arguments OR name + parameters
                if "name" in data:
                    if "arguments" in data:
                        if validated := self._validate_tool_call(data):
                            tool_calls.append(validated)
                    elif "parameters" in data:
                        normalized = {"name": data["name"], "arguments": data["parameters"]}
                        if validated := self._validate_tool_call(normalized):
                            tool_calls.append(validated)
            except json.JSONDecodeError:
                continue
        return tool_calls


class MistralToolCallParser(ToolCallParser):
    """Parser for Mistral and Mixtral models.

    Format: [TOOL_CALLS] [{"name": "func", "arguments": {...}}]
    """

    # Pattern for [TOOL_CALLS] marker followed by JSON array
    TOOL_CALLS_PATTERN = re.compile(
        r"\[TOOL_CALLS\]\s*(\[.*?\])",
        re.DOTALL,
    )

    # Fallback pattern for JSON array of tool calls
    JSON_ARRAY_PATTERN = re.compile(r"\[\s*\{.*?\}\s*\]", re.DOTALL)

    def parse(self, text: str) -> list[dict[str, Any]]:
        """Parse Mistral-style tool calls."""
        tool_calls = []

        # Try [TOOL_CALLS] format first
        match = self.TOOL_CALLS_PATTERN.search(text)
        if match:
            try:
                calls_array = json.loads(match.group(1))
                if isinstance(calls_array, list):
                    for call in calls_array:
                        if validated := self._validate_tool_call(call):
                            tool_calls.append(validated)
            except json.JSONDecodeError:
                logger.debug("Failed to parse Mistral TOOL_CALLS array")

        # Fallback to looking for JSON arrays
        if not tool_calls:
            for match in self.JSON_ARRAY_PATTERN.finditer(text):
                try:
                    calls_array = json.loads(match.group())
                    if isinstance(calls_array, list):
                        for call in calls_array:
                            if isinstance(call, dict):
                                validated = self._validate_tool_call(call)
                                if validated:
                                    tool_calls.append(validated)
                except json.JSONDecodeError:
                    continue

        return tool_calls


class HermesToolCallParser(ToolCallParser):
    """Parser for Hermes/Nous-style models.

    Similar to Qwen but may have variations in formatting.
    Format: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
    """

    TOOL_CALL_PATTERN = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)

    def parse(self, text: str) -> list[dict[str, Any]]:
        """Parse Hermes-style tool calls."""
        tool_calls = []

        for match in self.TOOL_CALL_PATTERN.finditer(text):
            content = match.group(1).strip()
            try:
                data = json.loads(content)
                if validated := self._validate_tool_call(data):
                    tool_calls.append(validated)
            except json.JSONDecodeError:
                logger.debug("Failed to parse Hermes tool call: %s", content[:100])

        return tool_calls


class GenericToolCallParser(ToolCallParser):
    """Generic fallback parser that tries multiple strategies.

    Used when model architecture is unknown or not specifically supported.
    """

    def __init__(self) -> None:
        """Initialize with sub-parsers to try in order."""
        self._parsers = [
            QwenToolCallParser(),
            MistralToolCallParser(),
            LlamaToolCallParser(),
        ]

    def parse(self, text: str) -> list[dict[str, Any]]:
        """Try each parser until one succeeds."""
        for parser in self._parsers:
            tool_calls = parser.parse(text)
            if tool_calls:
                return tool_calls

        # Final fallback: extract any JSON object with name + arguments
        return self._extract_json_objects(text)

    def _extract_json_objects(self, text: str) -> list[dict[str, Any]]:
        """Extract JSON objects that look like tool calls."""
        tool_calls = []
        depth = 0
        start = -1

        for i, char in enumerate(text):
            if char == "{":
                if depth == 0:
                    start = i
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0 and start >= 0:
                    json_str = text[start : i + 1]
                    try:
                        data = json.loads(json_str)
                        if "name" in data and ("arguments" in data or "parameters" in data):
                            if "parameters" in data and "arguments" not in data:
                                data["arguments"] = data.pop("parameters")
                            if validated := self._validate_tool_call(data):
                                tool_calls.append(validated)
                    except json.JSONDecodeError:
                        pass
                    start = -1

        return tool_calls


# Architecture to parser mapping
ARCHITECTURE_PARSERS: dict[str, type[ToolCallParser]] = {
    # Qwen family
    "Qwen2ForCausalLM": QwenToolCallParser,
    "Qwen2_5ForCausalLM": QwenToolCallParser,
    "Qwen3ForCausalLM": QwenToolCallParser,
    "Qwen2VLForConditionalGeneration": QwenToolCallParser,
    "Qwen2_5_VLForConditionalGeneration": QwenToolCallParser,
    # Llama family
    "LlamaForCausalLM": LlamaToolCallParser,
    "Llama3ForCausalLM": LlamaToolCallParser,
    # Mistral family
    "MistralForCausalLM": MistralToolCallParser,
    "MixtralForCausalLM": MistralToolCallParser,
    "Mistral3ForConditionalGeneration": MistralToolCallParser,
    # Hermes/Nous (often fine-tuned Llama/Mistral with Hermes format)
    "HermesForCausalLM": HermesToolCallParser,
}


class ToolCallParserRegistry:
    """Registry for tool call parsers by model architecture."""

    def __init__(self) -> None:
        """Initialize the registry with default parsers."""
        self._parsers: dict[str, type[ToolCallParser]] = ARCHITECTURE_PARSERS.copy()
        self._fallback = GenericToolCallParser

    def register(self, architecture: str, parser_class: type[ToolCallParser]) -> None:
        """Register a parser for a model architecture.

        Args:
            architecture: Model architecture name (e.g., "LlamaForCausalLM")
            parser_class: Parser class to use for this architecture
        """
        self._parsers[architecture] = parser_class

    def get_parser(self, architectures: list[str] | None) -> ToolCallParser:
        """Get the appropriate parser for a model's architectures.

        Args:
            architectures: List of architecture names from model config

        Returns:
            Instantiated parser for the model
        """
        if architectures:
            for arch in architectures:
                if arch in self._parsers:
                    logger.debug(
                        "Using %s parser for architecture %s", self._parsers[arch].__name__, arch
                    )
                    return self._parsers[arch]()

        logger.debug("No specific parser found, using generic fallback")
        return self._fallback()

    def get_parser_for_model(self, model_path: str) -> ToolCallParser:
        """Get parser by loading model config and detecting architecture.

        Args:
            model_path: Path to model or HuggingFace Hub ID

        Returns:
            Instantiated parser for the model
        """
        from transformers import AutoConfig  # noqa: PLC0415

        try:
            config = AutoConfig.from_pretrained(model_path)  # nosec
            architectures = getattr(config, "architectures", None)
            return self.get_parser(architectures)
        except Exception as e:
            logger.warning("Failed to load model config for parser detection: %s", e)
            return self._fallback()


# Global registry instance
_registry = ToolCallParserRegistry()


def get_parser(architectures: list[str] | None = None) -> ToolCallParser:
    """Get a parser for the given architectures.

    Args:
        architectures: List of architecture names from model config

    Returns:
        Instantiated parser
    """
    return _registry.get_parser(architectures)


def get_parser_for_model(model_path: str) -> ToolCallParser:
    """Get a parser for a model by path.

    Args:
        model_path: Path to model or HuggingFace Hub ID

    Returns:
        Instantiated parser
    """
    return _registry.get_parser_for_model(model_path)


def register_parser(architecture: str, parser_class: type[ToolCallParser]) -> None:
    """Register a custom parser for an architecture.

    Args:
        architecture: Model architecture name
        parser_class: Parser class to use
    """
    _registry.register(architecture, parser_class)