mito-ai 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mito_ai/__init__.py +19 -22
- mito_ai/_version.py +1 -1
- mito_ai/anthropic_client.py +24 -14
- mito_ai/chart_wizard/handlers.py +78 -17
- mito_ai/chart_wizard/urls.py +8 -5
- mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
- mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
- mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
- mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
- mito_ai/completions/completion_handlers/completion_handler.py +3 -5
- mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
- mito_ai/completions/completion_handlers/scratchpad_result_handler.py +6 -8
- mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
- mito_ai/completions/completion_handlers/utils.py +3 -7
- mito_ai/completions/handlers.py +32 -22
- mito_ai/completions/message_history.py +8 -10
- mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
- mito_ai/completions/prompt_builders/prompt_constants.py +2 -0
- mito_ai/constants.py +31 -2
- mito_ai/enterprise/__init__.py +1 -1
- mito_ai/enterprise/litellm_client.py +144 -0
- mito_ai/enterprise/utils.py +16 -2
- mito_ai/log/handlers.py +1 -1
- mito_ai/openai_client.py +36 -96
- mito_ai/provider_manager.py +420 -0
- mito_ai/settings/enterprise_handler.py +26 -0
- mito_ai/settings/urls.py +2 -0
- mito_ai/streamlit_conversion/agent_utils.py +2 -30
- mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
- mito_ai/streamlit_preview/handlers.py +6 -3
- mito_ai/streamlit_preview/urls.py +5 -3
- mito_ai/tests/message_history/test_generate_short_chat_name.py +103 -28
- mito_ai/tests/open_ai_utils_test.py +34 -36
- mito_ai/tests/providers/test_anthropic_client.py +174 -16
- mito_ai/tests/providers/test_azure.py +15 -15
- mito_ai/tests/providers/test_capabilities.py +14 -17
- mito_ai/tests/providers/test_gemini_client.py +14 -13
- mito_ai/tests/providers/test_model_resolution.py +145 -89
- mito_ai/tests/providers/test_openai_client.py +209 -13
- mito_ai/tests/providers/test_provider_limits.py +5 -5
- mito_ai/tests/providers/test_providers.py +229 -51
- mito_ai/tests/providers/test_retry_logic.py +13 -22
- mito_ai/tests/providers/utils.py +4 -4
- mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
- mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
- mito_ai/tests/test_constants.py +90 -0
- mito_ai/tests/test_enterprise_mode.py +217 -0
- mito_ai/tests/test_model_utils.py +362 -0
- mito_ai/utils/anthropic_utils.py +8 -6
- mito_ai/utils/gemini_utils.py +0 -3
- mito_ai/utils/litellm_utils.py +84 -0
- mito_ai/utils/model_utils.py +257 -0
- mito_ai/utils/open_ai_utils.py +29 -41
- mito_ai/utils/provider_utils.py +13 -29
- mito_ai/utils/telemetry_utils.py +14 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.44c109c7be36fb884d25.js +1059 -144
- mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.44c109c7be36fb884d25.js.map +1 -0
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.f7decebaf69618541e0f.js +17 -17
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js.map → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.f7decebaf69618541e0f.js.map +1 -1
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +78 -78
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/METADATA +2 -1
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/RECORD +90 -83
- mito_ai/completions/providers.py +0 -284
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js.map +0 -1
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/licenses/LICENSE +0 -0
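The headline change in this release: `mito_ai/completions/providers.py` is deleted (-284 lines) and replaced by a new `mito_ai/provider_manager.py` (+420 lines), with supporting model-allowlist utilities (`mito_ai/utils/model_utils.py`, +257) and LiteLLM support for enterprise deployments (`mito_ai/enterprise/litellm_client.py`, `mito_ai/utils/litellm_utils.py`). The sketch below is a hypothetical stand-in reconstructed only from the tests in this diff; it shows just the surface those tests exercise, not the shipped implementation.

```python
# Hypothetical stand-in for the new ProviderManager, inferred only from the
# tests in this diff. The real class in mito_ai/provider_manager.py is a
# traitlets Configurable with provider dispatch, retry logic, and more.
from traitlets.config import Configurable


class ProviderManagerSketch(Configurable):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Assumed default; the tests only assert the default is a non-None str.
        self._selected_model: str = "gpt-4.1"

    def set_selected_model(self, model: str) -> None:
        self._selected_model = model

    def get_selected_model(self) -> str:
        return self._selected_model

    async def request_completions(self, message_type, messages) -> str:
        # The shipped version validates the selected model against the
        # enterprise allowlist (raising ValueError when it is not allowed)
        # and dispatches to OpenAIClient / GeminiClient / AnthropicClient.
        raise NotImplementedError("sketch only")
```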
mito_ai/tests/providers/utils.py
CHANGED

@@ -38,7 +38,7 @@ def mock_openai_client() -> Any:
     mock_client.key_type = "user"
     mock_client.request_completions = AsyncMock(return_value="Test completion")
     mock_client.stream_completions = AsyncMock(return_value="Test completion")
-    return patch("mito_ai.
+    return patch("mito_ai.provider_manager.OpenAIClient", return_value=mock_client)

@@ -52,7 +52,7 @@ def mock_gemini_client() -> Any:
     mock_client.key_type = "gemini"
     mock_client.request_completions = AsyncMock(return_value="Test completion")
     mock_client.stream_completions = AsyncMock(return_value="Test completion")
-    return patch("mito_ai.
+    return patch("mito_ai.provider_manager.GeminiClient", return_value=mock_client)

@@ -66,7 +66,7 @@ def mock_azure_openai_client() -> Any:
     mock_client.key_type = "azure"
     mock_client.request_completions = AsyncMock(return_value="Test completion")
     mock_client.stream_completions = AsyncMock(return_value="Test completion")
-    return patch("mito_ai.
+    return patch("mito_ai.provider_manager.OpenAIClient", return_value=mock_client)

@@ -82,4 +82,4 @@ def mock_claude_client() -> Any:
     mock_client.request_completions = AsyncMock(return_value="Test completion")
     mock_client.stream_completions = AsyncMock(return_value="Test completion")
     mock_client.stream_response = AsyncMock(return_value="Test completion")
-    return patch("mito_ai.
+    return patch("mito_ai.provider_manager.AnthropicClient", return_value=mock_client)
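All four helpers change only their `patch()` target: the provider clients are now resolved from `mito_ai.provider_manager` rather than, presumably, the deleted `mito_ai.completions.providers` module. A hypothetical use of one helper in a test, assuming the returned patcher is used as a context manager:

```python
# Hypothetical test using the mock helper above. mock_openai_client() returns
# a unittest.mock patcher, so it works as a context manager.
import pytest
from mito_ai.tests.providers.utils import mock_openai_client


@pytest.mark.asyncio
async def test_completion_with_mocked_client() -> None:
    with mock_openai_client() as mock_client_cls:
        # The patcher replaces the OpenAIClient class on provider_manager;
        # calling the replacement returns the preconfigured mock instance.
        client = mock_client_cls()
        result = await client.request_completions()
        assert result == "Test completion"
```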
mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py
CHANGED

@@ -2,94 +2,38 @@
 # Distributed under the terms of the GNU Affero General Public License v3.0 License.
 
 from typing import List
-from anthropic.types import MessageParam
 import pytest
 import os
 from unittest.mock import patch, AsyncMock, MagicMock
+from traitlets.config import Config
 from mito_ai.streamlit_conversion.streamlit_agent_handler import (
-    get_response_from_agent,
     generate_new_streamlit_code,
     correct_error_in_generation,
     streamlit_handler
 )
 from mito_ai.path_utils import AbsoluteNotebookPath, AppFileName, get_absolute_app_path, get_absolute_notebook_dir_path, get_absolute_notebook_path
+from mito_ai.provider_manager import ProviderManager
 
 # Add this line to enable async support
 pytest_plugins = ('pytest_asyncio',)
 
 
-class TestGetResponseFromAgent:
-    """Test cases for get_response_from_agent function"""
-
-    @pytest.mark.asyncio
-    @patch('mito_ai.streamlit_conversion.agent_utils.stream_anthropic_completion_from_mito_server')
-    async def test_get_response_from_agent_success(self, mock_stream):
-        """Test get_response_from_agent with successful response"""
-        # Mock the async generator
-        async def mock_async_gen():
-            yield "Here's your code:\n```python\nimport streamlit\nst.title('Test')\n```"
-
-        mock_stream.return_value = mock_async_gen()
-
-        messages: List[MessageParam] = [{"role": "user", "content": [{"type": "text", "text": "test"}]}]
-        response = await get_response_from_agent(messages)
-
-        assert response is not None
-        assert len(response) > 0
-        assert "import streamlit" in response
-
-    @pytest.mark.asyncio
-    @patch('mito_ai.streamlit_conversion.agent_utils.stream_anthropic_completion_from_mito_server')
-    @pytest.mark.parametrize("mock_items,expected_result", [
-        (["Hello", " World", "!"], "Hello World!"),
-        ([], ""),
-        (["Here's your code: import streamlit"], "Here's your code: import streamlit")
-    ])
-    async def test_get_response_from_agent_parametrized(self, mock_stream, mock_items, expected_result):
-        """Test response from agent with different scenarios"""
-        # Mock the async generator
-        async def mock_async_gen():
-            for item in mock_items:
-                yield item
-
-        mock_stream.return_value = mock_async_gen()
-
-        messages: List[MessageParam] = [{"role": "user", "content": [{"type": "text", "text": "test"}]}]
-        result = await get_response_from_agent(messages)
-
-        assert result == expected_result
-        mock_stream.assert_called_once()
-
-
-    @pytest.mark.asyncio
-    @patch('mito_ai.streamlit_conversion.agent_utils.stream_anthropic_completion_from_mito_server')
-    async def test_get_response_from_agent_exception(self, mock_stream):
-        """Test exception handling in get_response_from_agent"""
-        mock_stream.side_effect = Exception("API Error")
-
-        messages: List[MessageParam] = [{"role": "user", "content": [{"type": "text", "text": "test"}]}]
-
-        with pytest.raises(Exception, match="API Error"):
-            await get_response_from_agent(messages)
-
-
 class TestGenerateStreamlitCode:
     """Test cases for generate_new_streamlit_code function"""
 
     @pytest.mark.asyncio
-
-    async def test_generate_new_streamlit_code_success(self, mock_stream):
+    async def test_generate_new_streamlit_code_success(self):
         """Test successful streamlit code generation"""
         mock_response = "Here's your code:\n```python\nimport streamlit\nst.title('Hello')\n```"
 
-
-
-
-
-        mock_stream.return_value = mock_async_gen()
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
 
-
-
+        with patch.object(provider, 'request_completions', new_callable=AsyncMock, return_value=mock_response):
+            notebook_data: List[dict] = [{"cells": []}]
+            result = await generate_new_streamlit_code(notebook_data, '', provider)
 
         expected_code = "import streamlit\nst.title('Hello')\n"
         assert result == expected_code

@@ -99,8 +43,7 @@ class TestCorrectErrorInGeneration:
     """Test cases for correct_error_in_generation function"""
 
     @pytest.mark.asyncio
-
-    async def test_correct_error_in_generation_success(self, mock_stream):
+    async def test_correct_error_in_generation_success(self):
         """Test successful error correction"""
         mock_response = """```search_replace
 >>>>>>> SEARCH

@@ -109,25 +52,28 @@ st.title('Test')
 st.title('Fixed')
 <<<<<<< REPLACE
 ```"""
-
-
-
-
-        mock_stream.return_value = mock_async_gen()
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
 
-
+        with patch.object(provider, 'request_completions', new_callable=AsyncMock, return_value=mock_response):
+            result = await correct_error_in_generation("ImportError: No module named 'pandas'", "import streamlit\nst.title('Test')\n", provider)
 
         expected_code = "import streamlit\nst.title('Fixed')\n"
         assert result == expected_code
 
     @pytest.mark.asyncio
-
-    async def test_correct_error_in_generation_exception(self, mock_stream):
+    async def test_correct_error_in_generation_exception(self):
         """Test exception handling in error correction"""
-
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
 
-        with
-
+        with patch.object(provider, 'request_completions', new_callable=AsyncMock, side_effect=Exception("API Error")):
+            with pytest.raises(Exception, match="API Error"):
+                await correct_error_in_generation("Some error", "import streamlit\nst.title('Test')", provider)
 
 
 class TestStreamlitHandler:

@@ -158,11 +104,17 @@ class TestStreamlitHandler:
         # Construct the expected app path using the same method as the production code
         app_directory = get_absolute_notebook_dir_path(notebook_path)
         expected_app_path = get_absolute_app_path(app_directory, app_file_name)
-
+
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
+
+        await streamlit_handler(True, notebook_path, app_file_name, '', provider)
 
         # Verify calls
         mock_parse.assert_called_once_with(notebook_path)
-        mock_generate_code.assert_called_once_with(mock_notebook_data, '')
+        mock_generate_code.assert_called_once_with(mock_notebook_data, '', provider)
         mock_validator.assert_called_once_with("import streamlit\nst.title('Test')", notebook_path)
         mock_create_file.assert_called_once_with(expected_app_path, "import streamlit\nst.title('Test')")
 

@@ -185,9 +137,14 @@ class TestStreamlitHandler:
         # Mock validation (always errors) - validate_app returns List[str]
         mock_validator.return_value = ["Persistent error"]
 
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
+
         # Now it should raise an exception instead of returning a tuple
         with pytest.raises(Exception):
-            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '')
+            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '', provider)
 
         # Verify that error correction was called 5 times (once per error, 5 retries)
         # Each retry processes 1 error, so 5 retries = 5 calls

@@ -213,9 +170,14 @@ class TestStreamlitHandler:
         # Mock file creation failure - now it should raise an exception
         mock_create_file.side_effect = Exception("Permission denied")
 
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
+
         # Now it should raise an exception instead of returning a tuple
         with pytest.raises(Exception):
-            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '')
+            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '', provider)
 
     @pytest.mark.asyncio
     @patch('mito_ai.streamlit_conversion.streamlit_agent_handler.parse_jupyter_notebook_to_extract_required_content')

@@ -224,8 +186,13 @@ class TestStreamlitHandler:
 
         mock_parse.side_effect = FileNotFoundError("Notebook not found")
 
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
+
         with pytest.raises(FileNotFoundError, match="Notebook not found"):
-            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '')
+            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '', provider)
 
     @pytest.mark.asyncio
     @patch('mito_ai.streamlit_conversion.streamlit_agent_handler.parse_jupyter_notebook_to_extract_required_content')

@@ -239,8 +206,13 @@ class TestStreamlitHandler:
         # Mock code generation failure
         mock_generate_code.side_effect = Exception("Generation failed")
 
+        provider_config = Config()
+        provider_config.ProviderManager = Config()
+        provider_config.OpenAIClient = Config()
+        provider = ProviderManager(config=provider_config)
+
         with pytest.raises(Exception, match="Generation failed"):
-            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '')
+            await streamlit_handler(True, AbsoluteNotebookPath("notebook.ipynb"), AppFileName('test-app-file-name.py'), '', provider)
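Each rewritten test repeats the same four lines of traitlets scaffolding before constructing a `ProviderManager`. The `test_enterprise_mode.py` file added below consolidates exactly this into a `provider_config` fixture; an equivalent fixture that yields the manager itself might look like:

```python
import pytest
from traitlets.config import Config
from mito_ai.provider_manager import ProviderManager


@pytest.fixture
def provider() -> ProviderManager:
    # The same scaffolding the tests above build inline: empty traitlets
    # sub-configs for the ProviderManager and OpenAIClient sections.
    provider_config = Config()
    provider_config.ProviderManager = Config()
    provider_config.OpenAIClient = Config()
    return ProviderManager(config=provider_config)
```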
mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py
CHANGED

@@ -63,12 +63,15 @@ class TestStreamlitPreviewHandler:
         mock_request.connection = MagicMock()
         mock_request.connection.context = MagicMock()
 
+        # Create a mock ProviderManager for the llm parameter
+        mock_llm = MagicMock()
+
         # Create handler instance
         handler = StreamlitPreviewHandler(
             application=mock_application,
             request=mock_request,
+            llm=mock_llm
         )
-        handler.initialize()
 
         # Mock authentication - set current_user to bypass @tornado.web.authenticated
         handler.current_user = "test_user"  # type: ignore
mito_ai/tests/test_constants.py
CHANGED

@@ -7,6 +7,7 @@ from mito_ai.constants import (
     ACTIVE_BASE_URL, MITO_PROD_BASE_URL, MITO_DEV_BASE_URL,
     MITO_STREAMLIT_DEV_BASE_URL, MITO_STREAMLIT_TEST_BASE_URL, ACTIVE_STREAMLIT_BASE_URL,
     COGNITO_CONFIG_DEV, ACTIVE_COGNITO_CONFIG,
+    parse_comma_separated_models,
 )
 
 

@@ -45,3 +46,92 @@ def test_cognito_config() -> Any:
 
     assert COGNITO_CONFIG_DEV == expected_config
     assert ACTIVE_COGNITO_CONFIG == COGNITO_CONFIG_DEV
+
+
+class TestParseCommaSeparatedModels:
+    """Tests for parse_comma_separated_models helper function."""
+
+    def test_parse_models_no_quotes(self) -> None:
+        """Test parsing models without quotes."""
+        models_str = "litellm/openai/gpt-4o,litellm/anthropic/claude-3-5-sonnet"
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o", "litellm/anthropic/claude-3-5-sonnet"]
+
+    def test_parse_models_double_quotes(self) -> None:
+        """Test parsing models with double quotes."""
+        # Entire string quoted
+        models_str = '"litellm/openai/gpt-4o,litellm/anthropic/claude-3-5-sonnet"'
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o", "litellm/anthropic/claude-3-5-sonnet"]
+
+        # Individual models quoted
+        models_str = '"litellm/openai/gpt-4o","litellm/anthropic/claude-3-5-sonnet"'
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o", "litellm/anthropic/claude-3-5-sonnet"]
+
+    def test_parse_models_single_quotes(self) -> None:
+        """Test parsing models with single quotes."""
+        # Entire string quoted
+        models_str = "'litellm/openai/gpt-4o,litellm/anthropic/claude-3-5-sonnet'"
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o", "litellm/anthropic/claude-3-5-sonnet"]
+
+        # Individual models quoted
+        models_str = "'litellm/openai/gpt-4o','litellm/anthropic/claude-3-5-sonnet'"
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o", "litellm/anthropic/claude-3-5-sonnet"]
+
+    def test_parse_models_mixed_quotes(self) -> None:
+        """Test parsing models where some have single quotes and some have double quotes."""
+        # Some models with single quotes, some with double quotes
+        models_str = "'litellm/openai/gpt-4o',\"litellm/anthropic/claude-3-5-sonnet\""
+        result = parse_comma_separated_models(models_str)
+        # Should strip both types of quotes
+        assert result == ["litellm/openai/gpt-4o", "litellm/anthropic/claude-3-5-sonnet"]
+
+    def test_parse_models_with_whitespace(self) -> None:
+        """Test parsing models with whitespace around commas and model names."""
+        models_str = " litellm/openai/gpt-4o , litellm/anthropic/claude-3-5-sonnet "
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o", "litellm/anthropic/claude-3-5-sonnet"]
+
+    def test_parse_models_empty_string(self) -> None:
+        """Test parsing empty string."""
+        result = parse_comma_separated_models("")
+        assert result == []
+
+    def test_parse_models_single_model(self) -> None:
+        """Test parsing single model."""
+        models_str = "litellm/openai/gpt-4o"
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o"]
+
+        # With quotes
+        models_str = '"litellm/openai/gpt-4o"'
+        result = parse_comma_separated_models(models_str)
+        assert result == ["litellm/openai/gpt-4o"]
+
+    def test_parse_models_abacus_format(self) -> None:
+        """Test parsing Abacus model format."""
+        models_str = "Abacus/gpt-4.1,Abacus/claude-haiku-4-5-20251001"
+        result = parse_comma_separated_models(models_str)
+        assert result == ["Abacus/gpt-4.1", "Abacus/claude-haiku-4-5-20251001"]
+
+        # With quotes
+        models_str = '"Abacus/gpt-4.1","Abacus/claude-haiku-4-5-20251001"'
+        result = parse_comma_separated_models(models_str)
+        assert result == ["Abacus/gpt-4.1", "Abacus/claude-haiku-4-5-20251001"]
+
+    @pytest.mark.parametrize("models_str,description", [
+        ('"model1,model2"', 'Double quotes, no space after comma'),
+        ("'model1,model2'", 'Single quotes, no space after comma'),
+        ("model1,model2", 'No quotes, no space after comma'),
+        ('"model1, model2"', 'Double quotes, space after comma'),
+        ("'model1, model2'", 'Single quotes, space after comma'),
+        ("model1, model2", 'No quotes, space after comma'),
+    ])
+    def test_parse_models_all_scenarios(self, models_str: str, description: str) -> None:
+        """Test all specific scenarios: quotes with and without spaces after commas."""
+        expected = ["model1", "model2"]
+        result = parse_comma_separated_models(models_str)
+        assert result == expected, f"Failed for {description}: {repr(models_str)}"
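These tests pin down the full contract of `parse_comma_separated_models`: split on commas, trim whitespace, tolerate single or double quotes around the whole string or around individual entries, and return an empty list for an empty string. A minimal reconstruction that satisfies every case above (not the shipped code in `mito_ai/constants.py`):

```python
from typing import List


def parse_comma_separated_models(models_str: str) -> List[str]:
    # Reconstructed from the test contract above; the shipped implementation
    # in mito_ai/constants.py may differ in detail.
    # Strip whitespace and any quotes wrapping the whole string first...
    stripped = models_str.strip().strip('"').strip("'")
    if not stripped:
        return []
    # ...then strip whitespace and quotes from each individual entry.
    return [part.strip().strip('"').strip("'") for part in stripped.split(",")]
```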
mito_ai/tests/test_enterprise_mode.py
ADDED

@@ -0,0 +1,217 @@
+# Copyright (c) Saga Inc.
+# Distributed under the terms of the GNU Affero General Public License v3.0 License.
+
+import pytest
+from unittest.mock import patch, MagicMock
+from traitlets.config import Config
+from mito_ai.utils.telemetry_utils import telemetry_turned_on, identify, log
+from mito_ai.utils.model_utils import get_available_models
+from mito_ai.provider_manager import ProviderManager
+from mito_ai.completions.models import MessageType
+from openai.types.chat import ChatCompletionMessageParam
+
+
+@pytest.fixture
+def provider_config() -> Config:
+    """Create a proper Config object for the ProviderManager."""
+    config = Config()
+    config.ProviderManager = Config()
+    config.OpenAIClient = Config()
+    return config
+
+
+class TestEnterpriseModeDetection:
+    """Tests for enterprise mode detection."""
+
+    @patch('mito_ai.utils.version_utils.is_enterprise')
+    def test_telemetry_disabled_when_enterprise(self, mock_is_enterprise):
+        """Test that telemetry is disabled when enterprise mode is enabled."""
+        mock_is_enterprise.return_value = True
+
+        result = telemetry_turned_on()
+
+        assert result is False
+
+    @patch('mito_ai.utils.telemetry_utils.is_enterprise')
+    def test_telemetry_enabled_when_not_enterprise(self, mock_is_enterprise):
+        """Test that telemetry can be enabled when enterprise mode is not enabled."""
+        mock_is_enterprise.return_value = False
+
+        # Mock other conditions that might disable telemetry
+        with patch('mito_ai.utils.telemetry_utils.MITOSHEET_HELPER_PRIVATE', False), \
+             patch('mito_ai.utils.telemetry_utils.is_pro', return_value=False), \
+             patch('mito_ai.utils.telemetry_utils.get_user_field', return_value=True):
+            result = telemetry_turned_on()
+            # Result depends on other conditions, but enterprise check should pass
+            # We just verify enterprise check doesn't block it
+            mock_is_enterprise.assert_called_once()
+
+
+class TestTelemetryDisabling:
+    """Tests for telemetry disabling in enterprise mode."""
+
+    @patch('mito_ai.utils.version_utils.is_enterprise')
+    @patch('mito_ai.utils.telemetry_utils.analytics')
+    def test_identify_skips_when_enterprise(self, mock_analytics, mock_is_enterprise):
+        """Test that identify() skips analytics calls when enterprise mode is enabled."""
+        mock_is_enterprise.return_value = True
+
+        identify()
+
+        # Should not call analytics.identify
+        mock_analytics.identify.assert_not_called()
+
+    @patch('mito_ai.utils.version_utils.is_enterprise')
+    @patch('mito_ai.utils.telemetry_utils.analytics')
+    def test_log_skips_when_enterprise(self, mock_analytics, mock_is_enterprise):
+        """Test that log() skips analytics calls when enterprise mode is enabled."""
+        mock_is_enterprise.return_value = True
+
+        log("test_event", {"param": "value"})
+
+        # Should not call analytics.track
+        mock_analytics.track.assert_not_called()
+
+
+class TestModelValidation:
+    """Tests for model validation in enterprise mode."""
+
+    @patch('mito_ai.utils.model_utils.is_enterprise')
+    @patch('mito_ai.utils.model_utils.constants')
+    def test_provider_manager_validates_model(self, mock_constants, mock_is_enterprise, provider_config: Config):
+        """Test that ProviderManager validates models against available models."""
+        mock_is_enterprise.return_value = True
+        mock_constants.LITELLM_BASE_URL = "https://litellm-server.com"
+        mock_constants.LITELLM_MODELS = ["openai/gpt-4o", "openai/gpt-4o-mini"]
+
+        provider_manager = ProviderManager(config=provider_config)
+        provider_manager.set_selected_model("openai/gpt-4o")
+
+        # Should not raise an error for valid model
+        available_models = get_available_models()
+        assert "openai/gpt-4o" in available_models
+
+    @patch('mito_ai.utils.model_utils.is_enterprise')
+    @patch('mito_ai.utils.model_utils.constants')
+    @pytest.mark.asyncio
+    async def test_provider_manager_rejects_invalid_model(self, mock_constants, mock_is_enterprise, provider_config: Config):
+        """Test that ProviderManager rejects invalid models."""
+        mock_is_enterprise.return_value = True
+        mock_constants.LITELLM_BASE_URL = "https://litellm-server.com"
+        mock_constants.LITELLM_MODELS = ["openai/gpt-4o"]
+        mock_constants.LITELLM_API_KEY = "test-key"
+
+        provider_manager = ProviderManager(config=provider_config)
+        provider_manager.set_selected_model("invalid-model")
+
+        messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": "test"}]
+
+        # Should raise ValueError for invalid model
+        with pytest.raises(ValueError, match="is not in the allowed model list"):
+            await provider_manager.request_completions(
+                message_type=MessageType.CHAT,
+                messages=messages
+            )
+
+    @patch('mito_ai.utils.model_utils.is_enterprise')
+    @patch('mito_ai.utils.model_utils.constants')
+    def test_available_models_endpoint_returns_litellm_models(self, mock_constants, mock_is_enterprise):
+        """Test that /available-models endpoint returns LiteLLM models when configured."""
+        mock_is_enterprise.return_value = True
+        mock_constants.LITELLM_BASE_URL = "https://litellm-server.com"
+        mock_constants.LITELLM_MODELS = ["openai/gpt-4o", "anthropic/claude-3-5-sonnet"]
+
+        result = get_available_models()
+
+        assert result == ["openai/gpt-4o", "anthropic/claude-3-5-sonnet"]
+
+    @patch('mito_ai.utils.model_utils.is_enterprise')
+    @patch('mito_ai.utils.model_utils.constants')
+    def test_available_models_endpoint_returns_standard_models_when_not_configured(self, mock_constants, mock_is_enterprise):
+        """Test that /available-models endpoint returns standard models when LiteLLM is not configured."""
+        mock_is_enterprise.return_value = True
+        mock_constants.LITELLM_BASE_URL = None
+        mock_constants.LITELLM_MODELS = []
+
+        result = get_available_models()
+
+        from mito_ai.utils.model_utils import STANDARD_MODELS
+        assert result == STANDARD_MODELS
+
+    @patch('mito_ai.utils.model_utils.is_enterprise')
+    @patch('mito_ai.utils.model_utils.constants')
+    @patch('mito_ai.utils.model_utils.is_abacus_configured')
+    def test_provider_manager_validates_abacus_model(self, mock_is_abacus_configured, mock_constants, mock_is_enterprise, provider_config: Config):
+        """Test that ProviderManager validates Abacus models against available models."""
+        mock_is_abacus_configured.return_value = True
+        mock_is_enterprise.return_value = True
+        mock_constants.ABACUS_BASE_URL = "https://routellm.abacus.ai/v1"
+        mock_constants.ABACUS_MODELS = ["Abacus/gpt-4.1", "Abacus/gpt-5.2"]
+
+        provider_manager = ProviderManager(config=provider_config)
+        provider_manager.set_selected_model("Abacus/gpt-4.1")
+
+        # Should not raise an error for valid model
+        available_models = get_available_models()
+        assert "Abacus/gpt-4.1" in available_models
+
+    @patch('mito_ai.utils.model_utils.is_enterprise')
+    @patch('mito_ai.utils.model_utils.constants')
+    @patch('mito_ai.utils.model_utils.is_abacus_configured')
+    @pytest.mark.asyncio
+    async def test_provider_manager_rejects_invalid_abacus_model(self, mock_is_abacus_configured, mock_constants, mock_is_enterprise, provider_config: Config):
+        """Test that ProviderManager rejects invalid Abacus models."""
+        mock_is_abacus_configured.return_value = True
+        mock_is_enterprise.return_value = True
+        mock_constants.ABACUS_BASE_URL = "https://routellm.abacus.ai/v1"
+        mock_constants.ABACUS_MODELS = ["Abacus/gpt-4.1"]
+        mock_constants.ABACUS_API_KEY = "test-key"
+
+        provider_manager = ProviderManager(config=provider_config)
+        provider_manager.set_selected_model("invalid-model")
+
+        messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": "test"}]
+
+        # Should raise ValueError for invalid model
+        with pytest.raises(ValueError, match="is not in the allowed model list"):
+            await provider_manager.request_completions(
+                message_type=MessageType.CHAT,
+                messages=messages
+            )
+
+    @patch('mito_ai.utils.model_utils.is_enterprise')
+    @patch('mito_ai.utils.model_utils.constants')
+    @patch('mito_ai.utils.model_utils.is_abacus_configured')
+    def test_available_models_endpoint_returns_abacus_models(self, mock_is_abacus_configured, mock_constants, mock_is_enterprise):
+        """Test that /available-models endpoint returns Abacus models when configured."""
+        mock_is_abacus_configured.return_value = True
+        mock_is_enterprise.return_value = True
+        mock_constants.ABACUS_BASE_URL = "https://routellm.abacus.ai/v1"
+        mock_constants.ABACUS_MODELS = ["Abacus/gpt-4.1", "Abacus/claude-haiku-4-5-20251001"]
+
+        result = get_available_models()
+
+        assert result == ["Abacus/gpt-4.1", "Abacus/claude-haiku-4-5-20251001"]
+
+
+class TestModelStorage:
+    """Tests for model storage in ProviderManager."""
+
+    def test_provider_manager_stores_model(self, provider_config: Config):
+        """Test that ProviderManager can store and retrieve selected model."""
+        provider_manager = ProviderManager(config=provider_config)
+
+        provider_manager.set_selected_model("gpt-4.1")
+        assert provider_manager.get_selected_model() == "gpt-4.1"
+
+        provider_manager.set_selected_model("claude-sonnet-4-5-20250929")
+        assert provider_manager.get_selected_model() == "claude-sonnet-4-5-20250929"
+
+    def test_provider_manager_default_model(self, provider_config: Config):
+        """Test that ProviderManager has a default model."""
+        provider_manager = ProviderManager(config=provider_config)
+
+        # Should have default model
+        default_model = provider_manager.get_selected_model()
+        assert default_model is not None
+        assert isinstance(default_model, str)