mito-ai 0.1.33__py3-none-any.whl → 0.1.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mito-ai might be problematic. Click here for more details.

Files changed (58)
  1. mito_ai/_version.py +1 -1
  2. mito_ai/anthropic_client.py +52 -54
  3. mito_ai/app_builder/handlers.py +2 -4
  4. mito_ai/completions/models.py +15 -1
  5. mito_ai/completions/prompt_builders/agent_system_message.py +10 -2
  6. mito_ai/completions/providers.py +79 -39
  7. mito_ai/constants.py +11 -24
  8. mito_ai/gemini_client.py +44 -48
  9. mito_ai/openai_client.py +30 -44
  10. mito_ai/tests/message_history/test_generate_short_chat_name.py +0 -4
  11. mito_ai/tests/open_ai_utils_test.py +18 -22
  12. mito_ai/tests/{test_anthropic_client.py → providers/test_anthropic_client.py} +37 -32
  13. mito_ai/tests/providers/test_azure.py +2 -6
  14. mito_ai/tests/providers/test_capabilities.py +120 -0
  15. mito_ai/tests/{test_gemini_client.py → providers/test_gemini_client.py} +40 -36
  16. mito_ai/tests/providers/test_mito_server_utils.py +448 -0
  17. mito_ai/tests/providers/test_model_resolution.py +130 -0
  18. mito_ai/tests/providers/test_openai_client.py +57 -0
  19. mito_ai/tests/providers/test_provider_completion_exception.py +66 -0
  20. mito_ai/tests/providers/test_provider_limits.py +42 -0
  21. mito_ai/tests/providers/test_providers.py +382 -0
  22. mito_ai/tests/providers/test_retry_logic.py +389 -0
  23. mito_ai/tests/providers/utils.py +85 -0
  24. mito_ai/tests/test_constants.py +15 -2
  25. mito_ai/tests/test_telemetry.py +12 -0
  26. mito_ai/utils/anthropic_utils.py +21 -29
  27. mito_ai/utils/gemini_utils.py +18 -22
  28. mito_ai/utils/mito_server_utils.py +92 -0
  29. mito_ai/utils/open_ai_utils.py +22 -46
  30. mito_ai/utils/provider_utils.py +49 -0
  31. mito_ai/utils/telemetry_utils.py +11 -1
  32. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/build_log.json +1 -1
  33. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/package.json +2 -2
  34. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/package.json.orig +1 -1
  35. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.281f4b9af60d620c6fb1.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js +737 -319
  36. mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.a20772bc113422d0f505.js.map +1 -0
  37. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.4f1d00fd0c58fcc05d8d.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.d2eea6519fa332d79efb.js +13 -16
  38. mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.d2eea6519fa332d79efb.js.map +1 -0
  39. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js → mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js +6 -2
  40. mito_ai-0.1.35.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.76efcc5c3be4056457ee.js.map +1 -0
  41. {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/METADATA +1 -1
  42. {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/RECORD +52 -43
  43. mito_ai/tests/providers_test.py +0 -438
  44. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/lib_index_js.281f4b9af60d620c6fb1.js.map +0 -1
  45. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/remoteEntry.4f1d00fd0c58fcc05d8d.js.map +0 -1
  46. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/style_index_js.06083e515de4862df010.js.map +0 -1
  47. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js +0 -7842
  48. mito_ai-0.1.33.data/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_html2canvas_dist_html2canvas_js.ea47e8c8c906197f8d19.js.map +0 -1
  49. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/etc/jupyter/jupyter_server_config.d/mito_ai.json +0 -0
  50. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/schemas/mito_ai/toolbar-buttons.json +0 -0
  51. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/style.js +0 -0
  52. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js +0 -0
  53. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_semver_index_js.9795f79265ddb416864b.js.map +0 -0
  54. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js +0 -0
  55. {mito_ai-0.1.33.data → mito_ai-0.1.35.data}/data/share/jupyter/labextensions/mito_ai/static/vendors-node_modules_vscode-diff_dist_index_js.ea55f1f9346638aafbcf.js.map +0 -0
  56. {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/WHEEL +0 -0
  57. {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/entry_points.txt +0 -0
  58. {mito_ai-0.1.33.dist-info → mito_ai-0.1.35.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,130 @@
1
+ # Copyright (c) Saga Inc.
2
+ # Distributed under the terms of the GNU Affero General Public License v3.0 License.
3
+
4
+ """
5
+ These tests ensure that the correct model is chosen for each message type, for each provider.
6
+ """
7
+
8
+ import pytest
9
+ from mito_ai.utils.provider_utils import does_message_require_fast_model
10
+ from mito_ai.completions.models import MessageType
11
+ from unittest.mock import AsyncMock, MagicMock, patch
12
+ from mito_ai.completions.providers import OpenAIProvider
13
+ from mito_ai.completions.models import MessageType
14
+ from mito_ai.utils.provider_utils import does_message_require_fast_model
15
+ from traitlets.config import Config
16
+
17
@pytest.fixture
def provider_config() -> Config:
    """Build the traitlets ``Config`` passed to ``OpenAIProvider`` in these tests.

    The returned config has two empty sub-sections, ``OpenAIProvider`` and
    ``OpenAIClient``.
    """
    cfg = Config()
    for section_name in ("OpenAIProvider", "OpenAIClient"):
        setattr(cfg, section_name, Config())
    return cfg
24
+
25
@pytest.fixture
def mock_messages():
    """Return a chat history consisting of a single user turn."""
    user_turn = {"role": "user", "content": "Test message"}
    return [user_turn]
29
+
30
# (message_type, expected_result) pairs: whether each message type is expected
# to require the fast model.
MESSAGE_TYPE_TEST_CASES = [
    # Message types expected NOT to require the fast model.
    (MessageType.CHAT, False),
    (MessageType.SMART_DEBUG, False),
    (MessageType.CODE_EXPLAIN, False),
    (MessageType.AGENT_EXECUTION, False),
    (MessageType.AGENT_AUTO_ERROR_FIXUP, False),
    # Message types expected to require the fast model.
    (MessageType.INLINE_COMPLETION, True),
    (MessageType.CHAT_NAME_GENERATION, True),
]
40
@pytest.mark.parametrize("message_type,expected_result", MESSAGE_TYPE_TEST_CASES)
def test_does_message_require_fast_model(message_type: MessageType, expected_result: bool) -> None:
    """Each message type maps to its expected fast-model requirement."""
    actual = does_message_require_fast_model(message_type)
    assert actual == expected_result
44
+
45
def test_does_message_require_fast_model_raises_error_for_unknown_message_type() -> None:
    """A value that is not a known MessageType is rejected with ValueError."""
    bogus_message_type = 'unknown_message_type'
    with pytest.raises(ValueError):
        does_message_require_fast_model(bogus_message_type)  # type: ignore
49
+
50
@pytest.mark.asyncio
async def test_request_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
    """request_completions consults does_message_require_fast_model and forwards the requested model.

    The fast-model check is spied on (wrapped, not replaced) where
    ``mito_ai.utils.open_ai_utils`` looks it up. Both OpenAI SDK client classes
    are patched so that no network request is made. The sync client's
    ``models.list`` is stubbed to cover the validation the OpenAIClient
    constructor performs.
    """
    # Provide a user key so that the OpenAI provider is selected.
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

    # Canned non-streaming response returned by chat.completions.create.
    fake_completion = MagicMock()
    fake_completion.choices = [MagicMock()]
    fake_completion.choices[0].message.content = "Test Completion"

    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as spy_fast_model, \
            patch('openai.AsyncOpenAI') as async_openai_cls, \
            patch('openai.OpenAI') as sync_openai_cls:
        async_client = MagicMock()
        async_client.chat.completions.create = AsyncMock(return_value=fake_completion)
        async_client.is_closed.return_value = False
        async_openai_cls.return_value = async_client

        sync_client = MagicMock()
        sync_client.models.list.return_value = MagicMock()
        sync_openai_cls.return_value = sync_client

        provider = OpenAIProvider(config=provider_config)
        await provider.request_completions(
            message_type=MessageType.CHAT,
            messages=mock_messages,
            model="gpt-3.5",
        )

    spy_fast_model.assert_called_once_with(MessageType.CHAT)
    # For a CHAT message the caller-supplied model is sent to the API unchanged.
    create_kwargs = async_client.chat.completions.create.call_args.kwargs
    assert create_kwargs['model'] == "gpt-3.5"
86
+
87
@pytest.mark.asyncio
async def test_stream_completions_calls_does_message_require_fast_model(provider_config: Config, mock_messages, monkeypatch: pytest.MonkeyPatch):
    """stream_completions consults does_message_require_fast_model and streams from the requested model.

    The async SDK client's ``chat.completions.create`` returns a one-chunk async
    stream whose only chunk has ``finish_reason == "stop"``. The sync SDK
    client is stubbed to cover the OpenAIClient constructor's validation.
    """
    # Provide a user key so that the OpenAI provider is selected.
    monkeypatch.setenv("OPENAI_API_KEY", "fake-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "fake-key")

    # The single streamed chunk; it carries text and ends the stream.
    final_chunk = MagicMock()
    final_chunk.choices = [MagicMock()]
    final_chunk.choices[0].delta.content = "Test Stream Completion"
    final_chunk.choices[0].finish_reason = "stop"

    async def single_chunk_stream():
        yield final_chunk

    with patch('mito_ai.utils.open_ai_utils.does_message_require_fast_model', wraps=does_message_require_fast_model) as spy_fast_model, \
            patch('openai.AsyncOpenAI') as async_openai_cls, \
            patch('openai.OpenAI') as sync_openai_cls:
        async_client = MagicMock()
        async_client.chat.completions.create = AsyncMock(return_value=single_chunk_stream())
        async_client.is_closed.return_value = False
        async_openai_cls.return_value = async_client

        sync_client = MagicMock()
        sync_client.models.list.return_value = MagicMock()
        sync_openai_cls.return_value = sync_client

        provider = OpenAIProvider(config=provider_config)
        await provider.stream_completions(
            message_type=MessageType.CHAT,
            messages=mock_messages,
            model="gpt-3.5",
            message_id="test_id",
            thread_id="test_thread",
            reply_fn=lambda x: None,
        )

    spy_fast_model.assert_called_once_with(MessageType.CHAT)
    # For a CHAT message the caller-supplied model is sent to the API unchanged.
    create_kwargs = async_client.chat.completions.create.call_args.kwargs
    assert create_kwargs['model'] == "gpt-3.5"
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Saga Inc.
2
+ # Distributed under the terms of the GNU Affero General Public License v3.0 License.
3
+
4
+ import pytest
5
+ from mito_ai.openai_client import OpenAIClient
6
+ from mito_ai.utils.open_ai_utils import FAST_OPENAI_MODEL
7
+ from mito_ai.completions.models import MessageType
8
+ from unittest.mock import MagicMock, patch, AsyncMock
9
+ from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
10
+
11
CUSTOM_MODEL = "smart-openai-model"


@pytest.mark.parametrize("message_type, expected_model", [
    # Message types that keep the model the caller passed in.
    (MessageType.CHAT, CUSTOM_MODEL),
    (MessageType.SMART_DEBUG, CUSTOM_MODEL),
    (MessageType.CODE_EXPLAIN, CUSTOM_MODEL),
    (MessageType.AGENT_EXECUTION, CUSTOM_MODEL),
    (MessageType.AGENT_AUTO_ERROR_FIXUP, CUSTOM_MODEL),
    # Message types that are switched to the fast model.
    (MessageType.INLINE_COMPLETION, FAST_OPENAI_MODEL),
    (MessageType.CHAT_NAME_GENERATION, FAST_OPENAI_MODEL),
])
@pytest.mark.asyncio
async def test_model_selection_based_on_message_type(message_type, expected_model):
    """
    The model sent to the OpenAI API depends on the message type.

    Inline completions and chat-name generation use FAST_OPENAI_MODEL. Every
    other message type keeps the model the caller requested.
    """
    client = OpenAIClient(api_key="test_key")  # type: ignore

    # Async create() stub whose response's first choice has content "test".
    create_stub = AsyncMock(
        return_value=MagicMock(choices=[MagicMock(message=MagicMock(content="test"))])
    )
    sdk_client = MagicMock()
    sdk_client.chat.completions.create = create_stub

    # Both ways the client could obtain an SDK client return the same stub.
    with patch.object(client, '_build_openai_client', return_value=sdk_client), \
            patch('openai.AsyncOpenAI', return_value=sdk_client):
        await client.request_completions(
            message_type=message_type,
            messages=[{"role": "user", "content": "Test message"}],
            model=CUSTOM_MODEL,
            response_format_info=None
        )

    create_stub.assert_called_once()
    assert create_stub.call_args.kwargs['model'] == expected_model
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Saga Inc.
2
+ # Distributed under the terms of the GNU Affero General Public License v3.0 License.
3
+
4
+ from mito_ai.utils.mito_server_utils import ProviderCompletionException
5
+ import pytest
6
+
7
+
8
class TestProviderCompletionException:
    """Tests for the attributes and message of ProviderCompletionException."""

    @pytest.mark.parametrize("error_message,provider_name,error_type,expected_title,expected_hint_contains", [
        (
            "Something went wrong",
            "LLM Provider",
            "LLMProviderError",
            "LLM Provider Error: Something went wrong",
            "LLM Provider"
        ),
        (
            "API key is invalid",
            "OpenAI",
            "AuthenticationError",
            "OpenAI Error: API key is invalid",
            "OpenAI"
        ),
        (
            "There was an error accessing the Anthropic API: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}",
            "Anthropic",
            "LLMProviderError",
            "Anthropic Error: There was an error accessing the Anthropic API: Error code: 529 - {'type': 'error', 'error': {'type': 'overloaded_error', 'message': 'Overloaded'}}",
            "Anthropic"
        ),
    ])
    def test_exception_initialization(
        self,
        error_message: str,
        provider_name: str,
        error_type: str,
        expected_title: str,
        expected_hint_contains: str
    ) -> None:
        """Explicit provider_name/error_type are stored and shape the title and hint.

        In every case the expected title is "<provider_name> Error: <error_message>".
        The title is also the exception's message, so ``str()`` and ``args[0]``
        return it too.
        """
        exception = ProviderCompletionException(
            error_message,
            provider_name=provider_name,
            error_type=error_type
        )

        assert exception.error_message == error_message
        assert exception.provider_name == provider_name
        assert exception.error_type == error_type
        assert exception.user_friendly_title == expected_title
        assert expected_hint_contains in exception.user_friendly_hint
        assert str(exception) == expected_title
        assert exception.args[0] == expected_title

    def test_default_initialization(self) -> None:
        """Omitting provider_name/error_type falls back to the generic LLM provider defaults."""
        error_msg = "Something went wrong"
        expected_title = "LLM Provider Error: Something went wrong"
        exception = ProviderCompletionException(error_msg)

        assert exception.error_message == error_msg
        assert exception.provider_name == "LLM Provider"
        assert exception.error_type == "LLMProviderError"
        assert exception.user_friendly_title == expected_title
        assert "LLM Provider" in exception.user_friendly_hint
        # As with explicit arguments, the title is also the exception's message.
        assert str(exception) == expected_title
        assert exception.args[0] == expected_title
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Saga Inc.
2
+ # Distributed under the terms of the GNU Affero General Public License v3.0 License.
3
+
4
+ import pytest
5
+ from mito_ai.completions.providers import OpenAIProvider
6
+ from mito_ai.tests.providers.utils import mock_openai_client, patch_server_limits
7
+ from mito_ai.utils.server_limits import OS_MONTHLY_AI_COMPLETIONS_LIMIT
8
+ from traitlets.config import Config
9
+
10
# Dummy user API key; the test below injects it through the environment and
# mito_ai.constants.
FAKE_API_KEY = "sk-1234567890"


@pytest.fixture
def provider_config() -> Config:
    """Return a traitlets Config with empty OpenAIProvider and OpenAIClient sections."""
    config = Config()
    config.OpenAIProvider, config.OpenAIClient = Config(), Config()
    return config
19
+
20
@pytest.mark.parametrize("is_pro,completion_count", [
    (False, 1),  # OS user below limit
    (False, OS_MONTHLY_AI_COMPLETIONS_LIMIT + 1),  # OS user above limit
    (True, 1),  # Pro user below limit
    (True, OS_MONTHLY_AI_COMPLETIONS_LIMIT + 1),  # Pro user above limit
])
def test_openai_provider_with_limits(
    is_pro: bool,
    completion_count: int,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config,
) -> None:
    """With a user OpenAI key set, the provider reports a user-key provider and no error.

    NOTE(review): every parametrization asserts the same outcome. This checks
    that the user-key provider is selected for each user type and usage count.
    It does not distinguish behavior below vs. above the limit.
    """
    monkeypatch.setenv("OPENAI_API_KEY", FAKE_API_KEY)
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", FAKE_API_KEY)

    server_limits = patch_server_limits(is_pro=is_pro, completion_count=completion_count)
    with server_limits, mock_openai_client():
        provider = OpenAIProvider(config=provider_config)
        assert "user key" in provider.capabilities.provider
        assert provider.last_error is None
@@ -0,0 +1,382 @@
1
+ # Copyright (c) Saga Inc.
2
+ # Distributed under the terms of the GNU Affero General Public License v3.0 License.
3
+
4
+ from __future__ import annotations
5
+ from datetime import datetime
6
+ from typing import Any, List, Optional
7
+ from unittest.mock import patch, MagicMock, AsyncMock
8
+
9
+ from mito_ai.tests.providers.utils import mock_azure_openai_client, mock_openai_client, patch_server_limits
10
+ import pytest
11
+ from traitlets.config import Config
12
+ from mito_ai.completions.providers import OpenAIProvider
13
+ from mito_ai.completions.models import (
14
+ MessageType,
15
+ AICapabilities,
16
+ CompletionReply
17
+ )
18
+ from mito_ai.utils.server_limits import OS_MONTHLY_AI_COMPLETIONS_LIMIT
19
+ from openai.types.chat import ChatCompletionMessageParam
20
+
21
# ISO "YYYY-MM-DD" date strings: a fixed date in 2020 and the current local date.
REALLY_OLD_DATE = "2020-01-01"
TODAY = f"{datetime.now():%Y-%m-%d}"
# Dummy OpenAI-style key that the tests inject as a user API key.
FAKE_API_KEY = "sk-1234567890"
24
+
25
@pytest.fixture
def provider_config() -> Config:
    """Return a fresh traitlets Config with empty OpenAIProvider/OpenAIClient sections."""
    config = Config()
    empty_sections = {"OpenAIProvider": Config(), "OpenAIClient": Config()}
    for section_name, section in empty_sections.items():
        config[section_name] = section
    return config
32
+
33
@pytest.fixture(autouse=True)
def reset_env_vars(monkeypatch: pytest.MonkeyPatch) -> None:
    """Unset the provider API-key/endpoint/model environment variables before every test.

    ``raising=False`` makes this a no-op for variables that are not set.
    monkeypatch restores any removed values after the test.
    """
    provider_env_vars = (
        "OPENAI_API_KEY",
        "CLAUDE_API_KEY",
        "GEMINI_API_KEY",
        "OLLAMA_MODEL",
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_MODEL",
    )
    for env_name in provider_env_vars:
        monkeypatch.delenv(env_name, raising=False)
41
+
42
+
43
+ # ====================
44
+ # TESTS
45
+ # ====================
46
+
47
@pytest.mark.parametrize("provider_config_data", [
    {
        "name": "openai",
        "env_vars": {"OPENAI_API_KEY": FAKE_API_KEY},
        "constants": {"OPENAI_API_KEY": FAKE_API_KEY},
        "model": "gpt-4o-mini",
        "mock_patch": "mito_ai.completions.providers.OpenAIClient",
        "mock_method": "request_completions",
        "provider_name": "OpenAI with user key",
        "key_type": "user"
    },
    {
        "name": "claude",
        "env_vars": {"CLAUDE_API_KEY": "claude-key"},
        "constants": {"CLAUDE_API_KEY": "claude-key", "OPENAI_API_KEY": None},
        "model": "claude-3-opus-20240229",
        "mock_patch": "mito_ai.completions.providers.AnthropicClient",
        "mock_method": "request_completions",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini",
        "env_vars": {"GEMINI_API_KEY": "gemini-key"},
        "constants": {"GEMINI_API_KEY": "gemini-key", "OPENAI_API_KEY": None},
        "model": "gemini-2.0-flash",
        "mock_patch": "mito_ai.completions.providers.GeminiClient",
        "mock_method": "request_completions",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
    {
        "name": "azure",
        "env_vars": {"AZURE_OPENAI_API_KEY": "azure-key"},
        "constants": {"AZURE_OPENAI_API_KEY": "azure-key", "OPENAI_API_KEY": None},
        "model": "gpt-4o",
        "mock_patch": "mito_ai.completions.providers.OpenAIClient",
        "mock_method": "request_completions",
        "provider_name": "Azure OpenAI",
        "key_type": "azure"
    }
])
@pytest.mark.asyncio
async def test_completion_request(
    provider_config_data: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """OpenAIProvider.request_completions delegates to the selected provider client.

    In each case the client class named by ``mock_patch`` is replaced by a
    MagicMock whose async ``request_completions`` returns "Test completion".
    The test checks that the provider returns that value and called
    ``mock_method`` exactly once.
    """
    for env_name, env_value in provider_config_data["env_vars"].items():
        monkeypatch.setenv(env_name, env_value)
    for constant_name, constant_value in provider_config_data["constants"].items():
        monkeypatch.setattr(f"mito_ai.constants.{constant_name}", constant_value)

    fake_client = MagicMock(
        capabilities=AICapabilities(
            configuration={"model": provider_config_data["model"]},
            provider=provider_config_data["provider_name"],
            type="ai_capabilities",
        ),
        key_type=provider_config_data["key_type"],
        request_completions=AsyncMock(return_value="Test completion"),
        stream_completions=AsyncMock(return_value="Test completion"),
    )
    chat_history: List[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Test message"}
    ]

    with patch(provider_config_data["mock_patch"], return_value=fake_client):
        llm = OpenAIProvider(config=provider_config)
        completion = await llm.request_completions(
            message_type=MessageType.CHAT,
            messages=chat_history,
            model=provider_config_data["model"],
        )

    assert completion == "Test completion"
    getattr(fake_client, provider_config_data["mock_method"]).assert_called_once()
129
+
130
+
131
@pytest.mark.parametrize("provider_config_data", [
    {
        "name": "openai",
        "env_vars": {"OPENAI_API_KEY": FAKE_API_KEY},
        "constants": {"OPENAI_API_KEY": FAKE_API_KEY},
        "model": "gpt-4o-mini",
        "mock_patch": "mito_ai.completions.providers.OpenAIClient",
        "mock_method": "stream_completions",
        "provider_name": "OpenAI with user key",
        "key_type": "user"
    },
    {
        "name": "claude",
        "env_vars": {"CLAUDE_API_KEY": "claude-key"},
        "constants": {"CLAUDE_API_KEY": "claude-key", "OPENAI_API_KEY": None},
        "model": "claude-3-opus-20240229",
        "mock_patch": "mito_ai.completions.providers.AnthropicClient",
        "mock_method": "stream_completions",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini",
        "env_vars": {"GEMINI_API_KEY": "gemini-key"},
        "constants": {"GEMINI_API_KEY": "gemini-key", "OPENAI_API_KEY": None},
        "model": "gemini-2.0-flash",
        "mock_patch": "mito_ai.completions.providers.GeminiClient",
        "mock_method": "stream_completions",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
])
@pytest.mark.asyncio
async def test_stream_completion_parameterized(
    provider_config_data: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """OpenAIProvider.stream_completions delegates to the selected provider client.

    The test checks three things:
    - the provider returns the mocked client's result;
    - ``mock_method`` was called exactly once;
    - at least one CompletionReply reached the reply callback.
    """
    for env_name, env_value in provider_config_data["env_vars"].items():
        monkeypatch.setenv(env_name, env_value)
    for constant_name, constant_value in provider_config_data["constants"].items():
        monkeypatch.setattr(f"mito_ai.constants.{constant_name}", constant_value)

    fake_client = MagicMock(
        capabilities=AICapabilities(
            configuration={"model": provider_config_data["model"]},
            provider=provider_config_data["provider_name"],
            type="ai_capabilities",
        ),
        key_type=provider_config_data["key_type"],
        request_completions=AsyncMock(return_value="Test completion"),
        stream_completions=AsyncMock(return_value="Test completion"),
        stream_response=AsyncMock(return_value="Test completion"),  # For Claude
    )
    chat_history: List[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Test message"}
    ]

    reply_chunks = []

    def collect_reply(chunk):
        reply_chunks.append(chunk)

    with patch(provider_config_data["mock_patch"], return_value=fake_client):
        llm = OpenAIProvider(config=provider_config)
        completion = await llm.stream_completions(
            message_type=MessageType.CHAT,
            messages=chat_history,
            model=provider_config_data["model"],
            message_id="test-id",
            thread_id="test-thread",
            reply_fn=collect_reply,
        )

    assert completion == "Test completion"
    getattr(fake_client, provider_config_data["mock_method"]).assert_called_once()
    assert len(reply_chunks) > 0
    assert isinstance(reply_chunks[0], CompletionReply)
213
+
214
+
215
def test_error_handling(monkeypatch: pytest.MonkeyPatch, provider_config: Config) -> None:
    """A provider built around a failing OpenAI client starts with no recorded error.

    NOTE(review): request_completions is never invoked here, so the
    ``side_effect`` configured below is not exercised. The test only checks
    the provider's state immediately after construction.
    """
    monkeypatch.setenv("OPENAI_API_KEY", "invalid-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", "invalid-key")

    failing_client = MagicMock(
        capabilities=AICapabilities(
            configuration={"model": "gpt-4o-mini"},
            provider="OpenAI with user key",
            type="ai_capabilities",
        ),
        key_type="user",
    )
    failing_client.request_completions.side_effect = Exception("API error")

    with patch("mito_ai.completions.providers.OpenAIClient", return_value=failing_client):
        llm = OpenAIProvider(config=provider_config)
        # No request has been made yet, so nothing should be recorded.
        assert llm.last_error is None
230
+
231
def test_claude_error_handling(monkeypatch: pytest.MonkeyPatch, provider_config: Config) -> None:
    """A provider built around a failing Claude client starts with no recorded error.

    The OpenAI key constant is cleared so that the Claude client is used.
    NOTE(review): request_completions is never invoked, so the configured
    ``side_effect`` is not exercised.
    """
    monkeypatch.setenv("CLAUDE_API_KEY", "invalid-key")
    monkeypatch.setattr("mito_ai.constants.CLAUDE_API_KEY", "invalid-key")
    monkeypatch.setattr("mito_ai.constants.OPENAI_API_KEY", None)

    failing_client = MagicMock(
        capabilities=AICapabilities(
            configuration={"model": "claude-3-opus-20240229"},
            provider="Claude",
            type="ai_capabilities",
        ),
        key_type="claude",
    )
    failing_client.request_completions.side_effect = Exception("API error")

    with patch("mito_ai.completions.providers.AnthropicClient", return_value=failing_client):
        llm = OpenAIProvider(config=provider_config)
        # No request has been made yet, so nothing should be recorded.
        assert llm.last_error is None
248
+
249
+
250
+ # Mito Server Fallback Tests
251
@pytest.mark.parametrize("mito_server_config", [
    {
        "name": "openai_fallback",
        "model": "gpt-4o-mini",
        "mock_function": "mito_ai.openai_client.get_ai_completion_from_mito_server",
        "provider_name": "Mito server",
        "key_type": "mito_server"
    },
    {
        "name": "claude_fallback",
        "model": "claude-3-opus-20240229",
        "mock_function": "mito_ai.anthropic_client.get_anthropic_completion_from_mito_server",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini_fallback",
        "model": "gemini-2.0-flash",
        "mock_function": "mito_ai.gemini_client.get_gemini_completion_from_mito_server",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
])
@pytest.mark.asyncio
async def test_mito_server_fallback_completion_request(
    mito_server_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """Without any API keys, request_completions goes through the Mito server function.

    All user keys are cleared and Azure is reported as unconfigured. The
    per-provider Mito server function is replaced by an AsyncMock returning
    "Mito server response".
    """
    for key_constant in ("OPENAI_API_KEY", "CLAUDE_API_KEY", "GEMINI_API_KEY"):
        monkeypatch.setattr(f"mito_ai.constants.{key_constant}", None)
    monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
    provider_config.OpenAIProvider.api_key = None

    chat_history: List[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Test message"}
    ]

    with patch(mito_server_config["mock_function"], new_callable=AsyncMock,
               return_value="Mito server response") as mito_call, \
            patch_server_limits():
        llm = OpenAIProvider(config=provider_config)
        completion = await llm.request_completions(
            message_type=MessageType.CHAT,
            messages=chat_history,
            model=mito_server_config["model"],
        )

    assert completion == "Mito server response"
    mito_call.assert_called_once()
307
+
308
+
309
@pytest.mark.parametrize("mito_server_config", [
    {
        "name": "openai_fallback",
        "model": "gpt-4o-mini",
        "mock_function": "mito_ai.openai_client.stream_ai_completion_from_mito_server",
        "provider_name": "Mito server",
        "key_type": "mito_server"
    },
    {
        "name": "claude_fallback",
        "model": "claude-3-opus-20240229",
        "mock_function": "mito_ai.anthropic_client.stream_anthropic_completion_from_mito_server",
        "provider_name": "Claude",
        "key_type": "claude"
    },
    {
        "name": "gemini_fallback",
        "model": "gemini-2.0-flash",
        "mock_function": "mito_ai.gemini_client.stream_gemini_completion_from_mito_server",
        "provider_name": "Gemini",
        "key_type": "gemini"
    },
])
@pytest.mark.asyncio
async def test_mito_server_fallback_stream_completion(
    mito_server_config: dict,
    monkeypatch: pytest.MonkeyPatch,
    provider_config: Config
) -> None:
    """Without any API keys, stream_completions goes through the Mito server streaming function.

    The streaming function is patched to return an async generator of three
    text chunks. The test checks that it was called once and that at least
    one CompletionReply reached the reply callback.
    """
    for key_constant in ("OPENAI_API_KEY", "CLAUDE_API_KEY", "GEMINI_API_KEY"):
        monkeypatch.setattr(f"mito_ai.constants.{key_constant}", None)
    monkeypatch.setattr("mito_ai.enterprise.utils.is_azure_openai_configured", lambda: False)
    provider_config.OpenAIProvider.api_key = None

    async def three_chunk_stream():
        for piece in ("Chunk 1", "Chunk 2", "Chunk 3"):
            yield piece

    chat_history: List[ChatCompletionMessageParam] = [
        {"role": "user", "content": "Test message"}
    ]
    reply_chunks = []

    def collect_reply(chunk):
        reply_chunks.append(chunk)

    # patch_server_limits applies to every case. update_mito_server_quota is
    # patched where mito_ai.openai_client looks it up.
    with patch(mito_server_config["mock_function"], return_value=three_chunk_stream()) as mito_stream, \
            patch_server_limits(), \
            patch("mito_ai.openai_client.update_mito_server_quota", MagicMock(return_value=None)):
        llm = OpenAIProvider(config=provider_config)
        await llm.stream_completions(
            message_type=MessageType.CHAT,
            messages=chat_history,
            model=mito_server_config["model"],
            message_id="test-id",
            thread_id="test-thread",
            reply_fn=collect_reply,
        )

    mito_stream.assert_called_once()
    assert len(reply_chunks) > 0
    assert isinstance(reply_chunks[0], CompletionReply)