mito-ai 0.1.56__py3-none-any.whl → 0.1.58__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Files changed (95)
  1. mito_ai/__init__.py +17 -21
  2. mito_ai/_version.py +1 -1
  3. mito_ai/anthropic_client.py +24 -14
  4. mito_ai/chart_wizard/__init__.py +3 -0
  5. mito_ai/chart_wizard/handlers.py +113 -0
  6. mito_ai/chart_wizard/urls.py +26 -0
  7. mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
  8. mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
  9. mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
  10. mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
  11. mito_ai/completions/completion_handlers/completion_handler.py +14 -7
  12. mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
  13. mito_ai/completions/completion_handlers/scratchpad_result_handler.py +64 -0
  14. mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
  15. mito_ai/completions/completion_handlers/utils.py +3 -7
  16. mito_ai/completions/handlers.py +36 -21
  17. mito_ai/completions/message_history.py +8 -10
  18. mito_ai/completions/models.py +23 -2
  19. mito_ai/completions/prompt_builders/agent_smart_debug_prompt.py +5 -3
  20. mito_ai/completions/prompt_builders/agent_system_message.py +97 -5
  21. mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
  22. mito_ai/completions/prompt_builders/chart_conversion_prompt.py +27 -0
  23. mito_ai/completions/prompt_builders/chat_system_message.py +2 -0
  24. mito_ai/completions/prompt_builders/prompt_constants.py +28 -0
  25. mito_ai/completions/prompt_builders/scratchpad_result_prompt.py +17 -0
  26. mito_ai/constants.py +8 -1
  27. mito_ai/enterprise/__init__.py +1 -1
  28. mito_ai/enterprise/litellm_client.py +137 -0
  29. mito_ai/log/handlers.py +1 -1
  30. mito_ai/openai_client.py +10 -90
  31. mito_ai/{completions/providers.py → provider_manager.py} +157 -53
  32. mito_ai/settings/enterprise_handler.py +26 -0
  33. mito_ai/settings/urls.py +2 -0
  34. mito_ai/streamlit_conversion/agent_utils.py +2 -30
  35. mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
  36. mito_ai/streamlit_preview/handlers.py +6 -3
  37. mito_ai/streamlit_preview/urls.py +5 -3
  38. mito_ai/tests/message_history/test_generate_short_chat_name.py +72 -28
  39. mito_ai/tests/providers/test_anthropic_client.py +174 -16
  40. mito_ai/tests/providers/test_azure.py +13 -13
  41. mito_ai/tests/providers/test_capabilities.py +14 -17
  42. mito_ai/tests/providers/test_gemini_client.py +14 -13
  43. mito_ai/tests/providers/test_model_resolution.py +145 -89
  44. mito_ai/tests/providers/test_openai_client.py +209 -13
  45. mito_ai/tests/providers/test_provider_limits.py +5 -5
  46. mito_ai/tests/providers/test_providers.py +229 -51
  47. mito_ai/tests/providers/test_retry_logic.py +13 -22
  48. mito_ai/tests/providers/utils.py +4 -4
  49. mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
  50. mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
  51. mito_ai/tests/test_enterprise_mode.py +162 -0
  52. mito_ai/tests/test_model_utils.py +271 -0
  53. mito_ai/utils/anthropic_utils.py +8 -6
  54. mito_ai/utils/gemini_utils.py +0 -3
  55. mito_ai/utils/litellm_utils.py +84 -0
  56. mito_ai/utils/model_utils.py +178 -0
  57. mito_ai/utils/open_ai_utils.py +0 -8
  58. mito_ai/utils/provider_utils.py +6 -21
  59. mito_ai/utils/telemetry_utils.py +14 -2
  60. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
  61. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
  62. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
  63. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.dfd7975de75d64db80d6.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js +2992 -282
  64. mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js.map +1 -0
  65. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.1e7b5cf362385f109883.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js +17 -17
  66. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.1e7b5cf362385f109883.js.map → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js.map +1 -1
  67. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +7 -2
  68. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/METADATA +2 -1
  69. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/RECORD +94 -81
  70. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.dfd7975de75d64db80d6.js.map +0 -1
  71. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
  72. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
  73. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
  74. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
  75. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
  76. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
  77. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
  78. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
  79. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
  80. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
  81. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
  82. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
  83. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
  84. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
  85. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
  86. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
  87. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
  88. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
  89. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
  90. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
  91. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
  92. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
  93. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/WHEEL +0 -0
  94. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/entry_points.txt +0 -0
  95. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/licenses/LICENSE +0 -0

mito_ai/tests/providers/test_gemini_client.py
@@ -6,7 +6,7 @@ import ast
 import inspect
 import requests
 from mito_ai.gemini_client import GeminiClient, get_gemini_system_prompt_and_messages
-from mito_ai.utils.gemini_utils import get_gemini_completion_function_params, FAST_GEMINI_MODEL
+from mito_ai.utils.gemini_utils import get_gemini_completion_function_params
 from google.genai.types import Part, GenerateContentResponse, Candidate, Content
 from mito_ai.completions.models import ResponseFormatInfo, AgentResponse
 from unittest.mock import MagicMock, patch
@@ -156,19 +156,20 @@ async def test_json_response_handling_with_multiple_parts():
     assert result == 'Here is the JSON: {"key": "value"} End of response'

 CUSTOM_MODEL = "smart-gemini-model"
-@pytest.mark.parametrize("message_type, expected_model", [
-    (MessageType.CHAT, CUSTOM_MODEL), #
-    (MessageType.SMART_DEBUG, CUSTOM_MODEL), #
-    (MessageType.CODE_EXPLAIN, CUSTOM_MODEL), #
-    (MessageType.AGENT_EXECUTION, CUSTOM_MODEL), #
-    (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL), #
-    (MessageType.INLINE_COMPLETION, FAST_GEMINI_MODEL), #
-    (MessageType.CHAT_NAME_GENERATION, FAST_GEMINI_MODEL), #
+@pytest.mark.parametrize("message_type", [
+    MessageType.CHAT,
+    MessageType.SMART_DEBUG,
+    MessageType.CODE_EXPLAIN,
+    MessageType.AGENT_EXECUTION,
+    MessageType.AGENT_AUTO_ERROR_FIXUP,
+    MessageType.INLINE_COMPLETION,
+    MessageType.CHAT_NAME_GENERATION,
 ])
 @pytest.mark.asyncio
-async def test_get_completion_model_selection_based_on_message_type(message_type, expected_model):
+async def test_get_completion_model_selection_uses_passed_model(message_type):
     """
-    Tests that the correct model is selected based on the message type.
+    Tests that the model passed to the client is used as-is.
+    Model selection based on message type is now handled by ProviderManager.
     """
     with patch('google.genai.Client') as mock_genai_class:
         mock_client = MagicMock()
@@ -189,7 +190,7 @@ async def test_get_completion_model_selection_based_on_message_type(message_type
             response_format_info=None
         )

-        # Verify that generate_content was called with the expected model
+        # Verify that generate_content was called with the model that was passed (not overridden)
         mock_models.generate_content.assert_called_once()
         call_args = mock_models.generate_content.call_args
-        assert call_args[1]['model'] == expected_model
+        assert call_args[1]['model'] == CUSTOM_MODEL
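
Taken together, the hunks above move model routing out of GeminiClient: the client now calls generate_content with exactly the model it is handed, and the fast-vs-smart decision lives upstream in ProviderManager. A minimal sketch of that contract (illustrative only, not mito_ai source; the real ProviderManager in this release has a much richer interface):

    # Illustrative sketch of the routing contract implied by the tests above.
    # The helper import is real (added in this release); the class below is a
    # simplified stand-in for ProviderManager, not the actual implementation.
    from mito_ai.utils.model_utils import get_fast_model_for_selected_model

    class RoutingSketch:
        def __init__(self, selected_model: str) -> None:
            self.selected_model = selected_model

        def resolve(self, use_fast_model: bool = False) -> str:
            # The manager resolves the concrete model; per-provider clients
            # (GeminiClient, OpenAIClient) then use the result verbatim.
            if use_fast_model:
                return get_fast_model_for_selected_model(self.selected_model)
            return self.selected_model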

mito_ai/tests/providers/test_model_resolution.py
@@ -6,19 +6,17 @@ These tests ensure that the correct model is chosen for each message type, for e
 """

 import pytest
-from mito_ai.utils.provider_utils import does_message_require_fast_model
+from mito_ai.utils.model_utils import get_fast_model_for_selected_model, get_smartest_model_for_selected_model
 from mito_ai.completions.models import MessageType
 from unittest.mock import AsyncMock, MagicMock, patch
-from mito_ai.completions.providers import OpenAIProvider
-from mito_ai.completions.models import MessageType
-from mito_ai.utils.provider_utils import does_message_require_fast_model
+from mito_ai.provider_manager import ProviderManager
 from traitlets.config import Config

 @pytest.fixture
 def provider_config() -> Config:
-    """Create a proper Config object for the OpenAIProvider."""
+    """Create a proper Config object for the ProviderManager."""
     config = Config()
-    config.OpenAIProvider = Config()
+    config.ProviderManager = Config()
     config.OpenAIClient = Config()
     return config

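The helpers imported here come from the new mito_ai/utils/model_utils.py (+178 lines; its body is not shown in this diff). The rewritten tests rely only on these helpers being pure functions of the selected model, so the expected model can be recomputed inside the assertion rather than hard-coded. The pattern, extracted from the tests below:

    # Assertion pattern used throughout the rewritten tests: recompute the
    # expected model with the same pure helper the provider uses, then compare
    # it against the model that reached the mocked OpenAI call.
    from mito_ai.utils.model_utils import get_fast_model_for_selected_model

    def expected_fast_model(selected_model: str) -> str:
        # Pure function of the selection, so no extra mocking is needed.
        return get_fast_model_for_selected_model(selected_model)

    # Inside a test, after request_completions(..., use_fast_model=True):
    #     call_args = mock_openai_client.chat.completions.create.call_args
    #     assert call_args[1]['model'] == expected_fast_model("gpt-5.2")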
@@ -27,104 +25,162 @@ def mock_messages():
     """Sample messages for testing."""
     return [{"role": "user", "content": "Test message"}]

-# Test cases for different message types and their expected fast model requirement
-MESSAGE_TYPE_TEST_CASES = [
-    (MessageType.CHAT, False),
-    (MessageType.SMART_DEBUG, False),
-    (MessageType.CODE_EXPLAIN, False),
-    (MessageType.AGENT_EXECUTION, False),
-    (MessageType.AGENT_AUTO_ERROR_FIXUP, False),
-    (MessageType.INLINE_COMPLETION, True),
-    (MessageType.CHAT_NAME_GENERATION, True),
-]
-@pytest.mark.parametrize("message_type,expected_result", MESSAGE_TYPE_TEST_CASES)
-def test_does_message_require_fast_model(message_type: MessageType, expected_result: bool) -> None:
-    """Test that does_message_require_fast_model returns the correct boolean for each message type."""
-    assert does_message_require_fast_model(message_type) == expected_result
+@pytest.mark.asyncio
+async def test_request_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that request_completions uses the correct model when use_fast_model=True."""
+    # Set up environment variables to ensure OpenAI provider is used
+    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
+    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
+
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "Test Completion"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client
+
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client
+
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-5.2")
+            await provider.request_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                use_fast_model=True
+            )
+
+            # Verify the model passed to the API call is the fast model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())
+
+@pytest.mark.asyncio
+async def test_stream_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that stream_completions uses the correct model when use_fast_model=True."""
+    # Set up environment variables to ensure OpenAI provider is used
+    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
+    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

-def test_does_message_require_fast_model_raises_error_for_unknown_message_type():
-    """Test that does_message_require_fast_model raises an error for an unknown message type."""
-    with pytest.raises(ValueError):
-        does_message_require_fast_model('unknown_message_type') # type: ignore
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].delta.content = "Test Stream Completion"
+    mock_response.choices[0].finish_reason = "stop"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        # Create an async generator for streaming
+        async def mock_stream():
+            yield mock_response
+
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client
+
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client
+
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-5.2")
+            await provider.stream_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                message_id="test_id",
+                thread_id="test_thread",
+                reply_fn=lambda x: None,
+                use_fast_model=True
+            )
+
+            # Verify the model passed to the API call is the fast model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())

 @pytest.mark.asyncio
-async def test_request_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
-    """Test that request_completions calls does_message_require_fast_model and uses the correct model."""
+async def test_request_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that request_completions uses the correct model when use_smartest_model=True."""
     # Set up environment variables to ensure OpenAI provider is used
     monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
     monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

-    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as mock_does_message_require_fast_model:
-        # Mock the OpenAI API call instead of the entire client
-        mock_response = MagicMock()
-        mock_response.choices = [MagicMock()]
-        mock_response.choices[0].message.content = "Test Completion"
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "Test Completion"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client

-        with patch('openai.AsyncOpenAI') as mock_openai_class:
-            mock_openai_client = MagicMock()
-            mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
-            mock_openai_client.is_closed.return_value = False
-            mock_openai_class.return_value = mock_openai_client
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client

-            # Mock the validation that happens in OpenAIClient constructor
-            with patch('openai.OpenAI') as mock_sync_openai_class:
-                mock_sync_client = MagicMock()
-                mock_sync_client.models.list.return_value = MagicMock()
-                mock_sync_openai_class.return_value = mock_sync_client
-
-                provider = OpenAIProvider(config=provider_config)
-                await provider.request_completions(
-                    message_type=MessageType.CHAT,
-                    messages=mock_messages,
-                    model="gpt-3.5",
-                )
-
-                mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
-                # Verify the model passed to the API call
-                call_args = mock_openai_client.chat.completions.create.call_args
-                assert call_args[1]['model'] == "gpt-3.5"
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-4.1")
+            await provider.request_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                use_smartest_model=True
+            )
+
+            # Verify the model passed to the API call is the smartest model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())

 @pytest.mark.asyncio
-async def test_stream_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
-    """Test that stream_completions calls does_message_require_fast_model and uses the correct model."""
+async def test_stream_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that stream_completions uses the correct model when use_smartest_model=True."""
     # Set up environment variables to ensure OpenAI provider is used
     monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
     monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

-    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as mock_does_message_require_fast_model:
-        # Mock the OpenAI API call instead of the entire client
-        mock_response = MagicMock()
-        mock_response.choices = [MagicMock()]
-        mock_response.choices[0].delta.content = "Test Stream Completion"
-        mock_response.choices[0].finish_reason = "stop"
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].delta.content = "Test Stream Completion"
+    mock_response.choices[0].finish_reason = "stop"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        # Create an async generator for streaming
+        async def mock_stream():
+            yield mock_response
+
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client

-        with patch('openai.AsyncOpenAI') as mock_openai_class:
-            mock_openai_client = MagicMock()
-            # Create an async generator for streaming
-            async def mock_stream():
-                yield mock_response
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client

-            mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
-            mock_openai_client.is_closed.return_value = False
-            mock_openai_class.return_value = mock_openai_client
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-4.1")
+            await provider.stream_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                message_id="test_id",
+                thread_id="test_thread",
+                reply_fn=lambda x: None,
+                use_smartest_model=True
+            )

-            # Mock the validation that happens in OpenAIClient constructor
-            with patch('openai.OpenAI') as mock_sync_openai_class:
-                mock_sync_client = MagicMock()
-                mock_sync_client.models.list.return_value = MagicMock()
-                mock_sync_openai_class.return_value = mock_sync_client
-
-                provider = OpenAIProvider(config=provider_config)
-                await provider.stream_completions(
-                    message_type=MessageType.CHAT,
-                    messages=mock_messages,
-                    model="gpt-3.5",
-                    message_id="test_id",
-                    thread_id="test_thread",
-                    reply_fn=lambda x: None
-                )
-
-                mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
-                # Verify the model passed to the API call
-                call_args = mock_openai_client.chat.completions.create.call_args
-                assert call_args[1]['model'] == "gpt-3.5"
+            # Verify the model passed to the API call is the smartest model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())
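
The signature change is the heart of this hunk: callers no longer pass model= per request; they set a selected model once on the ProviderManager and opt into a tier per request with use_fast_model / use_smartest_model. A usage sketch assembled from the calls in the tests above (argument names are taken verbatim from the tests; credential setup and error handling are omitted):

    # Usage sketch assembled from the rewritten tests; the call signatures
    # mirror the tests verbatim, but API-key setup and error handling are omitted.
    from traitlets.config import Config
    from mito_ai.provider_manager import ProviderManager
    from mito_ai.completions.models import MessageType

    async def demo() -> None:
        config = Config()
        config.ProviderManager = Config()
        config.OpenAIClient = Config()

        provider = ProviderManager(config=config)
        provider.set_selected_model("gpt-4.1")

        # Latency-sensitive request: opt into the fast tier for this call only.
        await provider.request_completions(
            message_type=MessageType.CHAT,
            messages=[{"role": "user", "content": "Test message"}],
            use_fast_model=True,
        )

        # Streaming with the smartest tier; reply_fn receives each chunk.
        await provider.stream_completions(
            message_type=MessageType.CHAT,
            messages=[{"role": "user", "content": "Test message"}],
            message_id="test_id",
            thread_id="test_thread",
            reply_fn=lambda chunk: None,
            use_smartest_model=True,
        )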

mito_ai/tests/providers/test_openai_client.py
@@ -3,25 +3,25 @@

 import pytest
 from mito_ai.openai_client import OpenAIClient
-from mito_ai.utils.open_ai_utils import FAST_OPENAI_MODEL
 from mito_ai.completions.models import MessageType
 from unittest.mock import MagicMock, patch, AsyncMock
 from openai.types.chat import ChatCompletion, ChatCompletionMessageParam

 CUSTOM_MODEL = "smart-openai-model"
-@pytest.mark.parametrize("message_type, expected_model", [
-    (MessageType.CHAT, CUSTOM_MODEL), #
-    (MessageType.SMART_DEBUG, CUSTOM_MODEL), #
-    (MessageType.CODE_EXPLAIN, CUSTOM_MODEL), #
-    (MessageType.AGENT_EXECUTION, CUSTOM_MODEL), #
-    (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL), #
-    (MessageType.INLINE_COMPLETION, FAST_OPENAI_MODEL), #
-    (MessageType.CHAT_NAME_GENERATION, FAST_OPENAI_MODEL), #
+@pytest.mark.parametrize("message_type", [
+    MessageType.CHAT,
+    MessageType.SMART_DEBUG,
+    MessageType.CODE_EXPLAIN,
+    MessageType.AGENT_EXECUTION,
+    MessageType.AGENT_AUTO_ERROR_FIXUP,
+    MessageType.INLINE_COMPLETION,
+    MessageType.CHAT_NAME_GENERATION,
 ])
 @pytest.mark.asyncio
-async def test_model_selection_based_on_message_type(message_type, expected_model):
+async def test_model_selection_uses_passed_model(message_type):
     """
-    Tests that the correct model is selected based on the message type.
+    Tests that the model passed to the client is used as-is.
+    Model selection based on message type is now handled by ProviderManager.
     """
     client = OpenAIClient(api_key="test_key") # type: ignore

@@ -51,7 +51,203 @@ async def test_model_selection_based_on_message_type(message_type, expected_mode
             response_format_info=None
         )

-        # Verify that create was called with the expected model
+        # Verify that create was called with the model that was passed (not overridden)
         mock_create.assert_called_once()
         call_args = mock_create.call_args
-        assert call_args[1]['model'] == expected_model
+        assert call_args[1]['model'] == CUSTOM_MODEL
+
+@pytest.mark.asyncio
+async def test_openai_client_uses_fast_model_from_provider_manager_without_override():
+    """Test that OpenAI client uses the fast model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_fast_model_for_selected_model
+
+    client = OpenAIClient(api_key="test_key") # type: ignore
+
+    # Mock the _build_openai_client method to return our mock client
+    with patch.object(client, '_build_openai_client') as mock_build_client, \
+         patch('openai.AsyncOpenAI') as mock_openai_class:
+
+        mock_client = MagicMock()
+        mock_chat = MagicMock()
+        mock_completions = MagicMock()
+        mock_client.chat = mock_chat
+        mock_chat.completions = mock_completions
+        mock_openai_class.return_value = mock_client
+        mock_build_client.return_value = mock_client
+
+        # Create an async mock for the create method
+        mock_create = AsyncMock()
+        mock_create.return_value = MagicMock(
+            choices=[MagicMock(message=MagicMock(content="test"))]
+        )
+        mock_completions.create = mock_create
+
+        # Use a fast model that would be selected by ProviderManager
+        fast_model = get_fast_model_for_selected_model("gpt-5.2")
+
+        await client.request_completions(
+            message_type=MessageType.CHAT,
+            messages=[{"role": "user", "content": "Test message"}],
+            model=fast_model,
+            response_format_info=None
+        )
+
+        # Verify that create was called with the fast model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == fast_model
+
+@pytest.mark.asyncio
+async def test_openai_client_uses_smartest_model_from_provider_manager_without_override():
+    """Test that OpenAI client uses the smartest model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_smartest_model_for_selected_model
+
+    client = OpenAIClient(api_key="test_key") # type: ignore
+
+    # Mock the _build_openai_client method to return our mock client
+    with patch.object(client, '_build_openai_client') as mock_build_client, \
+         patch('openai.AsyncOpenAI') as mock_openai_class:
+
+        mock_client = MagicMock()
+        mock_chat = MagicMock()
+        mock_completions = MagicMock()
+        mock_client.chat = mock_chat
+        mock_chat.completions = mock_completions
+        mock_openai_class.return_value = mock_client
+        mock_build_client.return_value = mock_client
+
+        # Create an async mock for the create method
+        mock_create = AsyncMock()
+        mock_create.return_value = MagicMock(
+            choices=[MagicMock(message=MagicMock(content="test"))]
+        )
+        mock_completions.create = mock_create
+
+        # Use a smartest model that would be selected by ProviderManager
+        smartest_model = get_smartest_model_for_selected_model("gpt-4.1")
+
+        await client.request_completions(
+            message_type=MessageType.CHAT,
+            messages=[{"role": "user", "content": "Test message"}],
+            model=smartest_model,
+            response_format_info=None
+        )
+
+        # Verify that create was called with the smartest model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == smartest_model
+
+@pytest.mark.asyncio
+async def test_openai_client_stream_uses_fast_model_from_provider_manager_without_override():
+    """Test that OpenAI client stream_completions uses the fast model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_fast_model_for_selected_model
+
+    client = OpenAIClient(api_key="test_key") # type: ignore
+
+    # Mock the _build_openai_client method to return our mock client
+    with patch.object(client, '_build_openai_client') as mock_build_client, \
+         patch('openai.AsyncOpenAI') as mock_openai_class:
+
+        mock_client = MagicMock()
+        mock_chat = MagicMock()
+        mock_completions = MagicMock()
+        mock_client.chat = mock_chat
+        mock_chat.completions = mock_completions
+        mock_openai_class.return_value = mock_client
+        mock_build_client.return_value = mock_client
+
+        # Create an async generator for streaming
+        async def mock_stream():
+            mock_chunk = MagicMock()
+            mock_chunk.choices = [MagicMock()]
+            mock_chunk.choices[0].delta.content = "test"
+            mock_chunk.choices[0].finish_reason = None
+            yield mock_chunk
+            mock_final_chunk = MagicMock()
+            mock_final_chunk.choices = [MagicMock()]
+            mock_final_chunk.choices[0].delta.content = ""
+            mock_final_chunk.choices[0].finish_reason = "stop"
+            yield mock_final_chunk
+
+        mock_create = AsyncMock(return_value=mock_stream())
+        mock_completions.create = mock_create
+
+        # Use a fast model that would be selected by ProviderManager
+        fast_model = get_fast_model_for_selected_model("gpt-5.2")
+
+        reply_chunks = []
+        def mock_reply(chunk):
+            reply_chunks.append(chunk)
+
+        await client.stream_completions(
+            message_type=MessageType.CHAT,
+            messages=[{"role": "user", "content": "Test message"}],
+            model=fast_model,
+            message_id="test-id",
+            thread_id="test-thread",
+            reply_fn=mock_reply,
+            response_format_info=None
+        )
+
+        # Verify that create was called with the fast model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == fast_model
+
+@pytest.mark.asyncio
+async def test_openai_client_stream_uses_smartest_model_from_provider_manager_without_override():
+    """Test that OpenAI client stream_completions uses the smartest model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_smartest_model_for_selected_model
+
+    client = OpenAIClient(api_key="test_key") # type: ignore
+
+    # Mock the _build_openai_client method to return our mock client
+    with patch.object(client, '_build_openai_client') as mock_build_client, \
+         patch('openai.AsyncOpenAI') as mock_openai_class:
+
+        mock_client = MagicMock()
+        mock_chat = MagicMock()
+        mock_completions = MagicMock()
+        mock_client.chat = mock_chat
+        mock_chat.completions = mock_completions
+        mock_openai_class.return_value = mock_client
+        mock_build_client.return_value = mock_client
+
+        # Create an async generator for streaming
+        async def mock_stream():
+            mock_chunk = MagicMock()
+            mock_chunk.choices = [MagicMock()]
+            mock_chunk.choices[0].delta.content = "test"
+            mock_chunk.choices[0].finish_reason = None
+            yield mock_chunk
+            mock_final_chunk = MagicMock()
+            mock_final_chunk.choices = [MagicMock()]
+            mock_final_chunk.choices[0].delta.content = ""
+            mock_final_chunk.choices[0].finish_reason = "stop"
+            yield mock_final_chunk
+
+        mock_create = AsyncMock(return_value=mock_stream())
+        mock_completions.create = mock_create
+
+        # Use a smartest model that would be selected by ProviderManager
+        smartest_model = get_smartest_model_for_selected_model("gpt-4.1")
+
+        reply_chunks = []
+        def mock_reply(chunk):
+            reply_chunks.append(chunk)
+
+        await client.stream_completions(
+            message_type=MessageType.CHAT,
+            messages=[{"role": "user", "content": "Test message"}],
+            model=smartest_model,
+            message_id="test-id",
+            thread_id="test-thread",
+            reply_fn=mock_reply,
+            response_format_info=None
+        )
+
+        # Verify that create was called with the smartest model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == smartest_model
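
At the client layer, the four new tests pin down the complementary guarantee: OpenAIClient treats model= as a required, authoritative input and never swaps it internally. Driving the client directly therefore looks like the sketch below, which mirrors the test calls ("test_key" is a placeholder API key, and the model helper is the real one from model_utils):

    # Sketch of calling OpenAIClient directly with an explicit model, mirroring
    # the new tests above. "test_key" is a placeholder, not a working API key.
    from mito_ai.openai_client import OpenAIClient
    from mito_ai.completions.models import MessageType
    from mito_ai.utils.model_utils import get_smartest_model_for_selected_model

    async def demo() -> None:
        client = OpenAIClient(api_key="test_key")  # type: ignore

        model = get_smartest_model_for_selected_model("gpt-4.1")
        await client.request_completions(
            message_type=MessageType.CHAT,
            messages=[{"role": "user", "content": "Test message"}],
            model=model,  # used as-is; the client no longer overrides it
            response_format_info=None,
        )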

mito_ai/tests/providers/test_provider_limits.py
@@ -2,7 +2,7 @@
 # Distributed under the terms of the GNU Affero General Public License v3.0 License.

 import pytest
-from mito_ai.completions.providers import OpenAIProvider
+from mito_ai.provider_manager import ProviderManager
 from mito_ai.tests.providers.utils import mock_openai_client, patch_server_limits
 from mito_ai.utils.server_limits import OS_MONTHLY_AI_COMPLETIONS_LIMIT
 from traitlets.config import Config
@@ -11,9 +11,9 @@ FAKE_API_KEY = "sk-1234567890"

 @pytest.fixture
 def provider_config() -> Config:
-    """Create a proper Config object for the OpenAIProvider."""
+    """Create a proper Config object for the ProviderManager."""
     config = Config()
-    config.OpenAIProvider = Config()
+    config.ProviderManager = Config()
     config.OpenAIClient = Config()
     return config

@@ -36,7 +36,7 @@ def test_openai_provider_with_limits(
         patch_server_limits(is_pro=is_pro, completion_count=completion_count),
         mock_openai_client()
     ):
-        llm = OpenAIProvider(config=provider_config)
+        llm = ProviderManager(config=provider_config)
         capabilities = llm.capabilities
-        assert "user key" in capabilities.provider
+        assert "OpenAI" in capabilities.provider
        assert llm.last_error is None