blaxel 0.2.37__py3-none-any.whl → 0.2.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import os
4
4
  from logging import getLogger
5
- from typing import TYPE_CHECKING, Any, Sequence
5
+ from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
6
6
 
7
7
  from blaxel.core import bl_model as bl_model_core
8
8
  from blaxel.core import settings
@@ -11,7 +11,7 @@ from blaxel.core import settings
11
11
  os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
12
12
 
13
13
  if TYPE_CHECKING:
14
- from llama_index.core.base.llms.types import (
14
+ from llama_index.core.base.llms.types import ( # type: ignore[import-not-found]
15
15
  ChatMessage,
16
16
  ChatResponse,
17
17
  ChatResponseAsyncGen,
@@ -20,27 +20,76 @@ if TYPE_CHECKING:
20
20
  CompletionResponseAsyncGen,
21
21
  CompletionResponseGen,
22
22
  )
23
+ from llama_index.core.llms.llm import ( # type: ignore[import-not-found]
24
+ ToolSelection,
25
+ )
26
+ from llama_index.core.tools.types import BaseTool # type: ignore[import-not-found]
27
+
28
+ # Runtime imports needed for class inheritance and construction
29
+ from llama_index.core.base.llms.types import ( # type: ignore[import-not-found]
30
+ LLMMetadata,
31
+ )
32
+ from llama_index.core.llms.function_calling import ( # type: ignore[import-not-found]
33
+ FunctionCallingLLM,
34
+ )
35
+ from pydantic import PrivateAttr # type: ignore[import-not-found]
23
36
 
24
37
  logger = getLogger(__name__)
25
38
 
39
+ DEFAULT_CONTEXT_WINDOW = 128000
40
+ DEFAULT_NUM_OUTPUT = 4096
41
+
42
+
43
+ class TokenRefreshingLLM(FunctionCallingLLM):
44
+ """Wrapper for LlamaIndex LLMs that refreshes token before each call.
26
45
 
27
- class TokenRefreshingWrapper:
28
- """Base wrapper class that refreshes token before each call."""
46
+ Inherits from FunctionCallingLLM to maintain type compatibility with
47
+ LlamaIndex's agents and components that validate isinstance(model, LLM).
48
+ """
49
+
50
+ _model_config_data: dict = PrivateAttr(default_factory=dict)
51
+ _wrapped: Any = PrivateAttr(default=None)
29
52
 
30
53
  def __init__(self, model_config: dict):
31
- self.model_config = model_config
32
- self.wrapped_model = self._create_model()
54
+ super().__init__()
55
+ self._model_config_data = model_config
56
+ self._wrapped = self._create_model()
57
+
58
+ @classmethod
59
+ def class_name(cls) -> str:
60
+ return "TokenRefreshingLLM"
61
+
62
+ @property
63
+ def wrapped_model(self) -> Any:
64
+ """Access the underlying wrapped LLM model."""
65
+ return self._wrapped
66
+
67
+ @property
68
+ def metadata(self) -> LLMMetadata:
69
+ """Get LLM metadata, with fallback for unknown model names."""
70
+ try:
71
+ return self._wrapped.metadata
72
+ except (ValueError, KeyError) as e:
73
+ logger.warning(f"Could not get metadata from wrapped model: {e}. Using defaults.")
74
+ return LLMMetadata(
75
+ context_window=DEFAULT_CONTEXT_WINDOW,
76
+ num_output=DEFAULT_NUM_OUTPUT,
77
+ is_chat_model=True,
78
+ model_name=self._model_config_data.get("model", "unknown"),
79
+ )
33
80
 
34
81
  def _create_model(self):
35
82
  """Create the model instance with current token."""
36
- config = self.model_config
83
+ config = self._model_config_data
37
84
  model_type = config["type"]
38
85
  model = config["model"]
39
86
  url = config["url"]
40
87
  kwargs = config.get("kwargs", {})
41
88
 
42
89
  if model_type == "anthropic":
43
- from llama_index.llms.anthropic import Anthropic
90
+ from llama_index.llms.anthropic import ( # type: ignore[import-not-found]
91
+ Anthropic,
92
+ )
44
93
 
45
94
  return Anthropic(
46
95
  model=model,
@@ -50,7 +99,7 @@ class TokenRefreshingWrapper:
50
99
  **kwargs,
51
100
  )
52
101
  elif model_type == "xai":
53
- from llama_index.llms.groq import Groq
102
+ from llama_index.llms.groq import Groq # type: ignore[import-not-found]
54
103
 
55
104
  return Groq(
56
105
  model=model,
@@ -60,7 +109,9 @@ class TokenRefreshingWrapper:
60
109
  )
61
110
  elif model_type == "gemini":
62
111
  from google.genai.types import HttpOptions
63
- from llama_index.llms.google_genai import GoogleGenAI
112
+ from llama_index.llms.google_genai import ( # type: ignore[import-not-found]
113
+ GoogleGenAI,
114
+ )
64
115
 
65
116
  return GoogleGenAI(
66
117
  api_key=settings.auth.token,
@@ -73,11 +124,13 @@ class TokenRefreshingWrapper:
73
124
  **kwargs,
74
125
  )
75
126
  elif model_type == "cohere":
76
- from .custom.cohere import Cohere
127
+ from .custom.cohere import Cohere # type: ignore[import-not-found]
77
128
 
78
129
  return Cohere(model=model, api_key=settings.auth.token, api_base=url, **kwargs)
79
130
  elif model_type == "deepseek":
80
- from llama_index.llms.deepseek import DeepSeek
131
+ from llama_index.llms.deepseek import ( # type: ignore[import-not-found]
132
+ DeepSeek,
133
+ )
81
134
 
82
135
  return DeepSeek(
83
136
  model=model,
@@ -86,11 +139,15 @@ class TokenRefreshingWrapper:
86
139
  **kwargs,
87
140
  )
88
141
  elif model_type == "mistral":
89
- from llama_index.llms.mistralai import MistralAI
142
+ from llama_index.llms.mistralai import ( # type: ignore[import-not-found]
143
+ MistralAI,
144
+ )
90
145
 
91
146
  return MistralAI(model=model, api_key=settings.auth.token, endpoint=url, **kwargs)
92
147
  elif model_type == "cerebras":
93
- from llama_index.llms.cerebras import Cerebras
148
+ from llama_index.llms.cerebras import ( # type: ignore[import-not-found]
149
+ Cerebras,
150
+ )
94
151
 
95
152
  return Cerebras(
96
153
  model=model,
@@ -99,7 +156,7 @@ class TokenRefreshingWrapper:
99
156
  **kwargs,
100
157
  )
101
158
  else:
102
- from llama_index.llms.openai import OpenAI
159
+ from llama_index.llms.openai import OpenAI # type: ignore[import-not-found]
103
160
 
104
161
  if model_type != "openai":
105
162
  logger.warning(
@@ -115,102 +172,106 @@ class TokenRefreshingWrapper:
115
172
 
116
173
  def _refresh_token(self):
117
174
  """Refresh the token and recreate the model if needed."""
118
- # Only refresh if using ClientCredentials (which has get_token method)
119
175
  current_token = settings.auth.token
120
176
 
121
- if hasattr(settings.auth, "get_token"):
122
- # This will trigger token refresh if needed
123
- settings.auth.get_token()
124
-
125
177
  new_token = settings.auth.token
126
178
 
127
- # If token changed, recreate the model
128
179
  if current_token != new_token:
129
- self.wrapped_model = self._create_model()
130
-
131
- def __getattr__(self, name):
132
- """Delegate attribute access to wrapped model."""
133
- return getattr(self.wrapped_model, name)
180
+ self._wrapped = self._create_model()
134
181
 
182
+ # --- Core LLM methods with token refresh ---
135
183
 
136
- class TokenRefreshingLLM(TokenRefreshingWrapper):
137
- """Wrapper for LlamaIndex LLMs that refreshes token before each call."""
184
+ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
185
+ self._refresh_token()
186
+ return self._wrapped.chat(messages, **kwargs)
138
187
 
139
- async def achat(
140
- self,
141
- messages: Sequence[ChatMessage],
142
- **kwargs: Any,
143
- ) -> ChatResponse:
144
- """Async chat with token refresh."""
188
+ async def achat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
145
189
  self._refresh_token()
146
- return await self.wrapped_model.achat(messages, **kwargs)
190
+ return await self._wrapped.achat(messages, **kwargs)
147
191
 
148
- def chat(
149
- self,
150
- messages: Sequence[ChatMessage],
151
- **kwargs: Any,
152
- ) -> ChatResponse:
153
- """Sync chat with token refresh."""
192
+ def complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
154
193
  self._refresh_token()
155
- return self.wrapped_model.chat(messages, **kwargs)
194
+ return self._wrapped.complete(prompt, formatted=formatted, **kwargs)
156
195
 
157
- async def astream_chat(
158
- self,
159
- messages: Sequence[ChatMessage],
160
- **kwargs: Any,
161
- ) -> ChatResponseAsyncGen:
162
- """Async stream chat with token refresh."""
196
+ async def acomplete(
197
+ self, prompt: str, formatted: bool = False, **kwargs: Any
198
+ ) -> CompletionResponse:
163
199
  self._refresh_token()
164
- async for chunk in self.wrapped_model.astream_chat(messages, **kwargs):
165
- yield chunk
200
+ return await self._wrapped.acomplete(prompt, formatted=formatted, **kwargs)
166
201
 
167
202
  def stream_chat(
168
- self,
169
- messages: Sequence[ChatMessage],
170
- **kwargs: Any,
203
+ self, messages: Sequence[ChatMessage], **kwargs: Any
171
204
  ) -> ChatResponseGen:
172
- """Sync stream chat with token refresh."""
173
205
  self._refresh_token()
174
- for chunk in self.wrapped_model.stream_chat(messages, **kwargs):
175
- yield chunk
206
+ return self._wrapped.stream_chat(messages, **kwargs)
176
207
 
177
- async def acomplete(
178
- self,
179
- prompt: str,
180
- **kwargs: Any,
181
- ) -> CompletionResponse:
182
- """Async complete with token refresh."""
208
+ async def astream_chat(
209
+ self, messages: Sequence[ChatMessage], **kwargs: Any
210
+ ) -> ChatResponseAsyncGen:
183
211
  self._refresh_token()
184
- return await self.wrapped_model.acomplete(prompt, **kwargs)
212
+ result = self._wrapped.astream_chat(messages, **kwargs)
213
+ # Handle both coroutine and async generator patterns
214
+ if hasattr(result, "__aiter__"):
215
+ return result
216
+ return await result
185
217
 
186
- def complete(
187
- self,
188
- prompt: str,
189
- **kwargs: Any,
190
- ) -> CompletionResponse:
191
- """Sync complete with token refresh."""
218
+ def stream_complete(
219
+ self, prompt: str, formatted: bool = False, **kwargs: Any
220
+ ) -> CompletionResponseGen:
192
221
  self._refresh_token()
193
- return self.wrapped_model.complete(prompt, **kwargs)
222
+ return self._wrapped.stream_complete(prompt, formatted=formatted, **kwargs)
194
223
 
195
224
  async def astream_complete(
196
- self,
197
- prompt: str,
198
- **kwargs: Any,
225
+ self, prompt: str, formatted: bool = False, **kwargs: Any
199
226
  ) -> CompletionResponseAsyncGen:
200
- """Async stream complete with token refresh."""
201
227
  self._refresh_token()
202
- async for chunk in self.wrapped_model.astream_complete(prompt, **kwargs):
203
- yield chunk
228
+ result = self._wrapped.astream_complete(prompt, formatted=formatted, **kwargs)
229
+ # Handle both coroutine and async generator patterns
230
+ if hasattr(result, "__aiter__"):
231
+ return result
232
+ return await result
204
233
 
205
- def stream_complete(
234
+ # --- FunctionCallingLLM methods (delegate to wrapped model) ---
235
+
236
+ def _prepare_chat_with_tools(
206
237
  self,
207
- prompt: str,
238
+ tools: Sequence[BaseTool],
239
+ user_msg: Union[str, ChatMessage, None] = None,
240
+ chat_history: List[ChatMessage] | None = None,
241
+ verbose: bool = False,
242
+ allow_parallel_tool_calls: bool = False,
243
+ tool_required: Any = None,
208
244
  **kwargs: Any,
209
- ) -> CompletionResponseGen:
210
- """Sync stream complete with token refresh."""
211
- self._refresh_token()
212
- for chunk in self.wrapped_model.stream_complete(prompt, **kwargs):
213
- yield chunk
245
+ ) -> Dict[str, Any]:
246
+ if hasattr(self._wrapped, "_prepare_chat_with_tools"):
247
+ return self._wrapped._prepare_chat_with_tools(
248
+ tools,
249
+ user_msg=user_msg,
250
+ chat_history=chat_history,
251
+ verbose=verbose,
252
+ allow_parallel_tool_calls=allow_parallel_tool_calls,
253
+ tool_required=tool_required,
254
+ **kwargs,
255
+ )
256
+ raise NotImplementedError(
257
+ f"The wrapped model ({type(self._wrapped).__name__}) does not support function calling"
258
+ )
259
+
260
+ def get_tool_calls_from_response(
261
+ self,
262
+ response: ChatResponse,
263
+ error_on_no_tool_call: bool = True,
264
+ **kwargs: Any,
265
+ ) -> List[ToolSelection]:
266
+ if hasattr(self._wrapped, "get_tool_calls_from_response"):
267
+ return self._wrapped.get_tool_calls_from_response(
268
+ response,
269
+ error_on_no_tool_call=error_on_no_tool_call,
270
+ **kwargs,
271
+ )
272
+ raise NotImplementedError(
273
+ f"The wrapped model ({type(self._wrapped).__name__}) does not support function calling"
274
+ )
214
275
 
215
276
 
216
277
  async def bl_model(name, **kwargs):
@@ -220,4 +281,4 @@ async def bl_model(name, **kwargs):
220
281
  model_config = {"type": type, "model": model, "url": url, "kwargs": kwargs}
221
282
 
222
283
  # Create and return the wrapper
223
- return TokenRefreshingLLM(model_config)
284
+ return TokenRefreshingLLM(model_config)
@@ -5,12 +5,14 @@ from blaxel.core.tools.common import create_model_from_json_schema
5
5
  from blaxel.core.tools.types import Tool
6
6
 
7
7
  if TYPE_CHECKING:
8
- from llama_index.core.tools import FunctionTool
8
+ from llama_index.core.tools import FunctionTool # type: ignore[import-not-found]
9
9
 
10
10
 
11
11
  def get_llamaindex_tool(tool: Tool) -> "FunctionTool":
12
- from llama_index.core.tools import FunctionTool
13
- from llama_index.core.tools.types import ToolMetadata
12
+ from llama_index.core.tools import FunctionTool # type: ignore[import-not-found]
13
+ from llama_index.core.tools.types import ( # type: ignore[import-not-found]
14
+ ToolMetadata,
15
+ )
14
16
 
15
17
  model_schema = create_model_from_json_schema(
16
18
  tool.input_schema, model_name=f"{tool.name}_Schema"
@@ -29,4 +31,4 @@ def get_llamaindex_tool(tool: Tool) -> "FunctionTool":
29
31
  async def bl_tools(tools_names: list[str], **kwargs) -> list["FunctionTool"]:
30
32
  tools = bl_tools_core(tools_names, **kwargs)
31
33
  await tools.initialize()
32
- return [get_llamaindex_tool(tool) for tool in tools.get_tools()]
34
+ return [get_llamaindex_tool(tool) for tool in tools.get_tools()]
blaxel/openai/model.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import httpx
2
- from agents import AsyncOpenAI, OpenAIChatCompletionsModel
2
+ from agents import OpenAIChatCompletionsModel
3
+ from openai import AsyncOpenAI
3
4
 
4
5
  from blaxel.core import bl_model as bl_model_core
5
6
  from blaxel.core import settings
@@ -14,6 +15,11 @@ class DynamicHeadersHTTPClient(httpx.AsyncClient):
14
15
  async def send(self, request, *args, **kwargs):
15
16
  # Update headers with the latest auth headers before each request
16
17
  auth_headers = settings.auth.get_headers()
18
+ # Remove the SDK's default "Authorization: Bearer replaced" header
19
+ # when our auth uses a different header (e.g. X-Blaxel-Authorization with API keys)
20
+ if "Authorization" not in auth_headers:
21
+ request.headers.pop("Authorization", None)
22
+ request.headers.pop("authorization", None)
17
23
  for key, value in auth_headers.items():
18
24
  request.headers[key] = value
19
25
  return await super().send(request, *args, **kwargs)
blaxel/openai/tools.py CHANGED
@@ -1,7 +1,8 @@
1
1
  import json
2
2
  from typing import Any
3
3
 
4
- from agents import FunctionTool, RunContextWrapper
4
+ from agents import FunctionTool # type: ignore[import-not-found]
5
+ from agents.tool_context import ToolContext # type: ignore[import-not-found]
5
6
 
6
7
  from blaxel.core.tools import bl_tools as bl_tools_core
7
8
  from blaxel.core.tools.types import Tool
@@ -24,6 +25,13 @@ def _clean_schema_for_openai(schema: dict) -> dict:
24
25
  if "additionalProperties" in cleaned_schema:
25
26
  del cleaned_schema["additionalProperties"]
26
27
 
28
+ # Ensure object type schemas have properties
29
+ if cleaned_schema.get("type") == "object":
30
+ if "properties" not in cleaned_schema:
31
+ cleaned_schema["properties"] = {}
32
+ if "required" not in cleaned_schema:
33
+ cleaned_schema["required"] = []
34
+
27
35
  # Recursively clean properties if they exist
28
36
  if "properties" in cleaned_schema:
29
37
  cleaned_schema["properties"] = {
@@ -39,9 +47,11 @@ def _clean_schema_for_openai(schema: dict) -> dict:
39
47
 
40
48
  def get_openai_tool(tool: Tool) -> FunctionTool:
41
49
  async def openai_coroutine(
42
- _: RunContextWrapper,
43
- arguments: dict[str, Any],
50
+ _: ToolContext[Any],
51
+ arguments: str,
44
52
  ) -> Any:
53
+ if not tool.coroutine:
54
+ raise ValueError(f"Tool {tool.name} does not have a coroutine defined")
45
55
  result = await tool.coroutine(**json.loads(arguments))
46
56
  return result
47
57
 
blaxel/pydantic/model.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from typing import Any
3
3
 
4
- from pydantic_ai.models import Model
4
+ from pydantic_ai.models import Model # type: ignore[import-not-found]
5
5
 
6
6
  from blaxel.core import bl_model as bl_model_core
7
7
  from blaxel.core import settings
@@ -30,8 +30,12 @@ class TokenRefreshingModel(Model):
30
30
 
31
31
  if type == "mistral":
32
32
  from mistralai.sdk import Mistral
33
- from pydantic_ai.models.mistral import MistralModel
34
- from pydantic_ai.providers.mistral import MistralProvider
33
+ from pydantic_ai.models.mistral import ( # type: ignore[import-not-found]
34
+ MistralModel,
35
+ )
36
+ from pydantic_ai.providers.mistral import ( # type: ignore[import-not-found]
37
+ MistralProvider,
38
+ )
35
39
 
36
40
  return MistralModel(
37
41
  model_name=model,
@@ -45,8 +49,12 @@ class TokenRefreshingModel(Model):
45
49
  )
46
50
  elif type == "cohere":
47
51
  from cohere import AsyncClientV2
48
- from pydantic_ai.models.cohere import CohereModel
49
- from pydantic_ai.providers.cohere import CohereProvider
52
+ from pydantic_ai.models.cohere import ( # type: ignore[import-not-found]
53
+ CohereModel,
54
+ )
55
+ from pydantic_ai.providers.cohere import ( # type: ignore[import-not-found]
56
+ CohereProvider,
57
+ )
50
58
 
51
59
  return CohereModel(
52
60
  model_name=model,
@@ -58,30 +66,42 @@ class TokenRefreshingModel(Model):
58
66
  ),
59
67
  )
60
68
  elif type == "xai":
61
- from pydantic_ai.models.openai import OpenAIModel
62
- from pydantic_ai.providers.openai import OpenAIProvider
69
+ from pydantic_ai.models.openai import ( # type: ignore[import-not-found]
70
+ OpenAIChatModel,
71
+ )
72
+ from pydantic_ai.providers.openai import ( # type: ignore[import-not-found]
73
+ OpenAIProvider,
74
+ )
63
75
 
64
- return OpenAIModel(
76
+ return OpenAIChatModel(
65
77
  model_name=model,
66
78
  provider=OpenAIProvider(
67
79
  base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
68
80
  ),
69
81
  )
70
82
  elif type == "deepseek":
71
- from pydantic_ai.models.openai import OpenAIModel
72
- from pydantic_ai.providers.openai import OpenAIProvider
83
+ from pydantic_ai.models.openai import ( # type: ignore[import-not-found]
84
+ OpenAIChatModel,
85
+ )
86
+ from pydantic_ai.providers.openai import ( # type: ignore[import-not-found]
87
+ OpenAIProvider,
88
+ )
73
89
 
74
- return OpenAIModel(
90
+ return OpenAIChatModel(
75
91
  model_name=model,
76
92
  provider=OpenAIProvider(
77
93
  base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
78
94
  ),
79
95
  )
80
96
  elif type == "cerebras":
81
- from pydantic_ai.models.openai import OpenAIModel
82
- from pydantic_ai.providers.openai import OpenAIProvider
97
+ from pydantic_ai.models.openai import ( # type: ignore[import-not-found]
98
+ OpenAIChatModel,
99
+ )
100
+ from pydantic_ai.providers.openai import ( # type: ignore[import-not-found]
101
+ OpenAIProvider,
102
+ )
83
103
 
84
- return OpenAIModel(
104
+ return OpenAIChatModel(
85
105
  model_name=model,
86
106
  provider=OpenAIProvider(
87
107
  base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
@@ -116,12 +136,12 @@ class TokenRefreshingModel(Model):
116
136
  ),
117
137
  )
118
138
  else:
119
- from pydantic_ai.models.openai import OpenAIModel
139
+ from pydantic_ai.models.openai import OpenAIChatModel
120
140
  from pydantic_ai.providers.openai import OpenAIProvider
121
141
 
122
142
  if type != "openai":
123
143
  logger.warning(f"Model {model} is not supported by Pydantic, defaulting to OpenAI")
124
- return OpenAIModel(
144
+ return OpenAIChatModel(
125
145
  model_name=model,
126
146
  provider=OpenAIProvider(
127
147
  base_url=f"{url}/v1", api_key=settings.auth.token, **kwargs
@@ -130,12 +150,6 @@ class TokenRefreshingModel(Model):
130
150
 
131
151
  def _get_fresh_model(self) -> Model:
132
152
  """Get or create a model with fresh token if needed."""
133
- # Only refresh if using ClientCredentials (which has get_token method)
134
- if hasattr(settings.auth, "get_token"):
135
- # This will trigger token refresh if needed
136
- logger.debug(f"Calling get_token for {self.model_config['type']} model")
137
- settings.auth.get_token()
138
-
139
153
  new_token = settings.auth.token
140
154
 
141
155
  # If token changed or no cached model, create new one
@@ -152,10 +166,10 @@ class TokenRefreshingModel(Model):
152
166
  return model.model_name
153
167
 
154
168
  @property
155
- def system(self) -> Any | None:
169
+ def system(self) -> str:
156
170
  """Return the system property from the wrapped model."""
157
171
  model = self._get_fresh_model()
158
- return model.system if hasattr(model, "system") else None
172
+ return model.system if hasattr(model, "system") else ""
159
173
 
160
174
  async def request(self, *args, **kwargs):
161
175
  """Make a request to the model with token refresh."""
blaxel/pydantic/tools.py CHANGED
@@ -1,11 +1,44 @@
1
- from pydantic_ai import RunContext
2
- from pydantic_ai.tools import Tool as PydanticTool
3
- from pydantic_ai.tools import ToolDefinition
1
+ from typing import Any
2
+
3
+ from pydantic_ai import RunContext # type: ignore[import-not-found]
4
+ from pydantic_ai.tools import Tool as PydanticTool # type: ignore[import-not-found]
5
+ from pydantic_ai.tools import ToolDefinition # type: ignore[import-not-found]
4
6
 
5
7
  from blaxel.core.tools import Tool
6
8
  from blaxel.core.tools import bl_tools as bl_tools_core
7
9
 
8
10
 
11
+ def _clean_schema_for_openai(schema: dict[str, Any]) -> dict[str, Any]:
12
+ """Clean JSON schema to be compatible with OpenAI function calling.
13
+
14
+ OpenAI requires object schemas to have a 'properties' field, even if empty.
15
+ """
16
+ if not isinstance(schema, dict):
17
+ return schema
18
+
19
+ cleaned = schema.copy()
20
+
21
+ if cleaned.get("type") == "object":
22
+ if "properties" not in cleaned:
23
+ cleaned["properties"] = {}
24
+ if "required" not in cleaned:
25
+ cleaned["required"] = []
26
+
27
+ if "additionalProperties" in cleaned:
28
+ del cleaned["additionalProperties"]
29
+ if "$schema" in cleaned:
30
+ del cleaned["$schema"]
31
+
32
+ if "properties" in cleaned:
33
+ cleaned["properties"] = {
34
+ k: _clean_schema_for_openai(v) for k, v in cleaned["properties"].items()
35
+ }
36
+ if "items" in cleaned and isinstance(cleaned["items"], dict):
37
+ cleaned["items"] = _clean_schema_for_openai(cleaned["items"])
38
+
39
+ return cleaned
40
+
41
+
9
42
  def get_pydantic_tool(tool: Tool) -> PydanticTool:
10
43
  """
11
44
  Converts a custom Tool object into a Pydantic AI Tool object.
@@ -27,7 +60,7 @@ def get_pydantic_tool(tool: Tool) -> PydanticTool:
27
60
  """Dynamically prepares the ToolDefinition using the custom Tool's attributes."""
28
61
  tool_def.name = tool.name # Override inferred name
29
62
  tool_def.description = tool.description # Override inferred description
30
- tool_def.parameters_json_schema = tool.input_schema
63
+ tool_def.parameters_json_schema = _clean_schema_for_openai(tool.input_schema)
31
64
  return tool_def
32
65
 
33
66
  async def pydantic_function(**kwargs):