mito-ai 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. mito_ai/__init__.py +19 -22
  2. mito_ai/_version.py +1 -1
  3. mito_ai/anthropic_client.py +24 -14
  4. mito_ai/chart_wizard/handlers.py +78 -17
  5. mito_ai/chart_wizard/urls.py +8 -5
  6. mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
  7. mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
  8. mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
  9. mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
  10. mito_ai/completions/completion_handlers/completion_handler.py +3 -5
  11. mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
  12. mito_ai/completions/completion_handlers/scratchpad_result_handler.py +6 -8
  13. mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
  14. mito_ai/completions/completion_handlers/utils.py +3 -7
  15. mito_ai/completions/handlers.py +32 -22
  16. mito_ai/completions/message_history.py +8 -10
  17. mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
  18. mito_ai/completions/prompt_builders/prompt_constants.py +2 -0
  19. mito_ai/constants.py +31 -2
  20. mito_ai/enterprise/__init__.py +1 -1
  21. mito_ai/enterprise/litellm_client.py +144 -0
  22. mito_ai/enterprise/utils.py +16 -2
  23. mito_ai/log/handlers.py +1 -1
  24. mito_ai/openai_client.py +36 -96
  25. mito_ai/provider_manager.py +420 -0
  26. mito_ai/settings/enterprise_handler.py +26 -0
  27. mito_ai/settings/urls.py +2 -0
  28. mito_ai/streamlit_conversion/agent_utils.py +2 -30
  29. mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
  30. mito_ai/streamlit_preview/handlers.py +6 -3
  31. mito_ai/streamlit_preview/urls.py +5 -3
  32. mito_ai/tests/message_history/test_generate_short_chat_name.py +103 -28
  33. mito_ai/tests/open_ai_utils_test.py +34 -36
  34. mito_ai/tests/providers/test_anthropic_client.py +174 -16
  35. mito_ai/tests/providers/test_azure.py +15 -15
  36. mito_ai/tests/providers/test_capabilities.py +14 -17
  37. mito_ai/tests/providers/test_gemini_client.py +14 -13
  38. mito_ai/tests/providers/test_model_resolution.py +145 -89
  39. mito_ai/tests/providers/test_openai_client.py +209 -13
  40. mito_ai/tests/providers/test_provider_limits.py +5 -5
  41. mito_ai/tests/providers/test_providers.py +229 -51
  42. mito_ai/tests/providers/test_retry_logic.py +13 -22
  43. mito_ai/tests/providers/utils.py +4 -4
  44. mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
  45. mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
  46. mito_ai/tests/test_constants.py +90 -0
  47. mito_ai/tests/test_enterprise_mode.py +217 -0
  48. mito_ai/tests/test_model_utils.py +362 -0
  49. mito_ai/utils/anthropic_utils.py +8 -6
  50. mito_ai/utils/gemini_utils.py +0 -3
  51. mito_ai/utils/litellm_utils.py +84 -0
  52. mito_ai/utils/model_utils.py +257 -0
  53. mito_ai/utils/open_ai_utils.py +29 -41
  54. mito_ai/utils/provider_utils.py +13 -29
  55. mito_ai/utils/telemetry_utils.py +14 -2
  56. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
  57. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
  58. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
  59. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.44c109c7be36fb884d25.js +1059 -144
  60. mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.44c109c7be36fb884d25.js.map +1 -0
  61. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.f7decebaf69618541e0f.js +17 -17
  62. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js.map → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.f7decebaf69618541e0f.js.map +1 -1
  63. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +78 -78
  64. {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/METADATA +2 -1
  65. {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/RECORD +90 -83
  66. mito_ai/completions/providers.py +0 -284
  67. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js.map +0 -1
  68. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
  69. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
  70. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
  71. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
  72. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
  73. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
  74. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
  75. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
  76. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
  77. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
  78. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
  79. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
  80. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
  81. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
  82. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
  83. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
  84. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
  85. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
  86. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
  87. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
  88. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
  89. {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
  90. {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/WHEEL +0 -0
  91. {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/entry_points.txt +0 -0
  92. {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/licenses/LICENSE +0 -0
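
The headline change in this release is a refactor of provider selection: `mito_ai/completions/providers.py` (and its `OpenAIProvider` facade) is deleted, and the new `mito_ai/provider_manager.py` (+420 lines) introduces a `ProviderManager`, backed by the new `mito_ai/utils/model_utils.py` helpers and an enterprise LiteLLM client. Below is a minimal sketch of the new call pattern, reconstructed from the updated tests in this diff; exact signatures live in `provider_manager.py` and may differ in detail.

```python
# Hedged sketch of the 0.1.59 call pattern, inferred from the test changes
# below; assumes mito_ai 0.1.59 is installed and an API key is configured.
import asyncio
from traitlets.config import Config
from mito_ai.provider_manager import ProviderManager
from mito_ai.completions.models import MessageType

async def main() -> None:
    config = Config()
    config.ProviderManager = Config()  # was config.OpenAIProvider in 0.1.57
    config.OpenAIClient = Config()

    provider = ProviderManager(config=config)
    # The model is now set once on the manager instead of being passed
    # to every request_completions/stream_completions call.
    provider.set_selected_model("gpt-4.1")

    completion = await provider.request_completions(
        message_type=MessageType.CHAT,
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(completion)

asyncio.run(main())
```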
mito_ai/tests/providers/test_azure.py

@@ -9,7 +9,7 @@ import pytest
 from traitlets.config import Config
 from openai.types.chat import ChatCompletionMessageParam
 
-from mito_ai.completions.providers import OpenAIProvider
+from mito_ai.provider_manager import ProviderManager
 from mito_ai.completions.models import (
     MessageType,
     AICapabilities,
@@ -29,9 +29,9 @@ FAKE_AZURE_API_VERSION = "2024-12-01-preview"
 
 @pytest.fixture
 def provider_config() -> Config:
-    """Create a proper Config object for the OpenAIProvider."""
+    """Create a proper Config object for the ProviderManager."""
     config = Config()
-    config.OpenAIProvider = Config()
+    config.ProviderManager = Config()
     config.OpenAIClient = Config()
     return config
 
@@ -40,7 +40,7 @@ def provider_config() -> Config:
 def reset_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
     """Reset all environment variables before each test."""
     for var in [
-        "OPENAI_API_KEY", "CLAUDE_API_KEY", "GEMINI_API_KEY", "OLLAMA_MODEL",
+        "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "GEMINI_API_KEY", "OLLAMA_MODEL",
         "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_MODEL",
         "AZURE_OPENAI_API_VERSION"
     ]:
@@ -176,11 +176,11 @@ class TestAzureOpenAIClientCreation:
         openai_client = OpenAIClient(config=provider_config)
 
         # Test with gpt-4.1 model
-        resolved_model = openai_client._adjust_model_for_azure_or_ollama("gpt-4.1")
+        resolved_model = openai_client._adjust_model_for_provider("gpt-4.1")
         assert resolved_model == FAKE_AZURE_MODEL
 
         # Test with any other model
-        resolved_model = openai_client._adjust_model_for_azure_or_ollama("gpt-3.5-turbo")
+        resolved_model = openai_client._adjust_model_for_provider("gpt-3.5-turbo")
         assert resolved_model == FAKE_AZURE_MODEL
 
 
@@ -405,7 +405,7 @@ class TestAzureOpenAIStreamCompletions:
 
 
 class TestAzureOpenAIProviderIntegration:
-    """Test Azure OpenAI integration through the OpenAIProvider."""
+    """Test Azure OpenAI integration through the ProviderManager."""
 
     @pytest.mark.asyncio
     @pytest.mark.parametrize("message_type", COMPLETION_MESSAGE_TYPES)
@@ -415,7 +415,7 @@ class TestAzureOpenAIProviderIntegration:
         provider_config: Config,
         message_type: MessageType
     ) -> None:
-        """Test that OpenAIProvider uses Azure OpenAI when gpt-4.1 is requested and Azure is configured."""
+        """Test that ProviderManager uses Azure OpenAI when gpt-4.1 is requested and Azure is configured."""
 
         # Mock the response
         mock_response = MagicMock()
@@ -428,7 +428,8 @@ class TestAzureOpenAIProviderIntegration:
         mock_azure_client.is_closed.return_value = False
         mock_azure_client_class.return_value = mock_azure_client
 
-        provider = OpenAIProvider(config=provider_config)
+        provider = ProviderManager(config=provider_config)
+        provider.set_selected_model("gpt-4.1")
 
         messages: List[ChatCompletionMessageParam] = [
             {"role": "user", "content": "Test message"}
@@ -437,7 +438,6 @@ class TestAzureOpenAIProviderIntegration:
         completion = await provider.request_completions(
             message_type=message_type,
             messages=messages,
-            model="gpt-4.1"
         )
 
         # Verify the completion was returned
@@ -461,7 +461,7 @@ class TestAzureOpenAIProviderIntegration:
         provider_config: Config,
         message_type: MessageType
     ) -> None:
-        """Test that OpenAIProvider stream_completions uses Azure OpenAI when gpt-4.1 is requested and Azure is configured."""
+        """Test that ProviderManager stream_completions uses Azure OpenAI when gpt-4.1 is requested and Azure is configured."""
 
         # Mock the streaming response
         mock_chunk1 = MagicMock()
@@ -484,7 +484,8 @@ class TestAzureOpenAIProviderIntegration:
         mock_azure_client.is_closed.return_value = False
         mock_azure_client_class.return_value = mock_azure_client
 
-        provider = OpenAIProvider(config=provider_config)
+        provider = ProviderManager(config=provider_config)
+        provider.set_selected_model("gpt-4.1")
 
         messages: List[ChatCompletionMessageParam] = [
             {"role": "user", "content": "Test message"}
@@ -497,7 +498,6 @@ class TestAzureOpenAIProviderIntegration:
         completion = await provider.stream_completions(
             message_type=message_type,
             messages=messages,
-            model="gpt-4.1",
            message_id="test-id",
            thread_id="test-thread",
            reply_fn=mock_reply
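
The two hunks above show the other half of the API change: `request_completions` and `stream_completions` no longer accept a per-call `model=` argument. A hedged sketch of the streaming path, mirroring the test above (the no-op `reply_fn` is a stand-in, as in the tests):

```python
# Hedged sketch of streaming after the model= kwarg removal; chunks are
# delivered through reply_fn while the final completion is awaited.
import asyncio
from traitlets.config import Config
from mito_ai.provider_manager import ProviderManager
from mito_ai.completions.models import MessageType

async def stream_example() -> None:
    config = Config()
    config.ProviderManager = Config()
    config.OpenAIClient = Config()

    provider = ProviderManager(config=config)
    provider.set_selected_model("gpt-4.1")  # set once, not per call

    completion = await provider.stream_completions(
        message_type=MessageType.CHAT,
        messages=[{"role": "user", "content": "Test message"}],
        message_id="test-id",
        thread_id="test-thread",
        reply_fn=lambda chunk: None,  # stand-in; real callers forward chunks
    )
    print(completion)

asyncio.run(stream_example())
```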
@@ -554,8 +554,8 @@ class TestAzureOpenAIConfigurationPriority:
         """Test that Azure OpenAI is used even when Claude key is available."""
 
         # Set Claude key (this should be overridden by Azure OpenAI)
-        monkeypatch.setenv("CLAUDE_API_KEY", "claude-key")
-        monkeypatch.setattr("mito_ai.constants.CLAUDE_API_KEY", "claude-key")
+        monkeypatch.setenv("ANTHROPIC_API_KEY", "claude-key")
+        monkeypatch.setattr("mito_ai.constants.ANTHROPIC_API_KEY", "claude-key")
 
         with patch("openai.AsyncAzureOpenAI") as mock_azure_client:
             openai_client = OpenAIClient(config=provider_config)
mito_ai/tests/providers/test_capabilities.py

@@ -3,7 +3,7 @@
 
 import pytest
 from unittest.mock import MagicMock, patch
-from mito_ai.completions.providers import OpenAIProvider
+from mito_ai.provider_manager import ProviderManager
 from mito_ai.tests.providers.utils import mock_azure_openai_client, mock_openai_client, patch_server_limits
 from traitlets.config import Config
 
@@ -11,9 +11,9 @@ FAKE_API_KEY = "sk-1234567890"
 
 @pytest.fixture
 def provider_config() -> Config:
-    """Create a proper Config object for the OpenAIProvider."""
+    """Create a proper Config object for the ProviderManager."""
     config = Config()
-    config.OpenAIProvider = Config()
+    config.ProviderManager = Config()
     config.OpenAIClient = Config()
     return config
 
@@ -22,7 +22,7 @@ def provider_config() -> Config:
         "name": "mito_server_fallback_no_keys",
         "setup": {
             "OPENAI_API_KEY": None,
-            "CLAUDE_API_KEY": None,
+            "ANTHROPIC_API_KEY": None,
             "GEMINI_API_KEY": None,
             "is_azure_configured": False,
         },
@@ -33,45 +33,45 @@ def provider_config() -> Config:
         "name": "claude_when_only_claude_key",
         "setup": {
             "OPENAI_API_KEY": None,
-            "CLAUDE_API_KEY": "claude-test-key",
+            "ANTHROPIC_API_KEY": "claude-test-key",
             "GEMINI_API_KEY": None,
             "is_azure_configured": False,
         },
         "expected_provider": "Claude",
-        "expected_key_type": "claude"
+        "expected_key_type": "user_key"
     },
     {
         "name": "gemini_when_only_gemini_key",
         "setup": {
             "OPENAI_API_KEY": None,
-            "CLAUDE_API_KEY": None,
+            "ANTHROPIC_API_KEY": None,
             "GEMINI_API_KEY": "gemini-test-key",
             "is_azure_configured": False,
         },
         "expected_provider": "Gemini",
-        "expected_key_type": "gemini"
+        "expected_key_type": "user_key"
     },
     {
         "name": "openai_when_openai_key",
         "setup": {
             "OPENAI_API_KEY": 'openai-test-key',
-            "CLAUDE_API_KEY": None,
+            "ANTHROPIC_API_KEY": None,
             "GEMINI_API_KEY": None,
             "is_azure_configured": False,
         },
-        "expected_provider": "OpenAI (user key)",
+        "expected_provider": "OpenAI",
         "expected_key_type": "user_key"
     },
     {
         "name": "claude_priority_over_gemini",
         "setup": {
             "OPENAI_API_KEY": None,
-            "CLAUDE_API_KEY": "claude-test-key",
+            "ANTHROPIC_API_KEY": "claude-test-key",
             "GEMINI_API_KEY": "gemini-test-key",
             "is_azure_configured": False,
         },
         "expected_provider": "Claude",
-        "expected_key_type": "claude"
+        "expected_key_type": "user_key"
     },
 ])
 def test_provider_capabilities_real_logic(
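
The parametrized cases above pin down most of the selection order. A hedged sketch of the priority they imply (Azure-over-Claude comes from the test_azure.py hunk earlier; the relative order of OpenAI and Claude keys is not exercised by these cases, so treat that part as an assumption; the real logic lives in `ProviderManager.capabilities`):

```python
# Hedged reconstruction of the selection order implied by the test cases.
from typing import Optional

def pick_provider(
    openai_key: Optional[str],
    anthropic_key: Optional[str],
    gemini_key: Optional[str],
    is_azure_configured: bool,
) -> str:
    if is_azure_configured:
        return "Azure OpenAI"  # wins even when ANTHROPIC_API_KEY is set
    if anthropic_key:
        return "Claude"        # takes priority over Gemini
    if gemini_key:
        return "Gemini"
    if openai_key:
        return "OpenAI"        # reported with key_type "user_key"
    return "Mito server"       # fallback when no keys are configured

assert pick_provider(None, "claude-test-key", "gemini-test-key", False) == "Claude"
assert pick_provider(None, None, None, False) == "Mito server"
```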
@@ -79,7 +79,7 @@ def test_provider_capabilities_real_logic(
     monkeypatch: pytest.MonkeyPatch,
     provider_config: Config
 ) -> None:
-    """Test the actual provider selection logic in OpenAIProvider.capabilities"""
+    """Test the actual provider selection logic in ProviderManager.capabilities"""
 
     # Set up the environment based on test case
     setup = test_case["setup"]
@@ -97,9 +97,6 @@ def test_provider_capabilities_real_logic(
         else:
             monkeypatch.setattr(f"mito_ai.constants.{key}", value)
 
-    # Clear the provider config API key to ensure it uses constants
-    provider_config.OpenAIProvider.api_key = None
-
     # Mock HTTP calls but let the real logic run
     with patch("openai.OpenAI") as mock_openai_constructor:
         with patch("openai.AsyncOpenAI") as mock_async_openai:
@@ -112,7 +109,7 @@ def test_provider_capabilities_real_logic(
             # Mock server limits for Mito server fallback
             with patch_server_limits():
                 # NOW create the provider after ALL mocks are set up
-                llm = OpenAIProvider(config=provider_config)
+                llm = ProviderManager(config=provider_config)
 
                 # Test capabilities
                 capabilities = llm.capabilities
mito_ai/tests/providers/test_gemini_client.py

@@ -6,7 +6,7 @@ import ast
 import inspect
 import requests
 from mito_ai.gemini_client import GeminiClient, get_gemini_system_prompt_and_messages
-from mito_ai.utils.gemini_utils import get_gemini_completion_function_params, FAST_GEMINI_MODEL
+from mito_ai.utils.gemini_utils import get_gemini_completion_function_params
 from google.genai.types import Part, GenerateContentResponse, Candidate, Content
 from mito_ai.completions.models import ResponseFormatInfo, AgentResponse
 from unittest.mock import MagicMock, patch
@@ -156,19 +156,20 @@ async def test_json_response_handling_with_multiple_parts():
     assert result == 'Here is the JSON: {"key": "value"} End of response'
 
 CUSTOM_MODEL = "smart-gemini-model"
-@pytest.mark.parametrize("message_type, expected_model", [
-    (MessageType.CHAT, CUSTOM_MODEL),
-    (MessageType.SMART_DEBUG, CUSTOM_MODEL),
-    (MessageType.CODE_EXPLAIN, CUSTOM_MODEL),
-    (MessageType.AGENT_EXECUTION, CUSTOM_MODEL),
-    (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL),
-    (MessageType.INLINE_COMPLETION, FAST_GEMINI_MODEL),
-    (MessageType.CHAT_NAME_GENERATION, FAST_GEMINI_MODEL),
+@pytest.mark.parametrize("message_type", [
+    MessageType.CHAT,
+    MessageType.SMART_DEBUG,
+    MessageType.CODE_EXPLAIN,
+    MessageType.AGENT_EXECUTION,
+    MessageType.AGENT_AUTO_ERROR_FIXUP,
+    MessageType.INLINE_COMPLETION,
+    MessageType.CHAT_NAME_GENERATION,
 ])
 @pytest.mark.asyncio
-async def test_get_completion_model_selection_based_on_message_type(message_type, expected_model):
+async def test_get_completion_model_selection_uses_passed_model(message_type):
     """
-    Tests that the correct model is selected based on the message type.
+    Tests that the model passed to the client is used as-is.
+    Model selection based on message type is now handled by ProviderManager.
     """
     with patch('google.genai.Client') as mock_genai_class:
         mock_client = MagicMock()
@@ -189,7 +190,7 @@ async def test_get_completion_model_selection_based_on_message_type(message_type
             response_format_info=None
         )
 
-        # Verify that generate_content was called with the expected model
+        # Verify that generate_content was called with the model that was passed (not overridden)
         mock_models.generate_content.assert_called_once()
         call_args = mock_models.generate_content.call_args
-        assert call_args[1]['model'] == expected_model
+        assert call_args[1]['model'] == CUSTOM_MODEL
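
This test captures the new division of responsibility: `GeminiClient` (and the other provider clients) now forward whatever model they are given, while fast/smart selection happens upstream. A hedged sketch of that upstream step; `resolve_model` is a hypothetical name, but the two helpers are real imports from `mito_ai.utils.model_utils` used by test_model_resolution.py below:

```python
# resolve_model is hypothetical; the helpers are the ones the tests import.
from mito_ai.utils.model_utils import (
    get_fast_model_for_selected_model,
    get_smartest_model_for_selected_model,
)

def resolve_model(
    selected_model: str,
    use_fast_model: bool = False,
    use_smartest_model: bool = False,
) -> str:
    """Pick the model variant before delegating to a provider client."""
    if use_fast_model:
        return get_fast_model_for_selected_model(selected_model)
    if use_smartest_model:
        return get_smartest_model_for_selected_model(selected_model)
    return selected_model  # clients use the selected model as-is
```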
mito_ai/tests/providers/test_model_resolution.py

@@ -6,19 +6,17 @@ These tests ensure that the correct model is chosen for each message type, for e
 """
 
 import pytest
-from mito_ai.utils.provider_utils import does_message_require_fast_model
+from mito_ai.utils.model_utils import get_fast_model_for_selected_model, get_smartest_model_for_selected_model
 from mito_ai.completions.models import MessageType
 from unittest.mock import AsyncMock, MagicMock, patch
-from mito_ai.completions.providers import OpenAIProvider
-from mito_ai.completions.models import MessageType
-from mito_ai.utils.provider_utils import does_message_require_fast_model
+from mito_ai.provider_manager import ProviderManager
 from traitlets.config import Config
 
 @pytest.fixture
 def provider_config() -> Config:
-    """Create a proper Config object for the OpenAIProvider."""
+    """Create a proper Config object for the ProviderManager."""
     config = Config()
-    config.OpenAIProvider = Config()
+    config.ProviderManager = Config()
     config.OpenAIClient = Config()
     return config
 
@@ -27,104 +25,162 @@ def mock_messages():
     """Sample messages for testing."""
     return [{"role": "user", "content": "Test message"}]
 
-# Test cases for different message types and their expected fast model requirement
-MESSAGE_TYPE_TEST_CASES = [
-    (MessageType.CHAT, False),
-    (MessageType.SMART_DEBUG, False),
-    (MessageType.CODE_EXPLAIN, False),
-    (MessageType.AGENT_EXECUTION, False),
-    (MessageType.AGENT_AUTO_ERROR_FIXUP, False),
-    (MessageType.INLINE_COMPLETION, True),
-    (MessageType.CHAT_NAME_GENERATION, True),
-]
-@pytest.mark.parametrize("message_type,expected_result", MESSAGE_TYPE_TEST_CASES)
-def test_does_message_require_fast_model(message_type: MessageType, expected_result: bool) -> None:
-    """Test that does_message_require_fast_model returns the correct boolean for each message type."""
-    assert does_message_require_fast_model(message_type) == expected_result
+@pytest.mark.asyncio
+async def test_request_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that request_completions uses the correct model when use_fast_model=True."""
+    # Set up environment variables to ensure OpenAI provider is used
+    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
+    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
+
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "Test Completion"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client
+
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client
+
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-5.2")
+            await provider.request_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                use_fast_model=True
+            )
+
+            # Verify the model passed to the API call is the fast model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())
+
+@pytest.mark.asyncio
+async def test_stream_completions_uses_fast_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that stream_completions uses the correct model when use_fast_model=True."""
+    # Set up environment variables to ensure OpenAI provider is used
+    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
+    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
 
-def test_does_message_require_fast_model_raises_error_for_unknown_message_type():
-    """Test that does_message_require_fast_model raises an error for an unknown message type."""
-    with pytest.raises(ValueError):
-        does_message_require_fast_model('unknown_message_type')  # type: ignore
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].delta.content = "Test Stream Completion"
+    mock_response.choices[0].finish_reason = "stop"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        # Create an async generator for streaming
+        async def mock_stream():
+            yield mock_response
+
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client
+
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client
+
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-5.2")
+            await provider.stream_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                message_id="test_id",
+                thread_id="test_thread",
+                reply_fn=lambda x: None,
+                use_fast_model=True
+            )
+
+            # Verify the model passed to the API call is the fast model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_fast_model_for_selected_model(provider.get_selected_model())
 
 @pytest.mark.asyncio
-async def test_request_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
-    """Test that request_completions calls does_message_require_fast_model and uses the correct model."""
+async def test_request_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that request_completions uses the correct model when use_smartest_model=True."""
     # Set up environment variables to ensure OpenAI provider is used
     monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
     monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
 
-    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as mock_does_message_require_fast_model:
-        # Mock the OpenAI API call instead of the entire client
-        mock_response = MagicMock()
-        mock_response.choices = [MagicMock()]
-        mock_response.choices[0].message.content = "Test Completion"
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "Test Completion"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client
 
-        with patch('openai.AsyncOpenAI') as mock_openai_class:
-            mock_openai_client = MagicMock()
-            mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
-            mock_openai_client.is_closed.return_value = False
-            mock_openai_class.return_value = mock_openai_client
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client
 
-            # Mock the validation that happens in OpenAIClient constructor
-            with patch('openai.OpenAI') as mock_sync_openai_class:
-                mock_sync_client = MagicMock()
-                mock_sync_client.models.list.return_value = MagicMock()
-                mock_sync_openai_class.return_value = mock_sync_client
-
-                provider = OpenAIProvider(config=provider_config)
-                await provider.request_completions(
-                    message_type=MessageType.CHAT,
-                    messages=mock_messages,
-                    model="gpt-3.5",
-                )
-
-                mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
-                # Verify the model passed to the API call
-                call_args = mock_openai_client.chat.completions.create.call_args
-                assert call_args[1]['model'] == "gpt-3.5"
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-4.1")
+            await provider.request_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                use_smartest_model=True
+            )
+
+            # Verify the model passed to the API call is the smartest model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())
 
 @pytest.mark.asyncio
-async def test_stream_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
-    """Test that stream_completions calls does_message_require_fast_model and uses the correct model."""
+async def test_stream_completions_uses_smartest_model_when_requested(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
+    """Test that stream_completions uses the correct model when use_smartest_model=True."""
     # Set up environment variables to ensure OpenAI provider is used
     monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
     monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")
 
-    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as mock_does_message_require_fast_model:
-        # Mock the OpenAI API call instead of the entire client
-        mock_response = MagicMock()
-        mock_response.choices = [MagicMock()]
-        mock_response.choices[0].delta.content = "Test Stream Completion"
-        mock_response.choices[0].finish_reason = "stop"
+    # Mock the OpenAI API call instead of the entire client
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].delta.content = "Test Stream Completion"
+    mock_response.choices[0].finish_reason = "stop"
+
+    with patch('openai.AsyncOpenAI') as mock_openai_class:
+        mock_openai_client = MagicMock()
+        # Create an async generator for streaming
+        async def mock_stream():
+            yield mock_response
+
+        mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
+        mock_openai_client.is_closed.return_value = False
+        mock_openai_class.return_value = mock_openai_client
 
-        with patch('openai.AsyncOpenAI') as mock_openai_class:
-            mock_openai_client = MagicMock()
-            # Create an async generator for streaming
-            async def mock_stream():
-                yield mock_response
+        # Mock the validation that happens in OpenAIClient constructor
+        with patch('openai.OpenAI') as mock_sync_openai_class:
+            mock_sync_client = MagicMock()
+            mock_sync_client.models.list.return_value = MagicMock()
+            mock_sync_openai_class.return_value = mock_sync_client
 
-            mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
-            mock_openai_client.is_closed.return_value = False
-            mock_openai_class.return_value = mock_openai_client
+            provider = ProviderManager(config=provider_config)
+            provider.set_selected_model("gpt-4.1")
+            await provider.stream_completions(
+                message_type=MessageType.CHAT,
+                messages=mock_messages,
+                message_id="test_id",
+                thread_id="test_thread",
+                reply_fn=lambda x: None,
+                use_smartest_model=True
+            )
 
-            # Mock the validation that happens in OpenAIClient constructor
-            with patch('openai.OpenAI') as mock_sync_openai_class:
-                mock_sync_client = MagicMock()
-                mock_sync_client.models.list.return_value = MagicMock()
-                mock_sync_openai_class.return_value = mock_sync_client
-
-                provider = OpenAIProvider(config=provider_config)
-                await provider.stream_completions(
-                    message_type=MessageType.CHAT,
-                    messages=mock_messages,
-                    model="gpt-3.5",
-                    message_id="test_id",
-                    thread_id="test_thread",
-                    reply_fn=lambda x: None
-                )
-
-                mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
-                # Verify the model passed to the API call
-                call_args = mock_openai_client.chat.completions.create.call_args
-                assert call_args[1]['model'] == "gpt-3.5"
+            # Verify the model passed to the API call is the smartest model
+            call_args = mock_openai_client.chat.completions.create.call_args
+            assert call_args[1]['model'] == get_smartest_model_for_selected_model(provider.get_selected_model())
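
Taken together, these tests define the new contract: callers select a model family once via `set_selected_model`, then opt individual requests into a fast or smartest variant with per-call flags. A final hedged sketch combining both, assuming mito_ai 0.1.59 is installed and an `OPENAI_API_KEY` is available:

```python
import asyncio
from traitlets.config import Config
from mito_ai.provider_manager import ProviderManager
from mito_ai.completions.models import MessageType
from mito_ai.utils.model_utils import get_fast_model_for_selected_model

async def demo() -> None:
    config = Config()
    config.ProviderManager = Config()
    config.OpenAIClient = Config()

    provider = ProviderManager(config=config)
    provider.set_selected_model("gpt-4.1")

    # Cheap auxiliary calls (e.g. chat-name generation) opt into the fast
    # variant of the selected model family:
    print(get_fast_model_for_selected_model(provider.get_selected_model()))
    await provider.request_completions(
        message_type=MessageType.CHAT_NAME_GENERATION,
        messages=[{"role": "user", "content": "Summarize this thread"}],
        use_fast_model=True,
    )

asyncio.run(demo())
```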