mito-ai 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mito-ai might be problematic. Click here for more details.
- mito_ai/_version.py +1 -1
- mito_ai/anthropic_client.py +52 -54
- mito_ai/app_builder/handlers.py +2 -4
- mito_ai/completions/models.py +15 -1
- mito_ai/completions/prompt_builders/agent_system_message.py +10 -2
- mito_ai/completions/providers.py +79 -39
- mito_ai/constants.py +11 -24
- mito_ai/gemini_client.py +44 -48
- mito_ai/openai_client.py +30 -44
- mito_ai/tests/message_history/test_generate_short_chat_name.py +0 -4
- mito_ai/tests/open_ai_utils_test.py +18 -22
- mito_ai/tests/{test_anthropic_client.py → providers/test_anthropic_client.py} +37 -32
- mito_ai/tests/providers/test_azure.py +2 -6
- mito_ai/tests/providers/test_capabilities.py +120 -0
- mito_ai/tests/{test_gemini_client.py → providers/test_gemini_client.py} +40 -36
- mito_ai/tests/providers/test_mito_server_utils.py +448 -0
- mito_ai/tests/providers/test_model_resolution.py +130 -0
- mito_ai/tests/providers/test_openai_client.py +57 -0
- mito_ai/tests/providers/test_provider_completion_exception.py +66 -0
- mito_ai/tests/providers/test_provider_limits.py +42 -0
- mito_ai/tests/providers/test_providers.py +382 -0
- mito_ai/tests/providers/test_retry_logic.py +389 -0
- mito_ai/tests/providers/utils.py +85 -0
- mito_ai/tests/test_constants.py +15 -2
- mito_ai/tests/test_telemetry.py +12 -0
- mito_ai/utils/anthropic_utils.py +21 -29
- mito_ai/utils/gemini_utils.py +18 -22
- mito_ai/utils/mito_server_utils.py +92 -0
- mito_ai/utils/open_ai_utils.py +22 -46
- mito_ai/utils/provider_utils.py +49 -0
- mito_ai/utils/telemetry_utils.py +11 -1
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +1 -1
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.281f4b9af60d620c6fb1.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js +737 -319
- mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js.map +1 -0
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.4f1d00fd0c58fcc05d8d.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.d2eea6519fa332d79efb.js +13 -16
- mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.d2eea6519fa332d79efb.js.map +1 -0
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js +6 -2
- mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js.map +1 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/METADATA +1 -1
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/RECORD +52 -43
- mito_ai/tests/providers_test.py +0 -438
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.281f4b9af60d620c6fb1.js.map +0 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.4f1d00fd0c58fcc05d8d.js.map +0 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js.map +0 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js +0 -7842
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js.map +0 -1
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js.map +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/licenses/LICENSE +0 -0
mito_ai/gemini_client.py
CHANGED
|
@@ -5,10 +5,9 @@ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
|
|
5
5
|
from google import genai
|
|
6
6
|
from google.genai import types
|
|
7
7
|
from google.genai.types import GenerateContentConfig, Part, Content, GenerateContentResponse
|
|
8
|
-
from mito_ai.completions.models import CompletionItem, CompletionReply, CompletionStreamChunk, MessageType, ResponseFormatInfo
|
|
8
|
+
from mito_ai.completions.models import CompletionError, CompletionItem, CompletionReply, CompletionStreamChunk, MessageType, ResponseFormatInfo
|
|
9
9
|
from mito_ai.utils.gemini_utils import get_gemini_completion_from_mito_server, stream_gemini_completion_from_mito_server, get_gemini_completion_function_params
|
|
10
|
-
|
|
11
|
-
GEMINI_FAST_MODEL = "gemini-2.0-flash-lite"
|
|
10
|
+
from mito_ai.utils.mito_server_utils import ProviderCompletionException
|
|
12
11
|
|
|
13
12
|
def extract_and_parse_gemini_json_response(response: GenerateContentResponse) -> Optional[str]:
|
|
14
13
|
"""
|
|
@@ -100,65 +99,62 @@ def get_gemini_system_prompt_and_messages(messages: List[Dict[str, Any]]) -> Tup
|
|
|
100
99
|
|
|
101
100
|
|
|
102
101
|
class GeminiClient:
|
|
103
|
-
def __init__(self, api_key: Optional[str]
|
|
102
|
+
def __init__(self, api_key: Optional[str]):
|
|
104
103
|
self.api_key = api_key
|
|
105
|
-
self.model = model
|
|
106
104
|
if api_key:
|
|
107
105
|
self.client = genai.Client(api_key=api_key)
|
|
108
106
|
|
|
109
107
|
async def request_completions(
|
|
110
108
|
self,
|
|
111
109
|
messages: List[Dict[str, Any]],
|
|
110
|
+
model: str,
|
|
112
111
|
response_format_info: Optional[ResponseFormatInfo] = None,
|
|
113
112
|
message_type: MessageType = MessageType.CHAT
|
|
114
113
|
) -> str:
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
114
|
+
# Extract system instructions and contents
|
|
115
|
+
system_instructions, contents = get_gemini_system_prompt_and_messages(messages)
|
|
116
|
+
|
|
117
|
+
# Get provider data for Gemini completion
|
|
118
|
+
provider_data = get_gemini_completion_function_params(
|
|
119
|
+
model=model,
|
|
120
|
+
contents=contents,
|
|
121
|
+
message_type=message_type,
|
|
122
|
+
response_format_info=response_format_info
|
|
123
|
+
)
|
|
124
|
+
|
|
125
|
+
if self.api_key:
|
|
126
|
+
# Generate content using the Gemini client
|
|
127
|
+
response_config = GenerateContentConfig(
|
|
128
|
+
system_instruction=system_instructions,
|
|
129
|
+
response_mime_type=provider_data.get("config", {}).get("response_mime_type"),
|
|
130
|
+
response_schema=provider_data.get("config", {}).get("response_schema")
|
|
131
|
+
)
|
|
132
|
+
response = self.client.models.generate_content(
|
|
133
|
+
model=provider_data["model"],
|
|
134
|
+
contents=contents, # type: ignore
|
|
135
|
+
config=response_config
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
result = extract_and_parse_gemini_json_response(response)
|
|
139
|
+
|
|
140
|
+
if not result:
|
|
141
|
+
return "No response received from Gemini API"
|
|
142
|
+
|
|
143
|
+
return result
|
|
144
|
+
else:
|
|
145
|
+
# Fallback to Mito server for completion
|
|
146
|
+
return await get_gemini_completion_from_mito_server(
|
|
147
|
+
model=provider_data["model"],
|
|
148
|
+
contents=messages, # Use the extracted contents instead of converted messages to avoid serialization issues
|
|
123
149
|
message_type=message_type,
|
|
124
|
-
|
|
150
|
+
config=provider_data.get("config", None),
|
|
151
|
+
response_format_info=response_format_info,
|
|
125
152
|
)
|
|
126
153
|
|
|
127
|
-
if self.api_key:
|
|
128
|
-
# Generate content using the Gemini client
|
|
129
|
-
response_config = GenerateContentConfig(
|
|
130
|
-
system_instruction=system_instructions,
|
|
131
|
-
response_mime_type=provider_data.get("config", {}).get("response_mime_type"),
|
|
132
|
-
response_schema=provider_data.get("config", {}).get("response_schema")
|
|
133
|
-
)
|
|
134
|
-
response = self.client.models.generate_content(
|
|
135
|
-
model=provider_data["model"],
|
|
136
|
-
contents=contents, # type: ignore
|
|
137
|
-
config=response_config
|
|
138
|
-
)
|
|
139
|
-
|
|
140
|
-
result = extract_and_parse_gemini_json_response(response)
|
|
141
|
-
|
|
142
|
-
if not result:
|
|
143
|
-
return "No response received from Gemini API"
|
|
144
|
-
|
|
145
|
-
return result
|
|
146
|
-
else:
|
|
147
|
-
# Fallback to Mito server for completion
|
|
148
|
-
return await get_gemini_completion_from_mito_server(
|
|
149
|
-
model=provider_data["model"],
|
|
150
|
-
contents=messages, # Use the extracted contents instead of converted messages to avoid serialization issues
|
|
151
|
-
message_type=message_type,
|
|
152
|
-
config=provider_data.get("config", None),
|
|
153
|
-
response_format_info=response_format_info,
|
|
154
|
-
)
|
|
155
|
-
|
|
156
|
-
except Exception as e:
|
|
157
|
-
return f"Error generating content: {str(e)}"
|
|
158
|
-
|
|
159
154
|
async def stream_completions(
|
|
160
155
|
self,
|
|
161
156
|
messages: List[Dict[str, Any]],
|
|
157
|
+
model: str,
|
|
162
158
|
message_id: str,
|
|
163
159
|
reply_fn: Callable[[Union[CompletionReply, CompletionStreamChunk]], None],
|
|
164
160
|
message_type: MessageType = MessageType.CHAT
|
|
@@ -169,7 +165,7 @@ class GeminiClient:
|
|
|
169
165
|
system_instructions, contents = get_gemini_system_prompt_and_messages(messages)
|
|
170
166
|
if self.api_key:
|
|
171
167
|
for chunk in self.client.models.generate_content_stream(
|
|
172
|
-
model=
|
|
168
|
+
model=model,
|
|
173
169
|
contents=contents, # type: ignore
|
|
174
170
|
config=GenerateContentConfig(
|
|
175
171
|
system_instruction=system_instructions
|
|
@@ -208,7 +204,7 @@ class GeminiClient:
|
|
|
208
204
|
return accumulated_response
|
|
209
205
|
else:
|
|
210
206
|
async for chunk_text in stream_gemini_completion_from_mito_server(
|
|
211
|
-
model=
|
|
207
|
+
model=model,
|
|
212
208
|
contents=messages, # Use the extracted contents instead of converted messages to avoid serialization issues
|
|
213
209
|
message_type=message_type,
|
|
214
210
|
message_id=message_id,
|
mito_ai/openai_client.py
CHANGED
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
from __future__ import annotations
|
|
5
5
|
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
|
|
6
6
|
|
|
7
|
+
from mito_ai.utils.mito_server_utils import ProviderCompletionException
|
|
7
8
|
import openai
|
|
8
9
|
from openai.types.chat import ChatCompletionMessageParam
|
|
9
10
|
from traitlets import Instance, Unicode, default, validate
|
|
@@ -36,8 +37,6 @@ from mito_ai.utils.telemetry_utils import (
|
|
|
36
37
|
|
|
37
38
|
OPENAI_MODEL_FALLBACK = "gpt-4.1"
|
|
38
39
|
|
|
39
|
-
OPENAI_FAST_MODEL = "gpt-4.1-nano"
|
|
40
|
-
|
|
41
40
|
class OpenAIClient(LoggingConfigurable):
|
|
42
41
|
"""Provide AI feature through OpenAI services."""
|
|
43
42
|
|
|
@@ -222,26 +221,20 @@ This attribute is observed by the websocket provider to push the error to the cl
|
|
|
222
221
|
)
|
|
223
222
|
return client
|
|
224
223
|
|
|
225
|
-
def
|
|
224
|
+
def _adjust_model_for_azure_or_ollama(self, model: str) -> str:
|
|
226
225
|
|
|
227
226
|
# If they have set an Azure OpenAI model, then we always use it
|
|
228
227
|
if is_azure_openai_configured() and constants.AZURE_OPENAI_MODEL is not None:
|
|
229
228
|
self.log.debug(f"Resolving to Azure OpenAI model: {constants.AZURE_OPENAI_MODEL}")
|
|
230
229
|
return constants.AZURE_OPENAI_MODEL
|
|
231
230
|
|
|
232
|
-
# Otherwise, we use the fast model for anything other than the agent mode
|
|
233
|
-
if response_format_info:
|
|
234
|
-
return OPENAI_FAST_MODEL
|
|
235
|
-
|
|
236
231
|
# If they have set an Ollama model, then we use it
|
|
237
232
|
if constants.OLLAMA_MODEL is not None:
|
|
238
233
|
return constants.OLLAMA_MODEL
|
|
239
234
|
|
|
240
|
-
#
|
|
241
|
-
|
|
242
|
-
return model
|
|
235
|
+
# Otherwise, we use the model they provided
|
|
236
|
+
return model
|
|
243
237
|
|
|
244
|
-
return OPENAI_MODEL_FALLBACK
|
|
245
238
|
|
|
246
239
|
async def request_completions(
|
|
247
240
|
self,
|
|
@@ -263,39 +256,33 @@ This attribute is observed by the websocket provider to push the error to the cl
|
|
|
263
256
|
# Reset the last error
|
|
264
257
|
self.last_error = None
|
|
265
258
|
completion = None
|
|
259
|
+
|
|
260
|
+
# Note: We don't catch exceptions here because we want them to bubble up
|
|
261
|
+
# to the providers file so we can handle all client exceptions in one place.
|
|
266
262
|
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
263
|
+
# Handle other providers as before
|
|
264
|
+
completion_function_params = get_open_ai_completion_function_params(
|
|
265
|
+
message_type, model, messages, False, response_format_info
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
# If they have set an Azure OpenAI or Ollama model, then we use it
|
|
269
|
+
completion_function_params["model"] = self._adjust_model_for_azure_or_ollama(completion_function_params["model"])
|
|
270
|
+
|
|
271
|
+
if self._active_async_client is not None:
|
|
272
|
+
response = await self._active_async_client.chat.completions.create(**completion_function_params)
|
|
273
|
+
completion = response.choices[0].message.content or ""
|
|
274
|
+
else:
|
|
275
|
+
last_message_content = str(messages[-1].get("content", "")) if messages else None
|
|
276
|
+
completion = await get_ai_completion_from_mito_server(
|
|
277
|
+
last_message_content,
|
|
278
|
+
completion_function_params,
|
|
279
|
+
self.timeout,
|
|
280
|
+
self.max_retries,
|
|
281
|
+
message_type,
|
|
277
282
|
)
|
|
278
283
|
|
|
279
|
-
|
|
280
|
-
response = await self._active_async_client.chat.completions.create(**completion_function_params)
|
|
281
|
-
completion = response.choices[0].message.content or ""
|
|
282
|
-
else:
|
|
283
|
-
last_message_content = str(messages[-1].get("content", "")) if messages else None
|
|
284
|
-
completion = await get_ai_completion_from_mito_server(
|
|
285
|
-
last_message_content,
|
|
286
|
-
completion_function_params,
|
|
287
|
-
self.timeout,
|
|
288
|
-
self.max_retries,
|
|
289
|
-
message_type,
|
|
290
|
-
)
|
|
291
|
-
|
|
292
|
-
update_mito_server_quota(message_type)
|
|
284
|
+
return completion
|
|
293
285
|
|
|
294
|
-
return completion
|
|
295
|
-
|
|
296
|
-
except BaseException as e:
|
|
297
|
-
self.last_error = CompletionError.from_exception(e)
|
|
298
|
-
raise
|
|
299
286
|
|
|
300
287
|
async def stream_completions(
|
|
301
288
|
self,
|
|
@@ -315,9 +302,6 @@ This attribute is observed by the websocket provider to push the error to the cl
|
|
|
315
302
|
# Reset the last error
|
|
316
303
|
self.last_error = None
|
|
317
304
|
accumulated_response = ""
|
|
318
|
-
|
|
319
|
-
# Validate that the model is supported.
|
|
320
|
-
model = self._resolve_model(model, response_format_info)
|
|
321
305
|
|
|
322
306
|
# Send initial acknowledgment
|
|
323
307
|
reply_fn(CompletionReply(
|
|
@@ -329,8 +313,10 @@ This attribute is observed by the websocket provider to push the error to the cl
|
|
|
329
313
|
|
|
330
314
|
# Handle other providers as before
|
|
331
315
|
completion_function_params = get_open_ai_completion_function_params(
|
|
332
|
-
model, messages, True, response_format_info
|
|
316
|
+
message_type, model, messages, True, response_format_info
|
|
333
317
|
)
|
|
318
|
+
|
|
319
|
+
completion_function_params["model"] = self._adjust_model_for_azure_or_ollama(completion_function_params["model"])
|
|
334
320
|
|
|
335
321
|
try:
|
|
336
322
|
if self._active_async_client is not None:
|
|
@@ -6,10 +6,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
6
6
|
from traitlets.config import Config
|
|
7
7
|
from mito_ai.completions.message_history import generate_short_chat_name
|
|
8
8
|
from mito_ai.completions.providers import OpenAIProvider
|
|
9
|
-
from mito_ai.completions.models import MessageType
|
|
10
|
-
from mito_ai.openai_client import OPENAI_FAST_MODEL
|
|
11
|
-
from mito_ai.anthropic_client import ANTHROPIC_FAST_MODEL
|
|
12
|
-
from mito_ai.gemini_client import GEMINI_FAST_MODEL
|
|
13
9
|
|
|
14
10
|
|
|
15
11
|
@pytest.fixture
|
|
@@ -80,28 +80,24 @@ def test_prepare_request_data_and_headers_basic() -> None:
|
|
|
80
80
|
mock_get_user_field.side_effect = ["test@example.com", "user123"]
|
|
81
81
|
|
|
82
82
|
# Mock the quota check
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
assert data["user_input"] == "test message"
|
|
102
|
-
|
|
103
|
-
# Verify headers
|
|
104
|
-
assert headers == {"Content-Type": "application/json"}
|
|
83
|
+
data, headers = _prepare_request_data_and_headers(
|
|
84
|
+
last_message_content="test message",
|
|
85
|
+
ai_completion_data={"key": "value"},
|
|
86
|
+
timeout=30,
|
|
87
|
+
max_retries=3,
|
|
88
|
+
message_type=MessageType.CHAT
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
# Verify data structure
|
|
92
|
+
assert data["timeout"] == 30
|
|
93
|
+
assert data["max_retries"] == 3
|
|
94
|
+
assert data["email"] == "test@example.com"
|
|
95
|
+
assert data["user_id"] == "user123"
|
|
96
|
+
assert data["data"] == {"key": "value"}
|
|
97
|
+
assert data["user_input"] == "test message"
|
|
98
|
+
|
|
99
|
+
# Verify headers
|
|
100
|
+
assert headers == {"Content-Type": "application/json"}
|
|
105
101
|
|
|
106
102
|
def test_prepare_request_data_and_headers_null_message() -> None:
|
|
107
103
|
"""Test handling of null message content"""
|
|
@@ -2,15 +2,16 @@
|
|
|
2
2
|
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
|
-
from mito_ai.anthropic_client import get_anthropic_system_prompt_and_messages, extract_and_parse_anthropic_json_response, AnthropicClient
|
|
6
|
-
from mito_ai.utils.anthropic_utils import get_anthropic_completion_function_params
|
|
7
|
-
from anthropic.types import
|
|
5
|
+
from mito_ai.anthropic_client import get_anthropic_system_prompt_and_messages, extract_and_parse_anthropic_json_response, AnthropicClient
|
|
6
|
+
from mito_ai.utils.anthropic_utils import get_anthropic_completion_function_params, FAST_ANTHROPIC_MODEL
|
|
7
|
+
from anthropic.types import Message, TextBlock, ToolUseBlock, Usage, ToolUseBlock, Message, Usage, TextBlock
|
|
8
8
|
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionUserMessageParam, ChatCompletionAssistantMessageParam, ChatCompletionSystemMessageParam
|
|
9
|
-
from mito_ai.completions.models import ResponseFormatInfo, AgentResponse
|
|
9
|
+
from mito_ai.completions.models import MessageType, ResponseFormatInfo, AgentResponse
|
|
10
10
|
from unittest.mock import MagicMock, patch
|
|
11
11
|
import anthropic
|
|
12
12
|
from typing import List, Dict, Any, cast, Union
|
|
13
13
|
|
|
14
|
+
|
|
14
15
|
# Dummy base64 image (1x1 PNG)
|
|
15
16
|
DUMMY_IMAGE_DATA_URL = (
|
|
16
17
|
""
|
|
@@ -231,40 +232,44 @@ def test_tool_use_without_agent_response():
|
|
|
231
232
|
extract_and_parse_anthropic_json_response(response)
|
|
232
233
|
assert "No valid AgentResponse format found" in str(exc_info.value)
|
|
233
234
|
|
|
234
|
-
CUSTOM_MODEL = "
|
|
235
|
-
@pytest.mark.parametrize("
|
|
236
|
-
(
|
|
237
|
-
(
|
|
235
|
+
CUSTOM_MODEL = "smart-anthropic-model"
|
|
236
|
+
@pytest.mark.parametrize("message_type, expected_model", [
|
|
237
|
+
(MessageType.CHAT, CUSTOM_MODEL), #
|
|
238
|
+
(MessageType.SMART_DEBUG, CUSTOM_MODEL), #
|
|
239
|
+
(MessageType.CODE_EXPLAIN, CUSTOM_MODEL), #
|
|
240
|
+
(MessageType.AGENT_EXECUTION, CUSTOM_MODEL), #
|
|
241
|
+
(MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL), #
|
|
242
|
+
(MessageType.INLINE_COMPLETION, FAST_ANTHROPIC_MODEL), #
|
|
243
|
+
(MessageType.CHAT_NAME_GENERATION, FAST_ANTHROPIC_MODEL), #
|
|
238
244
|
])
|
|
239
|
-
@pytest.mark.asyncio
|
|
240
|
-
async def
|
|
245
|
+
@pytest.mark.asyncio
|
|
246
|
+
async def test_model_selection_based_on_message_type(message_type, expected_model):
|
|
241
247
|
"""
|
|
242
|
-
Tests that the correct model is selected based on
|
|
248
|
+
Tests that the correct model is selected based on the message type.
|
|
243
249
|
"""
|
|
250
|
+
client = AnthropicClient(api_key="test_key")
|
|
244
251
|
|
|
245
|
-
#
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
)
|
|
259
|
-
client.client.messages.create.return_value = mock_response
|
|
260
|
-
|
|
261
|
-
with patch('mito_ai.anthropic_client.get_anthropic_completion_function_params', wraps=get_anthropic_completion_function_params) as mock_get_params:
|
|
252
|
+
# Mock the messages.create method directly
|
|
253
|
+
with patch.object(client.client.messages, 'create') as mock_create: # type: ignore
|
|
254
|
+
# Create a mock response
|
|
255
|
+
mock_response = Message(
|
|
256
|
+
id="test_id",
|
|
257
|
+
role="assistant",
|
|
258
|
+
content=[TextBlock(type="text", text="test")],
|
|
259
|
+
model='anthropic-model-we-do-not-check',
|
|
260
|
+
type="message",
|
|
261
|
+
usage=Usage(input_tokens=0, output_tokens=0)
|
|
262
|
+
)
|
|
263
|
+
mock_create.return_value = mock_response
|
|
264
|
+
|
|
262
265
|
await client.request_completions(
|
|
263
266
|
messages=[{"role": "user", "content": "Test message"}],
|
|
264
|
-
|
|
267
|
+
model=CUSTOM_MODEL,
|
|
268
|
+
message_type=message_type,
|
|
269
|
+
response_format_info=None
|
|
265
270
|
)
|
|
266
271
|
|
|
267
|
-
# Verify that
|
|
268
|
-
|
|
269
|
-
call_args =
|
|
272
|
+
# Verify that create was called with the expected model
|
|
273
|
+
mock_create.assert_called_once()
|
|
274
|
+
call_args = mock_create.call_args
|
|
270
275
|
assert call_args[1]['model'] == expected_model
|
|
@@ -176,15 +176,11 @@ class TestAzureOpenAIClientCreation:
|
|
|
176
176
|
openai_client = OpenAIClient(config=provider_config)
|
|
177
177
|
|
|
178
178
|
# Test with gpt-4.1 model
|
|
179
|
-
resolved_model = openai_client.
|
|
179
|
+
resolved_model = openai_client._adjust_model_for_azure_or_ollama("gpt-4.1")
|
|
180
180
|
assert resolved_model == FAKE_AZURE_MODEL
|
|
181
181
|
|
|
182
182
|
# Test with any other model
|
|
183
|
-
resolved_model = openai_client.
|
|
184
|
-
assert resolved_model == FAKE_AZURE_MODEL
|
|
185
|
-
|
|
186
|
-
# Test with no model specified
|
|
187
|
-
resolved_model = openai_client._resolve_model()
|
|
183
|
+
resolved_model = openai_client._adjust_model_for_azure_or_ollama("gpt-3.5-turbo")
|
|
188
184
|
assert resolved_model == FAKE_AZURE_MODEL
|
|
189
185
|
|
|
190
186
|
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright (c) Saga Inc.
|
|
2
|
+
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
from unittest.mock import MagicMock, patch
|
|
6
|
+
from mito_ai.completions.providers import OpenAIProvider
|
|
7
|
+
from mito_ai.tests.providers.utils import mock_azure_openai_client, mock_openai_client, patch_server_limits
|
|
8
|
+
from traitlets.config import Config
|
|
9
|
+
|
|
10
|
+
FAKE_API_KEY = "sk-1234567890"
|
|
11
|
+
|
|
12
|
+
@pytest.fixture
|
|
13
|
+
def provider_config() -> Config:
|
|
14
|
+
"""Create a proper Config object for the OpenAIProvider."""
|
|
15
|
+
config = Config()
|
|
16
|
+
config.OpenAIProvider = Config()
|
|
17
|
+
config.OpenAIClient = Config()
|
|
18
|
+
return config
|
|
19
|
+
|
|
20
|
+
@pytest.mark.parametrize("test_case", [
|
|
21
|
+
{
|
|
22
|
+
"name": "mito_server_fallback_no_keys",
|
|
23
|
+
"setup": {
|
|
24
|
+
"OPENAI_API_KEY": None,
|
|
25
|
+
"CLAUDE_API_KEY": None,
|
|
26
|
+
"GEMINI_API_KEY": None,
|
|
27
|
+
"is_azure_configured": False,
|
|
28
|
+
},
|
|
29
|
+
"expected_provider": "Mito server",
|
|
30
|
+
"expected_key_type": "mito_server_key"
|
|
31
|
+
},
|
|
32
|
+
{
|
|
33
|
+
"name": "claude_when_only_claude_key",
|
|
34
|
+
"setup": {
|
|
35
|
+
"OPENAI_API_KEY": None,
|
|
36
|
+
"CLAUDE_API_KEY": "claude-test-key",
|
|
37
|
+
"GEMINI_API_KEY": None,
|
|
38
|
+
"is_azure_configured": False,
|
|
39
|
+
},
|
|
40
|
+
"expected_provider": "Claude",
|
|
41
|
+
"expected_key_type": "claude"
|
|
42
|
+
},
|
|
43
|
+
{
|
|
44
|
+
"name": "gemini_when_only_gemini_key",
|
|
45
|
+
"setup": {
|
|
46
|
+
"OPENAI_API_KEY": None,
|
|
47
|
+
"CLAUDE_API_KEY": None,
|
|
48
|
+
"GEMINI_API_KEY": "gemini-test-key",
|
|
49
|
+
"is_azure_configured": False,
|
|
50
|
+
},
|
|
51
|
+
"expected_provider": "Gemini",
|
|
52
|
+
"expected_key_type": "gemini"
|
|
53
|
+
},
|
|
54
|
+
{
|
|
55
|
+
"name": "openai_when_openai_key",
|
|
56
|
+
"setup": {
|
|
57
|
+
"OPENAI_API_KEY": 'openai-test-key',
|
|
58
|
+
"CLAUDE_API_KEY": None,
|
|
59
|
+
"GEMINI_API_KEY": None,
|
|
60
|
+
"is_azure_configured": False,
|
|
61
|
+
},
|
|
62
|
+
"expected_provider": "OpenAI (user key)",
|
|
63
|
+
"expected_key_type": "user_key"
|
|
64
|
+
},
|
|
65
|
+
{
|
|
66
|
+
"name": "claude_priority_over_gemini",
|
|
67
|
+
"setup": {
|
|
68
|
+
"OPENAI_API_KEY": None,
|
|
69
|
+
"CLAUDE_API_KEY": "claude-test-key",
|
|
70
|
+
"GEMINI_API_KEY": "gemini-test-key",
|
|
71
|
+
"is_azure_configured": False,
|
|
72
|
+
},
|
|
73
|
+
"expected_provider": "Claude",
|
|
74
|
+
"expected_key_type": "claude"
|
|
75
|
+
},
|
|
76
|
+
])
|
|
77
|
+
def test_provider_capabilities_real_logic(
|
|
78
|
+
test_case: dict,
|
|
79
|
+
monkeypatch: pytest.MonkeyPatch,
|
|
80
|
+
provider_config: Config
|
|
81
|
+
) -> None:
|
|
82
|
+
"""Test the actual provider selection logic in OpenAIProvider.capabilities"""
|
|
83
|
+
|
|
84
|
+
# Set up the environment based on test case
|
|
85
|
+
setup = test_case["setup"]
|
|
86
|
+
|
|
87
|
+
# CRITICAL: Set up ALL mocks BEFORE creating any clients
|
|
88
|
+
for key, value in setup.items():
|
|
89
|
+
if key == "is_azure_configured":
|
|
90
|
+
if value:
|
|
91
|
+
# For Azure case, mock to return True and set required constants
|
|
92
|
+
monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: True)
|
|
93
|
+
monkeypatch.setattr("mito_ai.constants.AZURE_OPENAI_MODEL", "gpt-4o")
|
|
94
|
+
else:
|
|
95
|
+
# For non-Azure case, mock to return False
|
|
96
|
+
monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
|
|
97
|
+
else:
|
|
98
|
+
monkeypatch.setattr(f"mito_ai.constants.{key}", value)
|
|
99
|
+
|
|
100
|
+
# Clear the provider config API key to ensure it uses constants
|
|
101
|
+
provider_config.OpenAIProvider.api_key = None
|
|
102
|
+
|
|
103
|
+
# Mock HTTP calls but let the real logic run
|
|
104
|
+
with patch("openai.OpenAI") as mock_openai_constructor:
|
|
105
|
+
with patch("openai.AsyncOpenAI") as mock_async_openai:
|
|
106
|
+
with patch("openai.AsyncAzureOpenAI") as mock_async_azure_openai:
|
|
107
|
+
# Mock successful API key validation for OpenAI
|
|
108
|
+
mock_openai_instance = MagicMock()
|
|
109
|
+
mock_openai_instance.models.list.return_value = [MagicMock(id="gpt-4o-mini")]
|
|
110
|
+
mock_openai_constructor.return_value = mock_openai_instance
|
|
111
|
+
|
|
112
|
+
# Mock server limits for Mito server fallback
|
|
113
|
+
with patch_server_limits():
|
|
114
|
+
# NOW create the provider after ALL mocks are set up
|
|
115
|
+
llm = OpenAIProvider(config=provider_config)
|
|
116
|
+
|
|
117
|
+
# Test capabilities
|
|
118
|
+
capabilities = llm.capabilities
|
|
119
|
+
assert capabilities.provider == test_case["expected_provider"], f"Test case: {test_case['name']}"
|
|
120
|
+
assert llm.key_type == test_case["expected_key_type"], f"Test case: {test_case['name']}"
|