mito-ai 0.1.57__py3-none-any.whl → 0.1.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mito_ai/__init__.py +16 -22
- mito_ai/_version.py +1 -1
- mito_ai/anthropic_client.py +24 -14
- mito_ai/chart_wizard/handlers.py +78 -17
- mito_ai/chart_wizard/urls.py +8 -5
- mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
- mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
- mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
- mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
- mito_ai/completions/completion_handlers/completion_handler.py +3 -5
- mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
- mito_ai/completions/completion_handlers/scratchpad_result_handler.py +6 -8
- mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
- mito_ai/completions/completion_handlers/utils.py +3 -7
- mito_ai/completions/handlers.py +32 -22
- mito_ai/completions/message_history.py +8 -10
- mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
- mito_ai/constants.py +8 -1
- mito_ai/enterprise/__init__.py +1 -1
- mito_ai/enterprise/litellm_client.py +137 -0
- mito_ai/log/handlers.py +1 -1
- mito_ai/openai_client.py +10 -90
- mito_ai/{completions/providers.py → provider_manager.py} +157 -53
- mito_ai/settings/enterprise_handler.py +26 -0
- mito_ai/settings/urls.py +2 -0
- mito_ai/streamlit_conversion/agent_utils.py +2 -30
- mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
- mito_ai/streamlit_preview/handlers.py +6 -3
- mito_ai/streamlit_preview/urls.py +5 -3
- mito_ai/tests/message_history/test_generate_short_chat_name.py +72 -28
- mito_ai/tests/providers/test_anthropic_client.py +174 -16
- mito_ai/tests/providers/test_azure.py +13 -13
- mito_ai/tests/providers/test_capabilities.py +14 -17
- mito_ai/tests/providers/test_gemini_client.py +14 -13
- mito_ai/tests/providers/test_model_resolution.py +145 -89
- mito_ai/tests/providers/test_openai_client.py +209 -13
- mito_ai/tests/providers/test_provider_limits.py +5 -5
- mito_ai/tests/providers/test_providers.py +229 -51
- mito_ai/tests/providers/test_retry_logic.py +13 -22
- mito_ai/tests/providers/utils.py +4 -4
- mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
- mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
- mito_ai/tests/test_enterprise_mode.py +162 -0
- mito_ai/tests/test_model_utils.py +271 -0
- mito_ai/utils/anthropic_utils.py +8 -6
- mito_ai/utils/gemini_utils.py +0 -3
- mito_ai/utils/litellm_utils.py +84 -0
- mito_ai/utils/model_utils.py +178 -0
- mito_ai/utils/open_ai_utils.py +0 -8
- mito_ai/utils/provider_utils.py +6 -28
- mito_ai/utils/telemetry_utils.py +14 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js +671 -75
- mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.03302cc521d72eb56b00.js.map +1 -0
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js +17 -17
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js.map → mito_ai-0.1.58.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.570df809a692f53a7ab7.js.map +1 -1
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/METADATA +2 -1
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/RECORD +86 -79
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js.map +0 -1
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.58.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.58.dist-info}/licenses/LICENSE +0 -0
mito_ai/openai_client.py
CHANGED
@@ -7,7 +7,7 @@ from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
 from mito_ai.utils.mito_server_utils import ProviderCompletionException
 import openai
 from openai.types.chat import ChatCompletionMessageParam
-from traitlets import Instance,
+from traitlets import Instance, default, validate
 from traitlets.config import LoggingConfigurable

 from mito_ai import constants
@@ -30,22 +30,12 @@ from mito_ai.utils.open_ai_utils import (
     stream_ai_completion_from_mito_server,
 )
 from mito_ai.utils.server_limits import update_mito_server_quota
-from mito_ai.utils.telemetry_utils import (
-    MITO_SERVER_KEY,
-    USER_KEY,
-)

 OPENAI_MODEL_FALLBACK = "gpt-4.1"

 class OpenAIClient(LoggingConfigurable):
     """Provide AI feature through OpenAI services."""

-    api_key = Unicode(
-        config=True,
-        allow_none=True,
-        help="OpenAI API key. Default value is read from the OPENAI_API_KEY environment variable.",
-    )
-
     last_error = Instance(
         CompletionError,
         allow_none=True,
@@ -65,61 +55,6 @@ This attribute is observed by the websocket provider to push the error to the cl
         super().__init__(log=get_logger(), **kwargs)
         self.last_error = None
         self._async_client: Optional[openai.AsyncOpenAI] = None
-
-    @default("api_key")
-    def _api_key_default(self) -> Optional[str]:
-        default_key = constants.OPENAI_API_KEY
-        return self._validate_api_key(default_key)
-
-    @validate("api_key")
-    def _validate_api_key(self, api_key: Optional[str]) -> Optional[str]:
-        if not api_key:
-            self.log.debug(
-                "No OpenAI API key provided; following back to Mito server API."
-            )
-            return None
-
-        client = openai.OpenAI(api_key=api_key)
-        try:
-            # Make an http request to OpenAI to make sure it works
-            client.models.list()
-        except openai.AuthenticationError as e:
-            self.log.warning(
-                "Invalid OpenAI API key provided.",
-                exc_info=e,
-            )
-            self.last_error = CompletionError.from_exception(
-                e,
-                hint="You're missing the OPENAI_API_KEY environment variable. Run the following code in your terminal to set the environment variable and then relaunch the jupyter server `export OPENAI_API_KEY=<your-api-key>`",
-            )
-            return None
-        except openai.PermissionDeniedError as e:
-            self.log.warning(
-                "Invalid OpenAI API key provided.",
-                exc_info=e,
-            )
-            self.last_error = CompletionError.from_exception(e)
-            return None
-        except openai.InternalServerError as e:
-            self.log.debug(
-                "Unable to get OpenAI models due to OpenAI error.", exc_info=e
-            )
-            return api_key
-        except openai.RateLimitError as e:
-            self.log.debug(
-                "Unable to get OpenAI models due to rate limit error.", exc_info=e
-            )
-            return api_key
-        except openai.APIConnectionError as e:
-            self.log.warning(
-                "Unable to connect to OpenAI API.",
-                exec_info=e,
-            )
-            self.last_error = CompletionError.from_exception(e)
-            return None
-        else:
-            self.log.debug("User OpenAI API key validated.")
-            return api_key

     @property
     def capabilities(self) -> AICapabilities:
@@ -133,7 +68,7 @@ This attribute is observed by the websocket provider to push the error to the cl
             provider="Azure OpenAI",
         )

-        if constants.OLLAMA_MODEL
+        if constants.OLLAMA_MODEL:
             return AICapabilities(
                 configuration={
                     "model": constants.OLLAMA_MODEL
@@ -141,14 +76,12 @@ This attribute is observed by the websocket provider to push the error to the cl
                 provider="Ollama",
             )

-        if
-            self._validate_api_key(self.api_key)
-
+        if constants.OPENAI_API_KEY:
             return AICapabilities(
                 configuration={
-                    "model":
+                    "model": "<dynamic>"
                 },
-                provider="OpenAI
+                provider="OpenAI",
             )

         try:
@@ -169,19 +102,6 @@ This attribute is observed by the websocket provider to push the error to the cl
         if not self._async_client or self._async_client.is_closed():
             self._async_client = self._build_openai_client()
         return self._async_client
-
-
-    @property
-    def key_type(self) -> str:
-        """Returns the authentication key type being used."""
-
-        if self.api_key:
-            return USER_KEY
-
-        if constants.OLLAMA_MODEL:
-            return "ollama"
-
-        return MITO_SERVER_KEY

     def _build_openai_client(self) -> Optional[Union[openai.AsyncOpenAI, openai.AsyncAzureOpenAI]]:
         base_url = None
@@ -201,12 +121,12 @@ This attribute is observed by the websocket provider to push the error to the cl
                 timeout=self.timeout,
             )

-        elif constants.OLLAMA_MODEL
+        elif constants.OLLAMA_MODEL:
             base_url = constants.OLLAMA_BASE_URL
             llm_api_key = "ollama"
             self.log.debug(f"Using Ollama with model: {constants.OLLAMA_MODEL}")
-        elif
-            llm_api_key =
+        elif constants.OPENAI_API_KEY:
+            llm_api_key = constants.OPENAI_API_KEY
             self.log.debug("Using OpenAI with user-provided API key")
         else:
             self.log.warning("No valid API key or model configuration provided")
@@ -262,7 +182,7 @@ This attribute is observed by the websocket provider to push the error to the cl

         # Handle other providers as before
         completion_function_params = get_open_ai_completion_function_params(
-
+            model, messages, False, response_format_info
         )

         # If they have set an Azure OpenAI or Ollama model, then we use it
@@ -313,7 +233,7 @@ This attribute is observed by the websocket provider to push the error to the cl

         # Handle other providers as before
         completion_function_params = get_open_ai_completion_function_params(
-
+            model, messages, True, response_format_info
         )

         completion_function_params["model"] = self._adjust_model_for_azure_or_ollama(completion_function_params["model"])
mito_ai/{completions/providers.py → provider_manager.py}
RENAMED
@@ -6,7 +6,7 @@ import asyncio
 from typing import Any, Callable, Dict, List, Optional, Union, cast
 from mito_ai import constants
 from openai.types.chat import ChatCompletionMessageParam
-from traitlets import Instance
+from traitlets import Instance
 from traitlets.config import LoggingConfigurable
 from openai.types.chat import ChatCompletionMessageParam

@@ -24,32 +24,23 @@ from mito_ai.completions.models import (
     CompletionReply,
     CompletionStreamChunk,
     MessageType,
-    ResponseFormatInfo,
+    ResponseFormatInfo,
 )
+from mito_ai.utils.litellm_utils import is_litellm_configured
 from mito_ai.utils.telemetry_utils import (
-    KEY_TYPE_PARAM,
-    MITO_AI_COMPLETION_ERROR,
-    MITO_AI_COMPLETION_RETRY,
     MITO_SERVER_KEY,
     USER_KEY,
-    log,
     log_ai_completion_error,
     log_ai_completion_retry,
     log_ai_completion_success,
 )
 from mito_ai.utils.provider_utils import get_model_provider
-from mito_ai.utils.
+from mito_ai.utils.model_utils import get_available_models, get_fast_model_for_selected_model, get_smartest_model_for_selected_model

-__all__ = ["
+__all__ = ["ProviderManager"]

-class
-    """
-
-    api_key = Unicode(
-        config=True,
-        allow_none=True,
-        help="OpenAI API key. Default value is read from the OPENAI_API_KEY environment variable.",
-    )
+class ProviderManager(LoggingConfigurable):
+    """Manage AI providers (Claude, Gemini, OpenAI) and route requests to the appropriate client."""

     last_error = Instance(
         CompletionError,
@@ -61,29 +52,57 @@ This attribute is observed by the websocket provider to push the error to the cl

     def __init__(self, **kwargs: Dict[str, Any]) -> None:
         config = kwargs.get('config', {})
-        if 'api_key' in kwargs:
-            config['OpenAIClient'] = {'api_key': kwargs['api_key']}
         kwargs['config'] = config

         super().__init__(log=get_logger(), **kwargs)
         self.last_error = None
         self._openai_client: Optional[OpenAIClient] = OpenAIClient(**config)
+        # Initialize with the first available model to ensure it's always valid
+        # This respects LiteLLM configuration: if LiteLLM is configured, uses first LiteLLM model
+        # Otherwise, uses first standard model
+        available_models = get_available_models()
+        self._selected_model: str = available_models[0] if available_models else "gpt-4.1"
+
+    def get_selected_model(self) -> str:
+        """Get the currently selected model."""
+        return self._selected_model
+
+    def set_selected_model(self, model: str) -> None:
+        """Set the selected model."""
+        self._selected_model = model

     @property
     def capabilities(self) -> AICapabilities:
         """
         Returns the capabilities of the AI provider.
         """
-
+        # TODO: We should validate that these keys are actually valid for the provider
+        # otherwise it will look like we are using the user_key when actually falling back
+        # to the mito server because the key is invalid.
+        if is_litellm_configured():
+            return AICapabilities(
+                configuration={"model": "<dynamic>"},
+                provider="LiteLLM",
+            )
+
+        if constants.OPENAI_API_KEY:
+            return AICapabilities(
+                configuration={"model": "<dynamic>"},
+                provider="OpenAI",
+            )
+
+        if constants.ANTHROPIC_API_KEY:
             return AICapabilities(
                 configuration={"model": "<dynamic>"},
                 provider="Claude",
             )
-
+
+        if constants.GEMINI_API_KEY:
             return AICapabilities(
                 configuration={"model": "<dynamic>"},
                 provider="Gemini",
             )
+
         if self._openai_client:
             return self._openai_client.capabilities

@@ -94,65 +113,106 @@ This attribute is observed by the websocket provider to push the error to the cl

     @property
     def key_type(self) -> str:
-
-
-
-
-
-
+        # TODO: We should validate that these keys are actually valid for the provider
+        # otherwise it will look like we are using the user_key when actually falling back
+        # to the mito server because the key is invalid.
+        if is_litellm_configured():
+            return USER_KEY
+
+        if constants.ANTHROPIC_API_KEY or constants.GEMINI_API_KEY or constants.OPENAI_API_KEY or constants.OLLAMA_MODEL:
+            return USER_KEY
+
         return MITO_SERVER_KEY

     async def request_completions(
         self,
         message_type: MessageType,
         messages: List[ChatCompletionMessageParam],
-        model: str,
         response_format_info: Optional[ResponseFormatInfo] = None,
         user_input: Optional[str] = None,
         thread_id: Optional[str] = None,
-        max_retries: int = 3
+        max_retries: int = 3,
+        use_fast_model: bool = False,
+        use_smartest_model: bool = False
     ) -> str:
         """
         Request completions from the AI provider.
+
+        Args:
+            message_type: Type of message
+            messages: List of chat messages
+            response_format_info: Optional response format specification
+            user_input: Optional user input for logging
+            thread_id: Optional thread ID for logging
+            max_retries: Maximum number of retries
+            use_fast_model: If True, use the fastest model from the selected provider
+            use_smartest_model: If True, use the smartest model from the selected provider
         """
         self.last_error = None
         completion = None
         last_message_content = str(messages[-1].get('content', '')) if messages else ""
-
+
+        # Get the model to use (selected model, fast model, or smartest model if requested)
+        selected_model = self.get_selected_model()
+        if use_smartest_model:
+            resolved_model = get_smartest_model_for_selected_model(selected_model)
+        elif use_fast_model:
+            resolved_model = get_fast_model_for_selected_model(selected_model)
+        else:
+            resolved_model = selected_model
+
+        # Validate model is in allowed list (uses same function as endpoint)
+        available_models = get_available_models()
+        if resolved_model not in available_models:
+            raise ValueError(f"Model {resolved_model} is not in the allowed model list: {available_models}")
+
+        # Get model provider type
+        model_type = get_model_provider(resolved_model)

         # Retry loop
         for attempt in range(max_retries + 1):
             try:
-                if model_type == "
-
+                if model_type == "litellm":
+                    from mito_ai.enterprise.litellm_client import LiteLLMClient
+                    if not constants.LITELLM_BASE_URL:
+                        raise ValueError("LITELLM_BASE_URL is required for LiteLLM models")
+                    litellm_client = LiteLLMClient(api_key=constants.LITELLM_API_KEY, base_url=constants.LITELLM_BASE_URL)
+                    completion = await litellm_client.request_completions(
+                        messages=messages,
+                        model=resolved_model,
+                        response_format_info=response_format_info,
+                        message_type=message_type
+                    )
+                elif model_type == "claude":
+                    api_key = constants.ANTHROPIC_API_KEY
                     anthropic_client = AnthropicClient(api_key=api_key)
-                    completion = await anthropic_client.request_completions(messages,
+                    completion = await anthropic_client.request_completions(messages, resolved_model, response_format_info, message_type)
                 elif model_type == "gemini":
                     api_key = constants.GEMINI_API_KEY
                     gemini_client = GeminiClient(api_key=api_key)
                     messages_for_gemini = [dict(m) for m in messages]
-                    completion = await gemini_client.request_completions(messages_for_gemini,
+                    completion = await gemini_client.request_completions(messages_for_gemini, resolved_model, response_format_info, message_type)
                 elif model_type == "openai":
                     if not self._openai_client:
                         raise RuntimeError("OpenAI client is not initialized.")
                     completion = await self._openai_client.request_completions(
                         message_type=message_type,
                         messages=messages,
-                        model=
+                        model=resolved_model,
                         response_format_info=response_format_info
                     )
                 else:
-                    raise ValueError(f"No AI provider configured for model: {
+                    raise ValueError(f"No AI provider configured for model: {resolved_model}")

                 # Success! Log and return
                 log_ai_completion_success(
-                    key_type=USER_KEY if self.key_type ==
+                    key_type=USER_KEY if self.key_type == USER_KEY else MITO_SERVER_KEY,
                     message_type=message_type,
                     last_message_content=last_message_content,
                     response={"completion": completion},
                     user_input=user_input or "",
                     thread_id=thread_id or "",
-                    model=
+                    model=resolved_model
                 )
                 return completion  # type: ignore

@@ -160,7 +220,7 @@ This attribute is observed by the websocket provider to push the error to the cl
                 # If we hit a free tier limit, then raise an exception right away without retrying.
                 self.log.exception(f"Error during request_completions: {e}")
                 self.last_error = CompletionError.from_exception(e)
-                log_ai_completion_error(
+                log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                 raise

             except BaseException as e:
@@ -169,14 +229,14 @@ This attribute is observed by the websocket provider to push the error to the cl
                     # Exponential backoff: wait 2^attempt seconds
                     wait_time = 2 ** attempt
                     self.log.info(f"Retrying request_completions after {wait_time}s (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
-                    log_ai_completion_retry(
+                    log_ai_completion_retry(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                     await asyncio.sleep(wait_time)
                     continue
                 else:
                     # Final failure after all retries - set error state and raise
                     self.log.exception(f"Error during request_completions after {attempt + 1} attempts: {e}")
                     self.last_error = CompletionError.from_exception(e)
-                    log_ai_completion_error(
+                    log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id or "", message_type, e)
                     raise

         # This should never be reached due to the raise in the except block,
@@ -187,21 +247,50 @@ This attribute is observed by the websocket provider to push the error to the cl
         self,
         message_type: MessageType,
         messages: List[ChatCompletionMessageParam],
-        model: str,
         message_id: str,
         thread_id: str,
         reply_fn: Callable[[Union[CompletionReply, CompletionStreamChunk]], None],
         user_input: Optional[str] = None,
-        response_format_info: Optional[ResponseFormatInfo] = None
+        response_format_info: Optional[ResponseFormatInfo] = None,
+        use_fast_model: bool = False,
+        use_smartest_model: bool = False
     ) -> str:
         """
         Stream completions from the AI provider and return the accumulated response.
+
+        Args:
+            message_type: Type of message
+            messages: List of chat messages
+            message_id: ID of the message being processed
+            thread_id: Thread ID for logging
+            reply_fn: Function to call with each chunk for streaming replies
+            user_input: Optional user input for logging
+            response_format_info: Optional response format specification
+            use_fast_model: If True, use the fastest model from the selected provider
+            use_smartest_model: If True, use the smartest model from the selected provider
+
         Returns: The accumulated response string.
         """
         self.last_error = None
         accumulated_response = ""
         last_message_content = str(messages[-1].get('content', '')) if messages else ""
-
+
+        # Get the model to use (selected model, fast model, or smartest model if requested)
+        selected_model = self.get_selected_model()
+        if use_smartest_model:
+            resolved_model = get_smartest_model_for_selected_model(selected_model)
+        elif use_fast_model:
+            resolved_model = get_fast_model_for_selected_model(selected_model)
+        else:
+            resolved_model = selected_model
+
+        # Validate model is in allowed list (uses same function as endpoint)
+        available_models = get_available_models()
+        if resolved_model not in available_models:
+            raise ValueError(f"Model {resolved_model} is not in the allowed model list: {available_models}")
+
+        # Get model provider type
+        model_type = get_model_provider(resolved_model)
         reply_fn(CompletionReply(
             items=[
                 CompletionItem(content="", isIncomplete=True, token=message_id)
@@ -210,12 +299,28 @@ This attribute is observed by the websocket provider to push the error to the cl
         ))

         try:
-            if model_type == "
-
+            if model_type == "litellm":
+                from mito_ai.enterprise.litellm_client import LiteLLMClient
+                if not constants.LITELLM_BASE_URL:
+                    raise ValueError("LITELLM_BASE_URL is required for LiteLLM models")
+                litellm_client = LiteLLMClient(
+                    api_key=constants.LITELLM_API_KEY,
+                    base_url=constants.LITELLM_BASE_URL
+                )
+                accumulated_response = await litellm_client.stream_completions(
+                    messages=messages,
+                    model=resolved_model,
+                    message_type=message_type,
+                    message_id=message_id,
+                    reply_fn=reply_fn,
+                    response_format_info=response_format_info
+                )
+            elif model_type == "claude":
+                api_key = constants.ANTHROPIC_API_KEY
                 anthropic_client = AnthropicClient(api_key=api_key)
                 accumulated_response = await anthropic_client.stream_completions(
                     messages=messages,
-                    model=
+                    model=resolved_model,
                     message_type=message_type,
                     message_id=message_id,
                     reply_fn=reply_fn
@@ -228,7 +333,7 @@ This attribute is observed by the websocket provider to push the error to the cl
                 messages_for_gemini = [dict(m) for m in messages]
                 accumulated_response = await gemini_client.stream_completions(
                     messages=messages_for_gemini,
-                    model=
+                    model=resolved_model,
                     message_id=message_id,
                     reply_fn=reply_fn,
                     message_type=message_type
@@ -239,7 +344,7 @@ This attribute is observed by the websocket provider to push the error to the cl
                 accumulated_response = await self._openai_client.stream_completions(
                     message_type=message_type,
                     messages=messages,
-                    model=
+                    model=resolved_model,
                     message_id=message_id,
                     thread_id=thread_id,
                     reply_fn=reply_fn,
@@ -247,24 +352,24 @@ This attribute is observed by the websocket provider to push the error to the cl
                     response_format_info=response_format_info
                 )
             else:
-                raise ValueError(f"No AI provider configured for model: {
+                raise ValueError(f"No AI provider configured for model: {resolved_model}")

             # Log the successful completion
             log_ai_completion_success(
-                key_type=USER_KEY if self.key_type ==
+                key_type=USER_KEY if self.key_type == USER_KEY else MITO_SERVER_KEY,
                 message_type=message_type,
                 last_message_content=last_message_content,
                 response={"completion": accumulated_response},
                 user_input=user_input or "",
                 thread_id=thread_id,
-                model=
+                model=resolved_model
             )
             return accumulated_response

         except BaseException as e:
             self.log.exception(f"Error during stream_completions: {e}")
             self.last_error = CompletionError.from_exception(e)
-            log_ai_completion_error(
+            log_ai_completion_error(USER_KEY if self.key_type != MITO_SERVER_KEY else MITO_SERVER_KEY, thread_id, message_type, e)

             # Send error message to client before raising
             reply_fn(CompletionStreamChunk(
@@ -281,4 +386,3 @@ This attribute is observed by the websocket provider to push the error to the cl
                 error=CompletionError.from_exception(e),
             ))
             raise
-
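A minimal usage sketch of the renamed ProviderManager, based only on the signatures visible in the diff above; the MessageType member and the surrounding async setup are illustrative assumptions, not code from this release:

    # Hypothetical sketch -- not part of the package diff.
    import asyncio
    from mito_ai.provider_manager import ProviderManager
    from mito_ai.completions.models import MessageType

    async def main() -> None:
        manager = ProviderManager()
        # The model now lives on the manager instead of being passed per request.
        manager.set_selected_model("gpt-4.1")
        completion = await manager.request_completions(
            message_type=MessageType.CHAT,  # assumed enum member
            messages=[{"role": "user", "content": "Summarize this dataframe"}],
            use_smartest_model=True,        # resolved via get_smartest_model_for_selected_model
        )
        print(completion)

    asyncio.run(main())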
mito_ai/settings/enterprise_handler.py
ADDED
@@ -0,0 +1,26 @@
+# Copyright (c) Saga Inc.
+# Distributed under the terms of the GNU Affero General Public License v3.0 License.
+
+import json
+import tornado
+from jupyter_server.base.handlers import APIHandler
+from mito_ai.utils.model_utils import get_available_models
+
+
+class AvailableModelsHandler(APIHandler):
+    """REST handler for returning available models to the frontend."""
+
+    @tornado.web.authenticated
+    async def get(self) -> None:
+        """GET endpoint that returns the list of available models."""
+        try:
+            available_models = get_available_models()
+
+            self.write({
+                "models": available_models
+            })
+            self.finish()
+        except Exception as e:
+            self.set_status(500)
+            self.write({"error": str(e)})
+            self.finish()
mito_ai/settings/urls.py
CHANGED
@@ -4,6 +4,7 @@
 from typing import Any, List, Tuple
 from jupyter_server.utils import url_path_join
 from mito_ai.settings.handlers import SettingsHandler
+from mito_ai.settings.enterprise_handler import AvailableModelsHandler

 def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
     """Get all settings related URL patterns.
@@ -17,4 +18,5 @@ def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
     BASE_URL = base_url + "/mito-ai"
     return [
         (url_path_join(BASE_URL, "settings/(.*)"), SettingsHandler, {}),
+        (url_path_join(BASE_URL, "available-models"), AvailableModelsHandler, {}),
     ]
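For orientation, a hedged sketch of how a client could exercise the newly registered available-models endpoint; the server URL and token are placeholders, and the `token` authorization header is just one common Jupyter setup:

    # Hypothetical client-side call -- not part of the package diff.
    import requests

    JUPYTER_BASE = "http://localhost:8888"   # placeholder Jupyter server URL
    TOKEN = "<your-jupyter-token>"           # placeholder token

    resp = requests.get(
        f"{JUPYTER_BASE}/mito-ai/available-models",
        headers={"Authorization": f"token {TOKEN}"},
    )
    resp.raise_for_status()
    print(resp.json()["models"])  # AvailableModelsHandler returns {"models": [...]}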
mito_ai/streamlit_conversion/agent_utils.py
CHANGED
@@ -1,37 +1,9 @@
 # Copyright (c) Saga Inc.
 # Distributed under the terms of the GNU Affero General Public License v3.0 License.

-from typing import List
-import re
-from anthropic.types import MessageParam
-from mito_ai.streamlit_conversion.prompts.streamlit_system_prompt import streamlit_system_prompt
-from mito_ai.utils.anthropic_utils import stream_anthropic_completion_from_mito_server
+from typing import List
 from mito_ai.streamlit_conversion.prompts.prompt_constants import MITO_TODO_PLACEHOLDER
-from mito_ai.completions.models import MessageType
-
-STREAMLIT_AI_MODEL = "claude-sonnet-4-5-20250929"

 def extract_todo_placeholders(agent_response: str) -> List[str]:
     """Extract TODO placeholders from the agent's response"""
-    return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]
-
-async def get_response_from_agent(message_to_agent: List[MessageParam]) -> str:
-    """Gets the streaming response from the agent using the mito server"""
-    model = STREAMLIT_AI_MODEL
-    max_tokens = 64000  # TODO: If we move to haiku, we must reset this to 8192
-    temperature = 0.2
-
-    accumulated_response = ""
-    async for stream_chunk in stream_anthropic_completion_from_mito_server(
-        model = model,
-        max_tokens = max_tokens,
-        temperature = temperature,
-        system = streamlit_system_prompt,
-        messages = message_to_agent,
-        stream=True,
-        message_type=MessageType.STREAMLIT_CONVERSION,
-        reply_fn=None,
-        message_id=""
-    ):
-        accumulated_response += stream_chunk
-    return accumulated_response
+    return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]