mito-ai 0.1.28__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mito-ai might be problematic. Click here for more details.
- mito_ai/_version.py +1 -1
- mito_ai/openai_client.py +22 -6
- mito_ai/tests/providers/test_azure.py +635 -0
- mito_ai/utils/anthropic_utils.py +3 -0
- mito_ai/utils/open_ai_utils.py +0 -4
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +1 -1
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.28.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.114d2b34bc18a45df338.js → mito_ai-0.1.29.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.8fc39671fbc9ba62e74b.js +219 -72
- mito_ai-0.1.29.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.8fc39671fbc9ba62e74b.js.map +1 -0
- mito_ai-0.1.28.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.92c6411fdc4075df549b.js → mito_ai-0.1.29.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.6bf957d24d60237bb287.js +3 -3
- mito_ai-0.1.28.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.92c6411fdc4075df549b.js.map → mito_ai-0.1.29.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.6bf957d24d60237bb287.js.map +1 -1
- {mito_ai-0.1.28.dist-info → mito_ai-0.1.29.dist-info}/METADATA +1 -1
- {mito_ai-0.1.28.dist-info → mito_ai-0.1.29.dist-info}/RECORD +28 -27
- mito_ai-0.1.28.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.114d2b34bc18a45df338.js.map +0 -1
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js.map +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js.map +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js.map +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.28.data → mito_ai-0.1.29.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.28.dist-info → mito_ai-0.1.29.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.28.dist-info → mito_ai-0.1.29.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.28.dist-info → mito_ai-0.1.29.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,635 @@
|
|
|
1
|
+
# Copyright (c) Saga Inc.
|
|
2
|
+
# Distributed under the terms of the GNU Affero General Public License v3.0 License.
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
from typing import Any, List
|
|
6
|
+
from unittest.mock import patch, MagicMock, AsyncMock
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
from traitlets.config import Config
|
|
10
|
+
from openai.types.chat import ChatCompletionMessageParam
|
|
11
|
+
|
|
12
|
+
from mito_ai.completions.providers import OpenAIProvider
|
|
13
|
+
from mito_ai.completions.models import (
|
|
14
|
+
MessageType,
|
|
15
|
+
AICapabilities,
|
|
16
|
+
CompletionReply,
|
|
17
|
+
CompletionItem,
|
|
18
|
+
ResponseFormatInfo,
|
|
19
|
+
AgentResponse
|
|
20
|
+
)
|
|
21
|
+
from mito_ai.openai_client import OpenAIClient
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Placeholder Azure OpenAI credentials/configuration shared by all tests below.
# None of these values are real — every network call is mocked, so nothing
# ever reaches Azure.
FAKE_API_KEY = "sk-1234567890"
FAKE_AZURE_ENDPOINT = "https://test-azure-openai.openai.azure.com"
FAKE_AZURE_MODEL = "gpt-4o-azure"
FAKE_AZURE_API_VERSION = "2024-12-01-preview"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@pytest.fixture
def provider_config() -> Config:
    """Return a traitlets Config pre-populated with empty provider/client sections."""
    cfg = Config()
    # Both the provider and the client read their own config section.
    for section in ("OpenAIProvider", "OpenAIClient"):
        setattr(cfg, section, Config())
    return cfg
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@pytest.fixture(autouse=True)
def reset_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
    """Scrub every provider-related environment variable before each test runs."""
    provider_env_vars = (
        "OPENAI_API_KEY",
        "CLAUDE_API_KEY",
        "GEMINI_API_KEY",
        "OLLAMA_MODEL",
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_MODEL",
        "AZURE_OPENAI_API_VERSION",
    )
    for name in provider_env_vars:
        # raising=False: the variable may legitimately be unset already.
        monkeypatch.delenv(name, raising=False)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pytest.fixture
def mock_azure_openai_environment(monkeypatch: pytest.MonkeyPatch) -> None:
    """Configure env vars, constants, and enterprise checks so Azure OpenAI is active."""
    azure_settings = {
        "AZURE_OPENAI_API_KEY": FAKE_API_KEY,
        "AZURE_OPENAI_ENDPOINT": FAKE_AZURE_ENDPOINT,
        "AZURE_OPENAI_MODEL": FAKE_AZURE_MODEL,
        "AZURE_OPENAI_API_VERSION": FAKE_AZURE_API_VERSION,
    }
    # Mirror each setting in both the process environment and the constants module,
    # since different code paths read from different places.
    for name, value in azure_settings.items():
        monkeypatch.setenv(name, value)
        monkeypatch.setattr(f"mito_ai.constants.{name}", value)

    # Force the enterprise gating so the Azure code path is selected.
    monkeypatch.setattr("mito_ai.enterprise.utils.is_enterprise", lambda: True)
    monkeypatch.setattr("mito_ai.enterprise.utils.is_mitosheet_private", lambda: False)
    monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: True)
    # Patch the symbol at its import site inside the OpenAI client as well.
    monkeypatch.setattr("mito_ai.openai_client.is_azure_openai_configured", lambda: True)

    # Make sure a regular OpenAI key cannot take precedence over Azure.
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", None)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@pytest.fixture
def mock_azure_openai_client() -> Any:
    """Provide a stand-in Azure OpenAI client whose completion call is awaitable."""
    client = MagicMock()
    # The real client's create() is a coroutine, so the mock must be awaitable.
    client.chat.completions.create = AsyncMock()
    client.is_closed.return_value = False
    return client
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Test message types that should use Azure OpenAI
# Parametrized over by the tests below so every completion-producing
# message type is proven to route through the Azure client.
COMPLETION_MESSAGE_TYPES = [
    MessageType.CHAT,
    MessageType.SMART_DEBUG,
    MessageType.CODE_EXPLAIN,
    MessageType.AGENT_EXECUTION,
    MessageType.AGENT_AUTO_ERROR_FIXUP,
    MessageType.INLINE_COMPLETION,
    MessageType.CHAT_NAME_GENERATION,
]
|
|
95
|
+
|
|
96
|
+
# Common test data
# Minimal single-turn conversation reused by most tests in this module.
TEST_MESSAGES: List[ChatCompletionMessageParam] = [
    {"role": "user", "content": "Test message"}
]
|
|
100
|
+
|
|
101
|
+
# Helper functions for common test patterns
|
|
102
|
+
def create_mock_azure_client_with_response(response_content: str = "Test Azure completion") -> tuple[MagicMock, MagicMock]:
    """Build a fake Azure OpenAI client whose completion call resolves to *response_content*.

    Returns:
        A ``(client, response)`` pair so tests can inspect either side: the
        client records calls, and the response carries the canned content at
        ``response.choices[0].message.content``.
    """
    choice = MagicMock()
    choice.message.content = response_content
    canned_response = MagicMock()
    canned_response.choices = [choice]

    fake_client = MagicMock()
    # Awaiting create() must yield the canned response, as the real client does.
    fake_client.chat.completions.create = AsyncMock(return_value=canned_response)
    fake_client.is_closed.return_value = False

    return fake_client, canned_response
|
|
113
|
+
|
|
114
|
+
def create_mock_streaming_response(chunks: List[str]) -> Any:
    """Return an async-generator factory yielding one mock chunk per entry in *chunks*.

    Mirrors the OpenAI streaming protocol: the final chunk carries
    ``finish_reason == "stop"`` while every earlier chunk carries ``None``.
    Callers invoke the returned factory to obtain a fresh stream.
    """
    last_index = len(chunks) - 1

    async def _stream():
        for index, text in enumerate(chunks):
            choice = MagicMock()
            choice.delta.content = text
            choice.finish_reason = "stop" if index == last_index else None
            chunk = MagicMock()
            chunk.choices = [choice]
            yield chunk

    return _stream
|
|
124
|
+
|
|
125
|
+
def assert_azure_client_called_correctly(mock_azure_client_class: MagicMock, mock_azure_client: MagicMock, expected_model: str = FAKE_AZURE_MODEL, should_stream: bool = False) -> None:
    """Assert that the Azure OpenAI client was constructed and used as expected.

    Args:
        mock_azure_client_class: The patched ``openai.AsyncAzureOpenAI`` class.
        mock_azure_client: The client instance the patched class returned.
        expected_model: Model the request must have used — the configured Azure
            deployment model, regardless of what the caller requested.
        should_stream: When True, additionally require ``stream=True`` on the request.

    Raises:
        AssertionError: If construction/usage does not match expectations.
    """
    # The Azure client class must have been instantiated exactly once.
    mock_azure_client_class.assert_called_once()

    # Exactly one completion request must have gone through the Azure client.
    mock_azure_client.chat.completions.create.assert_called_once()

    # `.kwargs` is clearer than the positional `call_args[1]` tuple indexing.
    call_kwargs = mock_azure_client.chat.completions.create.call_args.kwargs

    # The request must target the configured Azure deployment model.
    assert call_kwargs["model"] == expected_model

    if should_stream:
        # `is True` rather than `== True` (E712): the API expects a real bool.
        assert call_kwargs["stream"] is True
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class TestAzureOpenAIClientCreation:
    """Verify that a configured Azure environment yields an Azure OpenAI client."""

    def test_azure_openai_client_capabilities(self, mock_azure_openai_environment: None, provider_config: Config) -> None:
        """Capabilities should advertise the Azure provider and deployment model."""
        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            client = OpenAIClient(config=provider_config)
            caps = client.capabilities

            assert caps.provider == "Azure OpenAI"
            assert caps.configuration["model"] == FAKE_AZURE_MODEL

            # The async client is built lazily; touching it proves construction works.
            _ = client._active_async_client
            azure_cls.assert_called_once()

    def test_azure_openai_client_creation_parameters(self, mock_azure_openai_environment: None, provider_config: Config) -> None:
        """The Azure client must be constructed with exactly the configured settings."""
        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            client = OpenAIClient(config=provider_config)
            _ = client._active_async_client  # trigger lazy construction

            azure_cls.assert_called_once_with(
                api_key=FAKE_API_KEY,
                api_version=FAKE_AZURE_API_VERSION,
                azure_endpoint=FAKE_AZURE_ENDPOINT,
                max_retries=1,
                timeout=30,
            )

    def test_azure_openai_model_resolution(self, mock_azure_openai_environment: None, provider_config: Config) -> None:
        """Model resolution should always yield the Azure deployment model."""
        with patch("openai.AsyncAzureOpenAI"):
            client = OpenAIClient(config=provider_config)

            # Any explicitly requested model is overridden by the Azure deployment.
            for requested in ("gpt-4.1", "gpt-3.5-turbo"):
                assert client._resolve_model(requested) == FAKE_AZURE_MODEL

            # Omitting the model entirely also resolves to the Azure deployment.
            assert client._resolve_model() == FAKE_AZURE_MODEL
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class TestAzureOpenAICompletions:
    """Exercise OpenAIClient.request_completions against the Azure backend."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("message_type", COMPLETION_MESSAGE_TYPES)
    async def test_request_completions_uses_azure_client(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config,
        message_type: MessageType
    ) -> None:
        """Every completion message type should be routed through the Azure client."""
        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            fake_client, _ = create_mock_azure_client_with_response()
            azure_cls.return_value = fake_client

            client = OpenAIClient(config=provider_config)

            result = await client.request_completions(
                message_type=message_type,
                messages=TEST_MESSAGES,
                model="gpt-4.1"
            )

            # The canned Azure response should come back unchanged.
            assert result == "Test Azure completion"

            # The request must have gone through the Azure client with the Azure model.
            assert_azure_client_called_correctly(azure_cls, fake_client)

    @pytest.mark.asyncio
    async def test_request_completions_with_response_format(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config
    ) -> None:
        """Structured (agent-mode) responses should also flow through Azure."""
        agent_payload = '{"type": "finished_task", "message": "Task completed"}'

        choice = MagicMock()
        choice.message.content = agent_payload
        canned = MagicMock()
        canned.choices = [choice]

        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            fake_client = MagicMock()
            fake_client.chat.completions.create = AsyncMock(return_value=canned)
            fake_client.is_closed.return_value = False
            azure_cls.return_value = fake_client

            client = OpenAIClient(config=provider_config)

            conversation: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": "Test message"}
            ]

            result = await client.request_completions(
                message_type=MessageType.AGENT_EXECUTION,
                messages=conversation,
                model="gpt-4.1",
                response_format_info=ResponseFormatInfo(
                    name="agent_response",
                    format=AgentResponse
                )
            )

            # The structured payload should be returned verbatim.
            assert result == '{"type": "finished_task", "message": "Task completed"}'

            # The request went through Azure with the Azure deployment model.
            fake_client.chat.completions.create.assert_called_once()
            assert fake_client.chat.completions.create.call_args[1]["model"] == FAKE_AZURE_MODEL

    @pytest.mark.asyncio
    @pytest.mark.parametrize("requested_model", ["gpt-4.1", "gpt-3.5-turbo", "gpt-4-turbo", "gpt-4o"])
    async def test_request_completions_uses_azure_model_not_requested_model(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config,
        requested_model: str
    ) -> None:
        """The configured Azure model overrides whatever model the caller asked for."""
        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            fake_client, _ = create_mock_azure_client_with_response()
            azure_cls.return_value = fake_client

            client = OpenAIClient(config=provider_config)

            result = await client.request_completions(
                message_type=MessageType.CHAT,
                messages=TEST_MESSAGES,
                model=requested_model
            )

            assert result == "Test Azure completion"

            assert_azure_client_called_correctly(azure_cls, fake_client)
            # The caller's model must never be what was actually sent.
            sent_model = fake_client.chat.completions.create.call_args[1]["model"]
            assert sent_model != requested_model
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
class TestAzureOpenAIStreamCompletions:
    """Exercise OpenAIClient.stream_completions against the Azure backend."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("message_type", COMPLETION_MESSAGE_TYPES)
    async def test_stream_completions_uses_azure_client(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config,
        message_type: MessageType
    ) -> None:
        """Streaming requests for every message type should go through Azure."""
        pieces = ["Hello", " World"]

        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            fake_client = MagicMock()
            fake_client.chat.completions.create = AsyncMock(
                return_value=create_mock_streaming_response(pieces)()
            )
            fake_client.is_closed.return_value = False
            azure_cls.return_value = fake_client

            client = OpenAIClient(config=provider_config)

            # Collect everything sent back through the reply callback.
            received = []

            completion = await client.stream_completions(
                message_type=message_type,
                messages=TEST_MESSAGES,
                model="gpt-4.1",
                message_id="test-id",
                thread_id="test-thread",
                reply_fn=received.append
            )

            # All streamed chunks should be joined into the final completion.
            assert completion == "".join(pieces)

            # The request went through Azure with streaming enabled.
            assert_azure_client_called_correctly(azure_cls, fake_client, should_stream=True)

            # The reply callback saw the initial reply plus chunk updates.
            assert len(received) >= 2
            assert isinstance(received[0], CompletionReply)

    @pytest.mark.asyncio
    async def test_stream_completions_with_response_format(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config
    ) -> None:
        """Agent-mode streaming should assemble the structured payload via Azure."""
        fragments = ['{"type": "finished_task",', ' "message": "Task completed"}']

        # Build one mock chunk per fragment; only the last signals "stop".
        chunk_mocks = []
        for index, fragment in enumerate(fragments):
            choice = MagicMock()
            choice.delta.content = fragment
            choice.finish_reason = "stop" if index == len(fragments) - 1 else None
            chunk = MagicMock()
            chunk.choices = [choice]
            chunk_mocks.append(chunk)

        async def fake_stream():
            for chunk in chunk_mocks:
                yield chunk

        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            fake_client = MagicMock()
            fake_client.chat.completions.create = AsyncMock(return_value=fake_stream())
            fake_client.is_closed.return_value = False
            azure_cls.return_value = fake_client

            client = OpenAIClient(config=provider_config)

            conversation: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": "Test message"}
            ]

            received = []

            completion = await client.stream_completions(
                message_type=MessageType.AGENT_EXECUTION,
                messages=conversation,
                model="gpt-4.1",
                message_id="test-id",
                thread_id="test-thread",
                reply_fn=received.append,
                response_format_info=ResponseFormatInfo(
                    name="agent_response",
                    format=AgentResponse
                )
            )

            # The fragments should be reassembled into the full JSON payload.
            assert completion == '{"type": "finished_task", "message": "Task completed"}'

            # The request went through Azure with the Azure deployment model.
            fake_client.chat.completions.create.assert_called_once()
            assert fake_client.chat.completions.create.call_args[1]["model"] == FAKE_AZURE_MODEL
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
class TestAzureOpenAIProviderIntegration:
    """End-to-end checks that OpenAIProvider delegates to Azure OpenAI."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("message_type", COMPLETION_MESSAGE_TYPES)
    async def test_provider_uses_azure_for_gpt_4_1(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config,
        message_type: MessageType
    ) -> None:
        """request_completions through the provider should hit the Azure client."""
        choice = MagicMock()
        choice.message.content = "Test Azure completion"
        canned = MagicMock()
        canned.choices = [choice]

        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            fake_client = MagicMock()
            fake_client.chat.completions.create = AsyncMock(return_value=canned)
            fake_client.is_closed.return_value = False
            azure_cls.return_value = fake_client

            provider = OpenAIProvider(config=provider_config)

            conversation: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": "Test message"}
            ]

            result = await provider.request_completions(
                message_type=message_type,
                messages=conversation,
                model="gpt-4.1"
            )

            assert result == "Test Azure completion"

            # The Azure client was built and used exactly once.
            azure_cls.assert_called_once()
            fake_client.chat.completions.create.assert_called_once()

            # The Azure deployment model replaces the requested gpt-4.1.
            assert fake_client.chat.completions.create.call_args[1]["model"] == FAKE_AZURE_MODEL

    @pytest.mark.asyncio
    @pytest.mark.parametrize("message_type", COMPLETION_MESSAGE_TYPES)
    async def test_provider_stream_uses_azure_for_gpt_4_1(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config,
        message_type: MessageType
    ) -> None:
        """stream_completions through the provider should stream via the Azure client."""
        fragments = ["Hello", " Azure"]

        # Build one mock chunk per fragment; only the last signals "stop".
        chunk_mocks = []
        for index, text in enumerate(fragments):
            choice = MagicMock()
            choice.delta.content = text
            choice.finish_reason = "stop" if index == len(fragments) - 1 else None
            chunk = MagicMock()
            chunk.choices = [choice]
            chunk_mocks.append(chunk)

        async def fake_stream():
            for chunk in chunk_mocks:
                yield chunk

        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            fake_client = MagicMock()
            fake_client.chat.completions.create = AsyncMock(return_value=fake_stream())
            fake_client.is_closed.return_value = False
            azure_cls.return_value = fake_client

            provider = OpenAIProvider(config=provider_config)

            conversation: List[ChatCompletionMessageParam] = [
                {"role": "user", "content": "Test message"}
            ]

            received = []

            completion = await provider.stream_completions(
                message_type=message_type,
                messages=conversation,
                model="gpt-4.1",
                message_id="test-id",
                thread_id="test-thread",
                reply_fn=received.append
            )

            # All streamed fragments should be joined into the final completion.
            assert completion == "Hello Azure"

            # The Azure client was built and used exactly once, with streaming on.
            azure_cls.assert_called_once()
            fake_client.chat.completions.create.assert_called_once()
            sent_kwargs = fake_client.chat.completions.create.call_args[1]
            assert sent_kwargs["model"] == FAKE_AZURE_MODEL
            assert sent_kwargs["stream"] == True
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
class TestAzureOpenAIConfigurationPriority:
    """Azure OpenAI must win whenever it is configured, even if other providers are too."""

    def _assert_azure_selected(self, provider_config: Config) -> None:
        # Shared check: the client reports Azure capabilities and builds the Azure client.
        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            client = OpenAIClient(config=provider_config)
            caps = client.capabilities

            assert caps.provider == "Azure OpenAI"
            assert caps.configuration["model"] == FAKE_AZURE_MODEL

            # Touch the lazily-built client so construction is triggered.
            _ = client._active_async_client
            azure_cls.assert_called_once()

    def test_azure_openai_priority_over_regular_openai(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config,
        monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A plain OpenAI key must not displace Azure OpenAI."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-regular-openai-key")
        monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "sk-regular-openai-key")

        self._assert_azure_selected(provider_config)

    def test_azure_openai_priority_over_claude(
        self,
        mock_azure_openai_environment: None,
        provider_config: Config,
        monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A Claude key must not displace Azure OpenAI."""
        monkeypatch.setenv("CLAUDE_API_KEY", "claude-key")
        monkeypatch.setattr("mito_ai.constants.CLAUDE_API_KEY", "claude-key")

        self._assert_azure_selected(provider_config)
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
class TestAzureOpenAINotConfigured:
    """When Azure OpenAI is incompletely configured, it must not be selected."""

    def _assert_azure_not_selected(self, provider_config: Config) -> None:
        # Shared check: the client neither reports Azure nor constructs its client.
        with patch("openai.AsyncAzureOpenAI") as azure_cls:
            client = OpenAIClient(config=provider_config)
            assert client.capabilities.provider != "Azure OpenAI"
            azure_cls.assert_not_called()

    def test_missing_azure_api_key(self, provider_config: Config, monkeypatch: pytest.MonkeyPatch) -> None:
        """Without AZURE_OPENAI_API_KEY the Azure path must stay disabled."""
        # Everything except the API key is present.
        partial_settings = {
            "AZURE_OPENAI_ENDPOINT": FAKE_AZURE_ENDPOINT,
            "AZURE_OPENAI_MODEL": FAKE_AZURE_MODEL,
            "AZURE_OPENAI_API_VERSION": FAKE_AZURE_API_VERSION,
        }
        for name, value in partial_settings.items():
            monkeypatch.setenv(name, value)
            monkeypatch.setattr(f"mito_ai.constants.{name}", value)
        # The API key itself is deliberately absent.
        monkeypatch.setattr("mito_ai.constants.AZURE_OPENAI_API_KEY", None)

        monkeypatch.setattr("mito_ai.enterprise.utils.is_enterprise", lambda: True)
        monkeypatch.setattr("mito_ai.enterprise.utils.is_mitosheet_private", lambda: False)
        # The configuration check fails because the key is missing.
        monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
        monkeypatch.setattr("mito_ai.openai_client.is_azure_openai_configured", lambda: False)

        self._assert_azure_not_selected(provider_config)

    def test_not_enterprise_user(self, provider_config: Config, monkeypatch: pytest.MonkeyPatch) -> None:
        """Full Azure settings without an enterprise license must stay disabled."""
        full_settings = {
            "AZURE_OPENAI_API_KEY": FAKE_API_KEY,
            "AZURE_OPENAI_ENDPOINT": FAKE_AZURE_ENDPOINT,
            "AZURE_OPENAI_MODEL": FAKE_AZURE_MODEL,
            "AZURE_OPENAI_API_VERSION": FAKE_AZURE_API_VERSION,
        }
        for name, value in full_settings.items():
            monkeypatch.setenv(name, value)
            monkeypatch.setattr(f"mito_ai.constants.{name}", value)

        # Not an enterprise user, so the configuration check reports False.
        monkeypatch.setattr("mito_ai.enterprise.utils.is_enterprise", lambda: False)
        monkeypatch.setattr("mito_ai.enterprise.utils.is_mitosheet_private", lambda: False)
        monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
        monkeypatch.setattr("mito_ai.openai_client.is_azure_openai_configured", lambda: False)

        self._assert_azure_not_selected(provider_config)
|
mito_ai/utils/anthropic_utils.py
CHANGED
|
@@ -211,6 +211,9 @@ def get_anthropic_completion_function_params(
|
|
|
211
211
|
"system": system,
|
|
212
212
|
}
|
|
213
213
|
if response_format_info is not None:
|
|
214
|
+
# TODO: This should not be here.. the model is resolved in the anthropic client.
|
|
215
|
+
# This also means that chat is using the fast model...
|
|
216
|
+
# I bet the same bug exists in gemini...
|
|
214
217
|
provider_data["model"] = INLINE_COMPLETION_MODEL
|
|
215
218
|
if tools:
|
|
216
219
|
provider_data["tools"] = tools
|
mito_ai/utils/open_ai_utils.py
CHANGED
|
@@ -27,8 +27,6 @@ from mito_ai.constants import MITO_OPENAI_URL
|
|
|
27
27
|
__user_email: Optional[str] = None
|
|
28
28
|
__user_id: Optional[str] = None
|
|
29
29
|
|
|
30
|
-
INLINE_COMPLETION_MODEL = "gpt-4.1-nano-2025-04-14"
|
|
31
|
-
|
|
32
30
|
def _prepare_request_data_and_headers(
|
|
33
31
|
last_message_content: Union[str, None],
|
|
34
32
|
ai_completion_data: Dict[str, Any],
|
|
@@ -311,8 +309,6 @@ def get_open_ai_completion_function_params(
|
|
|
311
309
|
"strict": True
|
|
312
310
|
}
|
|
313
311
|
}
|
|
314
|
-
else:
|
|
315
|
-
completion_function_params["model"] = INLINE_COMPLETION_MODEL
|
|
316
312
|
|
|
317
313
|
# o3-mini will error if we try setting the temperature
|
|
318
314
|
if model == "gpt-4o-mini":
|