mito-ai 0.1.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mito_ai/__init__.py +114 -0
- mito_ai/_version.py +4 -0
- mito_ai/anthropic_client.py +334 -0
- mito_ai/app_deploy/__init__.py +6 -0
- mito_ai/app_deploy/app_deploy_utils.py +44 -0
- mito_ai/app_deploy/handlers.py +345 -0
- mito_ai/app_deploy/models.py +98 -0
- mito_ai/app_manager/__init__.py +4 -0
- mito_ai/app_manager/handlers.py +167 -0
- mito_ai/app_manager/models.py +71 -0
- mito_ai/app_manager/utils.py +24 -0
- mito_ai/auth/README.md +18 -0
- mito_ai/auth/__init__.py +6 -0
- mito_ai/auth/handlers.py +96 -0
- mito_ai/auth/urls.py +13 -0
- mito_ai/chat_history/handlers.py +63 -0
- mito_ai/chat_history/urls.py +32 -0
- mito_ai/completions/completion_handlers/__init__.py +3 -0
- mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +59 -0
- mito_ai/completions/completion_handlers/agent_execution_handler.py +66 -0
- mito_ai/completions/completion_handlers/chat_completion_handler.py +141 -0
- mito_ai/completions/completion_handlers/code_explain_handler.py +113 -0
- mito_ai/completions/completion_handlers/completion_handler.py +42 -0
- mito_ai/completions/completion_handlers/inline_completer_handler.py +48 -0
- mito_ai/completions/completion_handlers/smart_debug_handler.py +160 -0
- mito_ai/completions/completion_handlers/utils.py +147 -0
- mito_ai/completions/handlers.py +415 -0
- mito_ai/completions/message_history.py +401 -0
- mito_ai/completions/models.py +404 -0
- mito_ai/completions/prompt_builders/__init__.py +3 -0
- mito_ai/completions/prompt_builders/agent_execution_prompt.py +57 -0
- mito_ai/completions/prompt_builders/agent_smart_debug_prompt.py +160 -0
- mito_ai/completions/prompt_builders/agent_system_message.py +472 -0
- mito_ai/completions/prompt_builders/chat_name_prompt.py +15 -0
- mito_ai/completions/prompt_builders/chat_prompt.py +116 -0
- mito_ai/completions/prompt_builders/chat_system_message.py +92 -0
- mito_ai/completions/prompt_builders/explain_code_prompt.py +32 -0
- mito_ai/completions/prompt_builders/inline_completer_prompt.py +197 -0
- mito_ai/completions/prompt_builders/prompt_constants.py +170 -0
- mito_ai/completions/prompt_builders/smart_debug_prompt.py +199 -0
- mito_ai/completions/prompt_builders/utils.py +84 -0
- mito_ai/completions/providers.py +284 -0
- mito_ai/constants.py +63 -0
- mito_ai/db/__init__.py +3 -0
- mito_ai/db/crawlers/__init__.py +6 -0
- mito_ai/db/crawlers/base_crawler.py +61 -0
- mito_ai/db/crawlers/constants.py +43 -0
- mito_ai/db/crawlers/snowflake.py +71 -0
- mito_ai/db/handlers.py +168 -0
- mito_ai/db/models.py +31 -0
- mito_ai/db/urls.py +34 -0
- mito_ai/db/utils.py +185 -0
- mito_ai/docker/mssql/compose.yml +37 -0
- mito_ai/docker/mssql/init/setup.sql +21 -0
- mito_ai/docker/mysql/compose.yml +18 -0
- mito_ai/docker/mysql/init/setup.sql +13 -0
- mito_ai/docker/oracle/compose.yml +17 -0
- mito_ai/docker/oracle/init/setup.sql +20 -0
- mito_ai/docker/postgres/compose.yml +17 -0
- mito_ai/docker/postgres/init/setup.sql +13 -0
- mito_ai/enterprise/__init__.py +3 -0
- mito_ai/enterprise/utils.py +15 -0
- mito_ai/file_uploads/__init__.py +3 -0
- mito_ai/file_uploads/handlers.py +248 -0
- mito_ai/file_uploads/urls.py +21 -0
- mito_ai/gemini_client.py +232 -0
- mito_ai/log/handlers.py +38 -0
- mito_ai/log/urls.py +21 -0
- mito_ai/logger.py +37 -0
- mito_ai/openai_client.py +382 -0
- mito_ai/path_utils.py +70 -0
- mito_ai/rules/handlers.py +44 -0
- mito_ai/rules/urls.py +22 -0
- mito_ai/rules/utils.py +56 -0
- mito_ai/settings/handlers.py +41 -0
- mito_ai/settings/urls.py +20 -0
- mito_ai/settings/utils.py +42 -0
- mito_ai/streamlit_conversion/agent_utils.py +37 -0
- mito_ai/streamlit_conversion/prompts/prompt_constants.py +172 -0
- mito_ai/streamlit_conversion/prompts/prompt_utils.py +10 -0
- mito_ai/streamlit_conversion/prompts/streamlit_app_creation_prompt.py +46 -0
- mito_ai/streamlit_conversion/prompts/streamlit_error_correction_prompt.py +28 -0
- mito_ai/streamlit_conversion/prompts/streamlit_finish_todo_prompt.py +45 -0
- mito_ai/streamlit_conversion/prompts/streamlit_system_prompt.py +56 -0
- mito_ai/streamlit_conversion/prompts/update_existing_app_prompt.py +50 -0
- mito_ai/streamlit_conversion/search_replace_utils.py +94 -0
- mito_ai/streamlit_conversion/streamlit_agent_handler.py +144 -0
- mito_ai/streamlit_conversion/streamlit_utils.py +85 -0
- mito_ai/streamlit_conversion/validate_streamlit_app.py +105 -0
- mito_ai/streamlit_preview/__init__.py +6 -0
- mito_ai/streamlit_preview/handlers.py +111 -0
- mito_ai/streamlit_preview/manager.py +152 -0
- mito_ai/streamlit_preview/urls.py +22 -0
- mito_ai/streamlit_preview/utils.py +29 -0
- mito_ai/tests/__init__.py +3 -0
- mito_ai/tests/chat_history/test_chat_history.py +211 -0
- mito_ai/tests/completions/completion_handlers_utils_test.py +190 -0
- mito_ai/tests/conftest.py +53 -0
- mito_ai/tests/create_agent_system_message_prompt_test.py +22 -0
- mito_ai/tests/data/prompt_lg.py +69 -0
- mito_ai/tests/data/prompt_sm.py +6 -0
- mito_ai/tests/data/prompt_xl.py +13 -0
- mito_ai/tests/data/stock_data.sqlite3 +0 -0
- mito_ai/tests/db/conftest.py +39 -0
- mito_ai/tests/db/connections_test.py +102 -0
- mito_ai/tests/db/mssql_test.py +29 -0
- mito_ai/tests/db/mysql_test.py +29 -0
- mito_ai/tests/db/oracle_test.py +29 -0
- mito_ai/tests/db/postgres_test.py +29 -0
- mito_ai/tests/db/schema_test.py +93 -0
- mito_ai/tests/db/sqlite_test.py +31 -0
- mito_ai/tests/db/test_db_constants.py +61 -0
- mito_ai/tests/deploy_app/test_app_deploy_utils.py +89 -0
- mito_ai/tests/file_uploads/__init__.py +2 -0
- mito_ai/tests/file_uploads/test_handlers.py +282 -0
- mito_ai/tests/message_history/test_generate_short_chat_name.py +120 -0
- mito_ai/tests/message_history/test_message_history_utils.py +469 -0
- mito_ai/tests/open_ai_utils_test.py +152 -0
- mito_ai/tests/performance_test.py +329 -0
- mito_ai/tests/providers/test_anthropic_client.py +447 -0
- mito_ai/tests/providers/test_azure.py +631 -0
- mito_ai/tests/providers/test_capabilities.py +120 -0
- mito_ai/tests/providers/test_gemini_client.py +195 -0
- mito_ai/tests/providers/test_mito_server_utils.py +448 -0
- mito_ai/tests/providers/test_model_resolution.py +130 -0
- mito_ai/tests/providers/test_openai_client.py +57 -0
- mito_ai/tests/providers/test_provider_completion_exception.py +66 -0
- mito_ai/tests/providers/test_provider_limits.py +42 -0
- mito_ai/tests/providers/test_providers.py +382 -0
- mito_ai/tests/providers/test_retry_logic.py +389 -0
- mito_ai/tests/providers/test_stream_mito_server_utils.py +140 -0
- mito_ai/tests/providers/utils.py +85 -0
- mito_ai/tests/rules/conftest.py +26 -0
- mito_ai/tests/rules/rules_test.py +117 -0
- mito_ai/tests/server_limits_test.py +406 -0
- mito_ai/tests/settings/conftest.py +26 -0
- mito_ai/tests/settings/settings_test.py +70 -0
- mito_ai/tests/settings/test_settings_constants.py +9 -0
- mito_ai/tests/streamlit_conversion/__init__.py +3 -0
- mito_ai/tests/streamlit_conversion/test_apply_search_replace.py +240 -0
- mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +246 -0
- mito_ai/tests/streamlit_conversion/test_streamlit_utils.py +193 -0
- mito_ai/tests/streamlit_conversion/test_validate_streamlit_app.py +112 -0
- mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +118 -0
- mito_ai/tests/streamlit_preview/test_streamlit_preview_manager.py +292 -0
- mito_ai/tests/test_constants.py +47 -0
- mito_ai/tests/test_telemetry.py +12 -0
- mito_ai/tests/user/__init__.py +2 -0
- mito_ai/tests/user/test_user.py +120 -0
- mito_ai/tests/utils/__init__.py +3 -0
- mito_ai/tests/utils/test_anthropic_utils.py +162 -0
- mito_ai/tests/utils/test_gemini_utils.py +98 -0
- mito_ai/tests/version_check_test.py +169 -0
- mito_ai/user/handlers.py +45 -0
- mito_ai/user/urls.py +21 -0
- mito_ai/utils/__init__.py +3 -0
- mito_ai/utils/anthropic_utils.py +168 -0
- mito_ai/utils/create.py +94 -0
- mito_ai/utils/db.py +74 -0
- mito_ai/utils/error_classes.py +42 -0
- mito_ai/utils/gemini_utils.py +133 -0
- mito_ai/utils/message_history_utils.py +87 -0
- mito_ai/utils/mito_server_utils.py +242 -0
- mito_ai/utils/open_ai_utils.py +200 -0
- mito_ai/utils/provider_utils.py +49 -0
- mito_ai/utils/schema.py +86 -0
- mito_ai/utils/server_limits.py +152 -0
- mito_ai/utils/telemetry_utils.py +480 -0
- mito_ai/utils/utils.py +89 -0
- mito_ai/utils/version_utils.py +94 -0
- mito_ai/utils/websocket_base.py +88 -0
- mito_ai/version_check.py +60 -0
- mito_ai-0.1.50.data/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +7 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/build_log.json +728 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/package.json +243 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +238 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +37 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.8f1845da6bf2b128c049.js +21602 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.8f1845da6bf2b128c049.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +198 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.78d3ccb73e7ca1da3aae.js +619 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.78d3ccb73e7ca1da3aae.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/style.js +4 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.5876024bb17dbd6a3ee6.js +712 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.5876024bb17dbd6a3ee6.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +533 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +6941 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +1021 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +59698 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +7440 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +2792 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +1 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +4859 -0
- mito_ai-0.1.50.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +1 -0
- mito_ai-0.1.50.dist-info/METADATA +221 -0
- mito_ai-0.1.50.dist-info/RECORD +205 -0
- mito_ai-0.1.50.dist-info/WHEEL +4 -0
- mito_ai-0.1.50.dist-info/entry_points.txt +2 -0
- mito_ai-0.1.50.dist-info/licenses/LICENSE +3 -0
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
# Copyright (c) Saga Inc.
|
|
2
|
+
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import json
|
|
6
|
+
import time
|
|
7
|
+
from unittest.mock import MagicMock, patch, AsyncMock
|
|
8
|
+
from tornado.httpclient import HTTPResponse
|
|
9
|
+
|
|
10
|
+
from mito_ai.utils.mito_server_utils import (
|
|
11
|
+
ProviderCompletionException,
|
|
12
|
+
get_response_from_mito_server
|
|
13
|
+
)
|
|
14
|
+
from mito_ai.completions.models import MessageType
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@pytest.fixture
def mock_request_params():
    """Return a baseline set of request keyword arguments for the tests.

    The dict is built fresh on every fixture invocation, so a test may
    overwrite individual entries before unpacking it into a call.
    """
    params = dict(
        url="https://api.example.com",
        headers={"Content-Type": "application/json"},
        data={"query": "test query"},
        timeout=30,
        max_retries=3,
        message_type=MessageType.CHAT,
        provider_name="Test Provider",
    )
    return params
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@pytest.fixture
def mock_http_dependencies():
    """Patch the HTTP client factory, the quota helpers and the clock.

    All four patches target names inside ``mito_ai.utils.mito_server_utils``
    and stay active until the requesting test finishes. The fixture yields a
    dict giving access to each installed mock under the keys
    ``mock_check_quota``, ``mock_create_client``, ``mock_http_client``,
    ``mock_update_quota`` and ``mock_time``.
    """
    with patch('mito_ai.utils.mito_server_utils._create_http_client') as create_client, \
         patch('mito_ai.utils.mito_server_utils.update_mito_server_quota') as update_quota, \
         patch('mito_ai.utils.mito_server_utils.check_mito_server_quota') as check_quota, \
         patch('mito_ai.utils.mito_server_utils.time.time') as clock:

        # Canned clock readings, intended as start_time and end_time.
        # Only two values are queued: a third call to the patched clock
        # within one test would exhaust the iterator and raise StopIteration.
        clock.side_effect = [0.0, 1.5]

        # The patched factory returns a (client, 30) tuple.
        http_client = MagicMock()
        create_client.return_value = (http_client, 30)

        installed = dict(
            mock_check_quota=check_quota,
            mock_create_client=create_client,
            mock_http_client=http_client,
            mock_update_quota=update_quota,
            mock_time=clock,
        )
        yield installed
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def create_mock_response(body_content: dict):
    """Build an ``HTTPResponse``-spec'd mock whose decoded body is JSON.

    Args:
        body_content: Dict serialized with ``json.dumps``; the resulting
            string is what ``response.body.decode()`` returns.

    Returns:
        The configured ``MagicMock``.
    """
    serialized = json.dumps(body_content)
    response = MagicMock(spec=HTTPResponse)
    response.body.decode.return_value = serialized
    return response
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class TestProviderCompletionException:
    """Checks on the attributes exposed by ``ProviderCompletionException``."""

    @pytest.mark.parametrize("error_message,provider_name,error_type,expected_title,expected_hint_contains", [
        (
            "Something went wrong",
            "LLM Provider",
            "LLMProviderError",
            "LLM Provider Error: Something went wrong",
            "LLM Provider"
        ),
        (
            "API key is invalid",
            "OpenAI",
            "AuthenticationError",
            "OpenAI Error: API key is invalid",
            "OpenAI"
        ),
        (
            "There was an error accessing the Anthropic API: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}",
            "Anthropic",
            "LLMProviderError",
            "Anthropic Error: There was an error accessing the Anthropic API: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}",
            "Anthropic"
        ),
    ])
    def test_exception_initialization(
        self,
        error_message: str,
        provider_name: str,
        error_type: str,
        expected_title: str,
        expected_hint_contains: str
    ):
        """Explicit constructor arguments are stored and reflected in the title and hint."""
        raised = ProviderCompletionException(
            error_message,
            provider_name=provider_name,
            error_type=error_type,
        )

        # The three inputs come back unchanged.
        assert raised.error_message == error_message
        assert raised.provider_name == provider_name
        assert raised.error_type == error_type

        # The title is also the exception's string form and its first arg.
        assert raised.user_friendly_title == expected_title
        assert expected_hint_contains in raised.user_friendly_hint
        assert str(raised) == expected_title
        assert raised.args[0] == expected_title

    def test_default_initialization(self):
        """Omitting provider name and error type falls back to the defaults."""
        message = "Something went wrong"
        raised = ProviderCompletionException(message)

        assert raised.error_message == message
        assert raised.provider_name == "LLM Provider"
        assert raised.error_type == "LLMProviderError"
        assert raised.user_friendly_title == "LLM Provider Error: Something went wrong"
        assert "LLM Provider" in raised.user_friendly_hint
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class TestGetResponseFromMitoServer:
    """Tests for ``get_response_from_mito_server``.

    Every test relies on the ``mock_http_dependencies`` fixture, so no real
    HTTP request or quota call is made; the HTTP client's ``fetch`` is
    replaced per test with an ``AsyncMock`` that returns (or raises) what the
    scenario needs.
    """

    @pytest.mark.parametrize("completion_value,message_type", [
        ("This is the AI response", MessageType.CHAT),
        ("Code completion here", MessageType.INLINE_COMPLETION),
        ("", MessageType.CHAT),  # Empty string
        (None, MessageType.INLINE_COMPLETION),  # None value
        ("Multi-line\nresponse\nhere", MessageType.CHAT),  # Multi-line response
    ])
    @pytest.mark.asyncio
    async def test_successful_completion_responses(
        self,
        completion_value,
        message_type: MessageType,
        mock_request_params,
        mock_http_dependencies
    ):
        """A ``completion`` field is returned as-is, including ``""`` and ``None``."""
        # Setup
        client = mock_http_dependencies['mock_http_client']
        client.fetch = AsyncMock(
            return_value=create_mock_response({"completion": completion_value})
        )
        mock_request_params["message_type"] = message_type

        # Execute
        result = await get_response_from_mito_server(**mock_request_params)

        # Verify
        assert result == completion_value
        mock_http_dependencies['mock_check_quota'].assert_called_once_with(message_type)
        mock_http_dependencies['mock_update_quota'].assert_called_once_with(message_type)
        client.close.assert_called_once()

        # Verify HTTP request was made correctly
        client.fetch.assert_called_once_with(
            mock_request_params["url"],
            method="POST",
            headers=mock_request_params["headers"],
            body=json.dumps(mock_request_params["data"]),
            request_timeout=30
        )

    @pytest.mark.parametrize("error_message,provider_name,expected_exception_provider", [
        (
            "There was an error accessing the Anthropic API: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}",
            "Anthropic",
            "Anthropic"
        ),
        (
            "Rate limit exceeded",
            "OpenAI",
            "OpenAI"
        ),
        (
            "Invalid API key",
            "Custom Provider",
            "Custom Provider"
        ),
        (
            "Server timeout",
            "Mito Server",
            "Mito Server"
        ),
    ])
    @pytest.mark.asyncio
    async def test_error_responses_from_server(
        self,
        error_message: str,
        provider_name: str,
        expected_exception_provider: str,
        mock_request_params,
        mock_http_dependencies
    ):
        """An ``error`` field in the body surfaces as ``ProviderCompletionException``."""
        # Setup
        client = mock_http_dependencies['mock_http_client']
        client.fetch = AsyncMock(
            return_value=create_mock_response({"error": error_message})
        )
        mock_request_params["provider_name"] = provider_name

        # Execute and verify exception
        with pytest.raises(ProviderCompletionException) as exc_info:
            await get_response_from_mito_server(**mock_request_params)

        # Verify exception details
        exception = exc_info.value
        assert exception.error_message == error_message
        assert exception.provider_name == expected_exception_provider
        assert f"{expected_exception_provider} Error" in str(exception)

        # Verify quota was updated and client was closed
        mock_http_dependencies['mock_update_quota'].assert_called_once_with(mock_request_params["message_type"])
        client.close.assert_called_once()

    @pytest.mark.parametrize("response_body,expected_error_contains", [
        ({"some_other_field": "value"}, "No completion found in response"),
        ({"data": "value", "status": "ok"}, "No completion found in response"),
        ({}, "No completion found in response"),
        ({"completion": None, "error": "also present"}, None),  # completion takes precedence
    ])
    @pytest.mark.asyncio
    async def test_invalid_response_formats(
        self,
        response_body: dict,
        expected_error_contains,
        mock_request_params,
        mock_http_dependencies
    ):
        """Bodies without ``completion`` raise; a body with ``completion`` succeeds.

        ``expected_error_contains`` is ``None`` for the one case that is
        expected to succeed (a ``completion`` key is present alongside
        ``error``), and the expected message fragment otherwise.
        """
        # Setup
        client = mock_http_dependencies['mock_http_client']
        client.fetch = AsyncMock(return_value=create_mock_response(response_body))

        if expected_error_contains is None:
            # This should succeed because completion field exists
            result = await get_response_from_mito_server(**mock_request_params)
            assert result == response_body["completion"]
            mock_http_dependencies['mock_update_quota'].assert_called_once()
        else:
            # Execute and verify exception
            with pytest.raises(ProviderCompletionException) as exc_info:
                await get_response_from_mito_server(**mock_request_params)

            # Verify exception details
            exception = exc_info.value
            assert expected_error_contains in exception.error_message
            assert str(response_body) in exception.error_message
            assert exception.provider_name == mock_request_params["provider_name"]

            # Verify quota was still updated even though the call failed
            mock_http_dependencies['mock_update_quota'].assert_called_once_with(mock_request_params["message_type"])

        # Client should always be closed
        client.close.assert_called_once()

    @pytest.mark.parametrize("invalid_json_content,expected_error_contains", [
        ("invalid json content", "Error parsing response"),
        ('{"incomplete": json', "Error parsing response"),
        ("", "Error parsing response"),
        ('{"malformed":', "Error parsing response"),
    ])
    @pytest.mark.asyncio
    async def test_json_parsing_errors(
        self,
        invalid_json_content: str,
        expected_error_contains: str,
        mock_request_params,
        mock_http_dependencies
    ):
        """A body that is not valid JSON raises ``ProviderCompletionException``."""
        # Setup - the helper cannot be used here because the body must NOT be JSON
        mock_response = MagicMock(spec=HTTPResponse)
        mock_response.body.decode.return_value = invalid_json_content
        client = mock_http_dependencies['mock_http_client']
        client.fetch = AsyncMock(return_value=mock_response)

        # Execute and verify exception
        with pytest.raises(ProviderCompletionException) as exc_info:
            await get_response_from_mito_server(**mock_request_params)

        # Verify exception details
        exception = exc_info.value
        assert expected_error_contains in exception.error_message
        assert exception.provider_name == mock_request_params["provider_name"]

        # Verify quota was updated and client was closed
        mock_http_dependencies['mock_update_quota'].assert_called_once_with(mock_request_params["message_type"])
        client.close.assert_called_once()

    @pytest.mark.parametrize("timeout,max_retries", [
        (30, 3),
        (45, 5),
        (60, 1),
        (15, 0),
    ])
    @pytest.mark.asyncio
    async def test_http_client_creation_parameters(
        self,
        timeout: int,
        max_retries: int,
        mock_request_params,
        mock_http_dependencies
    ):
        """``timeout`` and ``max_retries`` are forwarded to the client factory."""
        # Setup
        mock_http_dependencies['mock_http_client'].fetch = AsyncMock(
            return_value=create_mock_response({"completion": "test response"})
        )
        mock_request_params["timeout"] = timeout
        mock_request_params["max_retries"] = max_retries

        # Execute
        await get_response_from_mito_server(**mock_request_params)

        # Verify HTTP client creation
        mock_http_dependencies['mock_create_client'].assert_called_once_with(timeout, max_retries)

    @pytest.mark.parametrize("exception_type,exception_message", [
        (Exception, "Network error"),
        (ConnectionError, "Connection failed"),
        (TimeoutError, "Request timed out"),
        (RuntimeError, "Runtime error occurred"),
    ])
    @pytest.mark.asyncio
    async def test_http_client_always_closed_on_exception(
        self,
        exception_type,
        exception_message: str,
        mock_request_params,
        mock_http_dependencies
    ):
        """The client is closed even when ``fetch`` itself raises."""
        # Setup - make fetch raise an exception
        client = mock_http_dependencies['mock_http_client']
        client.fetch = AsyncMock(side_effect=exception_type(exception_message))

        # Execute and expect exception to bubble up
        with pytest.raises(exception_type, match=exception_message):
            await get_response_from_mito_server(**mock_request_params)

        # Verify client was still closed despite the exception
        client.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_default_provider_name(self, mock_http_dependencies):
        """Without ``provider_name`` the raised exception reports "Mito Server"."""
        # Setup
        mock_http_dependencies['mock_http_client'].fetch = AsyncMock(
            return_value=create_mock_response({"error": "Test error"})
        )

        # Test data without provider_name parameter
        request_params = {
            "url": "https://api.example.com",
            "headers": {"Content-Type": "application/json"},
            "data": {"query": "test query"},
            "timeout": 30,
            "max_retries": 3,
            "message_type": MessageType.CHAT,
            # Note: not providing provider_name parameter
        }

        # Execute and verify exception
        with pytest.raises(ProviderCompletionException) as exc_info:
            await get_response_from_mito_server(**request_params)  # type: ignore

        # Verify default provider name is used
        exception = exc_info.value
        assert exception.provider_name == "Mito Server"
        assert "Mito Server Error" in str(exception)

    @pytest.mark.asyncio
    async def test_provider_completion_exception_reraised(self, mock_request_params, mock_http_dependencies):
        """A ``ProviderCompletionException`` raised while parsing is not re-wrapped."""
        # Setup - any body will do; json.loads is patched to raise below
        mock_response = MagicMock(spec=HTTPResponse)
        mock_response.body.decode.return_value = "some json content"
        client = mock_http_dependencies['mock_http_client']
        client.fetch = AsyncMock(return_value=mock_response)

        parsing_failure = ProviderCompletionException("Custom parsing error", "Custom Provider")

        # check_mito_server_quota is already patched by the fixture, so only
        # json.loads needs patching here.
        with patch('mito_ai.utils.mito_server_utils.json.loads', side_effect=parsing_failure):
            # Execute and verify exception
            with pytest.raises(ProviderCompletionException) as exc_info:
                await get_response_from_mito_server(**mock_request_params)

        # Verify the original exception is preserved
        exception = exc_info.value
        assert exception.error_message == "Custom parsing error"
        assert exception.provider_name == "Custom Provider"

        # Verify quota check was called
        mock_http_dependencies['mock_check_quota'].assert_called_once_with(mock_request_params["message_type"])

        # Verify client was closed
        client.close.assert_called_once()

    @pytest.mark.parametrize("scenario,response_setup,main_exception,quota_exception", [
        ("successful_with_quota_error", {"completion": "Success"}, None, Exception("Quota update failed")),
        ("server_error_with_quota_error", {"error": "Server error"}, ProviderCompletionException, Exception("Quota update failed")),
        ("invalid_format_with_quota_error", {"invalid": "format"}, ProviderCompletionException, RuntimeError("Quota system down")),
        ("success_with_quota_timeout", {"completion": "Success"}, None, TimeoutError("Quota service timeout")),
    ])
    @pytest.mark.asyncio
    async def test_quota_update_exceptions_do_not_interfere(
        self,
        scenario: str,
        response_setup: dict,
        main_exception,
        quota_exception,
        mock_request_params,
        mock_http_dependencies
    ):
        """A failing quota update neither masks the result nor the original error.

        ``scenario`` is only a human-readable label for the parametrized case.
        """
        # Setup
        mock_http_dependencies['mock_http_client'].fetch = AsyncMock(
            return_value=create_mock_response(response_setup)
        )
        update_quota = mock_http_dependencies['mock_update_quota']
        update_quota.side_effect = quota_exception

        # Execute
        if main_exception:
            with pytest.raises(main_exception) as exc_info:
                await get_response_from_mito_server(**mock_request_params)

            # Verify the original error is preserved, not the quota error
            if "error" in response_setup:
                assert exc_info.value.error_message == response_setup["error"]
        else:
            # Should still succeed despite quota update failure
            result = await get_response_from_mito_server(**mock_request_params)
            assert result == response_setup["completion"]

        # Verify quota update was attempted
        update_quota.assert_called_once_with(mock_request_params["message_type"])
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# Copyright (c) Saga Inc.
|
|
2
|
+
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
These tests ensure that the correct model is chosen for each message type, for each provider.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
from mito_ai.utils.provider_utils import does_message_require_fast_model
|
|
10
|
+
from mito_ai.completions.models import MessageType
|
|
11
|
+
from unittest.mock import AsyncMock, MagicMock, patch
|
|
12
|
+
from mito_ai.completions.providers import OpenAIProvider
|
|
13
|
+
from mito_ai.completions.models import MessageType
|
|
14
|
+
from mito_ai.utils.provider_utils import does_message_require_fast_model
|
|
15
|
+
from traitlets.config import Config
|
|
16
|
+
|
|
17
|
+
@pytest.fixture
def provider_config() -> Config:
    """Return a Config with empty OpenAIProvider and OpenAIClient sections."""
    cfg = Config()
    for section in ("OpenAIProvider", "OpenAIClient"):
        cfg[section] = Config()
    return cfg
|
|
24
|
+
|
|
25
|
+
@pytest.fixture
def mock_messages():
    """Return a one-element message list holding a single user message."""
    user_message = {"role": "user", "content": "Test message"}
    return [user_message]
|
|
29
|
+
|
|
30
|
+
# Each case pairs a message type with whether it is expected to require the fast model.
_STANDARD_MODEL_MESSAGE_TYPES = (
    MessageType.CHAT,
    MessageType.SMART_DEBUG,
    MessageType.CODE_EXPLAIN,
    MessageType.AGENT_EXECUTION,
    MessageType.AGENT_AUTO_ERROR_FIXUP,
)
_FAST_MODEL_MESSAGE_TYPES = (
    MessageType.INLINE_COMPLETION,
    MessageType.CHAT_NAME_GENERATION,
)
MESSAGE_TYPE_TEST_CASES = (
    [(message_type, False) for message_type in _STANDARD_MODEL_MESSAGE_TYPES]
    + [(message_type, True) for message_type in _FAST_MODEL_MESSAGE_TYPES]
)


@pytest.mark.parametrize("message_type,expected_result", MESSAGE_TYPE_TEST_CASES)
def test_does_message_require_fast_model(message_type: MessageType, expected_result: bool) -> None:
    """Check the fast-model decision for every message type in the case table."""
    actual = does_message_require_fast_model(message_type)
    assert actual == expected_result
|
|
44
|
+
|
|
45
|
+
def test_does_message_require_fast_model_raises_error_for_unknown_message_type():
    """An argument that is not a known message type must raise ValueError."""
    bogus_message_type = 'unknown_message_type'
    with pytest.raises(ValueError):
        does_message_require_fast_model(bogus_message_type)  # type: ignore
|
|
49
|
+
|
|
50
|
+
@pytest.mark.asyncio
async def test_request_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
    """request_completions must consult does_message_require_fast_model once
    and forward the requested model to the OpenAI chat-completions call."""
    # Provide a fake key through both the environment and the constants module
    # so that the OpenAI provider is used.
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

    # Canned (non-streaming) completion returned by the mocked API call.
    choice = MagicMock()
    choice.message.content = "Test Completion"
    completion = MagicMock()
    completion.choices = [choice]

    # Only the API call is mocked, not the entire client.
    async_client = MagicMock()
    async_client.chat.completions.create = AsyncMock(return_value=completion)
    async_client.is_closed.return_value = False

    # Sync client stand-in for the validation done in the OpenAIClient constructor.
    sync_client = MagicMock()
    sync_client.models.list.return_value = MagicMock()

    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as fast_model_spy, \
            patch('openai.AsyncOpenAI', return_value=async_client), \
            patch('openai.OpenAI', return_value=sync_client):
        provider = OpenAIProvider(config=provider_config)
        await provider.request_completions(
            message_type=MessageType.CHAT,
            messages=mock_messages,
            model="gpt-3.5",
        )

    fast_model_spy.assert_called_once_with(MessageType.CHAT)

    # The model handed to the API must be the one that was requested.
    api_call = async_client.chat.completions.create.call_args
    assert api_call[1]['model'] == "gpt-3.5"
|
|
86
|
+
|
|
87
|
+
@pytest.mark.asyncio
async def test_stream_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
    """stream_completions must consult does_message_require_fast_model once
    and forward the requested model to the OpenAI chat-completions call."""
    # Provide a fake key through both the environment and the constants module
    # so that the OpenAI provider is used.
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

    # Single streamed chunk that also signals the end of the stream.
    chunk_choice = MagicMock()
    chunk_choice.delta.content = "Test Stream Completion"
    chunk_choice.finish_reason = "stop"
    chunk = MagicMock()
    chunk.choices = [chunk_choice]

    async def single_chunk_stream():
        # Async generator standing in for the streaming API response.
        yield chunk

    # Only the API call is mocked, not the entire client.
    async_client = MagicMock()
    async_client.chat.completions.create = AsyncMock(return_value=single_chunk_stream())
    async_client.is_closed.return_value = False

    # Sync client stand-in for the validation done in the OpenAIClient constructor.
    sync_client = MagicMock()
    sync_client.models.list.return_value = MagicMock()

    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as fast_model_spy, \
            patch('openai.AsyncOpenAI', return_value=async_client), \
            patch('openai.OpenAI', return_value=sync_client):
        provider = OpenAIProvider(config=provider_config)
        await provider.stream_completions(
            message_type=MessageType.CHAT,
            messages=mock_messages,
            model="gpt-3.5",
            message_id="test_id",
            thread_id="test_thread",
            reply_fn=lambda x: None
        )

    fast_model_spy.assert_called_once_with(MessageType.CHAT)

    # The model handed to the API must be the one that was requested.
    api_call = async_client.chat.completions.create.call_args
    assert api_call[1]['model'] == "gpt-3.5"
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# Copyright (c) Saga Inc.
|
|
2
|
+
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
from mito_ai.openai_client import OpenAIClient
|
|
6
|
+
from mito_ai.utils.open_ai_utils import FAST_OPENAI_MODEL
|
|
7
|
+
from mito_ai.completions.models import MessageType
|
|
8
|
+
from unittest.mock import MagicMock, patch, AsyncMock
|
|
9
|
+
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
|
|
10
|
+
|
|
11
|
+
CUSTOM_MODEL = "smart-openai-model"


@pytest.mark.parametrize("message_type, expected_model", [
    (MessageType.CHAT, CUSTOM_MODEL),
    (MessageType.SMART_DEBUG, CUSTOM_MODEL),
    (MessageType.CODE_EXPLAIN, CUSTOM_MODEL),
    (MessageType.AGENT_EXECUTION, CUSTOM_MODEL),
    (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL),
    (MessageType.INLINE_COMPLETION, FAST_OPENAI_MODEL),
    (MessageType.CHAT_NAME_GENERATION, FAST_OPENAI_MODEL),
])
@pytest.mark.asyncio
async def test_model_selection_based_on_message_type(message_type, expected_model):
    """
    Check which model reaches the OpenAI API for each message type.

    The caller always asks for CUSTOM_MODEL; the parametrization states
    whether that model or FAST_OPENAI_MODEL is expected in the API call.
    """
    client = OpenAIClient(api_key="test_key")  # type: ignore

    # Async stand-in for chat.completions.create, returning a minimal completion.
    create_spy = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="test"))])
    )
    fake_openai = MagicMock()
    fake_openai.chat.completions.create = create_spy

    # Both the client's builder method and the openai class yield the fake client.
    with patch.object(client, '_build_openai_client', return_value=fake_openai), \
            patch('openai.AsyncOpenAI', return_value=fake_openai):
        await client.request_completions(
            message_type=message_type,
            messages=[{"role": "user", "content": "Test message"}],
            model=CUSTOM_MODEL,
            response_format_info=None
        )

    # Exactly one API call, made with the expected model.
    create_spy.assert_called_once()
    assert create_spy.call_args.kwargs['model'] == expected_model
|