mito-ai 0.1.57__py3-none-any.whl → 0.1.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mito_ai/__init__.py +16 -22
- mito_ai/_version.py +1 -1
- mito_ai/anthropic_client.py +24 -14
- mito_ai/chart_wizard/handlers.py +78 -17
- mito_ai/chart_wizard/urls.py +8 -5
- mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
- mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
- mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
- mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
- mito_ai/completions/completion_handlers/completion_handler.py +3 -5
- mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
- mito_ai/completions/completion_handlers/scratchpad_result_handler.py +6 -8
- mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
- mito_ai/completions/completion_handlers/utils.py +3 -7
- mito_ai/completions/handlers.py +32 -22
- mito_ai/completions/message_history.py +8 -10
- mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
- mito_ai/constants.py +8 -1
- mito_ai/enterprise/__init__.py +1 -1
- mito_ai/enterprise/litellm_client.py +137 -0
- mito_ai/log/handlers.py +1 -1
- mito_ai/openai_client.py +10 -90
- mito_ai/{completions/providers.py → provider_manager.py} +157 -53
- mito_ai/settings/enterprise_handler.py +26 -0
- mito_ai/settings/urls.py +2 -0
- mito_ai/streamlit_conversion/agent_utils.py +2 -30
- mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
- mito_ai/streamlit_preview/handlers.py +6 -3
- mito_ai/streamlit_preview/urls.py +5 -3
- mito_ai/tests/message_history/test_generate_short_chat_name.py +72 -28
- mito_ai/tests/providers/test_anthropic_client.py +174 -16
- mito_ai/tests/providers/test_azure.py +13 -13
- mito_ai/tests/providers/test_capabilities.py +14 -17
- mito_ai/tests/providers/test_gemini_client.py +14 -13
- mito_ai/tests/providers/test_model_resolution.py +145 -89
- mito_ai/tests/providers/test_openai_client.py +209 -13
- mito_ai/tests/providers/test_provider_limits.py +5 -5
- mito_ai/tests/providers/test_providers.py +229 -51
- mito_ai/tests/providers/test_retry_logic.py +13 -22
- mito_ai/tests/providers/utils.py +4 -4
- mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
- mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
- mito_ai/tests/test_enterprise_mode.py +162 -0
- mito_ai/tests/test_model_utils.py +271 -0
- mito_ai/utils/anthropic_utils.py +8 -6
- mito_ai/utils/gemini_utils.py +0 -3
- mito_ai/utils/litellm_utils.py +84 -0
- mito_ai/utils/model_utils.py +178 -0
- mito_ai/utils/open_ai_utils.py +0 -8
- mito_ai/utils/provider_utils.py +6 -28
- mito_ai/utils/telemetry_utils.py +14 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js +671 -75
- mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js.map +1 -0
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js +17 -17
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js.map → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js.map +1 -1
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/METADATA +2 -1
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/RECORD +86 -79
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js.map +0 -1
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/licenses/LICENSE +0 -0
|
@@ -9,7 +9,7 @@ import pytest
|
|
|
9
9
|
from traitlets.config import Config
|
|
10
10
|
from openai.types.chat import ChatCompletionMessageParam
|
|
11
11
|
|
|
12
|
-
from mito_ai.
|
|
12
|
+
from mito_ai.provider_manager import ProviderManager
|
|
13
13
|
from mito_ai.completions.models import (
|
|
14
14
|
MessageType,
|
|
15
15
|
AICapabilities,
|
|
@@ -29,9 +29,9 @@ FAKE_AZURE_API_VERSION = "2024-12-01-preview"
|
|
|
29
29
|
|
|
30
30
|
@pytest.fixture
|
|
31
31
|
def provider_config() -> Config:
|
|
32
|
-
"""Create a proper Config object for the
|
|
32
|
+
"""Create a proper Config object for the ProviderManager."""
|
|
33
33
|
config = Config()
|
|
34
|
-
config.
|
|
34
|
+
config.ProviderManager = Config()
|
|
35
35
|
config.OpenAIClient = Config()
|
|
36
36
|
return config
|
|
37
37
|
|
|
@@ -40,7 +40,7 @@ def provider_config() -> Config:
|
|
|
40
40
|
def reset_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
41
41
|
"""Reset all environment variables before each test."""
|
|
42
42
|
for var in [
|
|
43
|
-
"OPENAI_API_KEY", "
|
|
43
|
+
"OPENAI_API_KEY", "ANTHROPIC_API_KEY", "GEMINI_API_KEY", "OLLAMA_MODEL",
|
|
44
44
|
"AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_MODEL",
|
|
45
45
|
"AZURE_OPENAI_API_VERSION"
|
|
46
46
|
]:
|
|
@@ -405,7 +405,7 @@ class TestAzureOpenAIStreamCompletions:
|
|
|
405
405
|
|
|
406
406
|
|
|
407
407
|
class TestAzureOpenAIProviderIntegration:
|
|
408
|
-
"""Test Azure OpenAI integration through the
|
|
408
|
+
"""Test Azure OpenAI integration through the ProviderManager."""
|
|
409
409
|
|
|
410
410
|
@pytest.mark.asyncio
|
|
411
411
|
@pytest.mark.parametrize("message_type", COMPLETION_MESSAGE_TYPES)
|
|
@@ -415,7 +415,7 @@ class TestAzureOpenAIProviderIntegration:
|
|
|
415
415
|
provider_config: Config,
|
|
416
416
|
message_type: MessageType
|
|
417
417
|
) -> None:
|
|
418
|
-
"""Test that
|
|
418
|
+
"""Test that ProviderManager uses Azure OpenAI when gpt-4.1 is requested and Azure is configured."""
|
|
419
419
|
|
|
420
420
|
# Mock the response
|
|
421
421
|
mock_response = MagicMock()
|
|
@@ -428,7 +428,8 @@ class TestAzureOpenAIProviderIntegration:
|
|
|
428
428
|
mock_azure_client.is_closed.return_value = False
|
|
429
429
|
mock_azure_client_class.return_value = mock_azure_client
|
|
430
430
|
|
|
431
|
-
provider =
|
|
431
|
+
provider = ProviderManager(config=provider_config)
|
|
432
|
+
provider.set_selected_model("gpt-4.1")
|
|
432
433
|
|
|
433
434
|
messages: List[ChatCompletionMessageParam] = [
|
|
434
435
|
{"role": "user", "content": "Test message"}
|
|
@@ -437,7 +438,6 @@ class TestAzureOpenAIProviderIntegration:
|
|
|
437
438
|
completion = await provider.request_completions(
|
|
438
439
|
message_type=message_type,
|
|
439
440
|
messages=messages,
|
|
440
|
-
model="gpt-4.1"
|
|
441
441
|
)
|
|
442
442
|
|
|
443
443
|
# Verify the completion was returned
|
|
@@ -461,7 +461,7 @@ class TestAzureOpenAIProviderIntegration:
|
|
|
461
461
|
provider_config: Config,
|
|
462
462
|
message_type: MessageType
|
|
463
463
|
) -> None:
|
|
464
|
-
"""Test that
|
|
464
|
+
"""Test that ProviderManager stream_completions uses Azure OpenAI when gpt-4.1 is requested and Azure is configured."""
|
|
465
465
|
|
|
466
466
|
# Mock the streaming response
|
|
467
467
|
mock_chunk1 = MagicMock()
|
|
@@ -484,7 +484,8 @@ class TestAzureOpenAIProviderIntegration:
|
|
|
484
484
|
mock_azure_client.is_closed.return_value = False
|
|
485
485
|
mock_azure_client_class.return_value = mock_azure_client
|
|
486
486
|
|
|
487
|
-
provider =
|
|
487
|
+
provider = ProviderManager(config=provider_config)
|
|
488
|
+
provider.set_selected_model("gpt-4.1")
|
|
488
489
|
|
|
489
490
|
messages: List[ChatCompletionMessageParam] = [
|
|
490
491
|
{"role": "user", "content": "Test message"}
|
|
@@ -497,7 +498,6 @@ class TestAzureOpenAIProviderIntegration:
|
|
|
497
498
|
completion = await provider.stream_completions(
|
|
498
499
|
message_type=message_type,
|
|
499
500
|
messages=messages,
|
|
500
|
-
model="gpt-4.1",
|
|
501
501
|
message_id="test-id",
|
|
502
502
|
thread_id="test-thread",
|
|
503
503
|
reply_fn=mock_reply
|
|
@@ -554,8 +554,8 @@ class TestAzureOpenAIConfigurationPriority:
|
|
|
554
554
|
"""Test that Azure OpenAI is used even when Claude key is available."""
|
|
555
555
|
|
|
556
556
|
# Set Claude key (this should be overridden by Azure OpenAI)
|
|
557
|
-
monkeypatch.setenv("
|
|
558
|
-
monkeypatch.setattr("mito_ai.constants.
|
|
557
|
+
monkeypatch.setenv("ANTHROPIC_API_KEY", "claude-key")
|
|
558
|
+
monkeypatch.setattr("mito_ai.constants.ANTHROPIC_API_KEY", "claude-key")
|
|
559
559
|
|
|
560
560
|
with patch("openai.AsyncAzureOpenAI") as mock_azure_client:
|
|
561
561
|
openai_client = OpenAIClient(config=provider_config)
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
|
|
4
4
|
import pytest
|
|
5
5
|
from unittest.mock import MagicMock, patch
|
|
6
|
-
from mito_ai.
|
|
6
|
+
from mito_ai.provider_manager import ProviderManager
|
|
7
7
|
from mito_ai.tests.providers.utils import mock_azure_openai_client, mock_openai_client, patch_server_limits
|
|
8
8
|
from traitlets.config import Config
|
|
9
9
|
|
|
@@ -11,9 +11,9 @@ FAKE_API_KEY = "sk-1234567890"
|
|
|
11
11
|
|
|
12
12
|
@pytest.fixture
|
|
13
13
|
def provider_config() -> Config:
|
|
14
|
-
"""Create a proper Config object for the
|
|
14
|
+
"""Create a proper Config object for the ProviderManager."""
|
|
15
15
|
config = Config()
|
|
16
|
-
config.
|
|
16
|
+
config.ProviderManager = Config()
|
|
17
17
|
config.OpenAIClient = Config()
|
|
18
18
|
return config
|
|
19
19
|
|
|
@@ -22,7 +22,7 @@ def provider_config() -> Config:
|
|
|
22
22
|
"name": "mito_server_fallback_no_keys",
|
|
23
23
|
"setup": {
|
|
24
24
|
"OPENAI_API_KEY": None,
|
|
25
|
-
"
|
|
25
|
+
"ANTHROPIC_API_KEY": None,
|
|
26
26
|
"GEMINI_API_KEY": None,
|
|
27
27
|
"is_azure_configured": False,
|
|
28
28
|
},
|
|
@@ -33,45 +33,45 @@ def provider_config() -> Config:
|
|
|
33
33
|
"name": "claude_when_only_claude_key",
|
|
34
34
|
"setup": {
|
|
35
35
|
"OPENAI_API_KEY": None,
|
|
36
|
-
"
|
|
36
|
+
"ANTHROPIC_API_KEY": "claude-test-key",
|
|
37
37
|
"GEMINI_API_KEY": None,
|
|
38
38
|
"is_azure_configured": False,
|
|
39
39
|
},
|
|
40
40
|
"expected_provider": "Claude",
|
|
41
|
-
"expected_key_type": "
|
|
41
|
+
"expected_key_type": "user_key"
|
|
42
42
|
},
|
|
43
43
|
{
|
|
44
44
|
"name": "gemini_when_only_gemini_key",
|
|
45
45
|
"setup": {
|
|
46
46
|
"OPENAI_API_KEY": None,
|
|
47
|
-
"
|
|
47
|
+
"ANTHROPIC_API_KEY": None,
|
|
48
48
|
"GEMINI_API_KEY": "gemini-test-key",
|
|
49
49
|
"is_azure_configured": False,
|
|
50
50
|
},
|
|
51
51
|
"expected_provider": "Gemini",
|
|
52
|
-
"expected_key_type": "
|
|
52
|
+
"expected_key_type": "user_key"
|
|
53
53
|
},
|
|
54
54
|
{
|
|
55
55
|
"name": "openai_when_openai_key",
|
|
56
56
|
"setup": {
|
|
57
57
|
"OPENAI_API_KEY": 'openai-test-key',
|
|
58
|
-
"
|
|
58
|
+
"ANTHROPIC_API_KEY": None,
|
|
59
59
|
"GEMINI_API_KEY": None,
|
|
60
60
|
"is_azure_configured": False,
|
|
61
61
|
},
|
|
62
|
-
"expected_provider": "OpenAI
|
|
62
|
+
"expected_provider": "OpenAI",
|
|
63
63
|
"expected_key_type": "user_key"
|
|
64
64
|
},
|
|
65
65
|
{
|
|
66
66
|
"name": "claude_priority_over_gemini",
|
|
67
67
|
"setup": {
|
|
68
68
|
"OPENAI_API_KEY": None,
|
|
69
|
-
"
|
|
69
|
+
"ANTHROPIC_API_KEY": "claude-test-key",
|
|
70
70
|
"GEMINI_API_KEY": "gemini-test-key",
|
|
71
71
|
"is_azure_configured": False,
|
|
72
72
|
},
|
|
73
73
|
"expected_provider": "Claude",
|
|
74
|
-
"expected_key_type": "
|
|
74
|
+
"expected_key_type": "user_key"
|
|
75
75
|
},
|
|
76
76
|
])
|
|
77
77
|
def test_provider_capabilities_real_logic(
|
|
@@ -79,7 +79,7 @@ def test_provider_capabilities_real_logic(
|
|
|
79
79
|
monkeypatch: pytest.MonkeyPatch,
|
|
80
80
|
provider_config: Config
|
|
81
81
|
) -> None:
|
|
82
|
-
"""Test the actual provider selection logic in
|
|
82
|
+
"""Test the actual provider selection logic in ProviderManager.capabilities"""
|
|
83
83
|
|
|
84
84
|
# Set up the environment based on test case
|
|
85
85
|
setup = test_case["setup"]
|
|
@@ -97,9 +97,6 @@ def test_provider_capabilities_real_logic(
|
|
|
97
97
|
else:
|
|
98
98
|
monkeypatch.setattr(f"mito_ai.constants.{key}", value)
|
|
99
99
|
|
|
100
|
-
# Clear the provider config API key to ensure it uses constants
|
|
101
|
-
provider_config.OpenAIProvider.api_key = None
|
|
102
|
-
|
|
103
100
|
# Mock HTTP calls but let the real logic run
|
|
104
101
|
with patch("openai.OpenAI") as mock_openai_constructor:
|
|
105
102
|
with patch("openai.AsyncOpenAI") as mock_async_openai:
|
|
@@ -112,7 +109,7 @@ def test_provider_capabilities_real_logic(
|
|
|
112
109
|
# Mock server limits for Mito server fallback
|
|
113
110
|
with patch_server_limits():
|
|
114
111
|
# NOW create the provider after ALL mocks are set up
|
|
115
|
-
llm =
|
|
112
|
+
llm = ProviderManager(config=provider_config)
|
|
116
113
|
|
|
117
114
|
# Test capabilities
|
|
118
115
|
capabilities = llm.capabilities
|
|
@@ -6,7 +6,7 @@ import ast
|
|
|
6
6
|
import inspect
|
|
7
7
|
import requests
|
|
8
8
|
from mito_ai.gemini_client import GeminiClient, get_gemini_system_prompt_and_messages
|
|
9
|
-
from mito_ai.utils.gemini_utils import get_gemini_completion_function_params
|
|
9
|
+
from mito_ai.utils.gemini_utils import get_gemini_completion_function_params
|
|
10
10
|
from google.genai.types import Part, GenerateContentResponse, Candidate, Content
|
|
11
11
|
from mito_ai.completions.models import ResponseFormatInfo, AgentResponse
|
|
12
12
|
from unittest.mock import MagicMock, patch
|
|
@@ -156,19 +156,20 @@ async def test_json_response_handling_with_multiple_parts():
|
|
|
156
156
|
assert result == 'Here is the JSON: {"key": "value"} End of response'
|
|
157
157
|
|
|
158
158
|
CUSTOM_MODEL = "smart-gemini-model"
|
|
159
|
-
@pytest.mark.parametrize("message_type
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
159
|
+
@pytest.mark.parametrize("message_type", [
|
|
160
|
+
MessageType.CHAT,
|
|
161
|
+
MessageType.SMART_DEBUG,
|
|
162
|
+
MessageType.CODE_EXPLAIN,
|
|
163
|
+
MessageType.AGENT_EXECUTION,
|
|
164
|
+
MessageType.AGENT_AUTO_ERROR_FIXUP,
|
|
165
|
+
MessageType.INLINE_COMPLETION,
|
|
166
|
+
MessageType.CHAT_NAME_GENERATION,
|
|
167
167
|
])
|
|
168
168
|
@pytest.mark.asyncio
|
|
169
|
-
async def
|
|
169
|
+
async def test_get_completion_model_selection_uses_passed_model(message_type):
|
|
170
170
|
"""
|
|
171
|
-
Tests that the
|
|
171
|
+
Tests that the model passed to the client is used as-is.
|
|
172
|
+
Model selection based on message type is now handled by ProviderManager.
|
|
172
173
|
"""
|
|
173
174
|
with patch('google.genai.Client') as mock_genai_class:
|
|
174
175
|
mock_client = MagicMock()
|
|
@@ -189,7 +190,7 @@ async def test_get_completion_model_selection_based_on_message_type(message_type
|
|
|
189
190
|
response_format_info=None
|
|
190
191
|
)
|
|
191
192
|
|
|
192
|
-
# Verify that generate_content was called with the
|
|
193
|
+
# Verify that generate_content was called with the model that was passed (not overridden)
|
|
193
194
|
mock_models.generate_content.assert_called_once()
|
|
194
195
|
call_args = mock_models.generate_content.call_args
|
|
195
|
-
assert call_args[1]['model'] ==
|
|
196
|
+
assert call_args[1]['model'] == CUSTOM_MODEL
|
|
@@ -6,19 +6,17 @@ These tests ensure that the correct model is chosen for each message type, for e
|
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
8
|
import pytest
|
|
9
|
-
from mito_ai.utils.
|
|
9
|
+
from mito_ai.utils.model_utils import get_fast_model_for_selected_model, get_smartest_model_for_selected_model
|
|
10
10
|
from mito_ai.completions.models import MessageType
|
|
11
11
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
12
|
-
from mito_ai.
|
|
13
|
-
from mito_ai.completions.models import MessageType
|
|
14
|
-
from mito_ai.utils.provider_utils import does_message_require_fast_model
|
|
12
|
+
from mito_ai.provider_manager import ProviderManager
|
|
15
13
|
from traitlets.config import Config
|
|
16
14
|
|
|
17
15
|
@pytest.fixture
|
|
18
16
|
def provider_config() -> Config:
|
|
19
|
-
"""Create a proper Config object for the
|
|
17
|
+
"""Create a proper Config object for the ProviderManager."""
|
|
20
18
|
config = Config()
|
|
21
|
-
config.
|
|
19
|
+
config.ProviderManager = Config()
|
|
22
20
|
config.OpenAIClient = Config()
|
|
23
21
|
return config
|
|
24
22
|
|
|
@@ -27,104 +25,162 @@ def mock_messages():
|
|
|
27
25
|
"""Sample messages for testing."""
|
|
28
26
|
return [{"role": "user", "content": "Test message"}]
|
|
29
27
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
(
|
|
35
|
-
(
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
(
|
|
39
|
-
]
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
28
|
+
@pytest.mark.asyncio
|
|
29
|
+
async def test_request_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
30
|
+
"""Test that request_completions uses the correct model when use_fast_model=True."""
|
|
31
|
+
# Set up environment variables to ensure OpenAI provider is used
|
|
32
|
+
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
33
|
+
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
34
|
+
|
|
35
|
+
# Mock the OpenAI API call instead of the entire client
|
|
36
|
+
mock_response = MagicMock()
|
|
37
|
+
mock_response.choices = [MagicMock()]
|
|
38
|
+
mock_response.choices[0].message.content = "Test Completion"
|
|
39
|
+
|
|
40
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
41
|
+
mock_openai_client = MagicMock()
|
|
42
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
|
43
|
+
mock_openai_client.is_closed.return_value = False
|
|
44
|
+
mock_openai_class.return_value = mock_openai_client
|
|
45
|
+
|
|
46
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
47
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
48
|
+
mock_sync_client = MagicMock()
|
|
49
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
50
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
51
|
+
|
|
52
|
+
provider = ProviderManager(config=provider_config)
|
|
53
|
+
provider.set_selected_model("gpt-5.2")
|
|
54
|
+
await provider.request_completions(
|
|
55
|
+
message_type=MessageType.CHAT,
|
|
56
|
+
messages=mock_messages,
|
|
57
|
+
use_fast_model=True
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
# Verify the model passed to the API call is the fast model
|
|
61
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
62
|
+
assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())
|
|
63
|
+
|
|
64
|
+
@pytest.mark.asyncio
|
|
65
|
+
async def test_stream_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
66
|
+
"""Test that stream_completions uses the correct model when use_fast_model=True."""
|
|
67
|
+
# Set up environment variables to ensure OpenAI provider is used
|
|
68
|
+
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
69
|
+
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
44
70
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
71
|
+
# Mock the OpenAI API call instead of the entire client
|
|
72
|
+
mock_response = MagicMock()
|
|
73
|
+
mock_response.choices = [MagicMock()]
|
|
74
|
+
mock_response.choices[0].delta.content = "Test Stream Completion"
|
|
75
|
+
mock_response.choices[0].finish_reason = "stop"
|
|
76
|
+
|
|
77
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
78
|
+
mock_openai_client = MagicMock()
|
|
79
|
+
# Create an async generator for streaming
|
|
80
|
+
async def mock_stream():
|
|
81
|
+
yield mock_response
|
|
82
|
+
|
|
83
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
|
|
84
|
+
mock_openai_client.is_closed.return_value = False
|
|
85
|
+
mock_openai_class.return_value = mock_openai_client
|
|
86
|
+
|
|
87
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
88
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
89
|
+
mock_sync_client = MagicMock()
|
|
90
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
91
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
92
|
+
|
|
93
|
+
provider = ProviderManager(config=provider_config)
|
|
94
|
+
provider.set_selected_model("gpt-5.2")
|
|
95
|
+
await provider.stream_completions(
|
|
96
|
+
message_type=MessageType.CHAT,
|
|
97
|
+
messages=mock_messages,
|
|
98
|
+
message_id="test_id",
|
|
99
|
+
thread_id="test_thread",
|
|
100
|
+
reply_fn=lambda x: None,
|
|
101
|
+
use_fast_model=True
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
# Verify the model passed to the API call is the fast model
|
|
105
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
106
|
+
assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())
|
|
49
107
|
|
|
50
108
|
@pytest.mark.asyncio
|
|
51
|
-
async def
|
|
52
|
-
"""Test that request_completions
|
|
109
|
+
async def test_request_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
110
|
+
"""Test that request_completions uses the correct model when use_smartest_model=True."""
|
|
53
111
|
# Set up environment variables to ensure OpenAI provider is used
|
|
54
112
|
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
55
113
|
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
56
114
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
115
|
+
# Mock the OpenAI API call instead of the entire client
|
|
116
|
+
mock_response = MagicMock()
|
|
117
|
+
mock_response.choices = [MagicMock()]
|
|
118
|
+
mock_response.choices[0].message.content = "Test Completion"
|
|
119
|
+
|
|
120
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
121
|
+
mock_openai_client = MagicMock()
|
|
122
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
|
123
|
+
mock_openai_client.is_closed.return_value = False
|
|
124
|
+
mock_openai_class.return_value = mock_openai_client
|
|
62
125
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
126
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
127
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
128
|
+
mock_sync_client = MagicMock()
|
|
129
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
130
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
68
131
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
)
|
|
81
|
-
|
|
82
|
-
mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
|
|
83
|
-
# Verify the model passed to the API call
|
|
84
|
-
call_args = mock_openai_client.chat.completions.create.call_args
|
|
85
|
-
assert call_args[1]['model'] == "gpt-3.5"
|
|
132
|
+
provider = ProviderManager(config=provider_config)
|
|
133
|
+
provider.set_selected_model("gpt-4.1")
|
|
134
|
+
await provider.request_completions(
|
|
135
|
+
message_type=MessageType.CHAT,
|
|
136
|
+
messages=mock_messages,
|
|
137
|
+
use_smartest_model=True
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
# Verify the model passed to the API call is the smartest model
|
|
141
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
142
|
+
assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())
|
|
86
143
|
|
|
87
144
|
@pytest.mark.asyncio
|
|
88
|
-
async def
|
|
89
|
-
"""Test that stream_completions
|
|
145
|
+
async def test_stream_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
|
|
146
|
+
"""Test that stream_completions uses the correct model when use_smartest_model=True."""
|
|
90
147
|
# Set up environment variables to ensure OpenAI provider is used
|
|
91
148
|
monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
|
|
92
149
|
monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
|
|
93
150
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
151
|
+
# Mock the OpenAI API call instead of the entire client
|
|
152
|
+
mock_response = MagicMock()
|
|
153
|
+
mock_response.choices = [MagicMock()]
|
|
154
|
+
mock_response.choices[0].delta.content = "Test Stream Completion"
|
|
155
|
+
mock_response.choices[0].finish_reason = "stop"
|
|
156
|
+
|
|
157
|
+
with patch('openai.AsyncOpenAI') as mock_openai_class:
|
|
158
|
+
mock_openai_client = MagicMock()
|
|
159
|
+
# Create an async generator for streaming
|
|
160
|
+
async def mock_stream():
|
|
161
|
+
yield mock_response
|
|
162
|
+
|
|
163
|
+
mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
|
|
164
|
+
mock_openai_client.is_closed.return_value = False
|
|
165
|
+
mock_openai_class.return_value = mock_openai_client
|
|
100
166
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
167
|
+
# Mock the validation that happens in OpenAIClient constructor
|
|
168
|
+
with patch('openai.OpenAI') as mock_sync_openai_class:
|
|
169
|
+
mock_sync_client = MagicMock()
|
|
170
|
+
mock_sync_client.models.list.return_value = MagicMock()
|
|
171
|
+
mock_sync_openai_class.return_value = mock_sync_client
|
|
106
172
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
173
|
+
provider = ProviderManager(config=provider_config)
|
|
174
|
+
provider.set_selected_model("gpt-4.1")
|
|
175
|
+
await provider.stream_completions(
|
|
176
|
+
message_type=MessageType.CHAT,
|
|
177
|
+
messages=mock_messages,
|
|
178
|
+
message_id="test_id",
|
|
179
|
+
thread_id="test_thread",
|
|
180
|
+
reply_fn=lambda x: None,
|
|
181
|
+
use_smartest_model=True
|
|
182
|
+
)
|
|
110
183
|
|
|
111
|
-
#
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
mock_sync_client.models.list.return_value = MagicMock()
|
|
115
|
-
mock_sync_openai_class.return_value = mock_sync_client
|
|
116
|
-
|
|
117
|
-
provider = OpenAIProvider(config=provider_config)
|
|
118
|
-
await provider.stream_completions(
|
|
119
|
-
message_type=MessageType.CHAT,
|
|
120
|
-
messages=mock_messages,
|
|
121
|
-
model="gpt-3.5",
|
|
122
|
-
message_id="test_id",
|
|
123
|
-
thread_id="test_thread",
|
|
124
|
-
reply_fn=lambda x: None
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
|
|
128
|
-
# Verify the model passed to the API call
|
|
129
|
-
call_args = mock_openai_client.chat.completions.create.call_args
|
|
130
|
-
assert call_args[1]['model'] == "gpt-3.5"
|
|
184
|
+
# Verify the model passed to the API call is the smartest model
|
|
185
|
+
call_args = mock_openai_client.chat.completions.create.call_args
|
|
186
|
+
assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())
|