mito-ai 0.1.57__py3-none-any.whl → 0.1.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. mito_ai/__init__.py +16 -22
  2. mito_ai/_version.py +1 -1
  3. mito_ai/anthropic_client.py +24 -14
  4. mito_ai/chart_wizard/handlers.py +78 -17
  5. mito_ai/chart_wizard/urls.py +8 -5
  6. mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
  7. mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
  8. mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
  9. mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
  10. mito_ai/completions/completion_handlers/completion_handler.py +3 -5
  11. mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
  12. mito_ai/completions/completion_handlers/scratchpad_result_handler.py +6 -8
  13. mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
  14. mito_ai/completions/completion_handlers/utils.py +3 -7
  15. mito_ai/completions/handlers.py +32 -22
  16. mito_ai/completions/message_history.py +8 -10
  17. mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
  18. mito_ai/constants.py +8 -1
  19. mito_ai/enterprise/__init__.py +1 -1
  20. mito_ai/enterprise/litellm_client.py +137 -0
  21. mito_ai/log/handlers.py +1 -1
  22. mito_ai/openai_client.py +10 -90
  23. mito_ai/{completions/providers.py → provider_manager.py} +157 -53
  24. mito_ai/settings/enterprise_handler.py +26 -0
  25. mito_ai/settings/urls.py +2 -0
  26. mito_ai/streamlit_conversion/agent_utils.py +2 -30
  27. mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
  28. mito_ai/streamlit_preview/handlers.py +6 -3
  29. mito_ai/streamlit_preview/urls.py +5 -3
  30. mito_ai/tests/message_history/test_generate_short_chat_name.py +72 -28
  31. mito_ai/tests/providers/test_anthropic_client.py +174 -16
  32. mito_ai/tests/providers/test_azure.py +13 -13
  33. mito_ai/tests/providers/test_capabilities.py +14 -17
  34. mito_ai/tests/providers/test_gemini_client.py +14 -13
  35. mito_ai/tests/providers/test_model_resolution.py +145 -89
  36. mito_ai/tests/providers/test_openai_client.py +209 -13
  37. mito_ai/tests/providers/test_provider_limits.py +5 -5
  38. mito_ai/tests/providers/test_providers.py +229 -51
  39. mito_ai/tests/providers/test_retry_logic.py +13 -22
  40. mito_ai/tests/providers/utils.py +4 -4
  41. mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
  42. mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
  43. mito_ai/tests/test_enterprise_mode.py +162 -0
  44. mito_ai/tests/test_model_utils.py +271 -0
  45. mito_ai/utils/anthropic_utils.py +8 -6
  46. mito_ai/utils/gemini_utils.py +0 -3
  47. mito_ai/utils/litellm_utils.py +84 -0
  48. mito_ai/utils/model_utils.py +178 -0
  49. mito_ai/utils/open_ai_utils.py +0 -8
  50. mito_ai/utils/provider_utils.py +6 -28
  51. mito_ai/utils/telemetry_utils.py +14 -2
  52. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
  53. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
  54. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
  55. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js +671 -75
  56. mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js.map +1 -0
  57. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js +17 -17
  58. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js.map → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js.map +1 -1
  59. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/METADATA +2 -1
  60. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/RECORD +86 -79
  61. mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js.map +0 -1
  62. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
  63. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
  64. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
  65. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
  66. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
  67. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
  68. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
  69. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
  70. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
  71. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
  72. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
  73. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
  74. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
  75. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
  76. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
  77. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
  78. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
  79. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
  80. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
  81. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
  82. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
  83. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +0 -0
  84. {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
  85. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/WHEEL +0 -0
  86. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/entry_points.txt +0 -0
  87. {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/licenses/LICENSE +0 -0
mito_ai/openai_client.py CHANGED
@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
 from mito_ai.utils.mito_server_utils import ProviderCompletionException
 import openai
 from openai.types.chat import ChatCompletionMessageParam
-from traitlets import Instance, Unicode, default, validate
+from traitlets import Instance, default, validate
 from traitlets.config import LoggingConfigurable

 from mito_ai import constants
@@ -30,22 +30,12 @@ from mito_ai.utils.open_ai_utils import (
     stream_ai_completion_from_mito_server,
 )
 from mito_ai.utils.server_limits import update_mito_server_quota
-from mito_ai.utils.telemetry_utils import (
-    MITO_SERVER_KEY,
-    USER_KEY,
-)

 OPENAI_MODEL_FALLBACK = "gpt-4.1"

 class OpenAIClient(LoggingConfigurable):
     """Provide AI feature through OpenAI services."""

-    api_key = Unicode(
-        config=True,
-        allow_none=True,
-        help="OpenAI API key. Default value is read from the OPENAI_API_KEY environment variable.",
-    )
-
     last_error = Instance(
         CompletionError,
         allow_none=True,
@@ -65,61 +55,6 @@ This attribute is observed by the websocket provider to push the error to the cl
         super().__init__(log=get_logger(), **kwargs)
         self.last_error = None
         self._async_client: Optional[openai.AsyncOpenAI] = None
-
-    @default("api_key")
-    def _api_key_default(self) -> Optional[str]:
-        default_key = constants.OPENAI_API_KEY
-        return self._validate_api_key(default_key)
-
-    @validate("api_key")
-    def _validate_api_key(self, api_key: Optional[str]) -> Optional[str]:
-        if not api_key:
-            self.log.debug(
-                "No OpenAI API key provided; following back to Mito server API."
-            )
-            return None
-
-        client = openai.OpenAI(api_key=api_key)
-        try:
-            # Make an http request to OpenAI to make sure it works
-            client.models.list()
-        except openai.AuthenticationError as e:
-            self.log.warning(
-                "Invalid OpenAI API key provided.",
-                exc_info=e,
-            )
-            self.last_error = CompletionError.from_exception(
-                e,
-                hint="You're missing the OPENAI_API_KEY environment variable. Run the following code in your terminal to set the environment variable and then relaunch the jupyter server `export OPENAI_API_KEY=<your-api-key>`",
-            )
-            return None
-        except openai.PermissionDeniedError as e:
-            self.log.warning(
-                "Invalid OpenAI API key provided.",
-                exc_info=e,
-            )
-            self.last_error = CompletionError.from_exception(e)
-            return None
-        except openai.InternalServerError as e:
-            self.log.debug(
-                "Unable to get OpenAI models due to OpenAI error.", exc_info=e
-            )
-            return api_key
-        except openai.RateLimitError as e:
-            self.log.debug(
-                "Unable to get OpenAI models due to rate limit error.", exc_info=e
-            )
-            return api_key
-        except openai.APIConnectionError as e:
-            self.log.warning(
-                "Unable to connect to OpenAI API.",
-                exec_info=e,
-            )
-            self.last_error = CompletionError.from_exception(e)
-            return None
-        else:
-            self.log.debug("User OpenAI API key validated.")
-            return api_key

     @property
     def capabilities(self) -> AICapabilities:
@@ -133,7 +68,7 @@ This attribute is observed by the websocket provider to push the error to the cl
                 provider="Azure OpenAI",
             )

-        if constants.OLLAMA_MODEL and not self.api_key:
+        if constants.OLLAMA_MODEL:
             return AICapabilities(
                 configuration={
                     "model": constants.OLLAMA_MODEL
@@ -141,14 +76,12 @@ This attribute is observed by the websocket provider to push the error to the cl
                 provider="Ollama",
             )

-        if self.api_key:
-            self._validate_api_key(self.api_key)
-
+        if constants.OPENAI_API_KEY:
             return AICapabilities(
                 configuration={
-                    "model": OPENAI_MODEL_FALLBACK,
+                    "model": "<dynamic>"
                 },
-                provider="OpenAI (user key)",
+                provider="OpenAI",
             )

         try:
@@ -169,19 +102,6 @@ This attribute is observed by the websocket provider to push the error to the cl
         if not self._async_client or self._async_client.is_closed():
             self._async_client = self._build_openai_client()
         return self._async_client
-
-
-    @property
-    def key_type(self) -> str:
-        """Returns the authentication key type being used."""
-
-        if self.api_key:
-            return USER_KEY
-
-        if constants.OLLAMA_MODEL:
-            return "ollama"
-
-        return MITO_SERVER_KEY

     def _build_openai_client(self) -> Optional[Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]]:
         base_url = None
@@ -201,12 +121,12 @@ This attribute is observed by the websocket provider to push the error to the cl
                 timeout=self.timeout,
             )

-        elif constants.OLLAMA_MODEL and not self.api_key:
+        elif constants.OLLAMA_MODEL:
             base_url = constants.OLLAMA_BASE_URL
             llm_api_key = "ollama"
             self.log.debug(f"Using Ollama with model: {constants.OLLAMA_MODEL}")
-        elif self.api_key:
-            llm_api_key = self.api_key
+        elif constants.OPENAI_API_KEY:
+            llm_api_key = constants.OPENAI_API_KEY
             self.log.debug("Using OpenAI with user-provided API key")
         else:
             self.log.warning("No valid API key or model configuration provided")
@@ -262,7 +182,7 @@ This attribute is observed by the websocket provider to push the error to the cl

         # Handle other providers as before
         completion_function_params = get_open_ai_completion_function_params(
-            message_type, model, messages, False, response_format_info
+            model, messages, False, response_format_info
         )

         # If they have set an Azure OpenAI or Ollama model, then we use it
@@ -313,7 +233,7 @@ This attribute is observed by the websocket provider to push the error to the cl

         # Handle other providers as before
         completion_function_params = get_open_ai_completion_function_params(
-            message_type, model, messages, True, response_format_info
+            model, messages, True, response_format_info
         )

         completion_function_params["model"] = self._adjust_model_for_azure_or_ollama(completion_function_params["model"])
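Note: with the api_key traitlet and its validation removed, OpenAIClient now relies entirely on constants.OPENAI_API_KEY. A minimal sketch of how that constant is presumably sourced, based on the removed help text ("Default value is read from the OPENAI_API_KEY environment variable"); the actual implementation lives in mito_ai/constants.py, which is not shown in this diff:

import os

# Hypothetical sketch, not from this diff; mirrors the removed traitlet's documented behavior.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")  # None when unset; the client then falls back to the Mito server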
mito_ai/{completions/providers.py → provider_manager.py} CHANGED
@@ -6,7 +6,7 @@ import asyncio
 from typing import Any, Callable, Dict, List, Optional, Union, cast
 from mito_ai import constants
 from openai.types.chat import ChatCompletionMessageParam
-from traitlets import Instance, Unicode, default, validate
+from traitlets import Instance
 from traitlets.config import LoggingConfigurable
 from openai.types.chat import ChatCompletionMessageParam

@@ -24,32 +24,23 @@ from mito_ai.completions.models import (
     CompletionReply,
     CompletionStreamChunk,
     MessageType,
-    ResponseFormatInfo, CompletionItemError,
+    ResponseFormatInfo,
 )
+from mito_ai.utils.litellm_utils import is_litellm_configured
 from mito_ai.utils.telemetry_utils import (
-    KEY_TYPE_PARAM,
-    MITO_AI_COMPLETION_ERROR,
-    MITO_AI_COMPLETION_RETRY,
     MITO_SERVER_KEY,
     USER_KEY,
-    log,
     log_ai_completion_error,
     log_ai_completion_retry,
     log_ai_completion_success,
 )
 from mito_ai.utils.provider_utils import get_model_provider
-from mito_ai.utils.mito_server_utils import ProviderCompletionException
+from mito_ai.utils.model_utils import get_available_models, get_fast_model_for_selected_model, get_smartest_model_for_selected_model

-__all__ = ["OpenAIProvider"]
+__all__ = ["ProviderManager"]

-class OpenAIProvider(LoggingConfigurable):
-    """Provide AI feature through OpenAI services."""
-
-    api_key = Unicode(
-        config=True,
-        allow_none=True,
-        help="OpenAI API key. Default value is read from the OPENAI_API_KEY environment variable.",
-    )
+class ProviderManager(LoggingConfigurable):
+    """Manage AI providers (Claude, Gemini, OpenAI) and route requests to the appropriate client."""

     last_error = Instance(
         CompletionError,
@@ -61,29 +52,57 @@ This attribute is observed by the websocket provider to push the error to the cl

     def __init__(self, **kwargs: Dict[str, Any]) -> None:
         config = kwargs.get('config', {})
-        if 'api_key' in kwargs:
-            config['OpenAIClient'] = {'api_key': kwargs['api_key']}
         kwargs['config'] = config

         super().__init__(log=get_logger(), **kwargs)
         self.last_error = None
         self._openai_client: Optional[OpenAIClient] = OpenAIClient(**config)
+        # Initialize with the first available model to ensure it's always valid
+        # This respects LiteLLM configuration: if LiteLLM is configured, uses first LiteLLM model
+        # Otherwise, uses first standard model
+        available_models = get_available_models()
+        self._selected_model: str = available_models[0] if available_models else "gpt-4.1"
+
+    def get_selected_model(self) -> str:
+        """Get the currently selected model."""
+        return self._selected_model
+
+    def set_selected_model(self, model: str) -> None:
+        """Set the selected model."""
+        self._selected_model = model

     @property
     def capabilities(self) -> AICapabilities:
         """
         Returns the capabilities of the AI provider.
         """
-        if constants.CLAUDE_API_KEY and not self.api_key:
+        # TODO: We should validate that these keys are actually valid for the provider
+        # otherwise it will look like we are using the user_key when actually falling back
+        # to the mito server because the key is invalid.
+        if is_litellm_configured():
+            return AICapabilities(
+                configuration={"model": "<dynamic>"},
+                provider="LiteLLM",
+            )
+
+        if constants.OPENAI_API_KEY:
+            return AICapabilities(
+                configuration={"model": "<dynamic>"},
+                provider="OpenAI",
+            )
+
+        if constants.ANTHROPIC_API_KEY:
             return AICapabilities(
                 configuration={"model": "<dynamic>"},
                 provider="Claude",
             )
-        if constants.GEMINI_API_KEY and not self.api_key:
+
+        if constants.GEMINI_API_KEY:
             return AICapabilities(
                 configuration={"model": "<dynamic>"},
                 provider="Gemini",
             )
+
         if self._openai_client:
             return self._openai_client.capabilities

@@ -94,65 +113,106 @@ This attribute is observed by the websocket provider to push the error to the cl

     @property
     def key_type(self) -> str:
-        if constants.CLAUDE_API_KEY and not self.api_key:
-            return "claude"
-        if constants.GEMINI_API_KEY and not self.api_key:
-            return "gemini"
-        if self._openai_client:
-            return self._openai_client.key_type
+        # TODO: We should validate that these keys are actually valid for the provider
+        # otherwise it will look like we are using the user_key when actually falling back
+        # to the mito server because the key is invalid.
+        if is_litellm_configured():
+            return USER_KEY
+
+        if constants.ANTHROPIC_API_KEY or constants.GEMINI_API_KEY or constants.OPENAI_API_KEY or constants.OLLAMA_MODEL:
+            return USER_KEY
+
         return MITO_SERVER_KEY

     async def request_completions(
         self,
         message_type: MessageType,
         messages: List[ChatCompletionMessageParam],
-        model: str,
         response_format_info: Optional[ResponseFormatInfo] = None,
         user_input: Optional[str] = None,
         thread_id: Optional[str] = None,
-        max_retries: int = 3
+        max_retries: int = 3,
+        use_fast_model: bool = False,
+        use_smartest_model: bool = False
     ) -> str:
         """
         Request completions from the AI provider.
+
+        Args:
+            message_type: Type of message
+            messages: List of chat messages
+            response_format_info: Optional response format specification
+            user_input: Optional user input for logging
+            thread_id: Optional thread ID for logging
+            max_retries: Maximum number of retries
+            use_fast_model: If True, use the fastest model from the selected provider
+            use_smartest_model: If True, use the smartest model from the selected provider
         """
         self.last_error = None
         completion = None
         last_message_content = str(messages[-1].get('content', '')) if messages else ""
-        model_type = get_model_provider(model)
+
+        # Get the model to use (selected model, fast model, or smartest model if requested)
+        selected_model = self.get_selected_model()
+        if use_smartest_model:
+            resolved_model = get_smartest_model_for_selected_model(selected_model)
+        elif use_fast_model:
+            resolved_model = get_fast_model_for_selected_model(selected_model)
+        else:
+            resolved_model = selected_model
+
+        # Validate model is in allowed list (uses same function as endpoint)
+        available_models = get_available_models()
+        if resolved_model not in available_models:
+            raise ValueError(f"Model {resolved_model} is not in the allowed model list: {available_models}")
+
+        # Get model provider type
+        model_type = get_model_provider(resolved_model)

         # Retry loop
         for attempt in range(max_retries + 1):
             try:
-                if model_type == "claude":
-                    api_key = constants.CLAUDE_API_KEY
+                if model_type == "litellm":
+                    from mito_ai.enterprise.litellm_client import LiteLLMClient
+                    if not constants.LITELLM_BASE_URL:
+                        raise ValueError("LITELLM_BASE_URL is required for LiteLLM models")
+                    litellm_client = LiteLLMClient(api_key=constants.LITELLM_API_KEY, base_url=constants.LITELLM_BASE_URL)
+                    completion = await litellm_client.request_completions(
+                        messages=messages,
+                        model=resolved_model,
+                        response_format_info=response_format_info,
+                        message_type=message_type
+                    )
+                elif model_type == "claude":
+                    api_key = constants.ANTHROPIC_API_KEY
                     anthropic_client = AnthropicClient(api_key=api_key)
-                    completion = await anthropic_client.request_completions(messages, model, response_format_info, message_type)
+                    completion = await anthropic_client.request_completions(messages, resolved_model, response_format_info, message_type)
                 elif model_type == "gemini":
                     api_key = constants.GEMINI_API_KEY
                     gemini_client = GeminiClient(api_key=api_key)
                     messages_for_gemini = [dict(m) for m in messages]
-                    completion = await gemini_client.request_completions(messages_for_gemini, model, response_format_info, message_type)
+                    completion = await gemini_client.request_completions(messages_for_gemini, resolved_model, response_format_info, message_type)
                 elif model_type == "openai":
                     if not self._openai_client:
                         raise RuntimeError("OpenAI client is not initialized.")
                     completion = await self._openai_client.request_completions(
                         message_type=message_type,
                         messages=messages,
-                        model=model,
+                        model=resolved_model,
                         response_format_info=response_format_info
                     )
                 else:
-                    raise ValueError(f"No AI provider configured for model: {model}")
+                    raise ValueError(f"No AI provider configured for model: {resolved_model}")

                 # Success! Log and return
                 log_ai_completion_success(
-                    key_type=USER_KEY if self.key_type == "user" else MITO_SERVER_KEY,
+                    key_type=USER_KEY if self.key_type == USER_KEY else MITO_SERVER_KEY,
                     message_type=message_type,
                     last_message_content=last_message_content,
                     response={"completion": completion},
                     user_input=user_input or "",
                     thread_id=thread_id or "",
-                    model=model
+                    model=resolved_model
                 )
                 return completion # type: ignore

@@ -160,7 +220,7 @@ This attribute is observed by the websocket provider to push the error to the cl
                 # If we hit a free tier limit, then raise an exception right away without retrying.
                 self.log.exception(f"Error during request_completions: {e}")
                 self.last_error = CompletionError.from_exception(e)
-                log_ai_completion_error('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id or "", message_type, e)
+                log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                 raise

             except BaseException as e:
@@ -169,14 +229,14 @@ This attribute is observed by the websocket provider to push the error to the cl
                     # Exponential backoff: wait 2^attempt seconds
                     wait_time = 2 ** attempt
                     self.log.info(f"Retrying request_completions after {wait_time}s (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
-                    log_ai_completion_retry('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id or "", message_type, e)
+                    log_ai_completion_retry(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                     await asyncio.sleep(wait_time)
                     continue
                 else:
                     # Final failure after all retries - set error state and raise
                     self.log.exception(f"Error during request_completions after {attempt + 1} attempts: {e}")
                     self.last_error = CompletionError.from_exception(e)
-                    log_ai_completion_error('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id or "", message_type, e)
+                    log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                     raise

         # This should never be reached due to the raise in the except block,
@@ -187,21 +247,50 @@ This attribute is observed by the websocket provider to push the error to the cl
         self,
         message_type: MessageType,
         messages: List[ChatCompletionMessageParam],
-        model: str,
         message_id: str,
         thread_id: str,
         reply_fn: Callable[[Union[CompletionReply, CompletionStreamChunk]], None],
         user_input: Optional[str] = None,
-        response_format_info: Optional[ResponseFormatInfo] = None
+        response_format_info: Optional[ResponseFormatInfo] = None,
+        use_fast_model: bool = False,
+        use_smartest_model: bool = False
     ) -> str:
         """
         Stream completions from the AI provider and return the accumulated response.
+
+        Args:
+            message_type: Type of message
+            messages: List of chat messages
+            message_id: ID of the message being processed
+            thread_id: Thread ID for logging
+            reply_fn: Function to call with each chunk for streaming replies
+            user_input: Optional user input for logging
+            response_format_info: Optional response format specification
+            use_fast_model: If True, use the fastest model from the selected provider
+            use_smartest_model: If True, use the smartest model from the selected provider
+
         Returns: The accumulated response string.
         """
         self.last_error = None
         accumulated_response = ""
         last_message_content = str(messages[-1].get('content', '')) if messages else ""
-        model_type = get_model_provider(model)
+
+        # Get the model to use (selected model, fast model, or smartest model if requested)
+        selected_model = self.get_selected_model()
+        if use_smartest_model:
+            resolved_model = get_smartest_model_for_selected_model(selected_model)
+        elif use_fast_model:
+            resolved_model = get_fast_model_for_selected_model(selected_model)
+        else:
+            resolved_model = selected_model
+
+        # Validate model is in allowed list (uses same function as endpoint)
+        available_models = get_available_models()
+        if resolved_model not in available_models:
+            raise ValueError(f"Model {resolved_model} is not in the allowed model list: {available_models}")
+
+        # Get model provider type
+        model_type = get_model_provider(resolved_model)
         reply_fn(CompletionReply(
             items=[
                 CompletionItem(content="", isIncomplete=True, token=message_id)
@@ -210,12 +299,28 @@ This attribute is observed by the websocket provider to push the error to the cl
         ))

         try:
-            if model_type == "claude":
-                api_key = constants.CLAUDE_API_KEY
+            if model_type == "litellm":
+                from mito_ai.enterprise.litellm_client import LiteLLMClient
+                if not constants.LITELLM_BASE_URL:
+                    raise ValueError("LITELLM_BASE_URL is required for LiteLLM models")
+                litellm_client = LiteLLMClient(
+                    api_key=constants.LITELLM_API_KEY,
+                    base_url=constants.LITELLM_BASE_URL
+                )
+                accumulated_response = await litellm_client.stream_completions(
+                    messages=messages,
+                    model=resolved_model,
+                    message_type=message_type,
+                    message_id=message_id,
+                    reply_fn=reply_fn,
+                    response_format_info=response_format_info
+                )
+            elif model_type == "claude":
+                api_key = constants.ANTHROPIC_API_KEY
                 anthropic_client = AnthropicClient(api_key=api_key)
                 accumulated_response = await anthropic_client.stream_completions(
                     messages=messages,
-                    model=model,
+                    model=resolved_model,
                     message_type=message_type,
                     message_id=message_id,
                     reply_fn=reply_fn
@@ -228,7 +333,7 @@ This attribute is observed by the websocket provider to push the error to the cl
                 messages_for_gemini = [dict(m) for m in messages]
                 accumulated_response = await gemini_client.stream_completions(
                     messages=messages_for_gemini,
-                    model=model,
+                    model=resolved_model,
                     message_id=message_id,
                     reply_fn=reply_fn,
                     message_type=message_type
@@ -239,7 +344,7 @@ This attribute is observed by the websocket provider to push the error to the cl
                 accumulated_response = await self._openai_client.stream_completions(
                     message_type=message_type,
                     messages=messages,
-                    model=model,
+                    model=resolved_model,
                     message_id=message_id,
                     thread_id=thread_id,
                     reply_fn=reply_fn,
@@ -247,24 +352,24 @@ This attribute is observed by the websocket provider to push the error to the cl
                     response_format_info=response_format_info
                 )
             else:
-                raise ValueError(f"No AI provider configured for model: {model}")
+                raise ValueError(f"No AI provider configured for model: {resolved_model}")

             # Log the successful completion
             log_ai_completion_success(
-                key_type=USER_KEY if self.key_type == "user" else MITO_SERVER_KEY,
+                key_type=USER_KEY if self.key_type == USER_KEY else MITO_SERVER_KEY,
                 message_type=message_type,
                 last_message_content=last_message_content,
                 response={"completion": accumulated_response},
                 user_input=user_input or "",
                 thread_id=thread_id,
-                model=model
+                model=resolved_model
             )
             return accumulated_response

         except BaseException as e:
             self.log.exception(f"Error during stream_completions: {e}")
             self.last_error = CompletionError.from_exception(e)
-            log_ai_completion_error('user_key' if self.key_type != MITO_SERVER_KEY else 'mito_server_key', thread_id, message_type, e)
+            log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id, message_type, e)

             # Send error message to client before raising
             reply_fn(CompletionStreamChunk(
@@ -281,4 +386,3 @@ This attribute is observed by the websocket provider to push the error to the cl
                 error=CompletionError.from_exception(e),
             ))
             raise
-
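With the per-request model argument removed, callers now set a model on the ProviderManager and let it resolve the fast or smartest variant and validate the result against get_available_models(). A hedged usage sketch based only on the signatures shown above; the model string, message content, and MessageType.CHAT are illustrative assumptions:

from mito_ai.provider_manager import ProviderManager
from mito_ai.completions.models import MessageType

async def ask(manager: ProviderManager) -> str:
    # The selected model must appear in get_available_models(), otherwise a ValueError is raised.
    manager.set_selected_model("gpt-4.1")
    return await manager.request_completions(
        message_type=MessageType.CHAT,  # assumed enum member
        messages=[{"role": "user", "content": "Summarize this dataframe"}],
        use_fast_model=True,  # resolve to the selected provider's fast model for this request
    )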
mito_ai/settings/enterprise_handler.py ADDED
@@ -0,0 +1,26 @@
+# Copyright (c) Saga Inc.
+# Distributed under the terms of the GNU Affero General Public License v3.0 License.
+
+import json
+import tornado
+from jupyter_server.base.handlers import APIHandler
+from mito_ai.utils.model_utils import get_available_models
+
+
+class AvailableModelsHandler(APIHandler):
+    """REST handler for returning available models to the frontend."""
+
+    @tornado.web.authenticated
+    async def get(self) -> None:
+        """GET endpoint that returns the list of available models."""
+        try:
+            available_models = get_available_models()
+
+            self.write({
+                "models": available_models
+            })
+            self.finish()
+        except Exception as e:
+            self.set_status(500)
+            self.write({"error": str(e)})
+            self.finish()
mito_ai/settings/urls.py CHANGED
@@ -4,6 +4,7 @@
 from typing import Any, List, Tuple
 from jupyter_server.utils import url_path_join
 from mito_ai.settings.handlers import SettingsHandler
+from mito_ai.settings.enterprise_handler import AvailableModelsHandler

 def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
     """Get all settings related URL patterns.
@@ -17,4 +18,5 @@ def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
     BASE_URL = base_url + "/mito-ai"
     return [
         (url_path_join(BASE_URL, "settings/(.*)"), SettingsHandler, {}),
+        (url_path_join(BASE_URL, "available-models"), AvailableModelsHandler, {}),
     ]
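Together with the handler above, this registers GET <base_url>/mito-ai/available-models. A hedged sketch of calling the endpoint from a notebook or test script; the host, port, and token are illustrative and depend on your Jupyter server setup:

import requests

# Query the new endpoint; authentication follows the standard Jupyter token scheme.
resp = requests.get(
    "http://localhost:8888/mito-ai/available-models",
    headers={"Authorization": "token <your-jupyter-token>"},
)
print(resp.json())  # expected shape: {"models": [...]}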
mito_ai/streamlit_conversion/agent_utils.py CHANGED
@@ -1,37 +1,9 @@
 # Copyright (c) Saga Inc.
 # Distributed under the terms of the GNU Affero General Public License v3.0 License.

-from typing import List, Tuple
-import re
-from anthropic.types import MessageParam
-from mito_ai.streamlit_conversion.prompts.streamlit_system_prompt import streamlit_system_prompt
-from mito_ai.utils.anthropic_utils import stream_anthropic_completion_from_mito_server
+from typing import List
 from mito_ai.streamlit_conversion.prompts.prompt_constants import MITO_TODO_PLACEHOLDER
-from mito_ai.completions.models import MessageType
-
-STREAMLIT_AI_MODEL = "claude-sonnet-4-5-20250929"

 def extract_todo_placeholders(agent_response: str) -> List[str]:
     """Extract TODO placeholders from the agent's response"""
-    return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]
-
-async def get_response_from_agent(message_to_agent: List[MessageParam]) -> str:
-    """Gets the streaming response from the agent using the mito server"""
-    model = STREAMLIT_AI_MODEL
-    max_tokens = 64000 # TODO: If we move to haiku, we must reset this to 8192
-    temperature = 0.2
-
-    accumulated_response = ""
-    async for stream_chunk in stream_anthropic_completion_from_mito_server(
-        model = model,
-        max_tokens = max_tokens,
-        temperature = temperature,
-        system = streamlit_system_prompt,
-        messages = message_to_agent,
-        stream=True,
-        message_type=MessageType.STREAMLIT_CONVERSION,
-        reply_fn=None,
-        message_id=""
-    ):
-        accumulated_response += stream_chunk
-    return accumulated_response
+    return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]
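For reference, the surviving helper can be exercised like this; the response string is illustrative and the placeholder value comes from prompt_constants:

from mito_ai.streamlit_conversion.agent_utils import extract_todo_placeholders
from mito_ai.streamlit_conversion.prompts.prompt_constants import MITO_TODO_PLACEHOLDER

# Build a fake agent response containing one placeholder line.
response = f"st.title('Report')\n# {MITO_TODO_PLACEHOLDER} wire up the data source\nst.write(df)"
print(extract_todo_placeholders(response))  # -> the single stripped line containing the placeholder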