mito-ai 0.1.32__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mito-ai might be problematic; see the registry's advisory page for this release for more details.

Files changed (58)
  1. mito_ai/_version.py +1 -1
  2. mito_ai/anthropic_client.py +52 -54
  3. mito_ai/app_builder/handlers.py +2 -4
  4. mito_ai/completions/models.py +15 -1
  5. mito_ai/completions/prompt_builders/agent_system_message.py +10 -2
  6. mito_ai/completions/providers.py +79 -39
  7. mito_ai/constants.py +11 -24
  8. mito_ai/gemini_client.py +44 -48
  9. mito_ai/openai_client.py +30 -44
  10. mito_ai/tests/message_history/test_generate_short_chat_name.py +0 -4
  11. mito_ai/tests/open_ai_utils_test.py +18 -22
  12. mito_ai/tests/{test_anthropic_client.py → providers/test_anthropic_client.py} +37 -32
  13. mito_ai/tests/providers/test_azure.py +2 -6
  14. mito_ai/tests/providers/test_capabilities.py +120 -0
  15. mito_ai/tests/{test_gemini_client.py → providers/test_gemini_client.py} +40 -36
  16. mito_ai/tests/providers/test_mito_server_utils.py +448 -0
  17. mito_ai/tests/providers/test_model_resolution.py +130 -0
  18. mito_ai/tests/providers/test_openai_client.py +57 -0
  19. mito_ai/tests/providers/test_provider_completion_exception.py +66 -0
  20. mito_ai/tests/providers/test_provider_limits.py +42 -0
  21. mito_ai/tests/providers/test_providers.py +382 -0
  22. mito_ai/tests/providers/test_retry_logic.py +389 -0
  23. mito_ai/tests/providers/utils.py +85 -0
  24. mito_ai/tests/test_constants.py +15 -2
  25. mito_ai/tests/test_telemetry.py +12 -0
  26. mito_ai/utils/anthropic_utils.py +21 -29
  27. mito_ai/utils/gemini_utils.py +18 -22
  28. mito_ai/utils/mito_server_utils.py +92 -0
  29. mito_ai/utils/open_ai_utils.py +22 -46
  30. mito_ai/utils/provider_utils.py +49 -0
  31. mito_ai/utils/telemetry_utils.py +11 -1
  32. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +1 -1
  33. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
  34. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
  35. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.42b54cf8f038cc526980.js → mito_ai-0.1.34.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js +785 -351
  36. mito_ai-0.1.34.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js.map +1 -0
  37. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.a711c58b58423173bd24.js → mito_ai-0.1.34.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.51d07439b02aaa830975.js +13 -16
  38. mito_ai-0.1.34.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.51d07439b02aaa830975.js.map +1 -0
  39. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js → mito_ai-0.1.34.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js +6 -2
  40. mito_ai-0.1.34.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js.map +1 -0
  41. {mito_ai-0.1.32.dist-info → mito_ai-0.1.34.dist-info}/METADATA +1 -1
  42. {mito_ai-0.1.32.dist-info → mito_ai-0.1.34.dist-info}/RECORD +52 -43
  43. mito_ai/tests/providers_test.py +0 -438
  44. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.42b54cf8f038cc526980.js.map +0 -1
  45. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.a711c58b58423173bd24.js.map +0 -1
  46. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js.map +0 -1
  47. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js +0 -7842
  48. mito_ai-0.1.32.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js.map +0 -1
  49. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
  50. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
  51. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
  52. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js +0 -0
  53. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js.map +0 -0
  54. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
  55. {mito_ai-0.1.32.data → mito_ai-0.1.34.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
  56. {mito_ai-0.1.32.dist-info → mito_ai-0.1.34.dist-info}/WHEEL +0 -0
  57. {mito_ai-0.1.32.dist-info → mito_ai-0.1.34.dist-info}/entry_points.txt +0 -0
  58. {mito_ai-0.1.32.dist-info → mito_ai-0.1.34.dist-info}/licenses/LICENSE +0 -0
mito_ai/gemini_client.py CHANGED
@@ -5,10 +5,9 @@ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5
5
  from google import genai
6
6
  from google.genai import types
7
7
  from google.genai.types import GenerateContentConfig, Part, Content, GenerateContentResponse
8
- from mito_ai.completions.models import CompletionItem, CompletionReply, CompletionStreamChunk, MessageType, ResponseFormatInfo
8
+ from mito_ai.completions.models import CompletionError, CompletionItem, CompletionReply, CompletionStreamChunk, MessageType, ResponseFormatInfo
9
9
  from mito_ai.utils.gemini_utils import get_gemini_completion_from_mito_server, stream_gemini_completion_from_mito_server, get_gemini_completion_function_params
10
-
11
- GEMINI_FAST_MODEL = "gemini-2.0-flash-lite"
10
+ from mito_ai.utils.mito_server_utils import ProviderCompletionException
12
11
 
13
12
  def extract_and_parse_gemini_json_response(response: GenerateContentResponse) -> Optional[str]:
14
13
  """
@@ -100,65 +99,62 @@ def get_gemini_system_prompt_and_messages(messages: List[Dict[str, Any]]) -> Tup
100
99
 
101
100
 
102
101
  class GeminiClient:
103
- def __init__(self, api_key: Optional[str], model: str):
102
+ def __init__(self, api_key: Optional[str]):
104
103
  self.api_key = api_key
105
- self.model = model
106
104
  if api_key:
107
105
  self.client = genai.Client(api_key=api_key)
108
106
 
109
107
  async def request_completions(
110
108
  self,
111
109
  messages: List[Dict[str, Any]],
110
+ model: str,
112
111
  response_format_info: Optional[ResponseFormatInfo] = None,
113
112
  message_type: MessageType = MessageType.CHAT
114
113
  ) -> str:
115
- try:
116
- # Extract system instructions and contents
117
- system_instructions, contents = get_gemini_system_prompt_and_messages(messages)
118
-
119
- # Get provider data for Gemini completion
120
- provider_data = get_gemini_completion_function_params(
121
- model=self.model if response_format_info else GEMINI_FAST_MODEL,
122
- contents=contents,
114
+ # Extract system instructions and contents
115
+ system_instructions, contents = get_gemini_system_prompt_and_messages(messages)
116
+
117
+ # Get provider data for Gemini completion
118
+ provider_data = get_gemini_completion_function_params(
119
+ model=model,
120
+ contents=contents,
121
+ message_type=message_type,
122
+ response_format_info=response_format_info
123
+ )
124
+
125
+ if self.api_key:
126
+ # Generate content using the Gemini client
127
+ response_config = GenerateContentConfig(
128
+ system_instruction=system_instructions,
129
+ response_mime_type=provider_data.get("config", {}).get("response_mime_type"),
130
+ response_schema=provider_data.get("config", {}).get("response_schema")
131
+ )
132
+ response = self.client.models.generate_content(
133
+ model=provider_data["model"],
134
+ contents=contents, # type: ignore
135
+ config=response_config
136
+ )
137
+
138
+ result = extract_and_parse_gemini_json_response(response)
139
+
140
+ if not result:
141
+ return "No response received from Gemini API"
142
+
143
+ return result
144
+ else:
145
+ # Fallback to Mito server for completion
146
+ return await get_gemini_completion_from_mito_server(
147
+ model=provider_data["model"],
148
+ contents=messages, # Use the extracted contents instead of converted messages to avoid serialization issues
123
149
  message_type=message_type,
124
- response_format_info=response_format_info
150
+ config=provider_data.get("config", None),
151
+ response_format_info=response_format_info,
125
152
  )
126
153
 
127
- if self.api_key:
128
- # Generate content using the Gemini client
129
- response_config = GenerateContentConfig(
130
- system_instruction=system_instructions,
131
- response_mime_type=provider_data.get("config", {}).get("response_mime_type"),
132
- response_schema=provider_data.get("config", {}).get("response_schema")
133
- )
134
- response = self.client.models.generate_content(
135
- model=provider_data["model"],
136
- contents=contents, # type: ignore
137
- config=response_config
138
- )
139
-
140
- result = extract_and_parse_gemini_json_response(response)
141
-
142
- if not result:
143
- return "No response received from Gemini API"
144
-
145
- return result
146
- else:
147
- # Fallback to Mito server for completion
148
- return await get_gemini_completion_from_mito_server(
149
- model=provider_data["model"],
150
- contents=messages, # Use the extracted contents instead of converted messages to avoid serialization issues
151
- message_type=message_type,
152
- config=provider_data.get("config", None),
153
- response_format_info=response_format_info,
154
- )
155
-
156
- except Exception as e:
157
- return f"Error generating content: {str(e)}"
158
-
159
154
  async def stream_completions(
160
155
  self,
161
156
  messages: List[Dict[str, Any]],
157
+ model: str,
162
158
  message_id: str,
163
159
  reply_fn: Callable[[Union[CompletionReply, CompletionStreamChunk]], None],
164
160
  message_type: MessageType = MessageType.CHAT
@@ -169,7 +165,7 @@ class GeminiClient:
169
165
  system_instructions, contents = get_gemini_system_prompt_and_messages(messages)
170
166
  if self.api_key:
171
167
  for chunk in self.client.models.generate_content_stream(
172
- model=self.model,
168
+ model=model,
173
169
  contents=contents, # type: ignore
174
170
  config=GenerateContentConfig(
175
171
  system_instruction=system_instructions
@@ -208,7 +204,7 @@ class GeminiClient:
208
204
  return accumulated_response
209
205
  else:
210
206
  async for chunk_text in stream_gemini_completion_from_mito_server(
211
- model=self.model,
207
+ model=model,
212
208
  contents=messages, # Use the extracted contents instead of converted messages to avoid serialization issues
213
209
  message_type=message_type,
214
210
  message_id=message_id,
mito_ai/openai_client.py CHANGED
@@ -4,6 +4,7 @@
4
4
  from __future__ import annotations
5
5
  from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
6
6
 
7
+ from mito_ai.utils.mito_server_utils import ProviderCompletionException
7
8
  import openai
8
9
  from openai.types.chat import ChatCompletionMessageParam
9
10
  from traitlets import Instance, Unicode, default, validate
@@ -36,8 +37,6 @@ from mito_ai.utils.telemetry_utils import (
36
37
 
37
38
  OPENAI_MODEL_FALLBACK = "gpt-4.1"
38
39
 
39
- OPENAI_FAST_MODEL = "gpt-4.1-nano"
40
-
41
40
  class OpenAIClient(LoggingConfigurable):
42
41
  """Provide AI feature through OpenAI services."""
43
42
 
@@ -222,26 +221,20 @@ This attribute is observed by the websocket provider to push the error to the cl
222
221
  )
223
222
  return client
224
223
 
225
- def _resolve_model(self, model: Optional[str] = None, response_format_info: Optional[ResponseFormatInfo] = None) -> str:
224
+ def _adjust_model_for_azure_or_ollama(self, model: str) -> str:
226
225
 
227
226
  # If they have set an Azure OpenAI model, then we always use it
228
227
  if is_azure_openai_configured() and constants.AZURE_OPENAI_MODEL is not None:
229
228
  self.log.debug(f"Resolving to Azure OpenAI model: {constants.AZURE_OPENAI_MODEL}")
230
229
  return constants.AZURE_OPENAI_MODEL
231
230
 
232
- # Otherwise, we use the fast model for anything other than the agent mode
233
- if response_format_info:
234
- return OPENAI_FAST_MODEL
235
-
236
231
  # If they have set an Ollama model, then we use it
237
232
  if constants.OLLAMA_MODEL is not None:
238
233
  return constants.OLLAMA_MODEL
239
234
 
240
- # If they have set a model, then we use it
241
- if model:
242
- return model
235
+ # Otherwise, we use the model they provided
236
+ return model
243
237
 
244
- return OPENAI_MODEL_FALLBACK
245
238
 
246
239
  async def request_completions(
247
240
  self,
@@ -263,39 +256,33 @@ This attribute is observed by the websocket provider to push the error to the cl
263
256
  # Reset the last error
264
257
  self.last_error = None
265
258
  completion = None
259
+
260
+ # Note: We don't catch exceptions here because we want them to bubble up
261
+ # to the providers file so we can handle all client exceptions in one place.
266
262
 
267
- try:
268
-
269
- # Make sure we are using the correct model
270
- # TODO: If we bring back inline completions or another action that needs to
271
- # respond fast, we must require the user to configure a fast model with Azure as well.
272
- model = self._resolve_model(model, response_format_info)
273
-
274
- # Handle other providers as before
275
- completion_function_params = get_open_ai_completion_function_params(
276
- model, messages, False, response_format_info
263
+ # Handle other providers as before
264
+ completion_function_params = get_open_ai_completion_function_params(
265
+ message_type, model, messages, False, response_format_info
266
+ )
267
+
268
+ # If they have set an Azure OpenAI or Ollama model, then we use it
269
+ completion_function_params["model"] = self._adjust_model_for_azure_or_ollama(completion_function_params["model"])
270
+
271
+ if self._active_async_client is not None:
272
+ response = await self._active_async_client.chat.completions.create(**completion_function_params)
273
+ completion = response.choices[0].message.content or ""
274
+ else:
275
+ last_message_content = str(messages[-1].get("content", "")) if messages else None
276
+ completion = await get_ai_completion_from_mito_server(
277
+ last_message_content,
278
+ completion_function_params,
279
+ self.timeout,
280
+ self.max_retries,
281
+ message_type,
277
282
  )
278
283
 
279
- if self._active_async_client is not None:
280
- response = await self._active_async_client.chat.completions.create(**completion_function_params)
281
- completion = response.choices[0].message.content or ""
282
- else:
283
- last_message_content = str(messages[-1].get("content", "")) if messages else None
284
- completion = await get_ai_completion_from_mito_server(
285
- last_message_content,
286
- completion_function_params,
287
- self.timeout,
288
- self.max_retries,
289
- message_type,
290
- )
291
-
292
- update_mito_server_quota(message_type)
284
+ return completion
293
285
 
294
- return completion
295
-
296
- except BaseException as e:
297
- self.last_error = CompletionError.from_exception(e)
298
- raise
299
286
 
300
287
  async def stream_completions(
301
288
  self,
@@ -315,9 +302,6 @@ This attribute is observed by the websocket provider to push the error to the cl
315
302
  # Reset the last error
316
303
  self.last_error = None
317
304
  accumulated_response = ""
318
-
319
- # Validate that the model is supported.
320
- model = self._resolve_model(model, response_format_info)
321
305
 
322
306
  # Send initial acknowledgment
323
307
  reply_fn(CompletionReply(
@@ -329,8 +313,10 @@ This attribute is observed by the websocket provider to push the error to the cl
329
313
 
330
314
  # Handle other providers as before
331
315
  completion_function_params = get_open_ai_completion_function_params(
332
- model, messages, True, response_format_info
316
+ message_type, model, messages, True, response_format_info
333
317
  )
318
+
319
+ completion_function_params["model"] = self._adjust_model_for_azure_or_ollama(completion_function_params["model"])
334
320
 
335
321
  try:
336
322
  if self._active_async_client is not None:
@@ -6,10 +6,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
6
6
  from traitlets.config import Config
7
7
  from mito_ai.completions.message_history import generate_short_chat_name
8
8
  from mito_ai.completions.providers import OpenAIProvider
9
- from mito_ai.completions.models import MessageType
10
- from mito_ai.openai_client import OPENAI_FAST_MODEL
11
- from mito_ai.anthropic_client import ANTHROPIC_FAST_MODEL
12
- from mito_ai.gemini_client import GEMINI_FAST_MODEL
13
9
 
14
10
 
15
11
  @pytest.fixture
@@ -80,28 +80,24 @@ def test_prepare_request_data_and_headers_basic() -> None:
80
80
  mock_get_user_field.side_effect = ["test@example.com", "user123"]
81
81
 
82
82
  # Mock the quota check
83
- with patch("mito_ai.utils.open_ai_utils.check_mito_server_quota") as mock_check_quota:
84
- data, headers = _prepare_request_data_and_headers(
85
- last_message_content="test message",
86
- ai_completion_data={"key": "value"},
87
- timeout=30,
88
- max_retries=3,
89
- message_type=MessageType.CHAT
90
- )
91
-
92
- # Verify quota check was called
93
- mock_check_quota.assert_called_once_with(MessageType.CHAT)
94
-
95
- # Verify data structure
96
- assert data["timeout"] == 30
97
- assert data["max_retries"] == 3
98
- assert data["email"] == "test@example.com"
99
- assert data["user_id"] == "user123"
100
- assert data["data"] == {"key": "value"}
101
- assert data["user_input"] == "test message"
102
-
103
- # Verify headers
104
- assert headers == {"Content-Type": "application/json"}
83
+ data, headers = _prepare_request_data_and_headers(
84
+ last_message_content="test message",
85
+ ai_completion_data={"key": "value"},
86
+ timeout=30,
87
+ max_retries=3,
88
+ message_type=MessageType.CHAT
89
+ )
90
+
91
+ # Verify data structure
92
+ assert data["timeout"] == 30
93
+ assert data["max_retries"] == 3
94
+ assert data["email"] == "test@example.com"
95
+ assert data["user_id"] == "user123"
96
+ assert data["data"] == {"key": "value"}
97
+ assert data["user_input"] == "test message"
98
+
99
+ # Verify headers
100
+ assert headers == {"Content-Type": "application/json"}
105
101
 
106
102
  def test_prepare_request_data_and_headers_null_message() -> None:
107
103
  """Test handling of null message content"""
@@ -2,15 +2,16 @@
2
2
  # Distributed under the terms of the GNU Affero General Public License v3.0 License.
3
3
 
4
4
  import pytest
5
- from mito_ai.anthropic_client import get_anthropic_system_prompt_and_messages, extract_and_parse_anthropic_json_response, AnthropicClient, ANTHROPIC_FAST_MODEL
6
- from mito_ai.utils.anthropic_utils import get_anthropic_completion_function_params
7
- from anthropic.types import MessageParam, Message, ContentBlock, TextBlock, ToolUseBlock, Usage
5
+ from mito_ai.anthropic_client import get_anthropic_system_prompt_and_messages, extract_and_parse_anthropic_json_response, AnthropicClient
6
+ from mito_ai.utils.anthropic_utils import get_anthropic_completion_function_params, FAST_ANTHROPIC_MODEL
7
+ from anthropic.types import Message, TextBlock, ToolUseBlock, Usage, ToolUseBlock, Message, Usage, TextBlock
8
8
  from openai.types.chat import ChatCompletionMessageParam, ChatCompletionUserMessageParam, ChatCompletionAssistantMessageParam, ChatCompletionSystemMessageParam
9
- from mito_ai.completions.models import ResponseFormatInfo, AgentResponse
9
+ from mito_ai.completions.models import MessageType, ResponseFormatInfo, AgentResponse
10
10
  from unittest.mock import MagicMock, patch
11
11
  import anthropic
12
12
  from typing import List, Dict, Any, cast, Union
13
13
 
14
+
14
15
  # Dummy base64 image (1x1 PNG)
15
16
  DUMMY_IMAGE_DATA_URL = (
16
17
  "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/wcAAgMBAp9l9AAAAABJRU5ErkJggg=="
@@ -231,40 +232,44 @@ def test_tool_use_without_agent_response():
231
232
  extract_and_parse_anthropic_json_response(response)
232
233
  assert "No valid AgentResponse format found" in str(exc_info.value)
233
234
 
234
- CUSTOM_MODEL = "claude-3-5-sonnet-latest"
235
- @pytest.mark.parametrize("response_format_info, expected_model", [
236
- (ResponseFormatInfo(name="agent_response", format=AgentResponse), CUSTOM_MODEL), # With response_format_info - should use self.model
237
- (None, ANTHROPIC_FAST_MODEL), # Without response_format_info - should use ANTHROPIC_FAST_MODEL
235
+ CUSTOM_MODEL = "smart-anthropic-model"
236
+ @pytest.mark.parametrize("message_type, expected_model", [
237
+ (MessageType.CHAT, CUSTOM_MODEL), #
238
+ (MessageType.SMART_DEBUG, CUSTOM_MODEL), #
239
+ (MessageType.CODE_EXPLAIN, CUSTOM_MODEL), #
240
+ (MessageType.AGENT_EXECUTION, CUSTOM_MODEL), #
241
+ (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL), #
242
+ (MessageType.INLINE_COMPLETION, FAST_ANTHROPIC_MODEL), #
243
+ (MessageType.CHAT_NAME_GENERATION, FAST_ANTHROPIC_MODEL), #
238
244
  ])
239
- @pytest.mark.asyncio
240
- async def test_model_selection_based_on_response_format_info(response_format_info, expected_model):
245
+ @pytest.mark.asyncio
246
+ async def test_model_selection_based_on_message_type(message_type, expected_model):
241
247
  """
242
- Tests that the correct model is selected based on whether response_format_info is provided.
248
+ Tests that the correct model is selected based on the message type.
243
249
  """
250
+ client = AnthropicClient(api_key="test_key")
244
251
 
245
- # Create an AnthropicClient with a specific model
246
- custom_model = CUSTOM_MODEL
247
- client = AnthropicClient(api_key="test_key", model=custom_model)
248
-
249
- # Mock the messages.create method to avoid actual API calls
250
- client.client = MagicMock()
251
- mock_response = Message(
252
- id="test_id",
253
- role="assistant",
254
- content=[TextBlock(type="text", text="Test response")],
255
- model=custom_model,
256
- type="message",
257
- usage=Usage(input_tokens=0, output_tokens=0)
258
- )
259
- client.client.messages.create.return_value = mock_response
260
-
261
- with patch('mito_ai.anthropic_client.get_anthropic_completion_function_params', wraps=get_anthropic_completion_function_params) as mock_get_params:
252
+ # Mock the messages.create method directly
253
+ with patch.object(client.client.messages, 'create') as mock_create: # type: ignore
254
+ # Create a mock response
255
+ mock_response = Message(
256
+ id="test_id",
257
+ role="assistant",
258
+ content=[TextBlock(type="text", text="test")],
259
+ model='anthropic-model-we-do-not-check',
260
+ type="message",
261
+ usage=Usage(input_tokens=0, output_tokens=0)
262
+ )
263
+ mock_create.return_value = mock_response
264
+
262
265
  await client.request_completions(
263
266
  messages=[{"role": "user", "content": "Test message"}],
264
- response_format_info=response_format_info
267
+ model=CUSTOM_MODEL,
268
+ message_type=message_type,
269
+ response_format_info=None
265
270
  )
266
271
 
267
- # Verify that get_anthropic_completion_function_params was called with the expected model
268
- mock_get_params.assert_called_once()
269
- call_args = mock_get_params.call_args
272
+ # Verify that create was called with the expected model
273
+ mock_create.assert_called_once()
274
+ call_args = mock_create.call_args
270
275
  assert call_args[1]['model'] == expected_model
@@ -176,15 +176,11 @@ class TestAzureOpenAIClientCreation:
176
176
  openai_client = OpenAIClient(config=provider_config)
177
177
 
178
178
  # Test with gpt-4.1 model
179
- resolved_model = openai_client._resolve_model("gpt-4.1")
179
+ resolved_model = openai_client._adjust_model_for_azure_or_ollama("gpt-4.1")
180
180
  assert resolved_model == FAKE_AZURE_MODEL
181
181
 
182
182
  # Test with any other model
183
- resolved_model = openai_client._resolve_model("gpt-3.5-turbo")
184
- assert resolved_model == FAKE_AZURE_MODEL
185
-
186
- # Test with no model specified
187
- resolved_model = openai_client._resolve_model()
183
+ resolved_model = openai_client._adjust_model_for_azure_or_ollama("gpt-3.5-turbo")
188
184
  assert resolved_model == FAKE_AZURE_MODEL
189
185
 
190
186
 
@@ -0,0 +1,120 @@
1
+ # Copyright (c) Saga Inc.
2
+ # Distributed under the terms of the GNU Affero General Public License v3.0 License.
3
+
4
+ import pytest
5
+ from unittest.mock import MagicMock, patch
6
+ from mito_ai.completions.providers import OpenAIProvider
7
+ from mito_ai.tests.providers.utils import mock_azure_openai_client, mock_openai_client, patch_server_limits
8
+ from traitlets.config import Config
9
+
10
+ FAKE_API_KEY = "sk-1234567890"
11
+
12
+ @pytest.fixture
13
+ def provider_config() -> Config:
14
+ """Create a proper Config object for the OpenAIProvider."""
15
+ config = Config()
16
+ config.OpenAIProvider = Config()
17
+ config.OpenAIClient = Config()
18
+ return config
19
+
20
+ @pytest.mark.parametrize("test_case", [
21
+ {
22
+ "name": "mito_server_fallback_no_keys",
23
+ "setup": {
24
+ "OPENAI_API_KEY": None,
25
+ "CLAUDE_API_KEY": None,
26
+ "GEMINI_API_KEY": None,
27
+ "is_azure_configured": False,
28
+ },
29
+ "expected_provider": "Mito server",
30
+ "expected_key_type": "mito_server_key"
31
+ },
32
+ {
33
+ "name": "claude_when_only_claude_key",
34
+ "setup": {
35
+ "OPENAI_API_KEY": None,
36
+ "CLAUDE_API_KEY": "claude-test-key",
37
+ "GEMINI_API_KEY": None,
38
+ "is_azure_configured": False,
39
+ },
40
+ "expected_provider": "Claude",
41
+ "expected_key_type": "claude"
42
+ },
43
+ {
44
+ "name": "gemini_when_only_gemini_key",
45
+ "setup": {
46
+ "OPENAI_API_KEY": None,
47
+ "CLAUDE_API_KEY": None,
48
+ "GEMINI_API_KEY": "gemini-test-key",
49
+ "is_azure_configured": False,
50
+ },
51
+ "expected_provider": "Gemini",
52
+ "expected_key_type": "gemini"
53
+ },
54
+ {
55
+ "name": "openai_when_openai_key",
56
+ "setup": {
57
+ "OPENAI_API_KEY": 'openai-test-key',
58
+ "CLAUDE_API_KEY": None,
59
+ "GEMINI_API_KEY": None,
60
+ "is_azure_configured": False,
61
+ },
62
+ "expected_provider": "OpenAI (user key)",
63
+ "expected_key_type": "user_key"
64
+ },
65
+ {
66
+ "name": "claude_priority_over_gemini",
67
+ "setup": {
68
+ "OPENAI_API_KEY": None,
69
+ "CLAUDE_API_KEY": "claude-test-key",
70
+ "GEMINI_API_KEY": "gemini-test-key",
71
+ "is_azure_configured": False,
72
+ },
73
+ "expected_provider": "Claude",
74
+ "expected_key_type": "claude"
75
+ },
76
+ ])
77
+ def test_provider_capabilities_real_logic(
78
+ test_case: dict,
79
+ monkeypatch: pytest.MonkeyPatch,
80
+ provider_config: Config
81
+ ) -> None:
82
+ """Test the actual provider selection logic in OpenAIProvider.capabilities"""
83
+
84
+ # Set up the environment based on test case
85
+ setup = test_case["setup"]
86
+
87
+ # CRITICAL: Set up ALL mocks BEFORE creating any clients
88
+ for key, value in setup.items():
89
+ if key == "is_azure_configured":
90
+ if value:
91
+ # For Azure case, mock to return True and set required constants
92
+ monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: True)
93
+ monkeypatch.setattr("mito_ai.constants.AZURE_OPENAI_MODEL", "gpt-4o")
94
+ else:
95
+ # For non-Azure case, mock to return False
96
+ monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
97
+ else:
98
+ monkeypatch.setattr(f"mito_ai.constants.{key}", value)
99
+
100
+ # Clear the provider config API key to ensure it uses constants
101
+ provider_config.OpenAIProvider.api_key = None
102
+
103
+ # Mock HTTP calls but let the real logic run
104
+ with patch("openai.OpenAI") as mock_openai_constructor:
105
+ with patch("openai.AsyncOpenAI") as mock_async_openai:
106
+ with patch("openai.AsyncAzureOpenAI") as mock_async_azure_openai:
107
+ # Mock successful API key validation for OpenAI
108
+ mock_openai_instance = MagicMock()
109
+ mock_openai_instance.models.list.return_value = [MagicMock(id="gpt-4o-mini")]
110
+ mock_openai_constructor.return_value = mock_openai_instance
111
+
112
+ # Mock server limits for Mito server fallback
113
+ with patch_server_limits():
114
+ # NOW create the provider after ALL mocks are set up
115
+ llm = OpenAIProvider(config=provider_config)
116
+
117
+ # Test capabilities
118
+ capabilities = llm.capabilities
119
+ assert capabilities.provider == test_case["expected_provider"], f"Test case: {test_case['name']}"
120
+ assert llm.key_type == test_case["expected_key_type"], f"Test case: {test_case['name']}"