mito-ai 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- mito_ai/_version.py +1 -1
- mito_ai/anthropic_client.py +52 -54
- mito_ai/app_builder/handlers.py +2 -4
- mito_ai/completions/models.py +15 -1
- mito_ai/completions/prompt_builders/agent_system_message.py +10 -2
- mito_ai/completions/providers.py +79 -39
- mito_ai/constants.py +11 -24
- mito_ai/gemini_client.py +44 -48
- mito_ai/openai_client.py +30 -44
- mito_ai/tests/message_history/test_generate_short_chat_name.py +0 -4
- mito_ai/tests/open_ai_utils_test.py +18 -22
- mito_ai/tests/{test_anthropic_client.py → providers/test_anthropic_client.py} +37 -32
- mito_ai/tests/providers/test_azure.py +2 -6
- mito_ai/tests/providers/test_capabilities.py +120 -0
- mito_ai/tests/{test_gemini_client.py → providers/test_gemini_client.py} +40 -36
- mito_ai/tests/providers/test_mito_server_utils.py +448 -0
- mito_ai/tests/providers/test_model_resolution.py +130 -0
- mito_ai/tests/providers/test_openai_client.py +57 -0
- mito_ai/tests/providers/test_provider_completion_exception.py +66 -0
- mito_ai/tests/providers/test_provider_limits.py +42 -0
- mito_ai/tests/providers/test_providers.py +382 -0
- mito_ai/tests/providers/test_retry_logic.py +389 -0
- mito_ai/tests/providers/utils.py +85 -0
- mito_ai/tests/test_constants.py +15 -2
- mito_ai/tests/test_telemetry.py +12 -0
- mito_ai/utils/anthropic_utils.py +21 -29
- mito_ai/utils/gemini_utils.py +18 -22
- mito_ai/utils/mito_server_utils.py +92 -0
- mito_ai/utils/open_ai_utils.py +22 -46
- mito_ai/utils/provider_utils.py +49 -0
- mito_ai/utils/telemetry_utils.py +11 -1
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +1 -1
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.281f4b9af60d620c6fb1.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js +737 -319
- mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js.map +1 -0
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.4f1d00fd0c58fcc05d8d.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.d2eea6519fa332d79efb.js +13 -16
- mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.d2eea6519fa332d79efb.js.map +1 -0
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js +6 -2
- mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js.map +1 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/METADATA +1 -1
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/RECORD +52 -43
- mito_ai/tests/providers_test.py +0 -438
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.281f4b9af60d620c6fb1.js.map +0 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.4f1d00fd0c58fcc05d8d.js.map +0 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js.map +0 -1
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js +0 -7842
- mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js.map +0 -1
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js.map +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
- {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/WHEEL +0 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/entry_points.txt +0 -0
- {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/licenses/LICENSE +0 -0

mito_ai/tests/providers/test_model_resolution.py (new file)
@@ -0,0 +1,130 @@
# Copyright (c) Saga Inc.
# Distributed under the terms of the GNU Affero General Public License v3.0 License.

"""
These tests ensure that the correct model is chosen for each message type, for each provider.
"""

import pytest
from mito_ai.utils.provider_utils import does_message_require_fast_model
from mito_ai.completions.models import MessageType
from unittest.mock import AsyncMock, MagicMock, patch
from mito_ai.completions.providers import OpenAIProvider
from mito_ai.completions.models import MessageType
from mito_ai.utils.provider_utils import does_message_require_fast_model
from traitlets.config import Config

@pytest.fixture
def provider_config() -> Config:
    """Create a proper Config object for the OpenAIProvider."""
    config = Config()
    config.OpenAIProvider = Config()
    config.OpenAIClient = Config()
    return config

@pytest.fixture
def mock_messages():
    """Sample messages for testing."""
    return [{"role": "user", "content": "Test message"}]

# Test cases for different message types and their expected fast model requirement
MESSAGE_TYPE_TEST_CASES = [
    (MessageType.CHAT, False),
    (MessageType.SMART_DEBUG, False),
    (MessageType.CODE_EXPLAIN, False),
    (MessageType.AGENT_EXECUTION, False),
    (MessageType.AGENT_AUTO_ERROR_FIXUP, False),
    (MessageType.INLINE_COMPLETION, True),
    (MessageType.CHAT_NAME_GENERATION, True),
]
@pytest.mark.parametrize("message_type,expected_result", MESSAGE_TYPE_TEST_CASES)
def test_does_message_require_fast_model(message_type: MessageType, expected_result: bool) -> None:
    """Test that does_message_require_fast_model returns the correct boolean for each message type."""
    assert does_message_require_fast_model(message_type) == expected_result

def test_does_message_require_fast_model_raises_error_for_unknown_message_type():
    """Test that does_message_require_fast_model raises an error for an unknown message type."""
    with pytest.raises(ValueError):
        does_message_require_fast_model('unknown_message_type')  # type: ignore

@pytest.mark.asyncio
async def test_request_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
    """Test that request_completions calls does_message_require_fast_model and uses the correct model."""
    # Set up environment variables to ensure OpenAI provider is used
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as mock_does_message_require_fast_model:
        # Mock the OpenAI API call instead of the entire client
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Test Completion"

        with patch('openai.AsyncOpenAI') as mock_openai_class:
            mock_openai_client = MagicMock()
            mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_response)
            mock_openai_client.is_closed.return_value = False
            mock_openai_class.return_value = mock_openai_client

            # Mock the validation that happens in OpenAIClient constructor
            with patch('openai.OpenAI') as mock_sync_openai_class:
                mock_sync_client = MagicMock()
                mock_sync_client.models.list.return_value = MagicMock()
                mock_sync_openai_class.return_value = mock_sync_client

                provider = OpenAIProvider(config=provider_config)
                await provider.request_completions(
                    message_type=MessageType.CHAT,
                    messages=mock_messages,
                    model="gpt-3.5",
                )

                mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
                # Verify the model passed to the API call
                call_args = mock_openai_client.chat.completions.create.call_args
                assert call_args[1]['model'] == "gpt-3.5"

@pytest.mark.asyncio
async def test_stream_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
    """Test that stream_completions calls does_message_require_fast_model and uses the correct model."""
    # Set up environment variables to ensure OpenAI provider is used
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as mock_does_message_require_fast_model:
        # Mock the OpenAI API call instead of the entire client
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].delta.content = "Test Stream Completion"
        mock_response.choices[0].finish_reason = "stop"

        with patch('openai.AsyncOpenAI') as mock_openai_class:
            mock_openai_client = MagicMock()
            # Create an async generator for streaming
            async def mock_stream():
                yield mock_response

            mock_openai_client.chat.completions.create = AsyncMock(return_value=mock_stream())
            mock_openai_client.is_closed.return_value = False
            mock_openai_class.return_value = mock_openai_client

            # Mock the validation that happens in OpenAIClient constructor
            with patch('openai.OpenAI') as mock_sync_openai_class:
                mock_sync_client = MagicMock()
                mock_sync_client.models.list.return_value = MagicMock()
                mock_sync_openai_class.return_value = mock_sync_client

                provider = OpenAIProvider(config=provider_config)
                await provider.stream_completions(
                    message_type=MessageType.CHAT,
                    messages=mock_messages,
                    model="gpt-3.5",
                    message_id="test_id",
                    thread_id="test_thread",
                    reply_fn=lambda x: None
                )

                mock_does_message_require_fast_model.assert_called_once_with(MessageType.CHAT)
                # Verify the model passed to the API call
                call_args = mock_openai_client.chat.completions.create.call_args
                assert call_args[1]['model'] == "gpt-3.5"
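
The test table above fully pins down the expected behavior of does_message_require_fast_model. The function itself lives in the new mito_ai/utils/provider_utils.py (+49 lines), which this diff does not display; a minimal sketch consistent with the tests might look like the following (the set names are illustrative, not from the package):

# Hypothetical sketch -- the real implementation is in the new
# mito_ai/utils/provider_utils.py, which this diff does not show.
from mito_ai.completions.models import MessageType

def does_message_require_fast_model(message_type: MessageType) -> bool:
    """Return True for latency-sensitive message types that should use a fast model."""
    fast_model_message_types = {
        MessageType.INLINE_COMPLETION,      # typed-ahead completions must return quickly
        MessageType.CHAT_NAME_GENERATION,   # naming a chat thread is low-stakes
    }
    slow_model_message_types = {
        MessageType.CHAT,
        MessageType.SMART_DEBUG,
        MessageType.CODE_EXPLAIN,
        MessageType.AGENT_EXECUTION,
        MessageType.AGENT_AUTO_ERROR_FIXUP,
    }
    if message_type in fast_model_message_types:
        return True
    if message_type in slow_model_message_types:
        return False
    # The tests expect a ValueError for unknown message types.
    raise ValueError(f"Unknown message type: {message_type}")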
mito_ai/tests/providers/test_openai_client.py (new file)
@@ -0,0 +1,57 @@
# Copyright (c) Saga Inc.
# Distributed under the terms of the GNU Affero General Public License v3.0 License.

import pytest
from mito_ai.openai_client import OpenAIClient
from mito_ai.utils.open_ai_utils import FAST_OPENAI_MODEL
from mito_ai.completions.models import MessageType
from unittest.mock import MagicMock, patch, AsyncMock
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam

CUSTOM_MODEL = "smart-openai-model"
@pytest.mark.parametrize("message_type, expected_model", [
    (MessageType.CHAT, CUSTOM_MODEL),
    (MessageType.SMART_DEBUG, CUSTOM_MODEL),
    (MessageType.CODE_EXPLAIN, CUSTOM_MODEL),
    (MessageType.AGENT_EXECUTION, CUSTOM_MODEL),
    (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL),
    (MessageType.INLINE_COMPLETION, FAST_OPENAI_MODEL),
    (MessageType.CHAT_NAME_GENERATION, FAST_OPENAI_MODEL),
])
@pytest.mark.asyncio
async def test_model_selection_based_on_message_type(message_type, expected_model):
    """
    Tests that the correct model is selected based on the message type.
    """
    client = OpenAIClient(api_key="test_key")  # type: ignore

    # Mock the _build_openai_client method to return our mock client
    with patch.object(client, '_build_openai_client') as mock_build_client, \
         patch('openai.AsyncOpenAI') as mock_openai_class:

        mock_client = MagicMock()
        mock_chat = MagicMock()
        mock_completions = MagicMock()
        mock_client.chat = mock_chat
        mock_chat.completions = mock_completions
        mock_openai_class.return_value = mock_client
        mock_build_client.return_value = mock_client

        # Create an async mock for the create method
        mock_create = AsyncMock()
        mock_create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content="test"))]
        )
        mock_completions.create = mock_create

        await client.request_completions(
            message_type=message_type,
            messages=[{"role": "user", "content": "Test message"}],
            model=CUSTOM_MODEL,
            response_format_info=None
        )

        # Verify that create was called with the expected model
        mock_create.assert_called_once()
        call_args = mock_create.call_args
        assert call_args[1]['model'] == expected_model
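
This test asserts that latency-sensitive message types override the caller's model with FAST_OPENAI_MODEL, while every other type keeps the requested model. The selection branch itself sits inside mito_ai's OpenAI client code, which this diff only summarizes; a minimal sketch of the rule, where the constant's value and the helper name resolve_model are assumptions:

# Hypothetical sketch of the selection branch the test exercises.
from mito_ai.completions.models import MessageType
from mito_ai.utils.provider_utils import does_message_require_fast_model

FAST_OPENAI_MODEL = "gpt-4o-mini"  # assumed value; the real constant is defined in open_ai_utils

def resolve_model(message_type: MessageType, requested_model: str) -> str:
    # Latency-sensitive requests are pinned to the fast model regardless of what
    # the caller asked for; everything else keeps the requested model.
    return FAST_OPENAI_MODEL if does_message_require_fast_model(message_type) else requested_model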
mito_ai/tests/providers/test_provider_completion_exception.py (new file)
@@ -0,0 +1,66 @@
# Copyright (c) Saga Inc.
# Distributed under the terms of the GNU Affero General Public License v3.0 License.

from mito_ai.utils.mito_server_utils import ProviderCompletionException
import pytest


class TestProviderCompletionException:
    """Test the ProviderCompletionException class."""

    @pytest.mark.parametrize("error_message,provider_name,error_type,expected_title,expected_hint_contains", [
        (
            "Something went wrong",
            "LLM Provider",
            "LLMProviderError",
            "LLM Provider Error: Something went wrong",
            "LLM Provider"
        ),
        (
            "API key is invalid",
            "OpenAI",
            "AuthenticationError",
            "OpenAI Error: API key is invalid",
            "OpenAI"
        ),
        (
            "There was an error accessing the Anthropic API: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}",
            "Anthropic",
            "LLMProviderError",
            "Anthropic Error: There was an error accessing the Anthropic API: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}",
            "Anthropic"
        ),
    ])
    def test_exception_initialization(
        self,
        error_message: str,
        provider_name: str,
        error_type: str,
        expected_title: str,
        expected_hint_contains: str
    ):
        """Test exception initialization with various parameter combinations."""
        exception = ProviderCompletionException(
            error_message,
            provider_name=provider_name,
            error_type=error_type
        )

        assert exception.error_message == error_message
        assert exception.provider_name == provider_name
        assert exception.error_type == error_type
        assert exception.user_friendly_title == expected_title
        assert expected_hint_contains in exception.user_friendly_hint
        assert str(exception) == expected_title
        assert exception.args[0] == expected_title

    def test_default_initialization(self):
        """Test exception initialization with default values."""
        error_msg = "Something went wrong"
        exception = ProviderCompletionException(error_msg)

        assert exception.error_message == error_msg
        assert exception.provider_name == "LLM Provider"
        assert exception.error_type == "LLMProviderError"
        assert exception.user_friendly_title == "LLM Provider Error: Something went wrong"
        assert "LLM Provider" in exception.user_friendly_hint
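
These assertions fully determine the observable shape of ProviderCompletionException, which is defined in the new mito_ai/utils/mito_server_utils.py (+92 lines, not shown in this diff). A sketch consistent with the assertions follows; the hint wording is a guess, since the tests only check that it mentions the provider name:

# Hypothetical sketch -- the real class lives in mito_ai/utils/mito_server_utils.py.
class ProviderCompletionException(Exception):
    """Raised when an LLM provider fails to return a completion."""

    def __init__(
        self,
        error_message: str,
        provider_name: str = "LLM Provider",
        error_type: str = "LLMProviderError",
    ) -> None:
        self.error_message = error_message
        self.provider_name = provider_name
        self.error_type = error_type
        # The tests require str(exception) == args[0] == user_friendly_title.
        self.user_friendly_title = f"{provider_name} Error: {error_message}"
        # Exact wording is an assumption; the tests only check the provider
        # name appears somewhere in the hint.
        self.user_friendly_hint = (
            f"There was a problem communicating with {provider_name}. "
            "Check your configuration and try again."
        )
        super().__init__(self.user_friendly_title)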
mito_ai/tests/providers/test_provider_limits.py (new file)
@@ -0,0 +1,42 @@
# Copyright (c) Saga Inc.
# Distributed under the terms of the GNU Affero General Public License v3.0 License.

import pytest
from mito_ai.completions.providers import OpenAIProvider
from mito_ai.tests.providers.utils import mock_openai_client, patch_server_limits
from mito_ai.utils.server_limits import OS_MONTHLY_AI_COMPLETIONS_LIMIT
from traitlets.config import Config

FAKE_API_KEY = "sk-1234567890"

@pytest.fixture
def provider_config() -> Config:
    """Create a proper Config object for the OpenAIProvider."""
    config = Config()
    config.OpenAIProvider = Config()
    config.OpenAIClient = Config()
    return config

@pytest.mark.parametrize("is_pro,completion_count", [
    (False, 1),  # OS user below limit
    (False, OS_MONTHLY_AI_COMPLETIONS_LIMIT + 1),  # OS user above limit
    (True, 1),  # Pro user below limit
    (True, OS_MONTHLY_AI_COMPLETIONS_LIMIT + 1),  # Pro user above limit
])
def test_openai_provider_with_limits(
    is_pro: bool,
    completion_count: int,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config) -> None:
    """Test OpenAI provider behavior with different user types and usage limits."""
    monkeypatch.setenv("OPENAI_API_KEY", FAKE_API_KEY)
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", FAKE_API_KEY)

    with (
        patch_server_limits(is_pro=is_pro, completion_count=completion_count),
        mock_openai_client()
    ):
        llm = OpenAIProvider(config=provider_config)
        capabilities = llm.capabilities
        assert "user key" in capabilities.provider
        assert llm.last_error is None
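
Both helpers used above come from the new shared test module mito_ai/tests/providers/utils.py (+85 lines), which this diff does not display. A rough sketch of what such context managers typically do; the patch targets and function names inside are assumptions inferred from how the tests use the helpers, not names confirmed by the diff:

# Hypothetical sketch of the shared helpers in mito_ai/tests/providers/utils.py.
from contextlib import contextmanager
from unittest.mock import MagicMock, patch

@contextmanager
def patch_server_limits(is_pro: bool = False, completion_count: int = 0):
    # Stub out the pro-status and quota lookups so tests never touch real
    # user state. The dotted paths below are assumed, not from the package.
    with patch("mito_ai.utils.server_limits.is_pro_user", return_value=is_pro), \
         patch("mito_ai.utils.server_limits.get_completion_count", return_value=completion_count):
        yield

@contextmanager
def mock_openai_client():
    # Stub both OpenAI SDK clients so provider construction and its key
    # validation make no network calls.
    with patch("openai.OpenAI") as mock_sync, patch("openai.AsyncOpenAI") as mock_async:
        mock_sync.return_value.models.list.return_value = MagicMock()
        yield mock_sync, mock_async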
mito_ai/tests/providers/test_providers.py (new file)
@@ -0,0 +1,382 @@
# Copyright (c) Saga Inc.
# Distributed under the terms of the GNU Affero General Public License v3.0 License.

from __future__ import annotations
from datetime import datetime
from typing import Any, List, Optional
from unittest.mock import patch, MagicMock, AsyncMock

from mito_ai.tests.providers.utils import mock_azure_openai_client, mock_openai_client, patch_server_limits
import pytest
from traitlets.config import Config
from mito_ai.completions.providers import OpenAIProvider
from mito_ai.completions.models import (
    MessageType,
    AICapabilities,
    CompletionReply
)
from mito_ai.utils.server_limits import OS_MONTHLY_AI_COMPLETIONS_LIMIT
from openai.types.chat import ChatCompletionMessageParam

REALLY_OLD_DATE = "2020-01-01"
TODAY = datetime.now().strftime("%Y-%m-%d")
FAKE_API_KEY = "sk-1234567890"

@pytest.fixture
def provider_config() -> Config:
    """Create a proper Config object for the OpenAIProvider."""
    config = Config()
    config.OpenAIProvider = Config()
    config.OpenAIClient = Config()
    return config

@pytest.fixture(autouse=True)
def reset_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
    for var in [
        "OPENAI_API_KEY", "CLAUDE_API_KEY",
        "GEMINI_API_KEY", "OLLAMA_MODEL",
        "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_MODEL"
    ]:
        monkeypatch.delenv(var, raising=False)


# ====================
# TESTS
# ====================

@pytest.mark.parametrize("provider_config_data", [
    {
        "name": "openai",
        "env_vars": {"OPENAI_API_KEY": FAKE_API_KEY},
        "constants": {"OPENAI_API_KEY": FAKE_API_KEY},
        "model": "gpt-4o-mini",
        "mock_patch": "mito_ai.completions.providers.OpenAIClient",
        "mock_method": "request_completions",
        "provider_name": "OpenAI with user key",
        "key_type": "user"
    },
    {
        "name": "claude",
        "env_vars": {"CLAUDE_API_KEY": "claude-key"},
        "constants": {"CLAUDE_API_KEY": "claude-key", "OPENAI_API_KEY": None},
        "model": "claude-3-opus-20240229",
        "mock_patch": "mito_ai.completions.providers.AnthropicClient",
        "mock_method": "request_completions",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini",
        "env_vars": {"GEMINI_API_KEY": "gemini-key"},
        "constants": {"GEMINI_API_KEY": "gemini-key", "OPENAI_API_KEY": None},
        "model": "gemini-2.0-flash",
        "mock_patch": "mito_ai.completions.providers.GeminiClient",
        "mock_method": "request_completions",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
    {
        "name": "azure",
        "env_vars": {"AZURE_OPENAI_API_KEY": "azure-key"},
        "constants": {"AZURE_OPENAI_API_KEY": "azure-key", "OPENAI_API_KEY": None},
        "model": "gpt-4o",
        "mock_patch": "mito_ai.completions.providers.OpenAIClient",
        "mock_method": "request_completions",
        "provider_name": "Azure OpenAI",
        "key_type": "azure"
    }
])
@pytest.mark.asyncio
async def test_completion_request(
    provider_config_data: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """Test completion requests for different providers."""
    # Set up environment variables
    for env_var, value in provider_config_data["env_vars"].items():
        monkeypatch.setenv(env_var, value)

    # Set up constants
    for constant, value in provider_config_data["constants"].items():
        monkeypatch.setattr(f"mito_ai.constants.{constant}", value)

    # Create mock client
    mock_client = MagicMock()
    mock_client.capabilities = AICapabilities(
        configuration={"model": provider_config_data["model"]},
        provider=provider_config_data["provider_name"],
        type="ai_capabilities"
    )
    mock_client.key_type = provider_config_data["key_type"]
    mock_client.request_completions = AsyncMock(return_value="Test completion")
    mock_client.stream_completions = AsyncMock(return_value="Test completion")

    with patch(provider_config_data["mock_patch"], return_value=mock_client):
        llm = OpenAIProvider(config=provider_config)
        messages: List[ChatCompletionMessageParam] = [
            {"role": "user", "content": "Test message"}
        ]

        completion = await llm.request_completions(
            message_type=MessageType.CHAT,
            messages=messages,
            model=provider_config_data["model"]
        )

        assert completion == "Test completion"
        getattr(mock_client, provider_config_data["mock_method"]).assert_called_once()


@pytest.mark.parametrize("provider_config_data", [
    {
        "name": "openai",
        "env_vars": {"OPENAI_API_KEY": FAKE_API_KEY},
        "constants": {"OPENAI_API_KEY": FAKE_API_KEY},
        "model": "gpt-4o-mini",
        "mock_patch": "mito_ai.completions.providers.OpenAIClient",
        "mock_method": "stream_completions",
        "provider_name": "OpenAI with user key",
        "key_type": "user"
    },
    {
        "name": "claude",
        "env_vars": {"CLAUDE_API_KEY": "claude-key"},
        "constants": {"CLAUDE_API_KEY": "claude-key", "OPENAI_API_KEY": None},
        "model": "claude-3-opus-20240229",
        "mock_patch": "mito_ai.completions.providers.AnthropicClient",
        "mock_method": "stream_completions",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini",
        "env_vars": {"GEMINI_API_KEY": "gemini-key"},
        "constants": {"GEMINI_API_KEY": "gemini-key", "OPENAI_API_KEY": None},
        "model": "gemini-2.0-flash",
        "mock_patch": "mito_ai.completions.providers.GeminiClient",
        "mock_method": "stream_completions",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
])
@pytest.mark.asyncio
async def test_stream_completion_parameterized(
    provider_config_data: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """Test stream completions for different providers."""
    # Set up environment variables
    for env_var, value in provider_config_data["env_vars"].items():
        monkeypatch.setenv(env_var, value)

    # Set up constants
    for constant, value in provider_config_data["constants"].items():
        monkeypatch.setattr(f"mito_ai.constants.{constant}", value)

    # Create mock client
    mock_client = MagicMock()
    mock_client.capabilities = AICapabilities(
        configuration={"model": provider_config_data["model"]},
        provider=provider_config_data["provider_name"],
        type="ai_capabilities"
    )
    mock_client.key_type = provider_config_data["key_type"]
    mock_client.request_completions = AsyncMock(return_value="Test completion")
    mock_client.stream_completions = AsyncMock(return_value="Test completion")
    mock_client.stream_response = AsyncMock(return_value="Test completion")  # For Claude

    with patch(provider_config_data["mock_patch"], return_value=mock_client):
        llm = OpenAIProvider(config=provider_config)
        messages: List[ChatCompletionMessageParam] = [
            {"role": "user", "content": "Test message"}
        ]

        reply_chunks = []
        def mock_reply(chunk):
            reply_chunks.append(chunk)

        completion = await llm.stream_completions(
            message_type=MessageType.CHAT,
            messages=messages,
            model=provider_config_data["model"],
            message_id="test-id",
            thread_id="test-thread",
            reply_fn=mock_reply
        )

        assert completion == "Test completion"
        getattr(mock_client, provider_config_data["mock_method"]).assert_called_once()
        assert len(reply_chunks) > 0
        assert isinstance(reply_chunks[0], CompletionReply)


def test_error_handling(monkeypatch: pytest.MonkeyPatch, provider_config: Config) -> None:
    monkeypatch.setenv("OPENAI_API_KEY", "invalid-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "invalid-key")
    mock_client = MagicMock()
    mock_client.capabilities = AICapabilities(
        configuration={"model": "gpt-4o-mini"},
        provider="OpenAI with user key",
        type="ai_capabilities"
    )
    mock_client.key_type = "user"
    mock_client.request_completions.side_effect = Exception("API error")

    with patch("mito_ai.completions.providers.OpenAIClient", return_value=mock_client):
        llm = OpenAIProvider(config=provider_config)
        assert llm.last_error is None  # Error should be None until a request is made

def test_claude_error_handling(monkeypatch: pytest.MonkeyPatch, provider_config: Config) -> None:
    monkeypatch.setenv("CLAUDE_API_KEY", "invalid-key")
    monkeypatch.setattr("mito_ai.constants.CLAUDE_API_KEY", "invalid-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", None)

    mock_client = MagicMock()
    mock_client.capabilities = AICapabilities(
        configuration={"model": "claude-3-opus-20240229"},
        provider="Claude",
        type="ai_capabilities"
    )
    mock_client.key_type = "claude"
    mock_client.request_completions.side_effect = Exception("API error")

    with patch("mito_ai.completions.providers.AnthropicClient", return_value=mock_client):
        llm = OpenAIProvider(config=provider_config)
        assert llm.last_error is None  # Error should be None until a request is made


# Mito Server Fallback Tests
@pytest.mark.parametrize("mito_server_config", [
    {
        "name": "openai_fallback",
        "model": "gpt-4o-mini",
        "mock_function": "mito_ai.openai_client.get_ai_completion_from_mito_server",
        "provider_name": "Mito server",
        "key_type": "mito_server"
    },
    {
        "name": "claude_fallback",
        "model": "claude-3-opus-20240229",
        "mock_function": "mito_ai.anthropic_client.get_anthropic_completion_from_mito_server",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini_fallback",
        "model": "gemini-2.0-flash",
        "mock_function": "mito_ai.gemini_client.get_gemini_completion_from_mito_server",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
])
@pytest.mark.asyncio
async def test_mito_server_fallback_completion_request(
    mito_server_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """Test that completion requests fallback to Mito server when no API keys are set."""
    # Clear all API keys to force Mito server fallback
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", None)
    monkeypatch.setattr("mito_ai.constants.CLAUDE_API_KEY", None)
    monkeypatch.setattr("mito_ai.constants.GEMINI_API_KEY", None)
    monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
    provider_config.OpenAIProvider.api_key = None

    # Mock the appropriate Mito server function
    with patch(mito_server_config["mock_function"], new_callable=AsyncMock) as mock_mito_function:
        mock_mito_function.return_value = "Mito server response"

        messages: List[ChatCompletionMessageParam] = [
            {"role": "user", "content": "Test message"}
        ]

        with patch_server_limits():
            llm = OpenAIProvider(config=provider_config)

            completion = await llm.request_completions(
                message_type=MessageType.CHAT,
                messages=messages,
                model=mito_server_config["model"]
            )

            assert completion == "Mito server response"
            mock_mito_function.assert_called_once()


@pytest.mark.parametrize("mito_server_config", [
    {
        "name": "openai_fallback",
        "model": "gpt-4o-mini",
        "mock_function": "mito_ai.openai_client.stream_ai_completion_from_mito_server",
        "provider_name": "Mito server",
        "key_type": "mito_server"
    },
    {
        "name": "claude_fallback",
        "model": "claude-3-opus-20240229",
        "mock_function": "mito_ai.anthropic_client.stream_anthropic_completion_from_mito_server",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini_fallback",
        "model": "gemini-2.0-flash",
        "mock_function": "mito_ai.gemini_client.stream_gemini_completion_from_mito_server",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
])
@pytest.mark.asyncio
async def test_mito_server_fallback_stream_completion(
    mito_server_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """Test that stream completions fallback to Mito server when no API keys are set."""
    # Clear all API keys to force Mito server fallback
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", None)
    monkeypatch.setattr("mito_ai.constants.CLAUDE_API_KEY", None)
    monkeypatch.setattr("mito_ai.constants.GEMINI_API_KEY", None)
    monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
    provider_config.OpenAIProvider.api_key = None

    # Create an async generator that yields chunks for streaming
    async def mock_stream_generator():
        yield "Chunk 1"
        yield "Chunk 2"
        yield "Chunk 3"

    # Mock the appropriate Mito server streaming function
    with patch(mito_server_config["mock_function"]) as mock_mito_stream:
        mock_mito_stream.return_value = mock_stream_generator()

        messages: List[ChatCompletionMessageParam] = [
            {"role": "user", "content": "Test message"}
        ]

        reply_chunks = []
        def mock_reply(chunk):
            reply_chunks.append(chunk)

        # Apply patch_server_limits for all cases, not just openai_fallback
        # Also patch update_mito_server_quota where it's actually used in openai_client
        with patch_server_limits(), patch("mito_ai.openai_client.update_mito_server_quota", MagicMock(return_value=None)):
            llm = OpenAIProvider(config=provider_config)

            completion = await llm.stream_completions(
                message_type=MessageType.CHAT,
                messages=messages,
                model=mito_server_config["model"],
                message_id="test-id",
                thread_id="test-thread",
                reply_fn=mock_reply
            )

            # Verify that the Mito server function was called
            mock_mito_stream.assert_called_once()
            # Verify that reply chunks were generated
            assert len(reply_chunks) > 0
            assert isinstance(reply_chunks[0], CompletionReply)
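
Taken together, the fallback tests pin down the provider-selection rule introduced in this release: when no user-supplied OpenAI, Claude, or Gemini key is set and Azure OpenAI is not configured, completions are proxied through the Mito server. A sketch of that rule; the function name uses_mito_server is illustrative, while the constants and is_azure_openai_configured are real names that the tests patch:

# Hypothetical sketch of the routing rule the fallback tests exercise; the
# real logic lives in mito_ai/completions/providers.py, not shown here.
from mito_ai import constants
from mito_ai.enterprise.utils import is_azure_openai_configured

def uses_mito_server() -> bool:
    # With no user-supplied key for any provider and no Azure OpenAI
    # configuration, requests fall through to the Mito server proxy.
    has_user_key = any([
        constants.OPENAI_API_KEY,
        constants.CLAUDE_API_KEY,
        constants.GEMINI_API_KEY,
    ])
    return not has_user_key and not is_azure_openai_configured()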