mito-ai 0.1.56__py3-none-any.whl → 0.1.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. mito_ai/__init__.py +17 -21
  2. mito_ai/_version.py +1 -1
  3. mito_ai/anthropic_client.py +24 -14
  4. mito_ai/chart_wizard/__init__.py +3 -0
  5. mito_ai/chart_wizard/handlers.py +113 -0
  6. mito_ai/chart_wizard/urls.py +26 -0
  7. mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
  8. mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
  9. mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
  10. mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
  11. mito_ai/completions/completion_handlers/completion_handler.py +14 -7
  12. mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
  13. mito_ai/completions/completion_handlers/scratchpad_result_handler.py +64 -0
  14. mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
  15. mito_ai/completions/completion_handlers/utils.py +3 -7
  16. mito_ai/completions/handlers.py +36 -21
  17. mito_ai/completions/message_history.py +8 -10
  18. mito_ai/completions/models.py +23 -2
  19. mito_ai/completions/prompt_builders/agent_smart_debug_prompt.py +5 -3
  20. mito_ai/completions/prompt_builders/agent_system_message.py +97 -5
  21. mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
  22. mito_ai/completions/prompt_builders/chart_conversion_prompt.py +27 -0
  23. mito_ai/completions/prompt_builders/chat_system_message.py +2 -0
  24. mito_ai/completions/prompt_builders/prompt_constants.py +28 -0
  25. mito_ai/completions/prompt_builders/scratchpad_result_prompt.py +17 -0
  26. mito_ai/constants.py +8 -1
  27. mito_ai/enterprise/__init__.py +1 -1
  28. mito_ai/enterprise/litellm_client.py +137 -0
  29. mito_ai/log/handlers.py +1 -1
  30. mito_ai/openai_client.py +10 -90
  31. mito_ai/{completions/providers.py → provider_manager.py} +157 -53
  32. mito_ai/settings/enterprise_handler.py +26 -0
  33. mito_ai/settings/urls.py +2 -0
  34. mito_ai/streamlit_conversion/agent_utils.py +2 -30
  35. mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
  36. mito_ai/streamlit_preview/handlers.py +6 -3
  37. mito_ai/streamlit_preview/urls.py +5 -3
  38. mito_ai/tests/message_history/test_generate_short_chat_name.py +72 -28
  39. mito_ai/tests/providers/test_anthropic_client.py +174 -16
  40. mito_ai/tests/providers/test_azure.py +13 -13
  41. mito_ai/tests/providers/test_capabilities.py +14 -17
  42. mito_ai/tests/providers/test_gemini_client.py +14 -13
  43. mito_ai/tests/providers/test_model_resolution.py +145 -89
  44. mito_ai/tests/providers/test_openai_client.py +209 -13
  45. mito_ai/tests/providers/test_provider_limits.py +5 -5
  46. mito_ai/tests/providers/test_providers.py +229 -51
  47. mito_ai/tests/providers/test_retry_logic.py +13 -22
  48. mito_ai/tests/providers/utils.py +4 -4
  49. mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
  50. mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
  51. mito_ai/tests/test_enterprise_mode.py +162 -0
  52. mito_ai/tests/test_model_utils.py +271 -0
  53. mito_ai/utils/anthropic_utils.py +8 -6
  54. mito_ai/utils/gemini_utils.py +0 -3
  55. mito_ai/utils/litellm_utils.py +84 -0
  56. mito_ai/utils/model_utils.py +178 -0
  57. mito_ai/utils/open_ai_utils.py +0 -8
  58. mito_ai/utils/provider_utils.py +6 -21
  59. mito_ai/utils/telemetry_utils.py +14 -2
  60. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
  61. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
  62. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
  63. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.dfd7975de75d64db80d6.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js +2992 -282
  64. mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js.map +1 -0
  65. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.1e7b5cf362385f109883.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js +17 -17
  66. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.1e7b5cf362385f109883.js.map → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js.map +1 -1
  67. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +7 -2
  68. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/METADATA +2 -1
  69. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/RECORD +94 -81
  70. mito_ai-0.1.56.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.dfd7975de75d64db80d6.js.map +0 -1
  71. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
  72. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
  73. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
  74. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
  75. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
  76. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
  77. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
  78. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
  79. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
  80. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
  81. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
  82. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
  83. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
  84. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
  85. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
  86. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
  87. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
  88. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
  89. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
  90. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
  91. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
  92. {mito_ai-0.1.56.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
  93. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/WHEEL +0 -0
  94. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/entry_points.txt +0 -0
  95. {mito_ai-0.1.56.dist-info → mito_ai-0.1.58.dist-info}/licenses/LICENSE +0 -0
mito_ai/{completions/providers.py → provider_manager.py} RENAMED
@@ -6,7 +6,7 @@ import asyncio
  from typing import Any, Callable, Dict, List, Optional, Union, cast
  from mito_ai import constants
  from openai.types.chat import ChatCompletionMessageParam
- from traitlets import Instance, Unicode, default, validate
+ from traitlets import Instance
  from traitlets.config import LoggingConfigurable
  from openai.types.chat import ChatCompletionMessageParam

@@ -24,32 +24,23 @@ from mito_ai.completions.models import (
      CompletionReply,
      CompletionStreamChunk,
      MessageType,
-     ResponseFormatInfo, CompletionItemError,
+     ResponseFormatInfo,
  )
+ from mito_ai.utils.litellm_utils import is_litellm_configured
  from mito_ai.utils.telemetry_utils import (
-     KEY_TYPE_PARAM,
-     MITO_AI_COMPLETION_ERROR,
-     MITO_AI_COMPLETION_RETRY,
      MITO_SERVER_KEY,
      USER_KEY,
-     log,
      log_ai_completion_error,
      log_ai_completion_retry,
      log_ai_completion_success,
  )
  from mito_ai.utils.provider_utils import get_model_provider
- from mito_ai.utils.mito_server_utils import ProviderCompletionException
+ from mito_ai.utils.model_utils import get_available_models, get_fast_model_for_selected_model, get_smartest_model_for_selected_model

- __all__ = ["OpenAIProvider"]
+ __all__ = ["ProviderManager"]

- class OpenAIProvider(LoggingConfigurable):
-     """Provide AI feature through OpenAI services."""
-
-     api_key = Unicode(
-         config=True,
-         allow_none=True,
-         help="OpenAI API key. Default value is read from the OPENAI_API_KEY environment variable.",
-     )
+ class ProviderManager(LoggingConfigurable):
+     """Manage AI providers (Claude, Gemini, OpenAI) and route requests to the appropriate client."""

      last_error = Instance(
          CompletionError,
@@ -61,29 +52,57 @@ This attribute is observed by the websocket provider to push the error to the client.

      def __init__(self, **kwargs: Dict[str, Any]) -> None:
          config = kwargs.get('config', {})
-         if 'api_key' in kwargs:
-             config['OpenAIClient'] = {'api_key': kwargs['api_key']}
          kwargs['config'] = config

          super().__init__(log=get_logger(), **kwargs)
          self.last_error = None
          self._openai_client: Optional[OpenAIClient] = OpenAIClient(**config)
+         # Initialize with the first available model to ensure it's always valid
+         # This respects LiteLLM configuration: if LiteLLM is configured, uses first LiteLLM model
+         # Otherwise, uses first standard model
+         available_models = get_available_models()
+         self._selected_model: str = available_models[0] if available_models else "gpt-4.1"
+
+     def get_selected_model(self) -> str:
+         """Get the currently selected model."""
+         return self._selected_model
+
+     def set_selected_model(self, model: str) -> None:
+         """Set the selected model."""
+         self._selected_model = model

      @property
      def capabilities(self) -> AICapabilities:
          """
          Returns the capabilities of the AI provider.
          """
-         if constants.CLAUDE_API_KEY and not self.api_key:
+         # TODO: We should validate that these keys are actually valid for the provider
+         # otherwise it will look like we are using the user_key when actually falling back
+         # to the mito server because the key is invalid.
+         if is_litellm_configured():
+             return AICapabilities(
+                 configuration={"model": "<dynamic>"},
+                 provider="LiteLLM",
+             )
+
+         if constants.OPENAI_API_KEY:
+             return AICapabilities(
+                 configuration={"model": "<dynamic>"},
+                 provider="OpenAI",
+             )
+
+         if constants.ANTHROPIC_API_KEY:
              return AICapabilities(
                  configuration={"model": "<dynamic>"},
                  provider="Claude",
              )
-         if constants.GEMINI_API_KEY and not self.api_key:
+
+         if constants.GEMINI_API_KEY:
              return AICapabilities(
                  configuration={"model": "<dynamic>"},
                  provider="Gemini",
              )
+
          if self._openai_client:
              return self._openai_client.capabilities

@@ -94,65 +113,106 @@ This attribute is observed by the websocket provider to push the error to the client.

      @property
      def key_type(self) -> str:
-         if constants.CLAUDE_API_KEY and not self.api_key:
-             return "claude"
-         if constants.GEMINI_API_KEY and not self.api_key:
-             return "gemini"
-         if self._openai_client:
-             return self._openai_client.key_type
+         # TODO: We should validate that these keys are actually valid for the provider
+         # otherwise it will look like we are using the user_key when actually falling back
+         # to the mito server because the key is invalid.
+         if is_litellm_configured():
+             return USER_KEY
+
+         if constants.ANTHROPIC_API_KEY or constants.GEMINI_API_KEY or constants.OPENAI_API_KEY or constants.OLLAMA_MODEL:
+             return USER_KEY
+
          return MITO_SERVER_KEY

      async def request_completions(
          self,
          message_type: MessageType,
          messages: List[ChatCompletionMessageParam],
-         model: str,
          response_format_info: Optional[ResponseFormatInfo] = None,
          user_input: Optional[str] = None,
          thread_id: Optional[str] = None,
-         max_retries: int = 3
+         max_retries: int = 3,
+         use_fast_model: bool = False,
+         use_smartest_model: bool = False
      ) -> str:
          """
          Request completions from the AI provider.
+
+         Args:
+             message_type: Type of message
+             messages: List of chat messages
+             response_format_info: Optional response format specification
+             user_input: Optional user input for logging
+             thread_id: Optional thread ID for logging
+             max_retries: Maximum number of retries
+             use_fast_model: If True, use the fastest model from the selected provider
+             use_smartest_model: If True, use the smartest model from the selected provider
          """
          self.last_error = None
          completion = None
          last_message_content = str(messages[-1].get('content', '')) if messages else ""
-         model_type = get_model_provider(model)
+
+         # Get the model to use (selected model, fast model, or smartest model if requested)
+         selected_model = self.get_selected_model()
+         if use_smartest_model:
+             resolved_model = get_smartest_model_for_selected_model(selected_model)
+         elif use_fast_model:
+             resolved_model = get_fast_model_for_selected_model(selected_model)
+         else:
+             resolved_model = selected_model
+
+         # Validate model is in allowed list (uses same function as endpoint)
+         available_models = get_available_models()
+         if resolved_model not in available_models:
+             raise ValueError(f"Model {resolved_model} is not in the allowed model list: {available_models}")
+
+         # Get model provider type
+         model_type = get_model_provider(resolved_model)

          # Retry loop
          for attempt in range(max_retries + 1):
              try:
-                 if model_type == "claude":
-                     api_key = constants.CLAUDE_API_KEY
+                 if model_type == "litellm":
+                     from mito_ai.enterprise.litellm_client import LiteLLMClient
+                     if not constants.LITELLM_BASE_URL:
+                         raise ValueError("LITELLM_BASE_URL is required for LiteLLM models")
+                     litellm_client = LiteLLMClient(api_key=constants.LITELLM_API_KEY, base_url=constants.LITELLM_BASE_URL)
+                     completion = await litellm_client.request_completions(
+                         messages=messages,
+                         model=resolved_model,
+                         response_format_info=response_format_info,
+                         message_type=message_type
+                     )
+                 elif model_type == "claude":
+                     api_key = constants.ANTHROPIC_API_KEY
                      anthropic_client = AnthropicClient(api_key=api_key)
-                     completion = await anthropic_client.request_completions(messages, model, response_format_info, message_type)
+                     completion = await anthropic_client.request_completions(messages, resolved_model, response_format_info, message_type)
                  elif model_type == "gemini":
                      api_key = constants.GEMINI_API_KEY
                      gemini_client = GeminiClient(api_key=api_key)
                      messages_for_gemini = [dict(m) for m in messages]
-                     completion = await gemini_client.request_completions(messages_for_gemini, model, response_format_info, message_type)
+                     completion = await gemini_client.request_completions(messages_for_gemini, resolved_model, response_format_info, message_type)
                  elif model_type == "openai":
                      if not self._openai_client:
                          raise RuntimeError("OpenAI client is not initialized.")
                      completion = await self._openai_client.request_completions(
                          message_type=message_type,
                          messages=messages,
-                         model=model,
+                         model=resolved_model,
                          response_format_info=response_format_info
                      )
                  else:
-                     raise ValueError(f"No AI provider configured for model: {model}")
+                     raise ValueError(f"No AI provider configured for model: {resolved_model}")

                  # Success! Log and return
                  log_ai_completion_success(
-                     key_type=USER_KEY if self.key_type == "user" else MITO_SERVER_KEY,
+                     key_type=USER_KEY if self.key_type == USER_KEY else MITO_SERVER_KEY,
                      message_type=message_type,
                      last_message_content=last_message_content,
                      response={"completion": completion},
                      user_input=user_input or "",
                      thread_id=thread_id or "",
-                     model=model
+                     model=resolved_model
                  )
                  return completion # type: ignore

@@ -160,7 +220,7 @@ This attribute is observed by the websocket provider to push the error to the client.
                  # If we hit a free tier limit, then raise an exception right away without retrying.
                  self.log.exception(f"Error during request_completions: {e}")
                  self.last_error = CompletionError.from_exception(e)
-                 log_ai_completion_error('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id or "", message_type, e)
+                 log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                  raise

              except BaseException as e:
@@ -169,14 +229,14 @@ This attribute is observed by the websocket provider to push the error to the client.
                      # Exponential backoff: wait 2^attempt seconds
                      wait_time = 2 ** attempt
                      self.log.info(f"Retrying request_completions after {wait_time}s (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
-                     log_ai_completion_retry('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id or "", message_type, e)
+                     log_ai_completion_retry(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                      await asyncio.sleep(wait_time)
                      continue
                  else:
                      # Final failure after all retries - set error state and raise
                      self.log.exception(f"Error during request_completions after {attempt + 1} attempts: {e}")
                      self.last_error = CompletionError.from_exception(e)
-                     log_ai_completion_error('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id or "", message_type, e)
+                     log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                      raise

          # This should never be reached due to the raise in the except block,
@@ -187,21 +247,50 @@ This attribute is observed by the websocket provider to push the error to the client.
          self,
          message_type: MessageType,
          messages: List[ChatCompletionMessageParam],
-         model: str,
          message_id: str,
          thread_id: str,
          reply_fn: Callable[[Union[CompletionReply, CompletionStreamChunk]], None],
          user_input: Optional[str] = None,
-         response_format_info: Optional[ResponseFormatInfo] = None
+         response_format_info: Optional[ResponseFormatInfo] = None,
+         use_fast_model: bool = False,
+         use_smartest_model: bool = False
      ) -> str:
          """
          Stream completions from the AI provider and return the accumulated response.
+
+         Args:
+             message_type: Type of message
+             messages: List of chat messages
+             message_id: ID of the message being processed
+             thread_id: Thread ID for logging
+             reply_fn: Function to call with each chunk for streaming replies
+             user_input: Optional user input for logging
+             response_format_info: Optional response format specification
+             use_fast_model: If True, use the fastest model from the selected provider
+             use_smartest_model: If True, use the smartest model from the selected provider
+
          Returns: The accumulated response string.
          """
          self.last_error = None
          accumulated_response = ""
          last_message_content = str(messages[-1].get('content', '')) if messages else ""
-         model_type = get_model_provider(model)
+
+         # Get the model to use (selected model, fast model, or smartest model if requested)
+         selected_model = self.get_selected_model()
+         if use_smartest_model:
+             resolved_model = get_smartest_model_for_selected_model(selected_model)
+         elif use_fast_model:
+             resolved_model = get_fast_model_for_selected_model(selected_model)
+         else:
+             resolved_model = selected_model
+
+         # Validate model is in allowed list (uses same function as endpoint)
+         available_models = get_available_models()
+         if resolved_model not in available_models:
+             raise ValueError(f"Model {resolved_model} is not in the allowed model list: {available_models}")
+
+         # Get model provider type
+         model_type = get_model_provider(resolved_model)
          reply_fn(CompletionReply(
              items=[
                  CompletionItem(content="", isIncomplete=True, token=message_id)
@@ -210,12 +299,28 @@ This attribute is observed by the websocket provider to push the error to the client.
          ))

          try:
-             if model_type == "claude":
-                 api_key = constants.CLAUDE_API_KEY
+             if model_type == "litellm":
+                 from mito_ai.enterprise.litellm_client import LiteLLMClient
+                 if not constants.LITELLM_BASE_URL:
+                     raise ValueError("LITELLM_BASE_URL is required for LiteLLM models")
+                 litellm_client = LiteLLMClient(
+                     api_key=constants.LITELLM_API_KEY,
+                     base_url=constants.LITELLM_BASE_URL
+                 )
+                 accumulated_response = await litellm_client.stream_completions(
+                     messages=messages,
+                     model=resolved_model,
+                     message_type=message_type,
+                     message_id=message_id,
+                     reply_fn=reply_fn,
+                     response_format_info=response_format_info
+                 )
+             elif model_type == "claude":
+                 api_key = constants.ANTHROPIC_API_KEY
                  anthropic_client = AnthropicClient(api_key=api_key)
                  accumulated_response = await anthropic_client.stream_completions(
                      messages=messages,
-                     model=model,
+                     model=resolved_model,
                      message_type=message_type,
                      message_id=message_id,
                      reply_fn=reply_fn
@@ -228,7 +333,7 @@ This attribute is observed by the websocket provider to push the error to the client.
                  messages_for_gemini = [dict(m) for m in messages]
                  accumulated_response = await gemini_client.stream_completions(
                      messages=messages_for_gemini,
-                     model=model,
+                     model=resolved_model,
                      message_id=message_id,
                      reply_fn=reply_fn,
                      message_type=message_type
@@ -239,7 +344,7 @@ This attribute is observed by the websocket provider to push the error to the client.
                  accumulated_response = await self._openai_client.stream_completions(
                      message_type=message_type,
                      messages=messages,
-                     model=model,
+                     model=resolved_model,
                      message_id=message_id,
                      thread_id=thread_id,
                      reply_fn=reply_fn,
@@ -247,24 +352,24 @@ This attribute is observed by the websocket provider to push the error to the client.
                      response_format_info=response_format_info
                  )
              else:
-                 raise ValueError(f"No AI provider configured for model: {model}")
+                 raise ValueError(f"No AI provider configured for model: {resolved_model}")

              # Log the successful completion
              log_ai_completion_success(
-                 key_type=USER_KEY if self.key_type == "user" else MITO_SERVER_KEY,
+                 key_type=USER_KEY if self.key_type == USER_KEY else MITO_SERVER_KEY,
                  message_type=message_type,
                  last_message_content=last_message_content,
                  response={"completion": accumulated_response},
                  user_input=user_input or "",
                  thread_id=thread_id,
-                 model=model
+                 model=resolved_model
              )
              return accumulated_response

          except BaseException as e:
              self.log.exception(f"Error during stream_completions: {e}")
              self.last_error = CompletionError.from_exception(e)
-             log_ai_completion_error('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id, message_type, e)
+             log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id, message_type, e)

              # Send error message to client before raising
              reply_fn(CompletionStreamChunk(
@@ -281,4 +386,3 @@ This attribute is observed by the websocket provider to push the error to the client.
                  error=CompletionError.from_exception(e),
              ))
              raise
-
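For orientation, here is a minimal sketch of how a caller might drive the renamed ProviderManager; the model-selection methods and the fast/smartest flags come straight from the signatures above. The chat message, the "gpt-4.1" choice, and the MessageType.CHAT member are illustrative assumptions (this diff only confirms MessageType.STREAMLIT_CONVERSION), and LiteLLM routing additionally requires the LITELLM_BASE_URL constant (and optionally LITELLM_API_KEY) to be set.

import asyncio
from mito_ai.provider_manager import ProviderManager
from mito_ai.completions.models import MessageType

async def ask() -> str:
    provider = ProviderManager()
    # Callers no longer pass `model`; the manager owns the selection and
    # validates it against get_available_models() before dispatching.
    provider.set_selected_model("gpt-4.1")  # hypothetical choice; must be in get_available_models()
    return await provider.request_completions(
        message_type=MessageType.CHAT,  # assumed enum member, not confirmed by this diff
        messages=[{"role": "user", "content": "Explain this traceback"}],
        use_smartest_model=True,  # upgrade within the selected provider's model family
    )

print(asyncio.run(ask()))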
mito_ai/settings/enterprise_handler.py ADDED
@@ -0,0 +1,26 @@
+ # Copyright (c) Saga Inc.
+ # Distributed under the terms of the GNU Affero General Public License v3.0 License.
+
+ import json
+ import tornado
+ from jupyter_server.base.handlers import APIHandler
+ from mito_ai.utils.model_utils import get_available_models
+
+
+ class AvailableModelsHandler(APIHandler):
+     """REST handler for returning available models to the frontend."""
+
+     @tornado.web.authenticated
+     async def get(self) -> None:
+         """GET endpoint that returns the list of available models."""
+         try:
+             available_models = get_available_models()
+
+             self.write({
+                 "models": available_models
+             })
+             self.finish()
+         except Exception as e:
+             self.set_status(500)
+             self.write({"error": str(e)})
+             self.finish()
mito_ai/settings/urls.py CHANGED
@@ -4,6 +4,7 @@
  from typing import Any, List, Tuple
  from jupyter_server.utils import url_path_join
  from mito_ai.settings.handlers import SettingsHandler
+ from mito_ai.settings.enterprise_handler import AvailableModelsHandler

  def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
      """Get all settings related URL patterns.
@@ -17,4 +18,5 @@ def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
      BASE_URL = base_url + "/mito-ai"
      return [
          (url_path_join(BASE_URL, "settings/(.*)"), SettingsHandler, {}),
+         (url_path_join(BASE_URL, "available-models"), AvailableModelsHandler, {}),
      ]
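Once the route above is registered, the new endpoint can be exercised directly. A quick check, assuming a local Jupyter server on port 8888 with no URL prefix and a standard access token (all three are illustrative, not taken from the package):

import requests

resp = requests.get(
    "http://localhost:8888/mito-ai/available-models",
    headers={"Authorization": "token <your-jupyter-token>"},  # placeholder token
)
resp.raise_for_status()
# Per the handler above, success returns {"models": [...]};
# failures return HTTP 500 with {"error": "..."}.
print(resp.json()["models"])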
mito_ai/streamlit_conversion/agent_utils.py CHANGED
@@ -1,37 +1,9 @@
  # Copyright (c) Saga Inc.
  # Distributed under the terms of the GNU Affero General Public License v3.0 License.

- from typing import List, Tuple
- import re
- from anthropic.types import MessageParam
- from mito_ai.streamlit_conversion.prompts.streamlit_system_prompt import streamlit_system_prompt
- from mito_ai.utils.anthropic_utils import stream_anthropic_completion_from_mito_server
+ from typing import List
  from mito_ai.streamlit_conversion.prompts.prompt_constants import MITO_TODO_PLACEHOLDER
- from mito_ai.completions.models import MessageType
-
- STREAMLIT_AI_MODEL = "claude-sonnet-4-5-20250929"

  def extract_todo_placeholders(agent_response: str) -> List[str]:
      """Extract TODO placeholders from the agent's response"""
-     return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]
-
- async def get_response_from_agent(message_to_agent: List[MessageParam]) -> str:
-     """Gets the streaming response from the agent using the mito server"""
-     model = STREAMLIT_AI_MODEL
-     max_tokens = 64000 # TODO: If we move to haiku, we must reset this to 8192
-     temperature = 0.2
-
-     accumulated_response = ""
-     async for stream_chunk in stream_anthropic_completion_from_mito_server(
-         model = model,
-         max_tokens = max_tokens,
-         temperature = temperature,
-         system = streamlit_system_prompt,
-         messages = message_to_agent,
-         stream=True,
-         message_type=MessageType.STREAMLIT_CONVERSION,
-         reply_fn=None,
-         message_id=""
-     ):
-         accumulated_response += stream_chunk
-     return accumulated_response
+     return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]
mito_ai/streamlit_conversion/streamlit_agent_handler.py CHANGED
@@ -1,9 +1,10 @@
  # Copyright (c) Saga Inc.
  # Distributed under the terms of the GNU Affero General Public License v3.0 License.

- from anthropic.types import MessageParam
- from typing import List, cast
- from mito_ai.streamlit_conversion.agent_utils import extract_todo_placeholders, get_response_from_agent
+ from typing import List
+ from openai.types.chat import ChatCompletionMessageParam
+ from mito_ai.streamlit_conversion.agent_utils import extract_todo_placeholders
+ from mito_ai.provider_manager import ProviderManager
  from mito_ai.streamlit_conversion.prompts.streamlit_app_creation_prompt import get_streamlit_app_creation_prompt
  from mito_ai.streamlit_conversion.prompts.streamlit_error_correction_prompt import get_streamlit_error_correction_prompt
  from mito_ai.streamlit_conversion.prompts.streamlit_finish_todo_prompt import get_finish_todo_prompt
@@ -15,22 +16,23 @@ from mito_ai.completions.models import MessageType
  from mito_ai.utils.error_classes import StreamlitConversionError
  from mito_ai.utils.telemetry_utils import log_streamlit_app_validation_retry, log_streamlit_app_conversion_success
  from mito_ai.path_utils import AbsoluteNotebookPath, AppFileName, get_absolute_notebook_dir_path, get_absolute_app_path, get_app_file_name
+ from mito_ai.streamlit_conversion.prompts.streamlit_system_prompt import streamlit_system_prompt

- async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt: str) -> str:
+ async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt: str, provider: ProviderManager) -> str:
      """Send a query to the agent, get its response and parse the code"""

      prompt_text = get_streamlit_app_creation_prompt(notebook, streamlit_app_prompt)

-     messages: List[MessageParam] = [
-         cast(MessageParam, {
-             "role": "user",
-             "content": [{
-                 "type": "text",
-                 "text": prompt_text
-             }]
-         })
+     messages: List[ChatCompletionMessageParam] = [
+         {"role": "system", "content": streamlit_system_prompt},
+         {"role": "user", "content": prompt_text}
      ]
-     agent_response = await get_response_from_agent(messages)
+     agent_response = await provider.request_completions(
+         message_type=MessageType.STREAMLIT_CONVERSION,
+         messages=messages,
+         use_smartest_model=True,
+         thread_id=None
+     )
      converted_code = extract_code_blocks(agent_response)

      # Extract the TODOs from the agent's response
@@ -39,16 +41,16 @@ async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt
      for todo_placeholder in todo_placeholders:
          print(f"Processing AI TODO: {todo_placeholder}")
          todo_prompt = get_finish_todo_prompt(notebook, converted_code, todo_placeholder)
-         todo_messages: List[MessageParam] = [
-             cast(MessageParam, {
-                 "role": "user",
-                 "content": [{
-                     "type": "text",
-                     "text": todo_prompt
-                 }]
-             })
+         todo_messages: List[ChatCompletionMessageParam] = [
+             {"role": "system", "content": streamlit_system_prompt},
+             {"role": "user", "content": todo_prompt}
          ]
-         todo_response = await get_response_from_agent(todo_messages)
+         todo_response = await provider.request_completions(
+             message_type=MessageType.STREAMLIT_CONVERSION,
+             messages=todo_messages,
+             use_smartest_model=True,
+             thread_id=None
+         )

          # Apply the search/replace to the streamlit app
          search_replace_pairs = extract_search_replace_blocks(todo_response)
@@ -57,21 +59,21 @@
      return converted_code


- async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_code: str, edit_prompt: str) -> str:
+ async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_code: str, edit_prompt: str, provider: ProviderManager) -> str:
      """Send a query to the agent, get its response and parse the code"""
      prompt_text = get_update_existing_app_prompt(notebook, streamlit_app_code, edit_prompt)

-     messages: List[MessageParam] = [
-         cast(MessageParam, {
-             "role": "user",
-             "content": [{
-                 "type": "text",
-                 "text": prompt_text
-             }]
-         })
+     messages: List[ChatCompletionMessageParam] = [
+         {"role": "system", "content": streamlit_system_prompt},
+         {"role": "user", "content": prompt_text}
      ]

-     agent_response = await get_response_from_agent(messages)
+     agent_response = await provider.request_completions(
+         message_type=MessageType.STREAMLIT_CONVERSION,
+         messages=messages,
+         use_smartest_model=True,
+         thread_id=None
+     )
      print(f"[Mito AI Search/Replace Tool]:\n {agent_response}")

      # Apply the search/replace to the streamlit app
@@ -81,18 +83,18 @@ async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_code
      return converted_code


- async def correct_error_in_generation(error: str, streamlit_app_code: str) -> str:
+ async def correct_error_in_generation(error: str, streamlit_app_code: str, provider: ProviderManager) -> str:
      """If errors are present, send it back to the agent to get corrections in code"""
-     messages: List[MessageParam] = [
-         cast(MessageParam, {
-             "role": "user",
-             "content": [{
-                 "type": "text",
-                 "text": get_streamlit_error_correction_prompt(error, streamlit_app_code)
-             }]
-         })
+     messages: List[ChatCompletionMessageParam] = [
+         {"role": "system", "content": streamlit_system_prompt},
+         {"role": "user", "content": get_streamlit_error_correction_prompt(error, streamlit_app_code)}
      ]
-     agent_response = await get_response_from_agent(messages)
+     agent_response = await provider.request_completions(
+         message_type=MessageType.STREAMLIT_CONVERSION,
+         messages=messages,
+         use_smartest_model=True,
+         thread_id=None
+     )

      # Apply the search/replace to the streamlit app
      search_replace_pairs = extract_search_replace_blocks(agent_response)
@@ -100,7 +102,7 @@ async def correct_error_in_generation(error: str, streamlit_app_code: str) -> str

      return streamlit_app_code

- async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath, app_file_name: AppFileName, streamlit_app_prompt: str = "") -> None:
+ async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath, app_file_name: AppFileName, streamlit_app_prompt: str, provider: ProviderManager) -> None:
      """Handler function for streamlit code generation and validation"""

      # Convert to absolute path for consistent handling
@@ -110,7 +112,7 @@ async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath

      if create_new_app:
          # Otherwise generate a new streamlit app
-         streamlit_code = await generate_new_streamlit_code(notebook_code, streamlit_app_prompt)
+         streamlit_code = await generate_new_streamlit_code(notebook_code, streamlit_app_prompt, provider)
      else:
          # If the user is editing an existing streamlit app, use the update function
          existing_streamlit_code = get_app_code_from_file(app_path)
@@ -118,14 +120,14 @@ async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath
          if existing_streamlit_code is None:
              raise StreamlitConversionError("Error updating existing streamlit app because app.py file was not found.", 404)

-         streamlit_code = await update_existing_streamlit_code(notebook_code, existing_streamlit_code, streamlit_app_prompt)
+         streamlit_code = await update_existing_streamlit_code(notebook_code, existing_streamlit_code, streamlit_app_prompt, provider)

      # Then, after creating/updating the app, validate that the new code runs
      errors = validate_app(streamlit_code, notebook_path)
      tries = 0
      while len(errors) > 0 and tries < 5:
          for error in errors:
-             streamlit_code = await correct_error_in_generation(error, streamlit_code)
+             streamlit_code = await correct_error_in_generation(error, streamlit_code, provider)

          errors = validate_app(streamlit_code, notebook_path)
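Taken together, the conversion pipeline now threads a ProviderManager through every agent call instead of the removed module-level get_response_from_agent helper. A sketch of the new calling convention, where the construction of the ProviderManager, the path values, and the prompt are illustrative:

import asyncio
from mito_ai.provider_manager import ProviderManager
from mito_ai.streamlit_conversion.streamlit_agent_handler import streamlit_handler

async def convert_notebook(notebook_path, app_file_name) -> None:
    # Each agent round trip above now goes through provider.request_completions
    # with use_smartest_model=True, rather than a hardcoded Claude model.
    provider = ProviderManager()
    await streamlit_handler(
        create_new_app=True,
        notebook_path=notebook_path,    # an AbsoluteNotebookPath
        app_file_name=app_file_name,    # an AppFileName
        streamlit_app_prompt="Build a dashboard from this notebook",  # illustrative
        provider=provider,
    )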