mito-ai 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mito_ai/__init__.py +19 -22
- mito_ai/_version.py +1 -1
- mito_ai/anthropic_client.py +24 -14
- mito_ai/chart_wizard/handlers.py +78 -17
- mito_ai/chart_wizard/urls.py +8 -5
- mito_ai/completions/completion_handlers/agent_auto_error_fixup_handler.py +6 -8
- mito_ai/completions/completion_handlers/agent_execution_handler.py +6 -8
- mito_ai/completions/completion_handlers/chat_completion_handler.py +13 -17
- mito_ai/completions/completion_handlers/code_explain_handler.py +13 -17
- mito_ai/completions/completion_handlers/completion_handler.py +3 -5
- mito_ai/completions/completion_handlers/inline_completer_handler.py +5 -6
- mito_ai/completions/completion_handlers/scratchpad_result_handler.py +6 -8
- mito_ai/completions/completion_handlers/smart_debug_handler.py +13 -17
- mito_ai/completions/completion_handlers/utils.py +3 -7
- mito_ai/completions/handlers.py +32 -22
- mito_ai/completions/message_history.py +8 -10
- mito_ai/completions/prompt_builders/chart_add_field_prompt.py +35 -0
- mito_ai/completions/prompt_builders/prompt_constants.py +2 -0
- mito_ai/constants.py +31 -2
- mito_ai/enterprise/__init__.py +1 -1
- mito_ai/enterprise/litellm_client.py +144 -0
- mito_ai/enterprise/utils.py +16 -2
- mito_ai/log/handlers.py +1 -1
- mito_ai/openai_client.py +36 -96
- mito_ai/provider_manager.py +420 -0
- mito_ai/settings/enterprise_handler.py +26 -0
- mito_ai/settings/urls.py +2 -0
- mito_ai/streamlit_conversion/agent_utils.py +2 -30
- mito_ai/streamlit_conversion/streamlit_agent_handler.py +48 -46
- mito_ai/streamlit_preview/handlers.py +6 -3
- mito_ai/streamlit_preview/urls.py +5 -3
- mito_ai/tests/message_history/test_generate_short_chat_name.py +103 -28
- mito_ai/tests/open_ai_utils_test.py +34 -36
- mito_ai/tests/providers/test_anthropic_client.py +174 -16
- mito_ai/tests/providers/test_azure.py +15 -15
- mito_ai/tests/providers/test_capabilities.py +14 -17
- mito_ai/tests/providers/test_gemini_client.py +14 -13
- mito_ai/tests/providers/test_model_resolution.py +145 -89
- mito_ai/tests/providers/test_openai_client.py +209 -13
- mito_ai/tests/providers/test_provider_limits.py +5 -5
- mito_ai/tests/providers/test_providers.py +229 -51
- mito_ai/tests/providers/test_retry_logic.py +13 -22
- mito_ai/tests/providers/utils.py +4 -4
- mito_ai/tests/streamlit_conversion/test_streamlit_agent_handler.py +57 -85
- mito_ai/tests/streamlit_preview/test_streamlit_preview_handler.py +4 -1
- mito_ai/tests/test_constants.py +90 -0
- mito_ai/tests/test_enterprise_mode.py +217 -0
- mito_ai/tests/test_model_utils.py +362 -0
- mito_ai/utils/anthropic_utils.py +8 -6
- mito_ai/utils/gemini_utils.py +0 -3
- mito_ai/utils/litellm_utils.py +84 -0
- mito_ai/utils/model_utils.py +257 -0
- mito_ai/utils/open_ai_utils.py +29 -41
- mito_ai/utils/provider_utils.py +13 -29
- mito_ai/utils/telemetry_utils.py +14 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +102 -102
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.44c109c7be36fb884d25.js +1059 -144
- mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.44c109c7be36fb884d25.js.map +1 -0
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.f7decebaf69618541e0f.js +17 -17
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.79c1ea8a3cda73a4cb6f.js.map → mito_ai-0.1.59.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.f7decebaf69618541e0f.js.map +1 -1
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.css +78 -78
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/METADATA +2 -1
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/RECORD +90 -83
- mito_ai/completions/providers.py +0 -284
- mito_ai-0.1.57.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.9d26322f3e78beb2b666.js.map +0 -1
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/node_modules_process_browser_js.4b128e94d31a81ebd209.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.f5d476ac514294615881.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_apis_signOut_mjs-node_module-75790d.688c25857e7b81b1740f.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_auth_dist_esm_providers_cognito_tokenProvider_tokenProvider_-72f1c8.a917210f057fcfe224ad.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_dist_esm_index_mjs.6bac1a8c4cc93f15f6b7.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_aws-amplify_ui-react_dist_esm_index_mjs.4fcecd65bef9e9847609.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_react-dom_client_js-node_modules_aws-amplify_ui-react_dist_styles_css.b43d4249e4d3dac9ad7b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.3f6754ac5116d47de76b.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.57.data → mito_ai-0.1.59.data}/data/share/jupyter/labextensions/mito_ai/themes/mito_ai/index.js +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.57.dist-info → mito_ai-0.1.59.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,420 @@
|
|
|
1
|
+
# Copyright (c) Saga Inc.
|
|
2
|
+
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
|
7
|
+
from mito_ai import constants
|
|
8
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
9
|
+
from traitlets import Instance
|
|
10
|
+
from traitlets.config import LoggingConfigurable
|
|
11
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
12
|
+
|
|
13
|
+
from mito_ai import constants
|
|
14
|
+
from mito_ai.enterprise.utils import is_azure_openai_configured
|
|
15
|
+
from mito_ai.gemini_client import GeminiClient
|
|
16
|
+
from mito_ai.openai_client import OpenAIClient
|
|
17
|
+
from mito_ai.anthropic_client import AnthropicClient
|
|
18
|
+
from mito_ai.logger import get_logger
|
|
19
|
+
from mito_ai.completions.models import (
|
|
20
|
+
AICapabilities,
|
|
21
|
+
CompletionError,
|
|
22
|
+
CompletionItem,
|
|
23
|
+
CompletionItemError,
|
|
24
|
+
CompletionReply,
|
|
25
|
+
CompletionStreamChunk,
|
|
26
|
+
MessageType,
|
|
27
|
+
ResponseFormatInfo,
|
|
28
|
+
)
|
|
29
|
+
from mito_ai.utils.litellm_utils import is_litellm_configured
|
|
30
|
+
from mito_ai.enterprise.utils import is_abacus_configured
|
|
31
|
+
from mito_ai.utils.telemetry_utils import (
|
|
32
|
+
MITO_SERVER_KEY,
|
|
33
|
+
USER_KEY,
|
|
34
|
+
log_ai_completion_error,
|
|
35
|
+
log_ai_completion_retry,
|
|
36
|
+
log_ai_completion_success,
|
|
37
|
+
)
|
|
38
|
+
from mito_ai.utils.provider_utils import get_model_provider
|
|
39
|
+
from mito_ai.utils.model_utils import get_available_models, get_fast_model_for_selected_model, get_smartest_model_for_selected_model
|
|
40
|
+
|
|
41
|
+
__all__ = ["ProviderManager"]
|
|
42
|
+
|
|
43
|
+
class ProviderManager(LoggingConfigurable):
    """Manage AI providers and route requests to the appropriate client.

    Supported provider types (as returned by ``get_model_provider``) are
    ``"abacus"``, ``"litellm"``, ``"claude"``, ``"gemini"`` and ``"openai"``.
    The manager stores the currently selected model. For every request it
    resolves the concrete model to use (optionally the fast or smartest
    variant of the selected model) and validates it against
    ``get_available_models()``. It then dispatches to the client for that
    model's provider.
    """

    # Despite the help text's wording, this is set for failures from every
    # provider routed through this manager, not only OpenAI.
    last_error = Instance(
        CompletionError,
        allow_none=True,
        help="""Last error encountered when using the OpenAI provider.

This attribute is observed by the websocket provider to push the error to the client.""",
    )

    def __init__(self, **kwargs: Any) -> None:
        """Create the manager.

        Args:
            **kwargs: Forwarded to ``LoggingConfigurable``. The ``config`` entry
                (defaulting to ``{}``) is also unpacked into ``OpenAIClient``.
        """
        config = kwargs.get('config', {})
        kwargs['config'] = config

        super().__init__(log=get_logger(), **kwargs)
        self.last_error = None
        self._openai_client: Optional[OpenAIClient] = OpenAIClient(**config)
        # Start from the first available model so the selection is always valid.
        # get_available_models() decides the list (e.g. LiteLLM models when LiteLLM
        # is configured, otherwise the standard models); "gpt-4.1" is only a
        # fallback for an empty list.
        available_models = get_available_models()
        self._selected_model: str = available_models[0] if available_models else "gpt-4.1"

    def get_selected_model(self) -> str:
        """Return the currently selected model name."""
        return self._selected_model

    def set_selected_model(self, model: str) -> None:
        """Set the selected model name. It is not validated here; validation happens per request."""
        self._selected_model = model

    @property
    def capabilities(self) -> AICapabilities:
        """
        Returns the capabilities of the AI provider.

        Precedence: Abacus, then LiteLLM, then OpenAI/Anthropic/Gemini API keys,
        then the OpenAI client's own capabilities, and finally the Mito server.
        """
        # TODO: We should validate that these keys are actually valid for the provider
        # otherwise it will look like we are using the user_key when actually falling back
        # to the mito server because the key is invalid.
        if is_abacus_configured():
            return AICapabilities(
                configuration={"model": "<dynamic>"},
                provider="Abacus AI",
            )

        if is_litellm_configured():
            return AICapabilities(
                configuration={"model": "<dynamic>"},
                provider="LiteLLM",
            )

        if constants.OPENAI_API_KEY:
            return AICapabilities(
                configuration={"model": "<dynamic>"},
                provider="OpenAI",
            )

        if constants.ANTHROPIC_API_KEY:
            return AICapabilities(
                configuration={"model": "<dynamic>"},
                provider="Claude",
            )

        if constants.GEMINI_API_KEY:
            return AICapabilities(
                configuration={"model": "<dynamic>"},
                provider="Gemini",
            )

        if self._openai_client:
            return self._openai_client.capabilities

        return AICapabilities(
            configuration={"model": "<dynamic>"},
            provider="Mito server",
        )

    @property
    def key_type(self) -> str:
        """Return ``USER_KEY`` when any user-supplied provider is configured, else ``MITO_SERVER_KEY``."""
        # TODO: We should validate that these keys are actually valid for the provider
        # otherwise it will look like we are using the user_key when actually falling back
        # to the mito server because the key is invalid.
        if is_abacus_configured():
            return USER_KEY

        if is_litellm_configured():
            return USER_KEY

        if constants.ANTHROPIC_API_KEY or constants.GEMINI_API_KEY or constants.OPENAI_API_KEY or constants.OLLAMA_MODEL:
            return USER_KEY

        return MITO_SERVER_KEY

    def _resolve_model(self, use_fast_model: bool, use_smartest_model: bool) -> str:
        """Pick the model for a request and check that it is allowed.

        ``use_smartest_model`` takes precedence over ``use_fast_model``. When
        neither flag is set, the currently selected model is used as-is.

        Raises:
            ValueError: If the resolved model is not in ``get_available_models()``
                (the same list the available-models endpoint returns).
        """
        selected_model = self.get_selected_model()
        if use_smartest_model:
            resolved_model = get_smartest_model_for_selected_model(selected_model)
        elif use_fast_model:
            resolved_model = get_fast_model_for_selected_model(selected_model)
        else:
            resolved_model = selected_model

        available_models = get_available_models()
        if resolved_model not in available_models:
            raise ValueError(f"Model {resolved_model} is not in the allowed model list: {available_models}")
        return resolved_model

    def _require_openai_client(self) -> OpenAIClient:
        """Return the OpenAI client, which also serves Abacus models, or raise if it is missing."""
        if not self._openai_client:
            raise RuntimeError("OpenAI client is not initialized.")
        return self._openai_client

    @staticmethod
    def _make_litellm_client() -> Any:
        """Build a LiteLLM client from ``constants``; raise ``ValueError`` if no base URL is configured."""
        # Imported lazily so the enterprise module is only loaded when LiteLLM is used.
        from mito_ai.enterprise.litellm_client import LiteLLMClient
        if not constants.LITELLM_BASE_URL:
            raise ValueError("LITELLM_BASE_URL is required for LiteLLM models")
        return LiteLLMClient(api_key=constants.LITELLM_API_KEY, base_url=constants.LITELLM_BASE_URL)

    def _record_error(self, error: BaseException, thread_id: str, message_type: MessageType) -> CompletionError:
        """Store ``error`` as ``last_error``, emit error telemetry, and return the CompletionError."""
        completion_error = CompletionError.from_exception(error)
        self.last_error = completion_error
        log_ai_completion_error(self.key_type, thread_id, message_type, error)
        return completion_error

    async def request_completions(
        self,
        message_type: MessageType,
        messages: List[ChatCompletionMessageParam],
        response_format_info: Optional[ResponseFormatInfo] = None,
        user_input: Optional[str] = None,
        thread_id: Optional[str] = None,
        max_retries: int = 3,
        use_fast_model: bool = False,
        use_smartest_model: bool = False
    ) -> str:
        """
        Request completions from the AI provider.

        Args:
            message_type: Type of message
            messages: List of chat messages
            response_format_info: Optional response format specification
            user_input: Optional user input for logging
            thread_id: Optional thread ID for logging
            max_retries: Maximum number of retries (the request is attempted up to max_retries + 1 times)
            use_fast_model: If True, use the fastest model from the selected provider
            use_smartest_model: If True, use the smartest model from the selected provider

        Returns:
            The completion text from the provider.

        Raises:
            ValueError: If the resolved model is not allowed (raised before any attempt).
            PermissionError: Re-raised immediately without retrying.
            Exception: The last error, re-raised once all retries are exhausted.
        """
        self.last_error = None
        completion = None
        last_message_content = str(messages[-1].get('content', '')) if messages else ""

        resolved_model = self._resolve_model(use_fast_model, use_smartest_model)
        model_type = get_model_provider(resolved_model)

        for attempt in range(max_retries + 1):
            try:
                if model_type in ("abacus", "openai"):
                    completion = await self._require_openai_client().request_completions(
                        message_type=message_type,
                        messages=messages,
                        model=resolved_model,
                        response_format_info=response_format_info
                    )
                elif model_type == "litellm":
                    completion = await self._make_litellm_client().request_completions(
                        messages=messages,
                        model=resolved_model,
                        response_format_info=response_format_info,
                        message_type=message_type
                    )
                elif model_type == "claude":
                    anthropic_client = AnthropicClient(api_key=constants.ANTHROPIC_API_KEY)
                    completion = await anthropic_client.request_completions(messages, resolved_model, response_format_info, message_type)
                elif model_type == "gemini":
                    gemini_client = GeminiClient(api_key=constants.GEMINI_API_KEY)
                    messages_for_gemini = [dict(m) for m in messages]
                    completion = await gemini_client.request_completions(messages_for_gemini, resolved_model, response_format_info, message_type)
                else:
                    raise ValueError(f"No AI provider configured for model: {resolved_model}")

                log_ai_completion_success(
                    key_type=self.key_type,
                    message_type=message_type,
                    last_message_content=last_message_content,
                    response={"completion": completion},
                    user_input=user_input or "",
                    thread_id=thread_id or "",
                    model=resolved_model
                )
                return completion  # type: ignore

            except PermissionError as e:
                # If we hit a free tier limit, then raise an exception right away without retrying.
                self.log.exception(f"Error during request_completions: {e}")
                self._record_error(e, thread_id or "", message_type)
                raise

            # Exception, not BaseException: cancellation (asyncio.CancelledError),
            # KeyboardInterrupt and SystemExit must propagate instead of being retried.
            except Exception as e:
                if attempt < max_retries:
                    # Exponential backoff: wait 2^attempt seconds
                    wait_time = 2 ** attempt
                    self.log.info(f"Retrying request_completions after {wait_time}s (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
                    log_ai_completion_retry(self.key_type, thread_id or "", message_type, e)
                    await asyncio.sleep(wait_time)
                    continue

                # Final failure after all retries - set error state and raise
                self.log.exception(f"Error during request_completions after {attempt + 1} attempts: {e}")
                self._record_error(e, thread_id or "", message_type)
                raise

        # Only reachable when max_retries < 0 (the loop body never runs).
        raise RuntimeError("Unexpected code path in request_completions")

    async def stream_completions(
        self,
        message_type: MessageType,
        messages: List[ChatCompletionMessageParam],
        message_id: str,
        thread_id: str,
        reply_fn: Callable[[Union[CompletionReply, CompletionStreamChunk]], None],
        user_input: Optional[str] = None,
        response_format_info: Optional[ResponseFormatInfo] = None,
        use_fast_model: bool = False,
        use_smartest_model: bool = False
    ) -> str:
        """
        Stream completions from the AI provider and return the accumulated response.

        An empty, incomplete ``CompletionReply`` is sent through ``reply_fn``
        before streaming starts. On failure, a final error ``CompletionStreamChunk``
        (``done=True``) is sent through ``reply_fn`` and the exception is re-raised.
        There is no retry here.

        Args:
            message_type: Type of message
            messages: List of chat messages
            message_id: ID of the message being processed
            thread_id: Thread ID for logging
            reply_fn: Function to call with each chunk for streaming replies
            user_input: Optional user input for logging
            response_format_info: Optional response format specification
            use_fast_model: If True, use the fastest model from the selected provider
            use_smartest_model: If True, use the smartest model from the selected provider

        Returns: The accumulated response string.

        Raises:
            ValueError: If the resolved model is not allowed (raised before anything is sent).
        """
        self.last_error = None
        accumulated_response = ""
        last_message_content = str(messages[-1].get('content', '')) if messages else ""

        resolved_model = self._resolve_model(use_fast_model, use_smartest_model)
        model_type = get_model_provider(resolved_model)
        reply_fn(CompletionReply(
            items=[
                CompletionItem(content="", isIncomplete=True, token=message_id)
            ],
            parent_id=message_id,
        ))

        try:
            if model_type in ("abacus", "openai"):
                accumulated_response = await self._require_openai_client().stream_completions(
                    message_type=message_type,
                    messages=messages,
                    model=resolved_model,
                    message_id=message_id,
                    thread_id=thread_id,
                    reply_fn=reply_fn,
                    user_input=user_input,
                    response_format_info=response_format_info
                )
            elif model_type == "litellm":
                accumulated_response = await self._make_litellm_client().stream_completions(
                    messages=messages,
                    model=resolved_model,
                    message_type=message_type,
                    message_id=message_id,
                    reply_fn=reply_fn,
                    response_format_info=response_format_info
                )
            elif model_type == "claude":
                anthropic_client = AnthropicClient(api_key=constants.ANTHROPIC_API_KEY)
                accumulated_response = await anthropic_client.stream_completions(
                    messages=messages,
                    model=resolved_model,
                    message_type=message_type,
                    message_id=message_id,
                    reply_fn=reply_fn
                )
            elif model_type == "gemini":
                gemini_client = GeminiClient(api_key=constants.GEMINI_API_KEY)
                # TODO: We shouldn't need to do this because the messages should already be dictionaries...
                # but if we do have to do some pre-processing, we should do it in the gemini_client instead.
                messages_for_gemini = [dict(m) for m in messages]
                accumulated_response = await gemini_client.stream_completions(
                    messages=messages_for_gemini,
                    model=resolved_model,
                    message_id=message_id,
                    reply_fn=reply_fn,
                    message_type=message_type
                )
            else:
                raise ValueError(f"No AI provider configured for model: {resolved_model}")

            log_ai_completion_success(
                key_type=self.key_type,
                message_type=message_type,
                last_message_content=last_message_content,
                response={"completion": accumulated_response},
                user_input=user_input or "",
                thread_id=thread_id,
                model=resolved_model
            )
            return accumulated_response

        # BaseException is intentional here. Nothing is swallowed (the error is
        # always re-raised), and the client still receives a terminal error chunk
        # even when the stream is interrupted by cancellation.
        except BaseException as e:
            self.log.exception(f"Error during stream_completions: {e}")
            completion_error = self._record_error(e, thread_id, message_type)

            # Send error message to client before raising
            reply_fn(CompletionStreamChunk(
                parent_id=message_id,
                chunk=CompletionItem(
                    content="",
                    isIncomplete=True,
                    error=CompletionItemError(
                        message=f"Failed to process completion: {e!r}"
                    ),
                    token=message_id,
                ),
                done=True,
                error=completion_error,
            ))
            raise
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright (c) Saga Inc.
|
|
2
|
+
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import tornado
|
|
6
|
+
from jupyter_server.base.handlers import APIHandler
|
|
7
|
+
from mito_ai.utils.model_utils import get_available_models
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class AvailableModelsHandler(APIHandler):
    """REST handler for returning available models to the frontend."""

    @tornado.web.authenticated
    async def get(self) -> None:
        """GET endpoint that returns the list of available models.

        Responds with ``{"models": [...]}`` as produced by ``get_available_models()``.
        If the lookup fails, responds with HTTP 500 and ``{"error": "<message>"}``.
        """
        # Keep the try narrow: if write()/finish() itself raised inside it, the
        # handler would attempt a second write/finish on the same request.
        try:
            available_models = get_available_models()
        except Exception as e:
            # Log with traceback so the failure is diagnosable on the server,
            # not only in the client's response body.
            self.log.exception("Failed to get available models")
            self.set_status(500)
            self.write({"error": str(e)})
            self.finish()
            return

        self.write({
            "models": available_models
        })
        self.finish()
|
mito_ai/settings/urls.py
CHANGED
|
@@ -4,6 +4,7 @@
|
|
|
4
4
|
from typing import Any, List, Tuple
|
|
5
5
|
from jupyter_server.utils import url_path_join
|
|
6
6
|
from mito_ai.settings.handlers import SettingsHandler
|
|
7
|
+
from mito_ai.settings.enterprise_handler import AvailableModelsHandler
|
|
7
8
|
|
|
8
9
|
def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
|
|
9
10
|
"""Get all settings related URL patterns.
|
|
@@ -17,4 +18,5 @@ def get_settings_urls(base_url: str) -> List[Tuple[str, Any, dict]]:
|
|
|
17
18
|
BASE_URL = base_url + "/mito-ai"
|
|
18
19
|
return [
|
|
19
20
|
(url_path_join(BASE_URL, "settings/(.*)"), SettingsHandler, {}),
|
|
21
|
+
(url_path_join(BASE_URL, "available-models"), AvailableModelsHandler, {}),
|
|
20
22
|
]
|
|
@@ -1,37 +1,9 @@
|
|
|
1
1
|
# Copyright (c) Saga Inc.
|
|
2
2
|
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
3
|
|
|
4
|
-
from typing import List
|
|
5
|
-
import re
|
|
6
|
-
from anthropic.types import MessageParam
|
|
7
|
-
from mito_ai.streamlit_conversion.prompts.streamlit_system_prompt import streamlit_system_prompt
|
|
8
|
-
from mito_ai.utils.anthropic_utils import stream_anthropic_completion_from_mito_server
|
|
4
|
+
from typing import List
|
|
9
5
|
from mito_ai.streamlit_conversion.prompts.prompt_constants import MITO_TODO_PLACEHOLDER
|
|
10
|
-
from mito_ai.completions.models import MessageType
|
|
11
|
-
|
|
12
|
-
STREAMLIT_AI_MODEL = "claude-sonnet-4-5-20250929"
|
|
13
6
|
|
|
14
7
|
def extract_todo_placeholders(agent_response: str) -> List[str]:
|
|
15
8
|
"""Extract TODO placeholders from the agent's response"""
|
|
16
|
-
return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]
|
|
17
|
-
|
|
18
|
-
async def get_response_from_agent(message_to_agent: List[MessageParam]) -> str:
|
|
19
|
-
"""Gets the streaming response from the agent using the mito server"""
|
|
20
|
-
model = STREAMLIT_AI_MODEL
|
|
21
|
-
max_tokens = 64000 # TODO: If we move to haiku, we must reset this to 8192
|
|
22
|
-
temperature = 0.2
|
|
23
|
-
|
|
24
|
-
accumulated_response = ""
|
|
25
|
-
async for stream_chunk in stream_anthropic_completion_from_mito_server(
|
|
26
|
-
model = model,
|
|
27
|
-
max_tokens = max_tokens,
|
|
28
|
-
temperature = temperature,
|
|
29
|
-
system = streamlit_system_prompt,
|
|
30
|
-
messages = message_to_agent,
|
|
31
|
-
stream=True,
|
|
32
|
-
message_type=MessageType.STREAMLIT_CONVERSION,
|
|
33
|
-
reply_fn=None,
|
|
34
|
-
message_id=""
|
|
35
|
-
):
|
|
36
|
-
accumulated_response += stream_chunk
|
|
37
|
-
return accumulated_response
|
|
9
|
+
return [line.strip() for line in agent_response.split('\n') if MITO_TODO_PLACEHOLDER in line]
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
# Copyright (c) Saga Inc.
|
|
2
2
|
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
3
|
|
|
4
|
-
from
|
|
5
|
-
from
|
|
6
|
-
from mito_ai.streamlit_conversion.agent_utils import extract_todo_placeholders
|
|
4
|
+
from typing import List
|
|
5
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
6
|
+
from mito_ai.streamlit_conversion.agent_utils import extract_todo_placeholders
|
|
7
|
+
from mito_ai.provider_manager import ProviderManager
|
|
7
8
|
from mito_ai.streamlit_conversion.prompts.streamlit_app_creation_prompt import get_streamlit_app_creation_prompt
|
|
8
9
|
from mito_ai.streamlit_conversion.prompts.streamlit_error_correction_prompt import get_streamlit_error_correction_prompt
|
|
9
10
|
from mito_ai.streamlit_conversion.prompts.streamlit_finish_todo_prompt import get_finish_todo_prompt
|
|
@@ -15,22 +16,23 @@ from mito_ai.completions.models import MessageType
|
|
|
15
16
|
from mito_ai.utils.error_classes import StreamlitConversionError
|
|
16
17
|
from mito_ai.utils.telemetry_utils import log_streamlit_app_validation_retry, log_streamlit_app_conversion_success
|
|
17
18
|
from mito_ai.path_utils import AbsoluteNotebookPath, AppFileName, get_absolute_notebook_dir_path, get_absolute_app_path, get_app_file_name
|
|
19
|
+
from mito_ai.streamlit_conversion.prompts.streamlit_system_prompt import streamlit_system_prompt
|
|
18
20
|
|
|
19
|
-
async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt: str) -> str:
|
|
21
|
+
async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt: str, provider: ProviderManager) -> str:
|
|
20
22
|
"""Send a query to the agent, get its response and parse the code"""
|
|
21
23
|
|
|
22
24
|
prompt_text = get_streamlit_app_creation_prompt(notebook, streamlit_app_prompt)
|
|
23
25
|
|
|
24
|
-
messages: List[
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
"content": [{
|
|
28
|
-
"type": "text",
|
|
29
|
-
"text": prompt_text
|
|
30
|
-
}]
|
|
31
|
-
})
|
|
26
|
+
messages: List[ChatCompletionMessageParam] = [
|
|
27
|
+
{"role": "system", "content": streamlit_system_prompt},
|
|
28
|
+
{"role": "user", "content": prompt_text}
|
|
32
29
|
]
|
|
33
|
-
agent_response = await
|
|
30
|
+
agent_response = await provider.request_completions(
|
|
31
|
+
message_type=MessageType.STREAMLIT_CONVERSION,
|
|
32
|
+
messages=messages,
|
|
33
|
+
use_smartest_model=True,
|
|
34
|
+
thread_id=None
|
|
35
|
+
)
|
|
34
36
|
converted_code = extract_code_blocks(agent_response)
|
|
35
37
|
|
|
36
38
|
# Extract the TODOs from the agent's response
|
|
@@ -39,16 +41,16 @@ async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt
|
|
|
39
41
|
for todo_placeholder in todo_placeholders:
|
|
40
42
|
print(f"Processing AI TODO: {todo_placeholder}")
|
|
41
43
|
todo_prompt = get_finish_todo_prompt(notebook, converted_code, todo_placeholder)
|
|
42
|
-
todo_messages: List[
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
"content": [{
|
|
46
|
-
"type": "text",
|
|
47
|
-
"text": todo_prompt
|
|
48
|
-
}]
|
|
49
|
-
})
|
|
44
|
+
todo_messages: List[ChatCompletionMessageParam] = [
|
|
45
|
+
{"role": "system", "content": streamlit_system_prompt},
|
|
46
|
+
{"role": "user", "content": todo_prompt}
|
|
50
47
|
]
|
|
51
|
-
todo_response = await
|
|
48
|
+
todo_response = await provider.request_completions(
|
|
49
|
+
message_type=MessageType.STREAMLIT_CONVERSION,
|
|
50
|
+
messages=todo_messages,
|
|
51
|
+
use_smartest_model=True,
|
|
52
|
+
thread_id=None
|
|
53
|
+
)
|
|
52
54
|
|
|
53
55
|
# Apply the search/replace to the streamlit app
|
|
54
56
|
search_replace_pairs = extract_search_replace_blocks(todo_response)
|
|
@@ -57,21 +59,21 @@ async def generate_new_streamlit_code(notebook: List[dict], streamlit_app_prompt
|
|
|
57
59
|
return converted_code
|
|
58
60
|
|
|
59
61
|
|
|
60
|
-
async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_code: str, edit_prompt: str) -> str:
|
|
62
|
+
async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_code: str, edit_prompt: str, provider: ProviderManager) -> str:
|
|
61
63
|
"""Send a query to the agent, get its response and parse the code"""
|
|
62
64
|
prompt_text = get_update_existing_app_prompt(notebook, streamlit_app_code, edit_prompt)
|
|
63
65
|
|
|
64
|
-
messages: List[
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
"content": [{
|
|
68
|
-
"type": "text",
|
|
69
|
-
"text": prompt_text
|
|
70
|
-
}]
|
|
71
|
-
})
|
|
66
|
+
messages: List[ChatCompletionMessageParam] = [
|
|
67
|
+
{"role": "system", "content": streamlit_system_prompt},
|
|
68
|
+
{"role": "user", "content": prompt_text}
|
|
72
69
|
]
|
|
73
70
|
|
|
74
|
-
agent_response = await
|
|
71
|
+
agent_response = await provider.request_completions(
|
|
72
|
+
message_type=MessageType.STREAMLIT_CONVERSION,
|
|
73
|
+
messages=messages,
|
|
74
|
+
use_smartest_model=True,
|
|
75
|
+
thread_id=None
|
|
76
|
+
)
|
|
75
77
|
print(f"[Mito AI Search/Replace Tool]:\n {agent_response}")
|
|
76
78
|
|
|
77
79
|
# Apply the search/replace to the streamlit app
|
|
@@ -81,18 +83,18 @@ async def update_existing_streamlit_code(notebook: List[dict], streamlit_app_cod
|
|
|
81
83
|
return converted_code
|
|
82
84
|
|
|
83
85
|
|
|
84
|
-
async def correct_error_in_generation(error: str, streamlit_app_code: str) -> str:
|
|
86
|
+
async def correct_error_in_generation(error: str, streamlit_app_code: str, provider: ProviderManager) -> str:
|
|
85
87
|
"""If errors are present, send it back to the agent to get corrections in code"""
|
|
86
|
-
messages: List[
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
"content": [{
|
|
90
|
-
"type": "text",
|
|
91
|
-
"text": get_streamlit_error_correction_prompt(error, streamlit_app_code)
|
|
92
|
-
}]
|
|
93
|
-
})
|
|
88
|
+
messages: List[ChatCompletionMessageParam] = [
|
|
89
|
+
{"role": "system", "content": streamlit_system_prompt},
|
|
90
|
+
{"role": "user", "content": get_streamlit_error_correction_prompt(error, streamlit_app_code)}
|
|
94
91
|
]
|
|
95
|
-
agent_response = await
|
|
92
|
+
agent_response = await provider.request_completions(
|
|
93
|
+
message_type=MessageType.STREAMLIT_CONVERSION,
|
|
94
|
+
messages=messages,
|
|
95
|
+
use_smartest_model=True,
|
|
96
|
+
thread_id=None
|
|
97
|
+
)
|
|
96
98
|
|
|
97
99
|
# Apply the search/replace to the streamlit app
|
|
98
100
|
search_replace_pairs = extract_search_replace_blocks(agent_response)
|
|
@@ -100,7 +102,7 @@ async def correct_error_in_generation(error: str, streamlit_app_code: str) -> st
|
|
|
100
102
|
|
|
101
103
|
return streamlit_app_code
|
|
102
104
|
|
|
103
|
-
async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath, app_file_name: AppFileName, streamlit_app_prompt: str
|
|
105
|
+
async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNotebookPath, app_file_name: AppFileName, streamlit_app_prompt: str, provider: ProviderManager) -> None:
|
|
104
106
|
"""Handler function for streamlit code generation and validation"""
|
|
105
107
|
|
|
106
108
|
# Convert to absolute path for consistent handling
|
|
@@ -110,7 +112,7 @@ async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNoteboo
|
|
|
110
112
|
|
|
111
113
|
if create_new_app:
|
|
112
114
|
# Otherwise generate a new streamlit app
|
|
113
|
-
streamlit_code = await generate_new_streamlit_code(notebook_code, streamlit_app_prompt)
|
|
115
|
+
streamlit_code = await generate_new_streamlit_code(notebook_code, streamlit_app_prompt, provider)
|
|
114
116
|
else:
|
|
115
117
|
# If the user is editing an existing streamlit app, use the update function
|
|
116
118
|
existing_streamlit_code = get_app_code_from_file(app_path)
|
|
@@ -118,14 +120,14 @@ async def streamlit_handler(create_new_app: bool, notebook_path: AbsoluteNoteboo
|
|
|
118
120
|
if existing_streamlit_code is None:
|
|
119
121
|
raise StreamlitConversionError("Error updating existing streamlit app because app.py file was not found.", 404)
|
|
120
122
|
|
|
121
|
-
streamlit_code = await update_existing_streamlit_code(notebook_code, existing_streamlit_code, streamlit_app_prompt)
|
|
123
|
+
streamlit_code = await update_existing_streamlit_code(notebook_code, existing_streamlit_code, streamlit_app_prompt, provider)
|
|
122
124
|
|
|
123
125
|
# Then, after creating/updating the app, validate that the new code runs
|
|
124
126
|
errors = validate_app(streamlit_code, notebook_path)
|
|
125
127
|
tries = 0
|
|
126
128
|
while len(errors) > 0 and tries < 5:
|
|
127
129
|
for error in errors:
|
|
128
|
-
streamlit_code = await correct_error_in_generation(error, streamlit_code)
|
|
130
|
+
streamlit_code = await correct_error_in_generation(error, streamlit_code, provider)
|
|
129
131
|
|
|
130
132
|
errors = validate_app(streamlit_code, notebook_path)
|
|
131
133
|
|