blaxel 0.2.36__py3-none-any.whl → 0.2.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. blaxel/__init__.py +2 -2
  2. blaxel/core/client/models/create_job_execution_request_env.py +3 -3
  3. blaxel/core/client/models/preview.py +48 -1
  4. blaxel/core/client/models/sandbox.py +10 -0
  5. blaxel/core/jobs/__init__.py +2 -2
  6. blaxel/core/sandbox/__init__.py +12 -0
  7. blaxel/core/sandbox/client/api/system/__init__.py +0 -0
  8. blaxel/core/sandbox/client/api/system/get_health.py +134 -0
  9. blaxel/core/sandbox/client/api/system/post_upgrade.py +196 -0
  10. blaxel/core/sandbox/client/models/__init__.py +8 -0
  11. blaxel/core/sandbox/client/models/content_search_match.py +24 -25
  12. blaxel/core/sandbox/client/models/content_search_response.py +25 -29
  13. blaxel/core/sandbox/client/models/find_match.py +13 -14
  14. blaxel/core/sandbox/client/models/find_response.py +21 -24
  15. blaxel/core/sandbox/client/models/fuzzy_search_match.py +17 -19
  16. blaxel/core/sandbox/client/models/fuzzy_search_response.py +21 -24
  17. blaxel/core/sandbox/client/models/health_response.py +159 -0
  18. blaxel/core/sandbox/client/models/process_upgrade_state.py +20 -0
  19. blaxel/core/sandbox/client/models/upgrade_request.py +71 -0
  20. blaxel/core/sandbox/client/models/upgrade_status.py +125 -0
  21. blaxel/core/sandbox/default/__init__.py +2 -0
  22. blaxel/core/sandbox/default/filesystem.py +20 -6
  23. blaxel/core/sandbox/default/preview.py +48 -1
  24. blaxel/core/sandbox/default/process.py +66 -21
  25. blaxel/core/sandbox/default/sandbox.py +36 -5
  26. blaxel/core/sandbox/default/system.py +71 -0
  27. blaxel/core/sandbox/sync/__init__.py +2 -0
  28. blaxel/core/sandbox/sync/filesystem.py +19 -2
  29. blaxel/core/sandbox/sync/preview.py +50 -3
  30. blaxel/core/sandbox/sync/process.py +38 -15
  31. blaxel/core/sandbox/sync/sandbox.py +29 -4
  32. blaxel/core/sandbox/sync/system.py +71 -0
  33. blaxel/core/sandbox/types.py +212 -5
  34. blaxel/core/tools/__init__.py +4 -0
  35. blaxel/core/volume/volume.py +10 -0
  36. blaxel/crewai/model.py +81 -44
  37. blaxel/crewai/tools.py +85 -2
  38. blaxel/googleadk/model.py +22 -3
  39. blaxel/googleadk/tools.py +25 -6
  40. blaxel/langgraph/custom/gemini.py +19 -12
  41. blaxel/langgraph/model.py +26 -18
  42. blaxel/langgraph/tools.py +6 -12
  43. blaxel/livekit/model.py +7 -2
  44. blaxel/livekit/tools.py +3 -1
  45. blaxel/llamaindex/model.py +145 -84
  46. blaxel/llamaindex/tools.py +6 -4
  47. blaxel/openai/model.py +7 -1
  48. blaxel/openai/tools.py +13 -3
  49. blaxel/pydantic/model.py +38 -24
  50. blaxel/pydantic/tools.py +37 -4
  51. blaxel-0.2.38.dist-info/METADATA +528 -0
  52. {blaxel-0.2.36.dist-info → blaxel-0.2.38.dist-info}/RECORD +54 -45
  53. blaxel-0.2.36.dist-info/METADATA +0 -228
  54. {blaxel-0.2.36.dist-info → blaxel-0.2.38.dist-info}/WHEEL +0 -0
  55. {blaxel-0.2.36.dist-info → blaxel-0.2.38.dist-info}/licenses/LICENSE +0 -0
blaxel/langgraph/custom/gemini.py CHANGED
@@ -23,16 +23,18 @@ from typing import (
 
 import httpx
 import requests
-from langchain_core.callbacks.manager import (
+from langchain_core.callbacks.manager import (  # type: ignore[import-not-found]
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models import LanguageModelInput
-from langchain_core.language_models.chat_models import (
+from langchain_core.language_models import (  # type: ignore[import-not-found]
+    LanguageModelInput,
+)
+from langchain_core.language_models.chat_models import (  # type: ignore[import-not-found]
     BaseChatModel,
     LangSmithParams,
 )
-from langchain_core.messages import (
+from langchain_core.messages import (  # type: ignore[import-not-found]
     AIMessage,
     AIMessageChunk,
     BaseMessage,
@@ -41,25 +43,30 @@ from langchain_core.messages import (
     SystemMessage,
     ToolMessage,
 )
-from langchain_core.messages.ai import UsageMetadata
-from langchain_core.messages.tool import (
+from langchain_core.messages.ai import UsageMetadata  # type: ignore[import-not-found]
+from langchain_core.messages.tool import (  # type: ignore[import-not-found]
     invalid_tool_call,
     tool_call,
     tool_call_chunk,
 )
-from langchain_core.output_parsers.openai_tools import (
+from langchain_core.output_parsers.openai_tools import (  # type: ignore[import-not-found]
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
     parse_tool_calls,
 )
-from langchain_core.outputs import (
+from langchain_core.outputs import (  # type: ignore[import-not-found]
     ChatGeneration,
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.runnables import Runnable, RunnablePassthrough
-from langchain_core.tools import BaseTool
-from langchain_core.utils.function_calling import convert_to_openai_tool
+from langchain_core.runnables import (  # type: ignore[import-not-found]
+    Runnable,
+    RunnablePassthrough,
+)
+from langchain_core.tools import BaseTool  # type: ignore[import-not-found]
+from langchain_core.utils.function_calling import (  # type: ignore[import-not-found]
+    convert_to_openai_tool,
+)
 from PIL import Image
 from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
 from tenacity import (
@@ -1467,4 +1474,4 @@ def image_bytes_to_b64_string(image_bytes: bytes, image_format: str = "jpeg") ->
     """Convert image bytes to base64 string."""
     import base64
 
-    return f"data:image/{image_format};base64,{base64.b64encode(image_bytes).decode('utf-8')}"
+    return f"data:image/{image_format};base64,{base64.b64encode(image_bytes).decode('utf-8')}"
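The recurring change across these integration modules is the optional-import pattern: type checkers see the import once under TYPE_CHECKING with an ignore marker, while runtime code defers the import until the dependency is actually needed. A minimal sketch of that pattern, using langchain_core as the optional dependency (the add tool is illustrative, not part of blaxel):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; the ignore comment silences
    # "import-not-found" when the optional extra is not installed.
    from langchain_core.tools import StructuredTool  # type: ignore[import-not-found]


def make_add_tool() -> "StructuredTool":
    # Deferred runtime import: langchain_core is required only if this runs.
    from langchain_core.tools import StructuredTool  # type: ignore[import-not-found]

    def add(a: int, b: int) -> int:
        return a + b

    return StructuredTool.from_function(func=add, name="add", description="Add two integers")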
blaxel/langgraph/model.py CHANGED
@@ -7,11 +7,15 @@ from blaxel.core import bl_model as bl_model_core
 from blaxel.core import settings
 
 if TYPE_CHECKING:
-    from langchain_core.callbacks import Callbacks
-    from langchain_core.language_models import LanguageModelInput
-    from langchain_core.messages import BaseMessage
-    from langchain_core.outputs import LLMResult
-    from langchain_core.runnables import RunnableConfig
+    from langchain_core.callbacks import Callbacks  # type: ignore[import-not-found]
+    from langchain_core.language_models import (  # type: ignore[import-not-found]
+        LanguageModelInput,
+    )
+    from langchain_core.messages import BaseMessage  # type: ignore[import-not-found]
+    from langchain_core.outputs import LLMResult  # type: ignore[import-not-found]
+    from langchain_core.runnables import (  # type: ignore[import-not-found]
+        RunnableConfig,
+    )
 
 logger = getLogger(__name__)
 
@@ -32,7 +36,7 @@
         kwargs = config.get("kwargs", {})
 
         if model_type == "mistral":
-            from langchain_openai import ChatOpenAI
+            from langchain_openai import ChatOpenAI  # type: ignore[import-not-found]
 
             return ChatOpenAI(
                 api_key=settings.auth.token,
@@ -41,7 +45,7 @@
                 **kwargs,
             )
         elif model_type == "cohere":
-            from langchain_cohere import ChatCohere
+            from langchain_cohere import ChatCohere  # type: ignore[import-not-found]
 
             return ChatCohere(
                 cohere_api_key=settings.auth.token,
@@ -50,7 +54,7 @@
                 **kwargs,
             )
         elif model_type == "xai":
-            from langchain_xai import ChatXAI
+            from langchain_xai import ChatXAI  # type: ignore[import-not-found]
 
             return ChatXAI(
                 model=model,
@@ -59,7 +63,9 @@
                 **kwargs,
             )
         elif model_type == "deepseek":
-            from langchain_deepseek import ChatDeepSeek
+            from langchain_deepseek import (  # type: ignore[import-not-found]
+                ChatDeepSeek,
+            )
 
             return ChatDeepSeek(
                 api_key=settings.auth.token,
@@ -68,7 +74,9 @@
                 **kwargs,
             )
         elif model_type == "anthropic":
-            from langchain_anthropic import ChatAnthropic
+            from langchain_anthropic import (  # type: ignore[import-not-found]
+                ChatAnthropic,
+            )
 
             return ChatAnthropic(
                 api_key=settings.auth.token,
@@ -78,7 +86,9 @@
                 **kwargs,
             )
         elif model_type == "gemini":
-            from .custom.gemini import ChatGoogleGenerativeAI
+            from .custom.gemini import (
+                ChatGoogleGenerativeAI,  # type: ignore[import-not-found]
+            )
 
             return ChatGoogleGenerativeAI(
                 model=model,
@@ -88,7 +98,9 @@
                 **kwargs,
             )
         elif model_type == "cerebras":
-            from langchain_cerebras import ChatCerebras
+            from langchain_cerebras import (  # type: ignore[import-not-found]
+                ChatCerebras,
+            )
 
             return ChatCerebras(
                 api_key=settings.auth.token,
@@ -97,7 +109,7 @@
                 **kwargs,
             )
         else:
-            from langchain_openai import ChatOpenAI
+            from langchain_openai import ChatOpenAI  # type: ignore[import-not-found]
 
             if model_type != "openai":
                 logger.warning(f"Model {model} is not supported by Langchain, defaulting to OpenAI")
@@ -113,10 +125,6 @@
         # Only refresh if using ClientCredentials (which has get_token method)
         current_token = settings.auth.token
 
-        if hasattr(settings.auth, "get_token"):
-            # This will trigger token refresh if needed
-            settings.auth.get_token()
-
        new_token = settings.auth.token
 
         # If token changed, recreate the model
@@ -251,4 +259,4 @@
     model_config = {"type": type, "model": model, "url": url, "kwargs": kwargs}
 
     # Create and return the wrapper
-    return TokenRefreshingChatModel(model_config)
+    return TokenRefreshingChatModel(model_config)
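The wrapper's premise, reduced to a sketch: chat clients bake the credential in at construction time, so when the ambient token changes the wrapped model must be rebuilt rather than reused. Here `auth` and `build_model` are illustrative stand-ins for the blaxel internals, not the real API:

from typing import Any, Callable


class TokenRefreshingModel:
    def __init__(self, auth: Any, build_model: Callable[[str], Any]):
        self._auth = auth
        self._build = build_model
        self._token = auth.token  # token the current model was built with
        self._model = build_model(self._token)

    def _refresh(self) -> None:
        # Reading .token is assumed to refresh credentials lazily; a changed
        # value means the client object is stale and must be recreated.
        current = self._auth.token
        if current != self._token:
            self._token = current
            self._model = self._build(current)

    def invoke(self, *args: Any, **kwargs: Any) -> Any:
        self._refresh()
        return self._model.invoke(*args, **kwargs)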
blaxel/langgraph/tools.py CHANGED
@@ -4,8 +4,7 @@ from blaxel.core.tools import bl_tools as bl_tools_core
 from blaxel.core.tools.types import Tool, ToolException
 
 if TYPE_CHECKING:
-    from langchain_core.tools import StructuredTool
-    from mcp.types import EmbeddedResource, ImageContent
+    from langchain_core.tools import StructuredTool  # type: ignore[import-not-found]
 
 
 def _clean_schema_for_openai(schema: Dict[str, Any]) -> Dict[str, Any]:
@@ -38,19 +37,14 @@
 
 
 def get_langchain_tool(tool: Tool) -> "StructuredTool":
-    from langchain_core.tools import StructuredTool
-    from mcp.types import (
-        CallToolResult,
-        EmbeddedResource,
-        ImageContent,
-        TextContent,
-    )
-
-    NonTextContent = ImageContent | EmbeddedResource
+    from langchain_core.tools import StructuredTool  # type: ignore[import-not-found]
+    from mcp.types import CallToolResult, EmbeddedResource, ImageContent, TextContent
 
     async def langchain_coroutine(
         **arguments: dict[str, Any],
-    ) -> tuple[str | list[str], list[NonTextContent] | None]:
+    ) -> tuple[str | list[str], list[ImageContent | EmbeddedResource] | None]:
+        if not tool.coroutine:
+            raise ValueError(f"Tool {tool.name} does not have a coroutine defined")
         result: CallToolResult = await tool.coroutine(**arguments)
         text_contents: list[TextContent] = []
         non_text_contents = []
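The widened return annotation spells out what the coroutine hands back: a pair of text content plus non-text artifacts, matching LangChain's content-and-artifact tool convention. A simplified sketch of that split over an MCP CallToolResult (condensed relative to the converter above):

from mcp.types import CallToolResult, EmbeddedResource, ImageContent, TextContent


def split_result(
    result: CallToolResult,
) -> tuple[str | list[str], list[ImageContent | EmbeddedResource] | None]:
    # Text blocks become the tool's string content; images and embedded
    # resources travel separately as artifacts.
    texts = [c.text for c in result.content if isinstance(c, TextContent)]
    non_text = [c for c in result.content if isinstance(c, (ImageContent, EmbeddedResource))]
    content: str | list[str] = texts[0] if len(texts) == 1 else texts
    return content, non_text or None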
blaxel/livekit/model.py CHANGED
@@ -1,8 +1,8 @@
 from logging import getLogger
 
 import httpx
-from livekit.plugins import openai
-from openai import AsyncOpenAI
+from livekit.plugins import openai  # type: ignore[import-not-found]
+from openai import AsyncOpenAI  # type: ignore[import-not-found]
 
 from blaxel.core import bl_model as bl_model_core
 from blaxel.core import settings
@@ -20,6 +20,11 @@
     async def send(self, request, *args, **kwargs):
         # Update headers with the latest auth headers before each request
         auth_headers = settings.auth.get_headers()
+        # Remove the SDK's default "Authorization: Bearer replaced" header
+        # when our auth uses a different header (e.g. X-Blaxel-Authorization with API keys)
+        if "Authorization" not in auth_headers:
+            request.headers.pop("Authorization", None)
+            request.headers.pop("authorization", None)
         for key, value in auth_headers.items():
             request.headers[key] = value
         return await super().send(request, *args, **kwargs)
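For context, this is roughly how such a client plugs into an OpenAI-style SDK. The header name, key, and base URL below are illustrative assumptions, not blaxel's actual values; note that httpx header access is case-insensitive, so the double pop in the diff is defensive:

import httpx
from openai import AsyncOpenAI


class StaticHeadersClient(httpx.AsyncClient):
    """Toy stand-in for DynamicHeadersHTTPClient, with hard-coded headers."""

    async def send(self, request, *args, **kwargs):
        auth_headers = {"X-Blaxel-Authorization": "apikey <key>"}  # hypothetical header
        if "Authorization" not in auth_headers:
            # Drop the SDK's placeholder bearer token when auth rides elsewhere.
            request.headers.pop("Authorization", None)
        for key, value in auth_headers.items():
            request.headers[key] = value
        return await super().send(request, *args, **kwargs)


client = AsyncOpenAI(
    api_key="replaced",  # placeholder; real credentials travel in the headers above
    base_url="https://example.invalid/v1",
    http_client=StaticHeadersClient(),
)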
blaxel/livekit/tools.py CHANGED
@@ -1,4 +1,4 @@
-from livekit.agents import function_tool, llm
+from livekit.agents import function_tool, llm  # type: ignore[import-not-found]
 
 from blaxel.core.tools import bl_tools as bl_tools_core
 from blaxel.core.tools.types import Tool
@@ -6,6 +6,8 @@ from blaxel.core.tools.types import Tool
 
 def livekit_coroutine(tool: Tool):
     async def livekit_coroutine_wrapper(raw_arguments: dict[str, object]):
+        if not tool.coroutine:
+            raise ValueError(f"Tool {tool.name} does not have a coroutine defined")
         result = await tool.coroutine(**raw_arguments)
         return result.model_dump_json()
 
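Both tool adapters gain the same guard because Tool.coroutine is optional: checking it narrows the type for checkers and turns a cryptic "'NoneType' object is not callable" into an actionable error. A standalone sketch (this Tool dataclass is a stand-in for blaxel.core.tools.types.Tool):

import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional


@dataclass
class Tool:
    name: str
    coroutine: Optional[Callable[..., Awaitable[Any]]] = None


async def call_tool(tool: Tool, **arguments: Any) -> Any:
    if not tool.coroutine:
        raise ValueError(f"Tool {tool.name} does not have a coroutine defined")
    return await tool.coroutine(**arguments)


async def echo(**kwargs: Any) -> dict:
    return kwargs


print(asyncio.run(call_tool(Tool("echo", echo), message="hi")))  # {'message': 'hi'}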
blaxel/llamaindex/model.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import os
 from logging import getLogger
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
 
 from blaxel.core import bl_model as bl_model_core
 from blaxel.core import settings
@@ -11,7 +11,7 @@ from blaxel.core import settings
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
 
 if TYPE_CHECKING:
-    from llama_index.core.base.llms.types import (
+    from llama_index.core.base.llms.types import (  # type: ignore[import-not-found]
         ChatMessage,
         ChatResponse,
         ChatResponseAsyncGen,
@@ -20,27 +20,76 @@ if TYPE_CHECKING:
         CompletionResponseAsyncGen,
         CompletionResponseGen,
     )
+    from llama_index.core.llms.llm import (  # type: ignore[import-not-found]
+        ToolSelection,
+    )
+    from llama_index.core.tools.types import BaseTool  # type: ignore[import-not-found]
+
+# Runtime imports needed for class inheritance and construction
+from llama_index.core.base.llms.types import (  # type: ignore[import-not-found]
+    LLMMetadata,
+)
+from llama_index.core.llms.function_calling import (  # type: ignore[import-not-found]
+    FunctionCallingLLM,
+)
+from pydantic import PrivateAttr  # type: ignore[import-not-found]
 
 logger = getLogger(__name__)
 
+DEFAULT_CONTEXT_WINDOW = 128000
+DEFAULT_NUM_OUTPUT = 4096
+
+
+class TokenRefreshingLLM(FunctionCallingLLM):
+    """Wrapper for LlamaIndex LLMs that refreshes token before each call.
 
-class TokenRefreshingWrapper:
-    """Base wrapper class that refreshes token before each call."""
+    Inherits from FunctionCallingLLM to maintain type compatibility with
+    LlamaIndex's agents and components that validate isinstance(model, LLM).
+    """
+
+    _model_config_data: dict = PrivateAttr(default_factory=dict)
+    _wrapped: Any = PrivateAttr(default=None)
 
     def __init__(self, model_config: dict):
-        self.model_config = model_config
-        self.wrapped_model = self._create_model()
+        super().__init__()
+        self._model_config_data = model_config
+        self._wrapped = self._create_model()
+
+    @classmethod
+    def class_name(cls) -> str:
+        return "TokenRefreshingLLM"
+
+    @property
+    def wrapped_model(self) -> Any:
+        """Access the underlying wrapped LLM model."""
+        return self._wrapped
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata, with fallback for unknown model names."""
+        try:
+            return self._wrapped.metadata
+        except (ValueError, KeyError) as e:
+            logger.warning(f"Could not get metadata from wrapped model: {e}. Using defaults.")
+            return LLMMetadata(
+                context_window=DEFAULT_CONTEXT_WINDOW,
+                num_output=DEFAULT_NUM_OUTPUT,
+                is_chat_model=True,
+                model_name=self._model_config_data.get("model", "unknown"),
+            )
 
     def _create_model(self):
         """Create the model instance with current token."""
-        config = self.model_config
+        config = self._model_config_data
         model_type = config["type"]
         model = config["model"]
         url = config["url"]
         kwargs = config.get("kwargs", {})
 
         if model_type == "anthropic":
-            from llama_index.llms.anthropic import Anthropic
+            from llama_index.llms.anthropic import (  # type: ignore[import-not-found]
+                Anthropic,
+            )
 
             return Anthropic(
                 model=model,
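The hunk above is the central refactor of this release: the duck-typed wrapper becomes a genuine FunctionCallingLLM subclass so isinstance checks in LlamaIndex agents pass. Because LlamaIndex LLMs are pydantic models, ad-hoc instance state has to live in PrivateAttr fields. A minimal illustration of that constraint in generic pydantic terms (not the blaxel class itself):

from pydantic import BaseModel, PrivateAttr


class Wrapper(BaseModel):
    # Private attrs are exempt from field validation and serialization,
    # which is what lets a pydantic model carry an arbitrary client object.
    _wrapped: object = PrivateAttr(default=None)

    def __init__(self, wrapped: object):
        super().__init__()
        self._wrapped = wrapped


w = Wrapper({"client": True})
print(w._wrapped)      # {'client': True}
print(w.model_dump())  # {} -- private state stays out of the schema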
@@ -50,7 +99,7 @@ class TokenRefreshingWrapper:
             **kwargs,
         )
         elif model_type == "xai":
-            from llama_index.llms.groq import Groq
+            from llama_index.llms.groq import Groq  # type: ignore[import-not-found]
 
             return Groq(
                 model=model,
@@ -60,7 +109,9 @@
         )
         elif model_type == "gemini":
             from google.genai.types import HttpOptions
-            from llama_index.llms.google_genai import GoogleGenAI
+            from llama_index.llms.google_genai import (  # type: ignore[import-not-found]
+                GoogleGenAI,
+            )
 
             return GoogleGenAI(
                 api_key=settings.auth.token,
@@ -73,11 +124,13 @@
             **kwargs,
         )
         elif model_type == "cohere":
-            from .custom.cohere import Cohere
+            from .custom.cohere import Cohere  # type: ignore[import-not-found]
 
             return Cohere(model=model, api_key=settings.auth.token, api_base=url, **kwargs)
         elif model_type == "deepseek":
-            from llama_index.llms.deepseek import DeepSeek
+            from llama_index.llms.deepseek import (  # type: ignore[import-not-found]
+                DeepSeek,
+            )
 
             return DeepSeek(
                 model=model,
@@ -86,11 +139,15 @@
             **kwargs,
         )
         elif model_type == "mistral":
-            from llama_index.llms.mistralai import MistralAI
+            from llama_index.llms.mistralai import (  # type: ignore[import-not-found]
+                MistralAI,
+            )
 
             return MistralAI(model=model, api_key=settings.auth.token, endpoint=url, **kwargs)
         elif model_type == "cerebras":
-            from llama_index.llms.cerebras import Cerebras
+            from llama_index.llms.cerebras import (  # type: ignore[import-not-found]
+                Cerebras,
+            )
 
             return Cerebras(
                 model=model,
@@ -99,7 +156,7 @@
             **kwargs,
         )
         else:
-            from llama_index.llms.openai import OpenAI
+            from llama_index.llms.openai import OpenAI  # type: ignore[import-not-found]
 
             if model_type != "openai":
                 logger.warning(
@@ -115,102 +172,106 @@
 
     def _refresh_token(self):
         """Refresh the token and recreate the model if needed."""
-        # Only refresh if using ClientCredentials (which has get_token method)
         current_token = settings.auth.token
 
-        if hasattr(settings.auth, "get_token"):
-            # This will trigger token refresh if needed
-            settings.auth.get_token()
-
         new_token = settings.auth.token
 
-        # If token changed, recreate the model
         if current_token != new_token:
-            self.wrapped_model = self._create_model()
-
-    def __getattr__(self, name):
-        """Delegate attribute access to wrapped model."""
-        return getattr(self.wrapped_model, name)
+            self._wrapped = self._create_model()
 
+    # --- Core LLM methods with token refresh ---
 
-class TokenRefreshingLLM(TokenRefreshingWrapper):
-    """Wrapper for LlamaIndex LLMs that refreshes token before each call."""
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        self._refresh_token()
+        return self._wrapped.chat(messages, **kwargs)
 
-    async def achat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Async chat with token refresh."""
+    async def achat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
         self._refresh_token()
-        return await self.wrapped_model.achat(messages, **kwargs)
+        return await self._wrapped.achat(messages, **kwargs)
 
-    def chat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
-    ) -> ChatResponse:
-        """Sync chat with token refresh."""
+    def complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
         self._refresh_token()
-        return self.wrapped_model.chat(messages, **kwargs)
+        return self._wrapped.complete(prompt, formatted=formatted, **kwargs)
 
-    async def astream_chat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
-    ) -> ChatResponseAsyncGen:
-        """Async stream chat with token refresh."""
+    async def acomplete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponse:
         self._refresh_token()
-        async for chunk in self.wrapped_model.astream_chat(messages, **kwargs):
-            yield chunk
+        return await self._wrapped.acomplete(prompt, formatted=formatted, **kwargs)
 
     def stream_chat(
-        self,
-        messages: Sequence[ChatMessage],
-        **kwargs: Any,
+        self, messages: Sequence[ChatMessage], **kwargs: Any
     ) -> ChatResponseGen:
-        """Sync stream chat with token refresh."""
         self._refresh_token()
-        for chunk in self.wrapped_model.stream_chat(messages, **kwargs):
-            yield chunk
+        return self._wrapped.stream_chat(messages, **kwargs)
 
-    async def acomplete(
-        self,
-        prompt: str,
-        **kwargs: Any,
-    ) -> CompletionResponse:
-        """Async complete with token refresh."""
+    async def astream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseAsyncGen:
         self._refresh_token()
-        return await self.wrapped_model.acomplete(prompt, **kwargs)
+        result = self._wrapped.astream_chat(messages, **kwargs)
+        # Handle both coroutine and async generator patterns
+        if hasattr(result, "__aiter__"):
+            return result
+        return await result
 
-    def complete(
-        self,
-        prompt: str,
-        **kwargs: Any,
-    ) -> CompletionResponse:
-        """Sync complete with token refresh."""
+    def stream_complete(
+        self, prompt: str, formatted: bool = False, **kwargs: Any
+    ) -> CompletionResponseGen:
         self._refresh_token()
-        return self.wrapped_model.complete(prompt, **kwargs)
+        return self._wrapped.stream_complete(prompt, formatted=formatted, **kwargs)
 
     async def astream_complete(
-        self,
-        prompt: str,
-        **kwargs: Any,
+        self, prompt: str, formatted: bool = False, **kwargs: Any
    ) -> CompletionResponseAsyncGen:
-        """Async stream complete with token refresh."""
         self._refresh_token()
-        async for chunk in self.wrapped_model.astream_complete(prompt, **kwargs):
-            yield chunk
+        result = self._wrapped.astream_complete(prompt, formatted=formatted, **kwargs)
+        # Handle both coroutine and async generator patterns
+        if hasattr(result, "__aiter__"):
+            return result
+        return await result
 
-    def stream_complete(
+    # --- FunctionCallingLLM methods (delegate to wrapped model) ---
+
+    def _prepare_chat_with_tools(
         self,
-        prompt: str,
+        tools: Sequence[BaseTool],
+        user_msg: Union[str, ChatMessage, None] = None,
+        chat_history: List[ChatMessage] | None = None,
+        verbose: bool = False,
+        allow_parallel_tool_calls: bool = False,
+        tool_required: Any = None,
         **kwargs: Any,
-    ) -> CompletionResponseGen:
-        """Sync stream complete with token refresh."""
-        self._refresh_token()
-        for chunk in self.wrapped_model.stream_complete(prompt, **kwargs):
-            yield chunk
+    ) -> Dict[str, Any]:
+        if hasattr(self._wrapped, "_prepare_chat_with_tools"):
+            return self._wrapped._prepare_chat_with_tools(
+                tools,
+                user_msg=user_msg,
+                chat_history=chat_history,
+                verbose=verbose,
+                allow_parallel_tool_calls=allow_parallel_tool_calls,
+                tool_required=tool_required,
+                **kwargs,
+            )
+        raise NotImplementedError(
+            f"The wrapped model ({type(self._wrapped).__name__}) does not support function calling"
+        )
+
+    def get_tool_calls_from_response(
+        self,
+        response: ChatResponse,
+        error_on_no_tool_call: bool = True,
+        **kwargs: Any,
+    ) -> List[ToolSelection]:
+        if hasattr(self._wrapped, "get_tool_calls_from_response"):
+            return self._wrapped.get_tool_calls_from_response(
+                response,
+                error_on_no_tool_call=error_on_no_tool_call,
+                **kwargs,
+            )
+        raise NotImplementedError(
+            f"The wrapped model ({type(self._wrapped).__name__}) does not support function calling"
        )
 
 
 async def bl_model(name, **kwargs):
@@ -220,4 +281,4 @@
     model_config = {"type": type, "model": model, "url": url, "kwargs": kwargs}
 
     # Create and return the wrapper
-    return TokenRefreshingLLM(model_config)
+    return TokenRefreshingLLM(model_config)
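The "coroutine and async generator patterns" comment reflects a real split among LlamaIndex backends: some implement astream_chat as an async generator function, others as an async def that returns the generator. A self-contained sketch of the normalization:

import asyncio


async def gen_style():  # async generator function
    yield "a"
    yield "b"


async def coro_style():  # coroutine that returns an async generator
    return gen_style()


async def normalize(result):
    # An async generator already exposes __aiter__; a coroutine must be
    # awaited first to obtain the generator it produces.
    if hasattr(result, "__aiter__"):
        return result
    return await result


async def main():
    for candidate in (gen_style(), coro_style()):
        stream = await normalize(candidate)
        print([chunk async for chunk in stream])  # ['a', 'b'] both times


asyncio.run(main())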
blaxel/llamaindex/tools.py CHANGED
@@ -5,12 +5,14 @@ from blaxel.core.tools.common import create_model_from_json_schema
 from blaxel.core.tools.types import Tool
 
 if TYPE_CHECKING:
-    from llama_index.core.tools import FunctionTool
+    from llama_index.core.tools import FunctionTool  # type: ignore[import-not-found]
 
 
 def get_llamaindex_tool(tool: Tool) -> "FunctionTool":
-    from llama_index.core.tools import FunctionTool
-    from llama_index.core.tools.types import ToolMetadata
+    from llama_index.core.tools import FunctionTool  # type: ignore[import-not-found]
+    from llama_index.core.tools.types import (  # type: ignore[import-not-found]
+        ToolMetadata,
+    )
 
     model_schema = create_model_from_json_schema(
         tool.input_schema, model_name=f"{tool.name}_Schema"
@@ -29,4 +31,4 @@
 async def bl_tools(tools_names: list[str], **kwargs) -> list["FunctionTool"]:
     tools = bl_tools_core(tools_names, **kwargs)
     await tools.initialize()
-    return [get_llamaindex_tool(tool) for tool in tools.get_tools()]
+    return [get_llamaindex_tool(tool) for tool in tools.get_tools()]
blaxel/openai/model.py CHANGED
@@ -1,5 +1,6 @@
 import httpx
-from agents import AsyncOpenAI, OpenAIChatCompletionsModel
+from agents import OpenAIChatCompletionsModel
+from openai import AsyncOpenAI
 
 from blaxel.core import bl_model as bl_model_core
 from blaxel.core import settings
@@ -14,6 +15,11 @@
     async def send(self, request, *args, **kwargs):
         # Update headers with the latest auth headers before each request
         auth_headers = settings.auth.get_headers()
+        # Remove the SDK's default "Authorization: Bearer replaced" header
+        # when our auth uses a different header (e.g. X-Blaxel-Authorization with API keys)
+        if "Authorization" not in auth_headers:
+            request.headers.pop("Authorization", None)
+            request.headers.pop("authorization", None)
         for key, value in auth_headers.items():
             request.headers[key] = value
         return await super().send(request, *args, **kwargs)
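The import fix matters because AsyncOpenAI lives in the openai package, not the agents SDK; the corrected pair composes as sketched below. Constructor arguments follow the commonly documented openai-agents API, and the key and URL are placeholders; treat this as illustrative rather than blaxel's exact wiring:

from agents import OpenAIChatCompletionsModel
from openai import AsyncOpenAI

# The raw OpenAI client owns transport concerns (base URL, custom http_client);
# the agents-SDK wrapper adapts it for agent runs.
client = AsyncOpenAI(api_key="replaced", base_url="https://example.invalid/v1")
model = OpenAIChatCompletionsModel(model="gpt-4o", openai_client=client)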