mito-ai 0.1.57__py3-none-any.whl → 0.1.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. mito_ai/__init__.py +16 -22
  2. mito_ai/_version.py +1 -1
  3. mito_ai/anthropic_client.py +24 -14
  4. mito_ai/chart_wizard/handlers.py +78 -17
  5. mito_ai/chart_wizard/urls.py +8 -5
  6. mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
  7. mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
  8. mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
  9. mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
  10. mito_ai/completions/completion_handlers/completion_handler.py +3 -5
  11. mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
  12. mito_ai/completions/completion_handlers/scratchpad_result_handler.py +6 -8
  13. mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
  14. mito_ai/completions/completion_handlers/utils.py +3 -7
  15. mito_ai/completions/handlers.py +32 -22
  16. mito_ai/completions/message_history.py +8 -10
  17. mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
  18. mito_ai/constants.py +8 -1
  19. mito_ai/enterprise/__init__.py +1 -1
  20. mito_ai/enterprise/litellm_client.py +137 -0
  21. mito_ai/log/handlers.py +1 -1
  22. mito_ai/openai_client.py +10 -90
  23. mito_ai/{completions/providers.py → provider_manager.py} +157 -53
  24. mito_ai/settings/enterprise_handler.py +26 -0
  25. mito_ai/settings/urls.py +2 -0
  26. mito_ai/streamlit_conversion/agent_utils.py +2 -30
  27. mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
  28. mito_ai/streamlit_preview/handlers.py +6 -3
  29. mito_ai/streamlit_preview/urls.py +5 -3
  30. mito_ai/tests/message_history/test_generate_short_chat_name.py +72 -28
  31. mito_ai/tests/providers/test_anthropic_client.py +174 -16
  32. mito_ai/tests/providers/test_azure.py +13 -13
  33. mito_ai/tests/providers/test_capabilities.py +14 -17
  34. mito_ai/tests/providers/test_gemini_client.py +14 -13
  35. mito_ai/tests/providers/test_model_resolution.py +145 -89
  36. mito_ai/tests/providers/test_openai_client.py +209 -13
  37. mito_ai/tests/providers/test_provider_limits.py +5 -5
  38. mito_ai/tests/providers/test_providers.py +229 -51
  39. mito_ai/tests/providers/test_retry_logic.py +13 -22
  40. mito_ai/tests/providers/utils.py +4 -4
  41. mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
  42. mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
  43. mito_ai/tests/test_enterprise_mode.py +162 -0
  44. mito_ai/tests/test_model_utils.py +271 -0
  45. mito_ai/utils/anthropic_utils.py +8 -6
  46. mito_ai/utils/gemini_utils.py +0 -3
  47. mito_ai/utils/litellm_utils.py +84 -0
  48. mito_ai/utils/model_utils.py +178 -0
  49. mito_ai/utils/open_ai_utils.py +0 -8
  50. mito_ai/utils/provider_utils.py +6 -28
  51. mito_ai/utils/telemetry_utils.py +14 -2
  52. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
  53. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
  54. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
  55. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js +671 -75
  56. mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js.map +1 -0
  57. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js +17 -17
  58. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js.map → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js.map +1 -1
  59. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/METADATA +2 -1
  60. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/RECORD +86 -79
  61. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js.map +0 -1
  62. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
  63. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
  64. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
  65. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
  66. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
  67. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
  68. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
  69. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
  70. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
  71. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
  72. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
  73. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
  74. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
  75. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
  76. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
  77. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
  78. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
  79. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
  80. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
  81. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
  82. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
  83. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +0 -0
  84. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
  85. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/WHEEL +0 -0
  86. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/entry_points.txt +0 -0
  87. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/licenses/LICENSE +0 -0
mito_ai/streamlit_conversion/streamlit_agent_handler.py
@@ -1,9 +1,10 @@
 # Copyright (c) Saga Inc.
 # Distributed under the terms of the GNU Affero General Public License v3.0 License.

-from anthropic.types import MessageParam
-from typing import List, cast
-from mito_ai.streamlit_conversion.agent_utils import extract_todo_placeholders, get_response_from_agent
+from typing import List
+from openai.types.chat import ChatCompletionMessageParam
+from mito_ai.streamlit_conversion.agent_utils import extract_todo_placeholders
+from mito_ai.provider_manager import ProviderManager
 from mito_ai.streamlit_conversion.prompts.streamlit_app_creation_prompt import get_streamlit_app_creation_prompt
 from mito_ai.streamlit_conversion.prompts.streamlit_error_correction_prompt import get_streamlit_error_correction_prompt
 from mito_ai.streamlit_conversion.prompts.streamlit_finish_todo_prompt import get_finish_todo_prompt
@@ -15,22 +16,23 @@ from mito_ai.completions.models import MessageType
 from mito_ai.utils.error_classes import StreamlitConversionError
 from mito_ai.utils.telemetry_utils import log_streamlit_app_validation_retry, log_streamlit_app_conversion_success
 from mito_ai.path_utils import AbsoluteNotebookPath, AppFileName, get_absolute_notebook_dir_path, get_absolute_app_path, get_app_file_name
+from mito_ai.streamlit_conversion.prompts.streamlit_system_prompt import streamlit_system_prompt

-async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt: str) -> str:
+async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt: str, provider: ProviderManager) -> str:
     """Send a query to the agent, get its response and parse the code"""

     prompt_text = get_streamlit_app_creation_prompt(notebook, streamlit_app_prompt)

-    messages: List[MessageParam] = [
-        cast(MessageParam, {
-            "role": "user",
-            "content": [{
-                "type": "text",
-                "text": prompt_text
-            }]
-        })
+    messages: List[ChatCompletionMessageParam] = [
+        {"role": "system", "content": streamlit_system_prompt},
+        {"role": "user", "content": prompt_text}
     ]
-    agent_response = await get_response_from_agent(messages)
+    agent_response = await provider.request_completions(
+        message_type=MessageType.STREAMLIT_CONVERSION,
+        messages=messages,
+        use_smartest_model=True,
+        thread_id=None
+    )
     converted_code = extract_code_blocks(agent_response)

     # Extract the TODOs from the agent's response
@@ -39,16 +41,16 @@ async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt
     for todo_placeholder in todo_placeholders:
         print(f"Processing AI TODO: {todo_placeholder}")
         todo_prompt = get_finish_todo_prompt(notebook, converted_code, todo_placeholder)
-        todo_messages: List[MessageParam] = [
-            cast(MessageParam, {
-                "role": "user",
-                "content": [{
-                    "type": "text",
-                    "text": todo_prompt
-                }]
-            })
+        todo_messages: List[ChatCompletionMessageParam] = [
+            {"role": "system", "content": streamlit_system_prompt},
+            {"role": "user", "content": todo_prompt}
         ]
-        todo_response = await get_response_from_agent(todo_messages)
+        todo_response = await provider.request_completions(
+            message_type=MessageType.STREAMLIT_CONVERSION,
+            messages=todo_messages,
+            use_smartest_model=True,
+            thread_id=None
+        )

         # Apply the search/replace to the streamlit app
         search_replace_pairs = extract_search_replace_blocks(todo_response)
@@ -57,21 +59,21 @@ async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt
     return converted_code


-async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_code: str, edit_prompt: str) -> str:
+async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_code: str, edit_prompt: str, provider: ProviderManager) -> str:
     """Send a query to the agent, get its response and parse the code"""
     prompt_text = get_update_existing_app_prompt(notebook, streamlit_app_code, edit_prompt)

-    messages: List[MessageParam] = [
-        cast(MessageParam, {
-            "role": "user",
-            "content": [{
-                "type": "text",
-                "text": prompt_text
-            }]
-        })
+    messages: List[ChatCompletionMessageParam] = [
+        {"role": "system", "content": streamlit_system_prompt},
+        {"role": "user", "content": prompt_text}
     ]

-    agent_response = await get_response_from_agent(messages)
+    agent_response = await provider.request_completions(
+        message_type=MessageType.STREAMLIT_CONVERSION,
+        messages=messages,
+        use_smartest_model=True,
+        thread_id=None
+    )
     print(f"[Mito AI Search/Replace Tool]:\n {agent_response}")

     # Apply the search/replace to the streamlit app
@@ -81,18 +83,18 @@ async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_cod
     return converted_code


-async def correct_error_in_generation(error: str, streamlit_app_code: str) -> str:
+async def correct_error_in_generation(error: str, streamlit_app_code: str, provider: ProviderManager) -> str:
     """If errors are present, send it back to the agent to get corrections in code"""
-    messages: List[MessageParam] = [
-        cast(MessageParam, {
-            "role": "user",
-            "content": [{
-                "type": "text",
-                "text": get_streamlit_error_correction_prompt(error, streamlit_app_code)
-            }]
-        })
+    messages: List[ChatCompletionMessageParam] = [
+        {"role": "system", "content": streamlit_system_prompt},
+        {"role": "user", "content": get_streamlit_error_correction_prompt(error, streamlit_app_code)}
     ]
-    agent_response = await get_response_from_agent(messages)
+    agent_response = await provider.request_completions(
+        message_type=MessageType.STREAMLIT_CONVERSION,
+        messages=messages,
+        use_smartest_model=True,
+        thread_id=None
+    )

     # Apply the search/replace to the streamlit app
     search_replace_pairs = extract_search_replace_blocks(agent_response)
@@ -100,7 +102,7 @@ async def correct_error_in_generation(error: str, streamlit_app_code: str) -> st

     return streamlit_app_code

-async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath, app_file_name: AppFileName, streamlit_app_prompt: str = "") -> None:
+async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath, app_file_name: AppFileName, streamlit_app_prompt: str, provider: ProviderManager) -> None:
     """Handler function for streamlit code generation and validation"""

     # Convert to absolute path for consistent handling
@@ -110,7 +112,7 @@ async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNoteboo

     if create_new_app:
         # Otherwise generate a new streamlit app
-        streamlit_code = await generate_new_streamlit_code(notebook_code, streamlit_app_prompt)
+        streamlit_code = await generate_new_streamlit_code(notebook_code, streamlit_app_prompt, provider)
     else:
         # If the user is editing an existing streamlit app, use the update function
         existing_streamlit_code = get_app_code_from_file(app_path)
@@ -118,14 +120,14 @@ async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNoteboo
         if existing_streamlit_code is None:
             raise StreamlitConversionError("Error updating existing streamlit app because app.py file was not found.", 404)

-        streamlit_code = await update_existing_streamlit_code(notebook_code, existing_streamlit_code, streamlit_app_prompt)
+        streamlit_code = await update_existing_streamlit_code(notebook_code, existing_streamlit_code, streamlit_app_prompt, provider)

     # Then, after creating/updating the app, validate that the new code runs
     errors = validate_app(streamlit_code, notebook_path)
     tries = 0
     while len(errors) > 0 and tries < 5:
         for error in errors:
-            streamlit_code = await correct_error_in_generation(error, streamlit_code)
+            streamlit_code = await correct_error_in_generation(error, streamlit_code, provider)

         errors = validate_app(streamlit_code, notebook_path)

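Taken together, these hunks drop the Anthropic-specific get_response_from_agent helper and thread a ProviderManager through every streamlit conversion call. Below is a minimal sketch of how a caller might drive the new streamlit_handler signature; the ProviderManager construction and set_selected_model call mirror the test changes later in this diff, while the model name, prompt, and function name are illustrative assumptions rather than the extension's actual wiring.

import asyncio
from traitlets.config import Config
from mito_ai.provider_manager import ProviderManager
from mito_ai.streamlit_conversion.streamlit_agent_handler import streamlit_handler

async def build_app(notebook_path, app_file_name):
    # Construct the provider once and pass it down; streamlit_handler forwards it
    # to provider.request_completions(..., use_smartest_model=True) as shown above.
    provider = ProviderManager(config=Config())
    provider.set_selected_model("claude-sonnet-4-5-20250929")  # illustrative model choice
    await streamlit_handler(True, notebook_path, app_file_name, "Build a sales dashboard", provider)

# asyncio.run(build_app(path, app_name))  # illustrative; real paths come from the preview handler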
mito_ai/streamlit_preview/handlers.py
@@ -11,15 +11,18 @@ from mito_ai.utils.telemetry_utils import log_streamlit_app_conversion_error, lo
 from mito_ai.completions.models import MessageType
 from mito_ai.utils.error_classes import StreamlitConversionError, StreamlitPreviewError
 from mito_ai.streamlit_conversion.streamlit_agent_handler import streamlit_handler
+from mito_ai.provider_manager import ProviderManager
 import traceback


 class StreamlitPreviewHandler(APIHandler):
     """REST handler for streamlit preview operations."""

-    def initialize(self) -> None:
+    def initialize(self, llm: ProviderManager) -> None:
         """Initialize the handler."""
+        super().initialize()
         self.preview_manager = StreamlitPreviewManager()
+        self._llm = llm

     @tornado.web.authenticated

@@ -45,11 +48,11 @@ class StreamlitPreviewHandler(APIHandler):
             print("[Mito AI] Force recreating streamlit app")

             # Create a new app
-            await streamlit_handler(True, absolute_notebook_path, app_file_name, streamlit_app_prompt)
+            await streamlit_handler(True, absolute_notebook_path, app_file_name, streamlit_app_prompt, self._llm)
         elif streamlit_app_prompt != '':
             # Update an existing app if there is a prompt provided. Otherwise, the user is just
             # starting an existing app so we can skip the streamlit_handler all together
-            await streamlit_handler(False, absolute_notebook_path, app_file_name, streamlit_app_prompt)
+            await streamlit_handler(False, absolute_notebook_path, app_file_name, streamlit_app_prompt, self._llm)

         # Start preview
         # TODO: There's a bug here where when the user rebuilds and already running app. Instead of
mito_ai/streamlit_preview/urls.py
@@ -4,12 +4,14 @@
 from typing import Any, List, Tuple
 from jupyter_server.utils import url_path_join
 from mito_ai.streamlit_preview.handlers import StreamlitPreviewHandler
+from mito_ai.provider_manager import ProviderManager

-def get_streamlit_preview_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
+def get_streamlit_preview_urls(base_url: str, provider_manager: ProviderManager) -> List[Tuple[str, Any, dict]]:
    """Get all streamlit preview related URL patterns.

    Args:
        base_url: The base URL for the Jupyter server
+       provider_manager: The ProviderManager instance

    Returns:
        List of (url_pattern, handler_class, handler_kwargs) tuples
@@ -17,6 +19,6 @@ def get_streamlit_preview_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
     BASE_URL = base_url + "/mito-ai"

     return [
-        (url_path_join(BASE_URL, "streamlit-preview"), StreamlitPreviewHandler, {}),
-        (url_path_join(BASE_URL, "streamlit-preview/(.+)"), StreamlitPreviewHandler, {}),
+        (url_path_join(BASE_URL, "streamlit-preview"), StreamlitPreviewHandler, {"llm": provider_manager}),
+        (url_path_join(BASE_URL, "streamlit-preview/(.+)"), StreamlitPreviewHandler, {"llm": provider_manager}),
     ]
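The {"llm": provider_manager} kwargs above are how the shared ProviderManager reaches StreamlitPreviewHandler.initialize(llm=...): Tornado passes the third element of each (pattern, handler, kwargs) tuple into the handler's initialize method. A rough sketch of the registration this relies on is below; web_app, provider_manager, and register_streamlit_preview are assumed names for illustration, as the extension's real setup code is not part of this diff.

from mito_ai.provider_manager import ProviderManager
from mito_ai.streamlit_preview.urls import get_streamlit_preview_urls

def register_streamlit_preview(web_app, provider_manager: ProviderManager) -> None:
    # Tornado calls StreamlitPreviewHandler.initialize(**kwargs) with each tuple's dict,
    # so every request handler sees the same ProviderManager via self._llm.
    handlers = get_streamlit_preview_urls(web_app.settings["base_url"], provider_manager)
    web_app.add_handlers(".*$", handlers)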
mito_ai/tests/message_history/test_generate_short_chat_name.py
@@ -5,24 +5,25 @@ import pytest
 from unittest.mock import AsyncMock, MagicMock, patch
 from traitlets.config import Config
 from mito_ai.completions.message_history import generate_short_chat_name
-from mito_ai.completions.providers import OpenAIProvider
+from mito_ai.provider_manager import ProviderManager


 @pytest.fixture
 def provider_config() -> Config:
-    """Create a proper Config object for the OpenAIProvider."""
+    """Create a proper Config object for the ProviderManager."""
     config = Config()
-    config.OpenAIProvider = Config()
+    config.ProviderManager = Config()
     config.OpenAIClient = Config()
     return config


 # Test cases for different models and their expected providers/fast models
 PROVIDER_TEST_CASES = [
-    # (model, client_patch_path)
-    ("gpt-4.1", "mito_ai.completions.providers.OpenAIClient"),
-    ("claude-3-5-sonnet-20241022", "mito_ai.completions.providers.AnthropicClient"),
-    ("gemini-2.0-flash-exp", "mito_ai.completions.providers.GeminiClient")
+    # (model, client_patch_path) - patch where the classes are used (in provider_manager)
+    ("gpt-4.1", "mito_ai.provider_manager.OpenAIClient"),
+    ("claude-sonnet-4-5-20250929", "mito_ai.provider_manager.AnthropicClient"),
+    ("gemini-3-flash-preview", "mito_ai.provider_manager.GeminiClient"),
+    ("openai/gpt-4o", "mito_ai.provider_manager.LiteLLMClient"), # LiteLLM test case
 ]

 @pytest.mark.parametrize("selected_model,client_patch_path", PROVIDER_TEST_CASES)
@@ -37,31 +38,77 @@ async def test_generate_short_chat_name_uses_correct_provider_and_fast_model(

     # Set up environment variables for all providers
     monkeypatch.setenv("OPENAI_API_KEY", "fake-openai-key")
-    monkeypatch.setenv("CLAUDE_API_KEY", "fake-claude-key")
+    monkeypatch.setenv("ANTHROPIC_API_KEY", "fake-claude-key")
     monkeypatch.setenv("GEMINI_API_KEY", "fake-gemini-key")
     monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-openai-key")
-    monkeypatch.setattr("mito_ai.constants.CLAUDE_API_KEY", "fake-claude-key")
+    monkeypatch.setattr("mito_ai.constants.ANTHROPIC_API_KEY", "fake-claude-key")
     monkeypatch.setattr("mito_ai.constants.GEMINI_API_KEY", "fake-gemini-key")

+    # Set up LiteLLM constants if testing LiteLLM
+    if "LiteLLMClient" in client_patch_path:
+        # Patch constants both at the source and where they're imported in model_utils
+        monkeypatch.setattr("mito_ai.constants.LITELLM_BASE_URL", "https://litellm-server.com")
+        monkeypatch.setattr("mito_ai.constants.LITELLM_API_KEY", "fake-litellm-key")
+        monkeypatch.setattr("mito_ai.constants.LITELLM_MODELS", ["openai/gpt-4o", "anthropic/claude-3-5-sonnet"])
+        # Also patch where constants is imported in model_utils (where get_available_models uses it)
+        monkeypatch.setattr("mito_ai.utils.model_utils.constants.LITELLM_BASE_URL", "https://litellm-server.com")
+        monkeypatch.setattr("mito_ai.utils.model_utils.constants.LITELLM_MODELS", ["openai/gpt-4o", "anthropic/claude-3-5-sonnet"])
+        # Mock is_enterprise to return True so LiteLLM models are available
+        monkeypatch.setattr("mito_ai.utils.version_utils.is_enterprise", lambda: True)
+
     # Create mock client for the specific provider being tested
     mock_client = MagicMock()
     mock_client.request_completions = AsyncMock(return_value="Test Chat Name")

+    # Create the ProviderManager first
+    llm_provider = ProviderManager(config=provider_config)
+
+    # Set the selected model (this is required for the ProviderManager to use the correct model)
+    llm_provider.set_selected_model(selected_model)
+
     # Patch the specific client class that should be used based on the model
-    # We need to patch before creating the OpenAIProvider since OpenAI client is created in constructor
-    with patch(client_patch_path, return_value=mock_client):
-        # Create the OpenAIProvider after patching so the mock client is used
-        llm_provider = OpenAIProvider(config=provider_config)
-
-        # Test the function
-        result = await generate_short_chat_name(
-            user_message="What is the capital of France?",
-            assistant_message="The capital of France is Paris.",
-            model=selected_model,
-            llm_provider=llm_provider
-        )
-
-        # Verify that the correct client's request_completions was called
+    # For Anthropic, Gemini, and LiteLLM, new instances are created in request_completions, so we patch the class
+    # For OpenAI, the instance is created in __init__, so we patch the instance's method
+    if "AnthropicClient" in client_patch_path:
+        with patch(client_patch_path, return_value=mock_client):
+            result = await generate_short_chat_name(
+                user_message="What is the capital of France?",
+                assistant_message="The capital of France is Paris.",
+                llm_provider=llm_provider
+            )
+    elif "GeminiClient" in client_patch_path:
+        with patch(client_patch_path, return_value=mock_client):
+            result = await generate_short_chat_name(
+                user_message="What is the capital of France?",
+                assistant_message="The capital of France is Paris.",
+                llm_provider=llm_provider
+            )
+    elif "LiteLLMClient" in client_patch_path:
+        # Patch LiteLLMClient where it's defined (it's imported inside request_completions)
+        # Also patch get_available_models to return LiteLLM models
+        with patch("mito_ai.enterprise.litellm_client.LiteLLMClient", return_value=mock_client), \
+             patch("mito_ai.provider_manager.get_available_models", return_value=["openai/gpt-4o", "anthropic/claude-3-5-sonnet"]):
+            result = await generate_short_chat_name(
+                user_message="What is the capital of France?",
+                assistant_message="The capital of France is Paris.",
+                llm_provider=llm_provider
+            )
+    else: # OpenAI
+        # For OpenAI, patch the instance's method since the client is created in __init__
+        assert llm_provider._openai_client is not None, "OpenAI client should be initialized"
+        with patch.object(llm_provider._openai_client, 'request_completions', new_callable=AsyncMock, return_value="Test Chat Name") as mock_openai_request:
+            result = await generate_short_chat_name(
+                user_message="What is the capital of France?",
+                assistant_message="The capital of France is Paris.",
+                llm_provider=llm_provider
+            )
+        # Verify that the OpenAI client's request_completions was called
+        mock_openai_request.assert_called_once() # type: ignore
+        # As a double check, if we have used the correct client, then we must get the correct result
+        assert result == "Test Chat Name"
+        return
+
+    # Verify that the correct client's request_completions was called (for Anthropic, Gemini, and LiteLLM)
     mock_client.request_completions.assert_called_once()

     # As a double check, if we have used the correct client, then we must get the correct result
@@ -74,13 +121,12 @@ async def test_generate_short_chat_name_cleans_gemini_response() -> None:
     """Test that generate_short_chat_name properly cleans Gemini-style responses with quotes and newlines."""

     # Create mock llm_provider that returns a response with quotes and newlines
-    mock_llm_provider = MagicMock(spec=OpenAIProvider)
+    mock_llm_provider = MagicMock(spec=ProviderManager)
     mock_llm_provider.request_completions = AsyncMock(return_value='"France Geography Discussion\n"')

     result = await generate_short_chat_name(
         user_message="What is the capital of France?",
         assistant_message="The capital of France is Paris.",
-        model="gemini-2.0-flash-exp",
         llm_provider=mock_llm_provider
     )

@@ -95,13 +141,12 @@ async def test_generate_short_chat_name_handles_empty_response() -> None:
     """Test that generate_short_chat_name handles empty or None responses gracefully."""

     # Test with empty string response
-    mock_llm_provider = MagicMock(spec=OpenAIProvider)
+    mock_llm_provider = MagicMock(spec=ProviderManager)
     mock_llm_provider.request_completions = AsyncMock(return_value="")

     result = await generate_short_chat_name(
         user_message="Test message",
         assistant_message="Test response",
-        model="gpt-4.1",
         llm_provider=mock_llm_provider
     )

@@ -113,7 +158,6 @@ async def test_generate_short_chat_name_handles_empty_response() -> None:
     result = await generate_short_chat_name(
         user_message="Test message",
         assistant_message="Test response",
-        model="gpt-4.1",
         llm_provider=mock_llm_provider
     )

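As the three tests above show, generate_short_chat_name no longer takes a model argument; the ProviderManager passed as llm_provider decides which fast model to use. A minimal sketch of a call under the new signature, assuming llm_provider is whatever ProviderManager instance the caller already holds:

# llm_provider is an existing ProviderManager (or a MagicMock(spec=ProviderManager) in tests)
chat_name = await generate_short_chat_name(
    user_message="What is the capital of France?",
    assistant_message="The capital of France is Paris.",
    llm_provider=llm_provider,
)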
mito_ai/tests/providers/test_anthropic_client.py
@@ -3,11 +3,10 @@

 import pytest
 from mito_ai.anthropic_client import get_anthropic_system_prompt_and_messages, get_anthropic_system_prompt_and_messages_with_caching, add_cache_control_to_message, extract_and_parse_anthropic_json_response, AnthropicClient
-from mito_ai.utils.anthropic_utils import FAST_ANTHROPIC_MODEL
 from anthropic.types import Message, TextBlock, ToolUseBlock, Usage, ToolUseBlock, Message, Usage, TextBlock
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionUserMessageParam, ChatCompletionAssistantMessageParam, ChatCompletionSystemMessageParam
 from mito_ai.completions.models import MessageType
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 import anthropic
 from typing import List, Dict, cast

@@ -233,24 +232,25 @@ def test_tool_use_without_agent_response():
     assert "No valid AgentResponse format found" in str(exc_info.value)

 CUSTOM_MODEL = "smart-anthropic-model"
-@pytest.mark.parametrize("message_type, expected_model", [
-    (MessageType.CHAT, CUSTOM_MODEL), #
-    (MessageType.SMART_DEBUG, CUSTOM_MODEL), #
-    (MessageType.CODE_EXPLAIN, CUSTOM_MODEL), #
-    (MessageType.AGENT_EXECUTION, CUSTOM_MODEL), #
-    (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL), #
-    (MessageType.INLINE_COMPLETION, FAST_ANTHROPIC_MODEL), #
-    (MessageType.CHAT_NAME_GENERATION, FAST_ANTHROPIC_MODEL), #
+@pytest.mark.parametrize("message_type", [
+    MessageType.CHAT,
+    MessageType.SMART_DEBUG,
+    MessageType.CODE_EXPLAIN,
+    MessageType.AGENT_EXECUTION,
+    MessageType.AGENT_AUTO_ERROR_FIXUP,
+    MessageType.INLINE_COMPLETION,
+    MessageType.CHAT_NAME_GENERATION,
 ])
 @pytest.mark.asyncio
-async def test_model_selection_based_on_message_type(message_type, expected_model):
+async def test_model_selection_uses_passed_model(message_type):
     """
-    Tests that the correct model is selected based on the message type.
+    Tests that the model passed to the client is used as-is.
+    Model selection based on message type is now handled by ProviderManager.
     """
     client = AnthropicClient(api_key="test_key")

-    # Mock the messages.create method directly
-    with patch.object(client.client.messages, 'create') as mock_create: # type: ignore
+    # Mock the beta.messages.create method directly (we now use beta API)
+    with patch.object(client.client.beta.messages, 'create') as mock_create: # type: ignore
         # Create a mock response
         mock_response = Message(
             id="test_id",
@@ -269,10 +269,168 @@ async def test_model_selection_based_on_message_type(message_type, expected_mode
             response_format_info=None
         )

-        # Verify that create was called with the expected model
+        # Verify that create was called with the model that was passed (not overridden)
         mock_create.assert_called_once()
         call_args = mock_create.call_args
-        assert call_args[1]['model'] == expected_model
+        assert call_args[1]['model'] == CUSTOM_MODEL
+
+@pytest.mark.asyncio
+async def test_anthropic_client_uses_fast_model_from_provider_manager_without_override():
+    """Test that Anthropic client uses the fast model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_fast_model_for_selected_model
+
+    client = AnthropicClient(api_key="test_key")
+
+    # Mock the beta.messages.create method directly (we now use beta API)
+    with patch.object(client.client.beta.messages, 'create') as mock_create: # type: ignore
+        # Create a mock response
+        mock_response = Message(
+            id="test_id",
+            role="assistant",
+            content=[TextBlock(type="text", text="test")],
+            model='anthropic-model-we-do-not-check',
+            type="message",
+            usage=Usage(input_tokens=0, output_tokens=0)
+        )
+        mock_create.return_value = mock_response
+
+        # Use a fast model that would be selected by ProviderManager
+        fast_model = get_fast_model_for_selected_model("claude-sonnet-4-5-20250929")
+
+        await client.request_completions(
+            messages=[{"role": "user", "content": "Test message"}],
+            model=fast_model,
+            message_type=MessageType.CHAT,
+            response_format_info=None
+        )
+
+        # Verify that create was called with the fast model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == fast_model
+
+@pytest.mark.asyncio
+async def test_anthropic_client_uses_smartest_model_from_provider_manager_without_override():
+    """Test that Anthropic client uses the smartest model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_smartest_model_for_selected_model
+
+    client = AnthropicClient(api_key="test_key")
+
+    # Mock the beta.messages.create method directly (we now use beta API)
+    with patch.object(client.client.beta.messages, 'create') as mock_create: # type: ignore
+        # Create a mock response
+        mock_response = Message(
+            id="test_id",
+            role="assistant",
+            content=[TextBlock(type="text", text="test")],
+            model='anthropic-model-we-do-not-check',
+            type="message",
+            usage=Usage(input_tokens=0, output_tokens=0)
+        )
+        mock_create.return_value = mock_response
+
+        # Use a smartest model that would be selected by ProviderManager
+        smartest_model = get_smartest_model_for_selected_model("claude-haiku-4-5-20251001")
+
+        await client.request_completions(
+            messages=[{"role": "user", "content": "Test message"}],
+            model=smartest_model,
+            message_type=MessageType.CHAT,
+            response_format_info=None
+        )
+
+        # Verify that create was called with the smartest model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == smartest_model
+
+@pytest.mark.asyncio
+async def test_anthropic_client_stream_uses_fast_model_from_provider_manager_without_override():
+    """Test that Anthropic client stream_completions uses the fast model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_fast_model_for_selected_model
+
+    client = AnthropicClient(api_key="test_key")
+
+    # Mock the beta.messages.create method for streaming
+    with patch.object(client.client.beta.messages, 'create') as mock_create: # type: ignore
+        # Create a mock stream response
+        class MockStreamChunk:
+            def __init__(self, chunk_type, text=""):
+                self.type = chunk_type
+                if chunk_type == "content_block_delta":
+                    self.delta = MagicMock()
+                    self.delta.type = "text_delta"
+                    self.delta.text = text
+
+        mock_stream = [
+            MockStreamChunk("content_block_delta", "test"),
+            MockStreamChunk("message_stop")
+        ]
+        mock_create.return_value = iter(mock_stream)
+
+        # Use a fast model that would be selected by ProviderManager
+        fast_model = get_fast_model_for_selected_model("claude-sonnet-4-5-20250929")
+
+        reply_chunks = []
+        def mock_reply(chunk):
+            reply_chunks.append(chunk)
+
+        await client.stream_completions(
+            messages=[{"role": "user", "content": "Test message"}],
+            model=fast_model,
+            message_id="test-id",
+            message_type=MessageType.CHAT,
+            reply_fn=mock_reply
+        )
+
+        # Verify that create was called with the fast model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == fast_model
+
+@pytest.mark.asyncio
+async def test_anthropic_client_stream_uses_smartest_model_from_provider_manager_without_override():
+    """Test that Anthropic client stream_completions uses the smartest model passed from ProviderManager without internal override."""
+    from mito_ai.utils.model_utils import get_smartest_model_for_selected_model
+
+    client = AnthropicClient(api_key="test_key")
+
+    # Mock the beta.messages.create method for streaming
+    with patch.object(client.client.beta.messages, 'create') as mock_create: # type: ignore
+        # Create a mock stream response
+        class MockStreamChunk:
+            def __init__(self, chunk_type, text=""):
+                self.type = chunk_type
+                if chunk_type == "content_block_delta":
+                    self.delta = MagicMock()
+                    self.delta.type = "text_delta"
+                    self.delta.text = text
+
+        mock_stream = [
+            MockStreamChunk("content_block_delta", "test"),
+            MockStreamChunk("message_stop")
+        ]
+        mock_create.return_value = iter(mock_stream)
+
+        # Use a smartest model that would be selected by ProviderManager
+        smartest_model = get_smartest_model_for_selected_model("claude-haiku-4-5-20251001")
+
+        reply_chunks = []
+        def mock_reply(chunk):
+            reply_chunks.append(chunk)
+
+        await client.stream_completions(
+            messages=[{"role": "user", "content": "Test message"}],
+            model=smartest_model,
+            message_id="test-id",
+            message_type=MessageType.CHAT,
+            reply_fn=mock_reply
+        )
+
+        # Verify that create was called with the smartest model that was passed (not overridden)
+        mock_create.assert_called_once()
+        call_args = mock_create.call_args
+        assert call_args[1]['model'] == smartest_model


 # Caching Tests