blaxel 0.2.37__py3-none-any.whl → 0.2.38__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- blaxel/__init__.py +2 -2
- blaxel/core/tools/__init__.py +4 -0
- blaxel/core/volume/volume.py +4 -0
- blaxel/crewai/model.py +81 -44
- blaxel/crewai/tools.py +85 -2
- blaxel/googleadk/model.py +22 -3
- blaxel/googleadk/tools.py +25 -6
- blaxel/langgraph/custom/gemini.py +19 -12
- blaxel/langgraph/model.py +26 -18
- blaxel/langgraph/tools.py +6 -11
- blaxel/livekit/model.py +7 -2
- blaxel/livekit/tools.py +3 -1
- blaxel/llamaindex/model.py +145 -84
- blaxel/llamaindex/tools.py +6 -4
- blaxel/openai/model.py +7 -1
- blaxel/openai/tools.py +13 -3
- blaxel/pydantic/model.py +38 -24
- blaxel/pydantic/tools.py +37 -4
- {blaxel-0.2.37.dist-info → blaxel-0.2.38.dist-info}/METADATA +5 -46
- {blaxel-0.2.37.dist-info → blaxel-0.2.38.dist-info}/RECORD +22 -22
- {blaxel-0.2.37.dist-info → blaxel-0.2.38.dist-info}/WHEEL +0 -0
- {blaxel-0.2.37.dist-info → blaxel-0.2.38.dist-info}/licenses/LICENSE +0 -0
blaxel/llamaindex/model.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import os
 from logging import getLogger
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
 
 from blaxel.core import bl_model as bl_model_core
 from blaxel.core import settings
@@ -11,7 +11,7 @@ from blaxel.core import settings
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
 
 if TYPE_CHECKING:
-    from llama_index.core.base.llms.types import (
+    from llama_index.core.base.llms.types import (  # type: ignore[import-not-found]
         ChatMessage,
         ChatResponse,
         ChatResponseAsyncGen,
@@ -20,27 +20,76 @@ if TYPE_CHECKING:
         CompletionResponseAsyncGen,
         CompletionResponseGen,
     )
+    from llama_index.core.llms.llm import (  # type: ignore[import-not-found]
+        ToolSelection,
+    )
+    from llama_index.core.tools.types import BaseTool  # type: ignore[import-not-found]
+
+# Runtime imports needed for class inheritance and construction
+from llama_index.core.base.llms.types import (  # type: ignore[import-not-found]
+    LLMMetadata,
+)
+from llama_index.core.llms.function_calling import (  # type: ignore[import-not-found]
+    FunctionCallingLLM,
+)
+from pydantic import PrivateAttr  # type: ignore[import-not-found]
 
 logger = getLogger(__name__)
 
+DEFAULT_CONTEXT_WINDOW = 128000
+DEFAULT_NUM_OUTPUT = 4096
+
+
+class TokenRefreshingLLM(FunctionCallingLLM):
+    """Wrapper for LlamaIndex LLMs that refreshes token before each call.
 
-class TokenRefreshingWrapper:
-    """Wrapper that refreshes token before each call."""
+    Inherits from FunctionCallingLLM to maintain type compatibility with
+    LlamaIndex's agents and components that validate isinstance(model, LLM).
+    """
+
+    _model_config_data: dict = PrivateAttr(default_factory=dict)
+    _wrapped: Any = PrivateAttr(default=None)
 
     def __init__(self, model_config: dict):
-        self.model_config = model_config
-        self.wrapped_model = self._create_model()
+        super().__init__()
+        self._model_config_data = model_config
+        self._wrapped = self._create_model()
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "TokenRefreshingLLM"
+
+    @property
+    def wrapped_model(self) -> Any:
+        """Access the underlying wrapped LLM model."""
+        return self._wrapped
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata, with fallback for unknown model names."""
+        try:
+            return self._wrapped.metadata
+        except (ValueError, KeyError) as e:
+            logger.warning(f"Could not get metadata from wrapped model: {e}. Using defaults.")
+            return LLMMetadata(
+                context_window=DEFAULT_CONTEXT_WINDOW,
+                num_output=DEFAULT_NUM_OUTPUT,
+                is_chat_model=True,
+                model_name=self._model_config_data.get("model", "unknown"),
+            )
 
     def _create_model(self):
        """Create the model instance with current token."""
-        config = self.model_config
+        config = self._model_config_data
         model_type = config["type"]
         model = config["model"]
         url = config["url"]
         kwargs = config.get("kwargs", {})
 
         if model_type == "anthropic":
-            from llama_index.llms.anthropic import Anthropic
+            from llama_index.llms.anthropic import (  # type: ignore[import-not-found]
+                Anthropic,
+            )
 
             return Anthropic(
                 model=model,
@@ -50,7 +99,7 @@ class TokenRefreshingWrapper:
                 **kwargs,
             )
         elif model_type == "xai":
-            from llama_index.llms.groq import Groq
+            from llama_index.llms.groq import Groq  # type: ignore[import-not-found]
 
             return Groq(
                 model=model,
@@ -60,7 +109,9 @@ class TokenRefreshingWrapper:
             )
         elif model_type == "gemini":
             from google.genai.types import HttpOptions
-            from llama_index.llms.google_genai import GoogleGenAI
+            from llama_index.llms.google_genai import (  # type: ignore[import-not-found]
+                GoogleGenAI,
+            )
 
             return GoogleGenAI(
                 api_key=settings.auth.token,
@@ -73,11 +124,13 @@ class TokenRefreshingWrapper:
                 **kwargs,
             )
         elif model_type == "cohere":
-            from .custom.cohere import Cohere
+            from .custom.cohere import Cohere  # type: ignore[import-not-found]
 
             return Cohere(model=model, api_key=settings.auth.token, api_base=url, **kwargs)
         elif model_type == "deepseek":
-            from llama_index.llms.deepseek import DeepSeek
+            from llama_index.llms.deepseek import (  # type: ignore[import-not-found]
+                DeepSeek,
+            )
 
             return DeepSeek(
                 model=model,
@@ -86,11 +139,15 @@ class TokenRefreshingWrapper:
                 **kwargs,
             )
         elif model_type == "mistral":
-            from llama_index.llms.mistralai import MistralAI
+            from llama_index.llms.mistralai import (  # type: ignore[import-not-found]
+                MistralAI,
+            )
 
             return MistralAI(model=model, api_key=settings.auth.token, endpoint=url, **kwargs)
         elif model_type == "cerebras":
-            from llama_index.llms.cerebras import Cerebras
+            from llama_index.llms.cerebras import (  # type: ignore[import-not-found]
+                Cerebras,
+            )
 
             return Cerebras(
                 model=model,
@@ -99,7 +156,7 @@ class TokenRefreshingWrapper:
                 **kwargs,
             )
         else:
-            from llama_index.llms.openai import OpenAI
+            from llama_index.llms.openai import OpenAI  # type: ignore[import-not-found]
 
             if model_type != "openai":
                 logger.warning(
@@ -115,102 +172,106 @@ class TokenRefreshingWrapper:
 
     def _refresh_token(self):
         """Refresh the token and recreate the model if needed."""
-        # Only refresh if using ClientCredentials (which has get_token method)
         current_token = settings.auth.token
 
-        if hasattr(settings.auth, "get_token"):
-            # This will trigger token refresh if needed
-            settings.auth.get_token()
-
         new_token = settings.auth.token
 
-        # If token changed, recreate the model
         if current_token != new_token:
-            self.wrapped_model = self._create_model()
-
-    def __getattr__(self, name):
-        """Delegate attribute access to wrapped model."""
-        return getattr(self.wrapped_model, name)
+            self._wrapped = self._create_model()
 
+    # --- Core LLM methods with token refresh ---
 
-
-
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        self._refresh_token()
+        return self._wrapped.chat(messages, **kwargs)
 
-    async def achat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Async chat with token refresh."""
+    async def achat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         self._refresh_token()
-        return await self.wrapped_model.achat(messages, **kwargs)
+        return await self._wrapped.achat(messages, **kwargs)
 
-    def chat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Sync chat with token refresh."""
+    def complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
         self._refresh_token()
-        return self.wrapped_model.chat(messages, **kwargs)
+        return self._wrapped.complete(prompt, formatted=formatted, **kwargs)
 
-    async def astream_chat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
-    ) -> ChatResponseAsyncGen:
-        """Async stream chat with token refresh."""
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
         self._refresh_token()
-        async for chunk in await self.wrapped_model.astream_chat(messages, **kwargs):
-            yield chunk
+        return await self._wrapped.acomplete(prompt, formatted=formatted, **kwargs)
 
     def stream_chat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
+        self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        """Sync stream chat with token refresh."""
         self._refresh_token()
-        for chunk in self.wrapped_model.stream_chat(messages, **kwargs):
-            yield chunk
+        return self._wrapped.stream_chat(messages, **kwargs)
 
-    async def acomplete(
-        self,
-        prompt: str,
-        **kwargs: Any,
-    ) -> CompletionResponse:
-        """Async complete with token refresh."""
+    async def astream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
         self._refresh_token()
-        return await self.wrapped_model.acomplete(prompt, **kwargs)
+        result = self._wrapped.astream_chat(messages, **kwargs)
+        # Handle both coroutine and async generator patterns
+        if hasattr(result, "__aiter__"):
+            return result
+        return await result
 
-    def complete(
-        self,
-        prompt: str,
-        **kwargs: Any,
-    ) -> CompletionResponse:
-        """Sync complete with token refresh."""
+    def stream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseGen:
         self._refresh_token()
-        return self.wrapped_model.complete(prompt, **kwargs)
+        return self._wrapped.stream_complete(prompt, formatted=formatted, **kwargs)
 
     async def astream_complete(
-        self,
-        prompt: str,
-        **kwargs: Any,
+        self, prompt: str, formatted: bool = False, **kwargs: Any
     ) -> CompletionResponseAsyncGen:
-        """Async stream complete with token refresh."""
         self._refresh_token()
-        async for chunk in await self.wrapped_model.astream_complete(prompt, **kwargs):
-            yield chunk
+        result = self._wrapped.astream_complete(prompt, formatted=formatted, **kwargs)
+        # Handle both coroutine and async generator patterns
+        if hasattr(result, "__aiter__"):
+            return result
+        return await result
 
-    def stream_complete(
-        self,
-        prompt: str,
-        **kwargs: Any,
-    ) -> CompletionResponseGen:
-        """Sync stream complete with token refresh."""
-        self._refresh_token()
-        for chunk in self.wrapped_model.stream_complete(prompt, **kwargs):
-            yield chunk
+    # --- FunctionCallingLLM methods (delegate to wrapped model) ---
+
+    def _prepare_chat_with_tools(
+        self,
+        tools: Sequence[BaseTool],
+        user_msg: Union[str, ChatMessage, None] = None,
+        chat_history: List[ChatMessage] | None = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        tool_required: Any = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        if hasattr(self._wrapped, "_prepare_chat_with_tools"):
+            return self._wrapped._prepare_chat_with_tools(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                allow_parallel_tool_calls=allow_parallel_tool_calls,
+                tool_required=tool_required,
+                **kwargs,
+            )
+        raise NotImplementedError(
+            f"The wrapped model ({type(self._wrapped).__name__}) does not support function calling"
+        )
+
+    def get_tool_calls_from_response(
+        self,
+        response: ChatResponse,
+        error_on_no_tool_call: bool = True,
+        **kwargs: Any,
+    ) -> List[ToolSelection]:
+        if hasattr(self._wrapped, "get_tool_calls_from_response"):
+            return self._wrapped.get_tool_calls_from_response(
+                response,
+                error_on_no_tool_call=error_on_no_tool_call,
+                **kwargs,
+            )
+        raise NotImplementedError(
+            f"The wrapped model ({type(self._wrapped).__name__}) does not support function calling"
+        )
 
 
 async def bl_model(name, **kwargs):
@@ -220,4 +281,4 @@ async def bl_model(name, **kwargs):
     model_config = {"type": type, "model": model, "url": url, "kwargs": kwargs}
 
     # Create and return the wrapper
-    return TokenRefreshingWrapper(model_config)
+    return TokenRefreshingLLM(model_config)
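The substance of this change: the old `TokenRefreshingWrapper` delegated everything through `__getattr__`, so it was not an `LLM` subclass and failed the `isinstance` validation that LlamaIndex agents perform; the new `TokenRefreshingLLM` subclasses `FunctionCallingLLM` directly. A minimal usage sketch, assuming `bl_model` is re-exported from `blaxel.llamaindex` and that a model named `"my-model"` is configured in the workspace (both assumptions, not shown in the diff):

```python
import asyncio

from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.llms.function_calling import FunctionCallingLLM

from blaxel.llamaindex import bl_model  # assumed re-export of blaxel.llamaindex.model.bl_model


async def main() -> None:
    llm = await bl_model("my-model")  # "my-model" is a placeholder name

    # The wrapper is now a genuine FunctionCallingLLM subclass, so it passes
    # the isinstance(model, LLM) checks LlamaIndex agents run on their llm.
    assert isinstance(llm, FunctionCallingLLM)

    # Each call refreshes the auth token first, then delegates to the
    # wrapped provider-specific LLM.
    response = await llm.achat([ChatMessage(role="user", content="Hello!")])
    print(response.message.content)


asyncio.run(main())
```

The `hasattr(result, "__aiter__")` checks in `astream_chat`/`astream_complete` exist because some LlamaIndex backends return an async generator directly while others return a coroutine that resolves to one; returning the generator as-is in the first case avoids awaiting a non-awaitable.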
blaxel/llamaindex/tools.py
CHANGED
@@ -5,12 +5,14 @@ from blaxel.core.tools.common import create_model_from_json_schema
 from blaxel.core.tools.types import Tool
 
 if TYPE_CHECKING:
-    from llama_index.core.tools import FunctionTool
+    from llama_index.core.tools import FunctionTool  # type: ignore[import-not-found]
 
 
 def get_llamaindex_tool(tool: Tool) -> "FunctionTool":
-    from llama_index.core.tools import FunctionTool
-    from llama_index.core.tools.types import ToolMetadata
+    from llama_index.core.tools import FunctionTool  # type: ignore[import-not-found]
+    from llama_index.core.tools.types import (  # type: ignore[import-not-found]
+        ToolMetadata,
+    )
 
     model_schema = create_model_from_json_schema(
         tool.input_schema, model_name=f"{tool.name}_Schema"
@@ -29,4 +31,4 @@ def get_llamaindex_tool(tool: Tool) -> "FunctionTool":
 async def bl_tools(tools_names: list[str], **kwargs) -> list["FunctionTool"]:
     tools = bl_tools_core(tools_names, **kwargs)
     await tools.initialize()
-    return [get_llamaindex_tool(tool) for tool in tools.get_tools()]
\ No newline at end of file
+    return [get_llamaindex_tool(tool) for tool in tools.get_tools()]
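A hedged sketch of consuming the converted tools; the tool name `"blaxel-search"` is a placeholder and the `blaxel.llamaindex` re-export of `bl_tools` is assumed:

```python
import asyncio

from blaxel.llamaindex import bl_tools  # assumed re-export


async def main() -> None:
    tools = await bl_tools(["blaxel-search"])  # placeholder tool name
    for tool in tools:
        # Each entry is a native llama_index FunctionTool whose ToolMetadata
        # (name, description, argument schema) mirrors the Blaxel tool definition.
        print(tool.metadata.name, "->", tool.metadata.description)


asyncio.run(main())
```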
blaxel/openai/model.py
CHANGED
@@ -1,5 +1,6 @@
 import httpx
-from agents import AsyncOpenAI, OpenAIChatCompletionsModel
+from agents import OpenAIChatCompletionsModel
+from openai import AsyncOpenAI
 
 from blaxel.core import bl_model as bl_model_core
 from blaxel.core import settings
@@ -14,6 +15,11 @@ class DynamicHeadersHTTPClient(httpx.AsyncClient):
     async def send(self, request, *args, **kwargs):
         # Update headers with the latest auth headers before each request
         auth_headers = settings.auth.get_headers()
+        # Remove the SDK's default "Authorization: Bearer replaced" header
+        # when our auth uses a different header (e.g. X-Blaxel-Authorization with API keys)
+        if "Authorization" not in auth_headers:
+            request.headers.pop("Authorization", None)
+            request.headers.pop("authorization", None)
         for key, value in auth_headers.items():
             request.headers[key] = value
         return await super().send(request, *args, **kwargs)
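The behavioral change is in `send`: when Blaxel authenticates with a non-standard header, the placeholder `Authorization` header that the OpenAI SDK injects is now stripped before the real headers are applied. A self-contained sketch of the pattern, with a hypothetical `get_auth_headers` standing in for `settings.auth.get_headers()`:

```python
import httpx


def get_auth_headers() -> dict[str, str]:
    # Hypothetical stand-in for settings.auth.get_headers(); with API-key
    # auth, Blaxel sends a custom header rather than Authorization.
    return {"X-Blaxel-Authorization": "Bearer my-api-key"}


class DynamicHeadersHTTPClient(httpx.AsyncClient):
    """httpx client that re-resolves auth headers on every outgoing request."""

    async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response:
        auth_headers = get_auth_headers()
        if "Authorization" not in auth_headers:
            # Drop the SDK's placeholder Authorization header so only the
            # real credential reaches the gateway.
            request.headers.pop("Authorization", None)
        for key, value in auth_headers.items():
            request.headers[key] = value
        return await super().send(request, *args, **kwargs)
```

`httpx.Headers` lookups are case-insensitive, so a single `pop` covers both spellings; the diff pops `"Authorization"` and `"authorization"` explicitly as a belt-and-braces measure.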
blaxel/openai/tools.py
CHANGED
@@ -1,7 +1,8 @@
 import json
 from typing import Any
 
-from agents import FunctionTool
+from agents import FunctionTool  # type: ignore[import-not-found]
+from agents.tool_context import ToolContext  # type: ignore[import-not-found]
 
 from blaxel.core.tools import bl_tools as bl_tools_core
 from blaxel.core.tools.types import Tool
@@ -24,6 +25,13 @@ def _clean_schema_for_openai(schema: dict) -> dict:
     if "additionalProperties" in cleaned_schema:
         del cleaned_schema["additionalProperties"]
 
+    # Ensure object type schemas have properties
+    if cleaned_schema.get("type") == "object":
+        if "properties" not in cleaned_schema:
+            cleaned_schema["properties"] = {}
+        if "required" not in cleaned_schema:
+            cleaned_schema["required"] = []
+
     # Recursively clean properties if they exist
     if "properties" in cleaned_schema:
         cleaned_schema["properties"] = {
@@ -39,9 +47,11 @@ def _clean_schema_for_openai(schema: dict) -> dict:
 
 def get_openai_tool(tool: Tool) -> FunctionTool:
     async def openai_coroutine(
-        _: Any,
-        arguments: str,
+        _: ToolContext[Any],
+        arguments: str,
     ) -> Any:
+        if not tool.coroutine:
+            raise ValueError(f"Tool {tool.name} does not have a coroutine defined")
         result = await tool.coroutine(**json.loads(arguments))
         return result
 
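The new object-schema guard matters for tools that take no arguments, which otherwise serialize to a bare `{"type": "object"}` that OpenAI's function-calling validation rejects. A tiny standalone demonstration of the equivalent transformation:

```python
# A no-argument tool often produces a bare object schema.
schema = {"type": "object"}

# Equivalent to the guard added to _clean_schema_for_openai above.
if schema.get("type") == "object":
    schema.setdefault("properties", {})
    schema.setdefault("required", [])

print(schema)
# {'type': 'object', 'properties': {}, 'required': []}
```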
blaxel/pydantic/model.py
CHANGED
@@ -1,7 +1,7 @@
 import logging
 from typing import Any
 
-from pydantic_ai.models import Model
+from pydantic_ai.models import Model  # type: ignore[import-not-found]
 
 from blaxel.core import bl_model as bl_model_core
 from blaxel.core import settings
@@ -30,8 +30,12 @@ class TokenRefreshingModel(Model):
 
         if type == "mistral":
             from mistralai.sdk import Mistral
-            from pydantic_ai.models.mistral import MistralModel
-            from pydantic_ai.providers.mistral import MistralProvider
+            from pydantic_ai.models.mistral import (  # type: ignore[import-not-found]
+                MistralModel,
+            )
+            from pydantic_ai.providers.mistral import (  # type: ignore[import-not-found]
+                MistralProvider,
+            )
 
             return MistralModel(
                 model_name=model,
@@ -45,8 +49,12 @@ class TokenRefreshingModel(Model):
             )
         elif type == "cohere":
             from cohere import AsyncClientV2
-            from pydantic_ai.models.cohere import CohereModel
-            from pydantic_ai.providers.cohere import CohereProvider
+            from pydantic_ai.models.cohere import (  # type: ignore[import-not-found]
+                CohereModel,
+            )
+            from pydantic_ai.providers.cohere import (  # type: ignore[import-not-found]
+                CohereProvider,
+            )
 
             return CohereModel(
                 model_name=model,
@@ -58,30 +66,42 @@ class TokenRefreshingModel(Model):
                 ),
             )
         elif type == "xai":
-            from pydantic_ai.models.openai import OpenAIModel
-            from pydantic_ai.providers.openai import OpenAIProvider
+            from pydantic_ai.models.openai import (  # type: ignore[import-not-found]
+                OpenAIChatModel,
+            )
+            from pydantic_ai.providers.openai import (  # type: ignore[import-not-found]
+                OpenAIProvider,
+            )
 
-            return OpenAIModel(
+            return OpenAIChatModel(
                 model_name=model,
                 provider=OpenAIProvider(
                     base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
                 ),
             )
         elif type == "deepseek":
-            from pydantic_ai.models.openai import OpenAIModel
-            from pydantic_ai.providers.openai import OpenAIProvider
+            from pydantic_ai.models.openai import (  # type: ignore[import-not-found]
+                OpenAIChatModel,
+            )
+            from pydantic_ai.providers.openai import (  # type: ignore[import-not-found]
+                OpenAIProvider,
+            )
 
-            return OpenAIModel(
+            return OpenAIChatModel(
                 model_name=model,
                 provider=OpenAIProvider(
                     base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
                 ),
             )
         elif type == "cerebras":
-            from pydantic_ai.models.openai import OpenAIModel
-            from pydantic_ai.providers.openai import OpenAIProvider
+            from pydantic_ai.models.openai import (  # type: ignore[import-not-found]
+                OpenAIChatModel,
+            )
+            from pydantic_ai.providers.openai import (  # type: ignore[import-not-found]
+                OpenAIProvider,
+            )
 
-            return OpenAIModel(
+            return OpenAIChatModel(
                 model_name=model,
                 provider=OpenAIProvider(
                     base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
@@ -116,12 +136,12 @@ class TokenRefreshingModel(Model):
                 ),
             )
         else:
-            from pydantic_ai.models.openai import OpenAIModel
+            from pydantic_ai.models.openai import OpenAIChatModel
             from pydantic_ai.providers.openai import OpenAIProvider
 
             if type != "openai":
                 logger.warning(f"Model {model} is not supported by Pydantic, defaulting to OpenAI")
-            return OpenAIModel(
+            return OpenAIChatModel(
                 model_name=model,
                 provider=OpenAIProvider(
                     base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
@@ -130,12 +150,6 @@ class TokenRefreshingModel(Model):
 
     def _get_fresh_model(self) -> Model:
         """Get or create a model with fresh token if needed."""
-        # Only refresh if using ClientCredentials (which has get_token method)
-        if hasattr(settings.auth, "get_token"):
-            # This will trigger token refresh if needed
-            logger.debug(f"Calling get_token for {self.model_config['type']} model")
-            settings.auth.get_token()
-
         new_token = settings.auth.token
 
         # If token changed or no cached model, create new one
@@ -152,10 +166,10 @@ class TokenRefreshingModel(Model):
         return model.model_name
 
     @property
-    def system(self) -> str | None:
+    def system(self) -> str:
         """Return the system property from the wrapped model."""
         model = self._get_fresh_model()
-        return model.system if hasattr(model, "system") else None
+        return model.system if hasattr(model, "system") else ""
 
     async def request(self, *args, **kwargs):
         """Make a request to the model with token refresh."""
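Stripped of the provider branches, the refresh strategy is: remember which token the cached model was built with and rebuild only when the current token differs (the explicit `get_token()` call is removed, presumably because reading the token now triggers any needed refresh). A condensed, self-contained sketch of the caching pattern, with simplified names and `make_client` standing in for the provider constructors:

```python
from typing import Callable


class TokenRefreshingCache:
    """Condensed illustration of the rebuild-on-token-change pattern."""

    def __init__(
        self,
        get_token: Callable[[], str],
        make_client: Callable[[str], object],
    ):
        self._get_token = get_token
        self._make_client = make_client
        self._cached_token: str | None = None
        self._cached_client: object | None = None

    def get(self) -> object:
        token = self._get_token()
        if self._cached_client is None or token != self._cached_token:
            # First use, or the token rotated: rebuild the client so it
            # carries the fresh credential.
            self._cached_client = self._make_client(token)
            self._cached_token = token
        return self._cached_client


tokens = iter(["t1", "t1", "t2"])
cache = TokenRefreshingCache(lambda: next(tokens), lambda t: f"client({t})")
print(cache.get(), cache.get(), cache.get())  # client(t1) client(t1) client(t2)
```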
blaxel/pydantic/tools.py
CHANGED
@@ -1,11 +1,44 @@
-from pydantic_ai import RunContext
-
-from pydantic_ai.tools import Tool as PydanticTool, ToolDefinition
+from typing import Any
+
+from pydantic_ai import RunContext  # type: ignore[import-not-found]
+from pydantic_ai.tools import Tool as PydanticTool  # type: ignore[import-not-found]
+from pydantic_ai.tools import ToolDefinition  # type: ignore[import-not-found]
 
 from blaxel.core.tools import Tool
 from blaxel.core.tools import bl_tools as bl_tools_core
 
 
+def _clean_schema_for_openai(schema: dict[str, Any]) -> dict[str, Any]:
+    """Clean JSON schema to be compatible with OpenAI function calling.
+
+    OpenAI requires object schemas to have a 'properties' field, even if empty.
+    """
+    if not isinstance(schema, dict):
+        return schema
+
+    cleaned = schema.copy()
+
+    if cleaned.get("type") == "object":
+        if "properties" not in cleaned:
+            cleaned["properties"] = {}
+        if "required" not in cleaned:
+            cleaned["required"] = []
+
+    if "additionalProperties" in cleaned:
+        del cleaned["additionalProperties"]
+    if "$schema" in cleaned:
+        del cleaned["$schema"]
+
+    if "properties" in cleaned:
+        cleaned["properties"] = {
+            k: _clean_schema_for_openai(v) for k, v in cleaned["properties"].items()
+        }
+    if "items" in cleaned and isinstance(cleaned["items"], dict):
+        cleaned["items"] = _clean_schema_for_openai(cleaned["items"])
+
+    return cleaned
+
+
 def get_pydantic_tool(tool: Tool) -> PydanticTool:
     """
     Converts a custom Tool object into a Pydantic AI Tool object.
@@ -27,7 +60,7 @@ def get_pydantic_tool(tool: Tool) -> PydanticTool:
         """Dynamically prepares the ToolDefinition using the custom Tool's attributes."""
         tool_def.name = tool.name  # Override inferred name
         tool_def.description = tool.description  # Override inferred description
-        tool_def.parameters_json_schema = tool.input_schema
+        tool_def.parameters_json_schema = _clean_schema_for_openai(tool.input_schema)
         return tool_def
 
     async def pydantic_function(**kwargs):
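For context on where the cleaned schema lands: pydantic-ai lets a `Tool` rewrite its own `ToolDefinition` through a `prepare` hook before each run, which is the hook modified in the hunk above. A hedged sketch of that mechanism in isolation (the `echo` tool and its `prepare` function are illustrative, not part of blaxel):

```python
from pydantic_ai import RunContext
from pydantic_ai.tools import Tool as PydanticTool
from pydantic_ai.tools import ToolDefinition


async def echo(text: str) -> str:
    """Illustrative tool body."""
    return text


async def prepare(ctx: RunContext, tool_def: ToolDefinition) -> ToolDefinition:
    # The same kind of hook where blaxel now routes the parameter schema
    # through _clean_schema_for_openai before it reaches the provider.
    tool_def.description = "Echoes its input back."
    return tool_def


tool = PydanticTool(echo, prepare=prepare)
```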