mito-ai 0.1.56__py3-none-any.whl → 0.1.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mito_ai/__init__.py +17 -21
- mito_ai/_version.py +1 -1
- mito_ai/anthropic_client.py +24 -14
- mito_ai/chart_wizard/__init__.py +3 -0
- mito_ai/chart_wizard/handlers.py +113 -0
- mito_ai/chart_wizard/urls.py +26 -0
- mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
- mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
- mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
- mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
- mito_ai/completions/completion_handlers/completion_handler.py +14 -7
- mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
- mito_ai/completions/completion_handlers/scratchpad_result_handler.py +64 -0
- mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
- mito_ai/completions/completion_handlers/utils.py +3 -7
- mito_ai/completions/handlers.py +36 -21
- mito_ai/completions/message_history.py +8 -10
- mito_ai/completions/models.py +23 -2
- mito_ai/completions/prompt_builders/agent_smart_debug_prompt.py +5 -3
- mito_ai/completions/prompt_builders/agent_system_message.py +97 -5
- mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
- mito_ai/completions/prompt_builders/chart_conversion_prompt.py +27 -0
- mito_ai/completions/prompt_builders/chat_system_message.py +2 -0
- mito_ai/completions/prompt_builders/prompt_constants.py +28 -0
- mito_ai/completions/prompt_builders/scratchpad_result_prompt.py +17 -0
- mito_ai/constants.py +8 -1
- mito_ai/enterprise/__init__.py +1 -1
- mito_ai/enterprise/litellm_client.py +137 -0
- mito_ai/log/handlers.py +1 -1
- mito_ai/openai_client.py +10 -90
- mito_ai/{completions/providers.py → provider_manager.py} +157 -53
- mito_ai/settings/enterprise_handler.py +26 -0
- mito_ai/settings/urls.py +2 -0
- mito_ai/streamlit_conversion/agent_utils.py +2 -30
- mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
- mito_ai/streamlit_preview/handlers.py +6 -3
- mito_ai/streamlit_preview/urls.py +5 -3
- mito_ai/tests/message_history/test_generate_short_chat_name.py +72 -28
- mito_ai/tests/providers/test_anthropic_client.py +174 -16
- mito_ai/tests/providers/test_azure.py +13 -13
- mito_ai/tests/providers/test_capabilities.py +14 -17
- mito_ai/tests/providers/test_gemini_client.py +14 -13
- mito_ai/tests/providers/test_model_resolution.py +145 -89
- mito_ai/tests/providers/test_openai_client.py +209 -13
- mito_ai/tests/providers/test_provider_limits.py +5 -5
- mito_ai/tests/providers/test_providers.py +229 -51
- mito_ai/tests/providers/test_retry_logic.py +13 -22
- mito_ai/tests/providers/utils.py +4 -4
- mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
- mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
- mito_ai/tests/test_enterprise_mode.py +162 -0
- mito_ai/tests/test_model_utils.py +271 -0
- mito_ai/utils/anthropic_utils.py +8 -6
- mito_ai/utils/gemini_utils.py +0 -3
- mito_ai/utils/litellm_utils.py +84 -0
- mito_ai/utils/model_utils.py +178 -0
- mito_ai/utils/open_ai_utils.py +0 -8
- mito_ai/utils/provider_utils.py +6 -21
- mito_ai/utils/telemetry_utils.py +14 -2
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.dfd7975de75d64db80d6.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js +2992 -282
- mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js.map +1 -0
- mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.1e7b5cf362385f109883.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js +17 -17
- mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.1e7b5cf362385f109883.js.map → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js.map +1 -1
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +7 -2
- {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/METADATA +2 -1
- {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/RECORD +94 -81
- mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.dfd7975de75d64db80d6.js.map +0 -1
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
- {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/licenses/LICENSE +0 -0
|
@@ -6,7 +6,7 @@ import ast
|
|
|
6
6
|
import inspect
|
|
7
7
|
import requests
|
|
8
8
|
from mito_ai.gemini_client import GeminiClient, get_gemini_system_prompt_and_messages
|
|
9
|
-
from mito_ai.utils.gemini_utils import get_gemini_completion_function_params
|
|
9
|
+
from mito_ai.utils.gemini_utils import get_gemini_completion_function_params
|
|
10
10
|
from google.genai.types import Part, GenerateContentResponse, Candidate, Content
|
|
11
11
|
from mito_ai.completions.models import ResponseFormatInfo, AgentResponse
|
|
12
12
|
from unittest.mock import MagicMock, patch
|
|
@@ -156,19 +156,20 @@ async def test_json_response_handling_with_multiple_parts():
|
|
|
156
156
|
assert result == 'Here is the JSON: {"key": "value"} End of response'
|
|
157
157
|
|
|
158
158
|
CUSTOM_MODEL = "smart-gemini-model"
|
|
159
|
-
@pytest.mark.parametrize("message_type
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
159
|
+
@pytest.mark.parametrize("message_type", [
|
|
160
|
+
MessageType.CHAT,
|
|
161
|
+
MessageType.SMART_DEBUG,
|
|
162
|
+
MessageType.CODE_EXPLAIN,
|
|
163
|
+
MessageType.AGENT_EXECUTION,
|
|
164
|
+
MessageType.AGENT_AUTO_ERROR_FIXUP,
|
|
165
|
+
MessageType.INLINE_COMPLETION,
|
|
166
|
+
MessageType.CHAT_NAME_GENERATION,
|
|
167
167
|
])
|
|
168
168
|
@pytest.mark.asyncio
|
|
169
|
-
async def
|
|
169
|
+
async def test_get_completion_model_selection_uses_passed_model(message_type):
|
|
170
170
|
"""
|
|
171
|
-
Tests that the
|
|
171
|
+
Tests that the model passed to the client is used as-is.
|
|
172
|
+
Model selection based on message type is now handled by ProviderManager.
|
|
172
173
|
"""
|
|
173
174
|
with patch('google.genai.Client') as mock_genai_class:
|
|
174
175
|
mock_client = MagicMock()
|
|
@@ -189,7 +190,7 @@ async def test_get_completion_model_selection_based_on_message_type(message_type
|
|
|
189
190
|
response_format_info=None
|
|
190
191
|
)
|
|
191
192
|
|
|
192
|
-
# Verify that generate_content was called with the
|
|
193
|
+
# Verify that generate_content was called with the model that was passed (not overridden)
|
|
193
194
|
mock_models.generate_content.assert_called_once()
|
|
194
195
|
call_args = mock_models.generate_content.call_args
|
|
195
|
-
assert call_args[1]['model'] ==
|
|
196
|
+
assert call_args[1]['model'] == CUSTOM_MODEL
|
|
@@ -6,19 +6,17 @@ These tests ensure that the correct model is chosen for each message type, for e
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import pytest
|
|
9
|
-
from mito_ai.utils.
|
|
9
|
+
from mito_ai.utils.model_utils import get_fast_model_for_selected_model, get_smartest_model_for_selected_model
|
|
10
10
|
from mito_ai.completions.models import MessageType
|
|
11
11
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
12
|
-
from mito_ai.
|
|
13
|
-
from mito_ai.completions.models import MessageType
|
|
14
|
-
from mito_ai.utils.provider_utils import does_message_require_fast_model
|
|
12
|
+
from mito_ai.provider_manager import ProviderManager
|
|
15
13
|
from traitlets.config import Config
|
|
16
14
|
|
|
17
15
|
@pytest.fixture
|
|
18
16
|
def provider_config() -> Config:
|
|
19
|
-
"""Create a proper Config object for the
|
|
17
|
+
"""Create a proper Config object for the ProviderManager."""
|
|
20
18
|
config = Config()
|
|
21
|
-
config.
|
|
19
|
+
config.ProviderManager = Config()
|
|
22
20
|
config.OpenAIClient = Config()
|
|
23
21
|
return config
|
|
24
22
|
|
|
@@ -27,104 +25,162 @@ def mock_messages():
|
|
|
27
25
|
"""Sample messages for testing."""
|
|
28
26
|
return [{"role": "user", "content": "Test message"}]
|
|
29
27
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
(
|
|
35
|
-
(
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
(
|
|
39
|
-
]
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
28
|
+
@pytest.mark.asyncio
|
|
29
|
+
async def test_request_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
30
|
+
"""Test that request_completions uses the correct model when use_fast_model=True."""
|
|
31
|
+
# Set up environment variables to ensure OpenAI provider is used
|
|
32
|
+
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
33
|
+
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
34
|
+
|
|
35
|
+
# Mock the OpenAI API call instead of the entire client
|
|
36
|
+
mock_response = MagicMock()
|
|
37
|
+
mock_response.choices = [MagicMock()]
|
|
38
|
+
mock_response.choices[0].message.content = "Test Completion"
|
|
39
|
+
|
|
40
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
41
|
+
mock_openai_client = MagicMock()
|
|
42
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
|
43
|
+
mock_openai_client.is_closed.return_value = False
|
|
44
|
+
mock_openai_class.return_value = mock_openai_client
|
|
45
|
+
|
|
46
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
47
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
48
|
+
mock_sync_client = MagicMock()
|
|
49
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
50
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
51
|
+
|
|
52
|
+
provider = ProviderManager(config=provider_config)
|
|
53
|
+
provider.set_selected_model("gpt-5.2")
|
|
54
|
+
await provider.request_completions(
|
|
55
|
+
message_type=MessageType.CHAT,
|
|
56
|
+
messages=mock_messages,
|
|
57
|
+
use_fast_model=True
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# Verify the model passed to the API call is the fast model
|
|
61
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
62
|
+
assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())
|
|
63
|
+
|
|
64
|
+
@pytest.mark.asyncio
|
|
65
|
+
async def test_stream_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
66
|
+
"""Test that stream_completions uses the correct model when use_fast_model=True."""
|
|
67
|
+
# Set up environment variables to ensure OpenAI provider is used
|
|
68
|
+
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
69
|
+
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
44
70
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
71
|
+
# Mock the OpenAI API call instead of the entire client
|
|
72
|
+
mock_response = MagicMock()
|
|
73
|
+
mock_response.choices = [MagicMock()]
|
|
74
|
+
mock_response.choices[0].delta.content = "Test Stream Completion"
|
|
75
|
+
mock_response.choices[0].finish_reason = "stop"
|
|
76
|
+
|
|
77
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
78
|
+
mock_openai_client = MagicMock()
|
|
79
|
+
# Create an async generator for streaming
|
|
80
|
+
async def mock_stream():
|
|
81
|
+
yield mock_response
|
|
82
|
+
|
|
83
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
|
|
84
|
+
mock_openai_client.is_closed.return_value = False
|
|
85
|
+
mock_openai_class.return_value = mock_openai_client
|
|
86
|
+
|
|
87
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
88
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
89
|
+
mock_sync_client = MagicMock()
|
|
90
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
91
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
92
|
+
|
|
93
|
+
provider = ProviderManager(config=provider_config)
|
|
94
|
+
provider.set_selected_model("gpt-5.2")
|
|
95
|
+
await provider.stream_completions(
|
|
96
|
+
message_type=MessageType.CHAT,
|
|
97
|
+
messages=mock_messages,
|
|
98
|
+
message_id="test_id",
|
|
99
|
+
thread_id="test_thread",
|
|
100
|
+
reply_fn=lambda x: None,
|
|
101
|
+
use_fast_model=True
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
# Verify the model passed to the API call is the fast model
|
|
105
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
106
|
+
assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())
|
|
49
107
|
|
|
50
108
|
@pytest.mark.asyncio
|
|
51
|
-
async def
|
|
52
|
-
"""Test that request_completions
|
|
109
|
+
async def test_request_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
110
|
+
"""Test that request_completions uses the correct model when use_smartest_model=True."""
|
|
53
111
|
# Set up environment variables to ensure OpenAI provider is used
|
|
54
112
|
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
55
113
|
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
56
114
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
115
|
+
# Mock the OpenAI API call instead of the entire client
|
|
116
|
+
mock_response = MagicMock()
|
|
117
|
+
mock_response.choices = [MagicMock()]
|
|
118
|
+
mock_response.choices[0].message.content = "Test Completion"
|
|
119
|
+
|
|
120
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
121
|
+
mock_openai_client = MagicMock()
|
|
122
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
|
123
|
+
mock_openai_client.is_closed.return_value = False
|
|
124
|
+
mock_openai_class.return_value = mock_openai_client
|
|
62
125
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
126
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
127
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
128
|
+
mock_sync_client = MagicMock()
|
|
129
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
130
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
68
131
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
81
|
-
|
|
82
|
-
mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
|
|
83
|
-
# Verify the model passed to the API call
|
|
84
|
-
call_args = mock_openai_client.chat.completions.create.call_args
|
|
85
|
-
assert call_args[1]['model'] == "gpt-3.5"
|
|
132
|
+
provider = ProviderManager(config=provider_config)
|
|
133
|
+
provider.set_selected_model("gpt-4.1")
|
|
134
|
+
await provider.request_completions(
|
|
135
|
+
message_type=MessageType.CHAT,
|
|
136
|
+
messages=mock_messages,
|
|
137
|
+
use_smartest_model=True
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
# Verify the model passed to the API call is the smartest model
|
|
141
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
142
|
+
assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())
|
|
86
143
|
|
|
87
144
|
@pytest.mark.asyncio
|
|
88
|
-
async def
|
|
89
|
-
"""Test that stream_completions
|
|
145
|
+
async def test_stream_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
146
|
+
"""Test that stream_completions uses the correct model when use_smartest_model=True."""
|
|
90
147
|
# Set up environment variables to ensure OpenAI provider is used
|
|
91
148
|
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
92
149
|
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
93
150
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
151
|
+
# Mock the OpenAI API call instead of the entire client
|
|
152
|
+
mock_response = MagicMock()
|
|
153
|
+
mock_response.choices = [MagicMock()]
|
|
154
|
+
mock_response.choices[0].delta.content = "Test Stream Completion"
|
|
155
|
+
mock_response.choices[0].finish_reason = "stop"
|
|
156
|
+
|
|
157
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
158
|
+
mock_openai_client = MagicMock()
|
|
159
|
+
# Create an async generator for streaming
|
|
160
|
+
async def mock_stream():
|
|
161
|
+
yield mock_response
|
|
162
|
+
|
|
163
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
|
|
164
|
+
mock_openai_client.is_closed.return_value = False
|
|
165
|
+
mock_openai_class.return_value = mock_openai_client
|
|
100
166
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
167
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
168
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
169
|
+
mock_sync_client = MagicMock()
|
|
170
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
171
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
106
172
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
173
|
+
provider = ProviderManager(config=provider_config)
|
|
174
|
+
provider.set_selected_model("gpt-4.1")
|
|
175
|
+
await provider.stream_completions(
|
|
176
|
+
message_type=MessageType.CHAT,
|
|
177
|
+
messages=mock_messages,
|
|
178
|
+
message_id="test_id",
|
|
179
|
+
thread_id="test_thread",
|
|
180
|
+
reply_fn=lambda x: None,
|
|
181
|
+
use_smartest_model=True
|
|
182
|
+
)
|
|
110
183
|
|
|
111
|
-
#
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
mock_sync_client.models.list.return_value = MagicMock()
|
|
115
|
-
mock_sync_openai_class.return_value = mock_sync_client
|
|
116
|
-
|
|
117
|
-
provider = OpenAIProvider(config=provider_config)
|
|
118
|
-
await provider.stream_completions(
|
|
119
|
-
message_type=MessageType.CHAT,
|
|
120
|
-
messages=mock_messages,
|
|
121
|
-
model="gpt-3.5",
|
|
122
|
-
message_id="test_id",
|
|
123
|
-
thread_id="test_thread",
|
|
124
|
-
reply_fn=lambda x: None
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
|
|
128
|
-
# Verify the model passed to the API call
|
|
129
|
-
call_args = mock_openai_client.chat.completions.create.call_args
|
|
130
|
-
assert call_args[1]['model'] == "gpt-3.5"
|
|
184
|
+
# Verify the model passed to the API call is the smartest model
|
|
185
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
186
|
+
assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())
|
|
@@ -3,25 +3,25 @@
|
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
5
|
from mito_ai.openai_client import OpenAIClient
|
|
6
|
-
from mito_ai.utils.open_ai_utils import FAST_OPENAI_MODEL
|
|
7
6
|
from mito_ai.completions.models import MessageType
|
|
8
7
|
from unittest.mock import MagicMock, patch, AsyncMock
|
|
9
8
|
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
|
|
10
9
|
|
|
11
10
|
CUSTOM_MODEL = "smart-openai-model"
|
|
12
|
-
@pytest.mark.parametrize("message_type
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
11
|
+
@pytest.mark.parametrize("message_type", [
|
|
12
|
+
MessageType.CHAT,
|
|
13
|
+
MessageType.SMART_DEBUG,
|
|
14
|
+
MessageType.CODE_EXPLAIN,
|
|
15
|
+
MessageType.AGENT_EXECUTION,
|
|
16
|
+
MessageType.AGENT_AUTO_ERROR_FIXUP,
|
|
17
|
+
MessageType.INLINE_COMPLETION,
|
|
18
|
+
MessageType.CHAT_NAME_GENERATION,
|
|
20
19
|
])
|
|
21
20
|
@pytest.mark.asyncio
|
|
22
|
-
async def
|
|
21
|
+
async def test_model_selection_uses_passed_model(message_type):
|
|
23
22
|
"""
|
|
24
|
-
Tests that the
|
|
23
|
+
Tests that the model passed to the client is used as-is.
|
|
24
|
+
Model selection based on message type is now handled by ProviderManager.
|
|
25
25
|
"""
|
|
26
26
|
client = OpenAIClient(api_key="test_key") # type: ignore
|
|
27
27
|
|
|
@@ -51,7 +51,203 @@ async def test_model_selection_based_on_message_type(message_type, expected_mode
|
|
|
51
51
|
response_format_info=None
|
|
52
52
|
)
|
|
53
53
|
|
|
54
|
-
# Verify that create was called with the
|
|
54
|
+
# Verify that create was called with the model that was passed (not overridden)
|
|
55
55
|
mock_create.assert_called_once()
|
|
56
56
|
call_args = mock_create.call_args
|
|
57
|
-
assert call_args[1]['model'] ==
|
|
57
|
+
assert call_args[1]['model'] == CUSTOM_MODEL
|
|
58
|
+
|
|
59
|
+
@pytest.mark.asyncio
|
|
60
|
+
async def test_openai_client_uses_fast_model_from_provider_manager_without_override():
|
|
61
|
+
"""Test that OpenAI client uses the fast model passed from ProviderManager without internal override."""
|
|
62
|
+
from mito_ai.utils.model_utils import get_fast_model_for_selected_model
|
|
63
|
+
|
|
64
|
+
client = OpenAIClient(api_key="test_key") # type: ignore
|
|
65
|
+
|
|
66
|
+
# Mock the _build_openai_client method to return our mock client
|
|
67
|
+
with patch.object(client, '_build_openai_client') as mock_build_client, \
|
|
68
|
+
patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
69
|
+
|
|
70
|
+
mock_client = MagicMock()
|
|
71
|
+
mock_chat = MagicMock()
|
|
72
|
+
mock_completions = MagicMock()
|
|
73
|
+
mock_client.chat = mock_chat
|
|
74
|
+
mock_chat.completions = mock_completions
|
|
75
|
+
mock_openai_class.return_value = mock_client
|
|
76
|
+
mock_build_client.return_value = mock_client
|
|
77
|
+
|
|
78
|
+
# Create an async mock for the create method
|
|
79
|
+
mock_create = AsyncMock()
|
|
80
|
+
mock_create.return_value = MagicMock(
|
|
81
|
+
choices=[MagicMock(message=MagicMock(content="test"))]
|
|
82
|
+
)
|
|
83
|
+
mock_completions.create = mock_create
|
|
84
|
+
|
|
85
|
+
# Use a fast model that would be selected by ProviderManager
|
|
86
|
+
fast_model = get_fast_model_for_selected_model("gpt-5.2")
|
|
87
|
+
|
|
88
|
+
await client.request_completions(
|
|
89
|
+
message_type=MessageType.CHAT,
|
|
90
|
+
messages=[{"role": "user", "content": "Test message"}],
|
|
91
|
+
model=fast_model,
|
|
92
|
+
response_format_info=None
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
# Verify that create was called with the fast model that was passed (not overridden)
|
|
96
|
+
mock_create.assert_called_once()
|
|
97
|
+
call_args = mock_create.call_args
|
|
98
|
+
assert call_args[1]['model'] == fast_model
|
|
99
|
+
|
|
100
|
+
@pytest.mark.asyncio
|
|
101
|
+
async def test_openai_client_uses_smartest_model_from_provider_manager_without_override():
|
|
102
|
+
"""Test that OpenAI client uses the smartest model passed from ProviderManager without internal override."""
|
|
103
|
+
from mito_ai.utils.model_utils import get_smartest_model_for_selected_model
|
|
104
|
+
|
|
105
|
+
client = OpenAIClient(api_key="test_key") # type: ignore
|
|
106
|
+
|
|
107
|
+
# Mock the _build_openai_client method to return our mock client
|
|
108
|
+
with patch.object(client, '_build_openai_client') as mock_build_client, \
|
|
109
|
+
patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
110
|
+
|
|
111
|
+
mock_client = MagicMock()
|
|
112
|
+
mock_chat = MagicMock()
|
|
113
|
+
mock_completions = MagicMock()
|
|
114
|
+
mock_client.chat = mock_chat
|
|
115
|
+
mock_chat.completions = mock_completions
|
|
116
|
+
mock_openai_class.return_value = mock_client
|
|
117
|
+
mock_build_client.return_value = mock_client
|
|
118
|
+
|
|
119
|
+
# Create an async mock for the create method
|
|
120
|
+
mock_create = AsyncMock()
|
|
121
|
+
mock_create.return_value = MagicMock(
|
|
122
|
+
choices=[MagicMock(message=MagicMock(content="test"))]
|
|
123
|
+
)
|
|
124
|
+
mock_completions.create = mock_create
|
|
125
|
+
|
|
126
|
+
# Use a smartest model that would be selected by ProviderManager
|
|
127
|
+
smartest_model = get_smartest_model_for_selected_model("gpt-4.1")
|
|
128
|
+
|
|
129
|
+
await client.request_completions(
|
|
130
|
+
message_type=MessageType.CHAT,
|
|
131
|
+
messages=[{"role": "user", "content": "Test message"}],
|
|
132
|
+
model=smartest_model,
|
|
133
|
+
response_format_info=None
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
# Verify that create was called with the smartest model that was passed (not overridden)
|
|
137
|
+
mock_create.assert_called_once()
|
|
138
|
+
call_args = mock_create.call_args
|
|
139
|
+
assert call_args[1]['model'] == smartest_model
|
|
140
|
+
|
|
141
|
+
@pytest.mark.asyncio
|
|
142
|
+
async def test_openai_client_stream_uses_fast_model_from_provider_manager_without_override():
|
|
143
|
+
"""Test that OpenAI client stream_completions uses the fast model passed from ProviderManager without internal override."""
|
|
144
|
+
from mito_ai.utils.model_utils import get_fast_model_for_selected_model
|
|
145
|
+
|
|
146
|
+
client = OpenAIClient(api_key="test_key") # type: ignore
|
|
147
|
+
|
|
148
|
+
# Mock the _build_openai_client method to return our mock client
|
|
149
|
+
with patch.object(client, '_build_openai_client') as mock_build_client, \
|
|
150
|
+
patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
151
|
+
|
|
152
|
+
mock_client = MagicMock()
|
|
153
|
+
mock_chat = MagicMock()
|
|
154
|
+
mock_completions = MagicMock()
|
|
155
|
+
mock_client.chat = mock_chat
|
|
156
|
+
mock_chat.completions = mock_completions
|
|
157
|
+
mock_openai_class.return_value = mock_client
|
|
158
|
+
mock_build_client.return_value = mock_client
|
|
159
|
+
|
|
160
|
+
# Create an async generator for streaming
|
|
161
|
+
async def mock_stream():
|
|
162
|
+
mock_chunk = MagicMock()
|
|
163
|
+
mock_chunk.choices = [MagicMock()]
|
|
164
|
+
mock_chunk.choices[0].delta.content = "test"
|
|
165
|
+
mock_chunk.choices[0].finish_reason = None
|
|
166
|
+
yield mock_chunk
|
|
167
|
+
mock_final_chunk = MagicMock()
|
|
168
|
+
mock_final_chunk.choices = [MagicMock()]
|
|
169
|
+
mock_final_chunk.choices[0].delta.content = ""
|
|
170
|
+
mock_final_chunk.choices[0].finish_reason = "stop"
|
|
171
|
+
yield mock_final_chunk
|
|
172
|
+
|
|
173
|
+
mock_create = AsyncMock(return_value=mock_stream())
|
|
174
|
+
mock_completions.create = mock_create
|
|
175
|
+
|
|
176
|
+
# Use a fast model that would be selected by ProviderManager
|
|
177
|
+
fast_model = get_fast_model_for_selected_model("gpt-5.2")
|
|
178
|
+
|
|
179
|
+
reply_chunks = []
|
|
180
|
+
def mock_reply(chunk):
|
|
181
|
+
reply_chunks.append(chunk)
|
|
182
|
+
|
|
183
|
+
await client.stream_completions(
|
|
184
|
+
message_type=MessageType.CHAT,
|
|
185
|
+
messages=[{"role": "user", "content": "Test message"}],
|
|
186
|
+
model=fast_model,
|
|
187
|
+
message_id="test-id",
|
|
188
|
+
thread_id="test-thread",
|
|
189
|
+
reply_fn=mock_reply,
|
|
190
|
+
response_format_info=None
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
# Verify that create was called with the fast model that was passed (not overridden)
|
|
194
|
+
mock_create.assert_called_once()
|
|
195
|
+
call_args = mock_create.call_args
|
|
196
|
+
assert call_args[1]['model'] == fast_model
|
|
197
|
+
|
|
198
|
+
@pytest.mark.asyncio
|
|
199
|
+
async def test_openai_client_stream_uses_smartest_model_from_provider_manager_without_override():
|
|
200
|
+
"""Test that OpenAI client stream_completions uses the smartest model passed from ProviderManager without internal override."""
|
|
201
|
+
from mito_ai.utils.model_utils import get_smartest_model_for_selected_model
|
|
202
|
+
|
|
203
|
+
client = OpenAIClient(api_key="test_key") # type: ignore
|
|
204
|
+
|
|
205
|
+
# Mock the _build_openai_client method to return our mock client
|
|
206
|
+
with patch.object(client, '_build_openai_client') as mock_build_client, \
|
|
207
|
+
patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
208
|
+
|
|
209
|
+
mock_client = MagicMock()
|
|
210
|
+
mock_chat = MagicMock()
|
|
211
|
+
mock_completions = MagicMock()
|
|
212
|
+
mock_client.chat = mock_chat
|
|
213
|
+
mock_chat.completions = mock_completions
|
|
214
|
+
mock_openai_class.return_value = mock_client
|
|
215
|
+
mock_build_client.return_value = mock_client
|
|
216
|
+
|
|
217
|
+
# Create an async generator for streaming
|
|
218
|
+
async def mock_stream():
|
|
219
|
+
mock_chunk = MagicMock()
|
|
220
|
+
mock_chunk.choices = [MagicMock()]
|
|
221
|
+
mock_chunk.choices[0].delta.content = "test"
|
|
222
|
+
mock_chunk.choices[0].finish_reason = None
|
|
223
|
+
yield mock_chunk
|
|
224
|
+
mock_final_chunk = MagicMock()
|
|
225
|
+
mock_final_chunk.choices = [MagicMock()]
|
|
226
|
+
mock_final_chunk.choices[0].delta.content = ""
|
|
227
|
+
mock_final_chunk.choices[0].finish_reason = "stop"
|
|
228
|
+
yield mock_final_chunk
|
|
229
|
+
|
|
230
|
+
mock_create = AsyncMock(return_value=mock_stream())
|
|
231
|
+
mock_completions.create = mock_create
|
|
232
|
+
|
|
233
|
+
# Use a smartest model that would be selected by ProviderManager
|
|
234
|
+
smartest_model = get_smartest_model_for_selected_model("gpt-4.1")
|
|
235
|
+
|
|
236
|
+
reply_chunks = []
|
|
237
|
+
def mock_reply(chunk):
|
|
238
|
+
reply_chunks.append(chunk)
|
|
239
|
+
|
|
240
|
+
await client.stream_completions(
|
|
241
|
+
message_type=MessageType.CHAT,
|
|
242
|
+
messages=[{"role": "user", "content": "Test message"}],
|
|
243
|
+
model=smartest_model,
|
|
244
|
+
message_id="test-id",
|
|
245
|
+
thread_id="test-thread",
|
|
246
|
+
reply_fn=mock_reply,
|
|
247
|
+
response_format_info=None
|
|
248
|
+
)
|
|
249
|
+
|
|
250
|
+
# Verify that create was called with the smartest model that was passed (not overridden)
|
|
251
|
+
mock_create.assert_called_once()
|
|
252
|
+
call_args = mock_create.call_args
|
|
253
|
+
assert call_args[1]['model'] == smartest_model
|
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
|
-
from mito_ai.
|
|
5
|
+
from mito_ai.provider_manager import ProviderManager
|
|
6
6
|
from mito_ai.tests.providers.utils import mock_openai_client, patch_server_limits
|
|
7
7
|
from mito_ai.utils.server_limits import OS_MONTHLY_AI_COMPLETIONS_LIMIT
|
|
8
8
|
from traitlets.config import Config
|
|
@@ -11,9 +11,9 @@ FAKE_API_KEY = "sk-1234567890"
|
|
|
11
11
|
|
|
12
12
|
@pytest.fixture
|
|
13
13
|
def provider_config() -> Config:
|
|
14
|
-
"""Create a proper Config object for the
|
|
14
|
+
"""Create a proper Config object for the ProviderManager."""
|
|
15
15
|
config = Config()
|
|
16
|
-
config.
|
|
16
|
+
config.ProviderManager = Config()
|
|
17
17
|
config.OpenAIClient = Config()
|
|
18
18
|
return config
|
|
19
19
|
|
|
@@ -36,7 +36,7 @@ def test_openai_provider_with_limits(
|
|
|
36
36
|
patch_server_limits(is_pro=is_pro, completion_count=completion_count),
|
|
37
37
|
mock_openai_client()
|
|
38
38
|
):
|
|
39
|
-
llm =
|
|
39
|
+
llm = ProviderManager(config=provider_config)
|
|
40
40
|
capabilities = llm.capabilities
|
|
41
|
-
assert "
|
|
41
|
+
assert "OpenAI" in capabilities.provider
|
|
42
42
|
assert llm.last_error is None
|