terminal_sherpa-0.1.0-py3-none-any.whl
This diff shows the content of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- ask/__init__.py +0 -0
- ask/config.py +94 -0
- ask/exceptions.py +25 -0
- ask/main.py +115 -0
- ask/providers/__init__.py +35 -0
- ask/providers/anthropic.py +73 -0
- ask/providers/base.py +28 -0
- ask/providers/openai.py +106 -0
- terminal_sherpa-0.1.0.dist-info/METADATA +242 -0
- terminal_sherpa-0.1.0.dist-info/RECORD +20 -0
- terminal_sherpa-0.1.0.dist-info/WHEEL +5 -0
- terminal_sherpa-0.1.0.dist-info/entry_points.txt +2 -0
- terminal_sherpa-0.1.0.dist-info/top_level.txt +2 -0
- test/conftest.py +58 -0
- test/test_anthropic.py +173 -0
- test/test_config.py +164 -0
- test/test_exceptions.py +55 -0
- test/test_main.py +206 -0
- test/test_openai.py +269 -0
- test/test_providers.py +77 -0
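
Several tests below take fixtures such as mock_openai_key and mock_env_vars from test/conftest.py (+58 lines), which is not reproduced in this section. A minimal sketch of what those two fixtures must do, inferred from the assertions in test_openai.py (the key value "test-openai-key" is taken from the tests; the real 58-line file likely defines more, e.g. an Anthropic key fixture):

```python
# Sketch of test/conftest.py fixtures inferred from the tests below;
# the packaged file may differ in detail.
import pytest


@pytest.fixture
def mock_openai_key(monkeypatch):
    """Set the OpenAI key the tests assert on (value taken from the tests)."""
    monkeypatch.setenv("OPENAI_API_KEY", "test-openai-key")


@pytest.fixture
def mock_env_vars(monkeypatch):
    """Remove provider API keys so missing-key paths can be exercised."""
    monkeypatch.delenv("OPENAI_API_KEY", raising=False)
    monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
```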
test/test_exceptions.py
ADDED
@@ -0,0 +1,55 @@
+"""Tests for custom exception classes."""
+
+from ask.exceptions import (
+    APIError,
+    AuthenticationError,
+    ConfigurationError,
+    RateLimitError,
+)
+
+
+def test_configuration_error():
+    """Test ConfigurationError creation."""
+    error = ConfigurationError("Config problem")
+    assert str(error) == "Config problem"
+    assert isinstance(error, Exception)
+
+
+def test_authentication_error():
+    """Test AuthenticationError creation."""
+    error = AuthenticationError("Auth failed")
+    assert str(error) == "Auth failed"
+    assert isinstance(error, Exception)
+
+
+def test_api_error():
+    """Test APIError creation."""
+    error = APIError("API failed")
+    assert str(error) == "API failed"
+    assert isinstance(error, Exception)
+
+
+def test_rate_limit_error():
+    """Test RateLimitError inheritance."""
+    error = RateLimitError("Rate limit exceeded")
+    assert str(error) == "Rate limit exceeded"
+    assert isinstance(error, APIError)
+    assert isinstance(error, Exception)
+
+
+def test_exception_hierarchy():
+    """Test exception inheritance hierarchy."""
+    # Test that RateLimitError is a subclass of APIError
+    assert issubclass(RateLimitError, APIError)
+
+    # Test that all exceptions are subclasses of Exception
+    assert issubclass(ConfigurationError, Exception)
+    assert issubclass(AuthenticationError, Exception)
+    assert issubclass(APIError, Exception)
+    assert issubclass(RateLimitError, Exception)
+
+    # Test that custom exceptions are not subclasses of each other
+    assert not issubclass(ConfigurationError, AuthenticationError)
+    assert not issubclass(AuthenticationError, ConfigurationError)
+    assert not issubclass(APIError, ConfigurationError)
+    assert not issubclass(APIError, AuthenticationError)
test/test_main.py
ADDED
@@ -0,0 +1,206 @@
+"""Tests for the CLI interface."""
+
+import argparse
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from ask.exceptions import APIError, AuthenticationError, ConfigurationError
+from ask.main import (
+    configure_logging,
+    load_configuration,
+    main,
+    parse_arguments,
+    resolve_provider,
+)
+
+
+def test_parse_arguments_basic():
+    """Test basic argument parsing."""
+    with patch("sys.argv", ["ask", "list files"]):
+        args = parse_arguments()
+        assert args.prompt == "list files"
+        assert args.model is None
+        assert args.verbose is False
+
+
+def test_parse_arguments_with_model():
+    """Test --model argument."""
+    with patch("sys.argv", ["ask", "list files", "--model", "anthropic:sonnet"]):
+        args = parse_arguments()
+        assert args.prompt == "list files"
+        assert args.model == "anthropic:sonnet"
+
+
+def test_parse_arguments_with_verbose():
+    """Test --verbose flag."""
+    with patch("sys.argv", ["ask", "list files", "--verbose"]):
+        args = parse_arguments()
+        assert args.prompt == "list files"
+        assert args.verbose is True
+
+
+def test_configure_logging_verbose():
+    """Test verbose logging configuration."""
+    with patch("ask.main.logger") as mock_logger:
+        configure_logging(verbose=True)
+        mock_logger.remove.assert_called_once()
+        mock_logger.add.assert_called_once()
+
+        # Check that DEBUG level was set
+        call_args = mock_logger.add.call_args
+        assert "DEBUG" in str(call_args)
+
+
+def test_configure_logging_normal():
+    """Test normal logging configuration."""
+    with patch("ask.main.logger") as mock_logger:
+        configure_logging(verbose=False)
+        mock_logger.remove.assert_called_once()
+        mock_logger.add.assert_called_once()
+
+        # Check that ERROR level was set
+        call_args = mock_logger.add.call_args
+        assert "ERROR" in str(call_args)
+
+
+def test_load_configuration_success():
+    """Test successful config loading."""
+    mock_config = {"ask": {"default_model": "anthropic"}}
+
+    with patch("ask.config.load_config", return_value=mock_config):
+        config = load_configuration()
+        assert config == mock_config
+
+
+def test_load_configuration_error():
+    """Test configuration error handling."""
+    with patch(
+        "ask.config.load_config", side_effect=ConfigurationError("Config error")
+    ):
+        with patch("ask.main.logger") as mock_logger:
+            with pytest.raises(SystemExit):
+                load_configuration()
+            mock_logger.error.assert_called_once_with(
+                "Configuration error: Config error"
+            )
+
+
+def test_resolve_provider_with_model_arg():
+    """Test provider resolution with --model."""
+    args = argparse.Namespace(model="anthropic:sonnet")
+    config_data = {}
+    mock_provider = MagicMock()
+
+    with patch("ask.config.get_provider_config", return_value=("anthropic", {})):
+        with patch("ask.providers.get_provider", return_value=mock_provider):
+            with patch("ask.main.logger"):
+                result = resolve_provider(args, config_data)
+                assert result == mock_provider
+
+
+def test_resolve_provider_with_default_model():
+    """Test default model from config."""
+    args = argparse.Namespace(model=None)
+    config_data = {"ask": {"default_model": "anthropic"}}
+    mock_provider = MagicMock()
+
+    with patch("ask.config.get_default_model", return_value="anthropic"):
+        with patch("ask.config.get_provider_config", return_value=("anthropic", {})):
+            with patch("ask.providers.get_provider", return_value=mock_provider):
+                with patch("ask.main.logger"):
+                    result = resolve_provider(args, config_data)
+                    assert result == mock_provider
+
+
+def test_resolve_provider_with_env_fallback():
+    """Test environment variable fallback."""
+    args = argparse.Namespace(model=None)
+    config_data = {}
+    mock_provider = MagicMock()
+
+    with patch("ask.config.get_default_model", return_value=None):
+        with patch("ask.config.get_default_provider", return_value="anthropic"):
+            with patch(
+                "ask.config.get_provider_config", return_value=("anthropic", {})
+            ):
+                with patch("ask.providers.get_provider", return_value=mock_provider):
+                    with patch("ask.main.logger"):
+                        result = resolve_provider(args, config_data)
+                        assert result == mock_provider
+
+
+def test_resolve_provider_no_keys():
+    """Test when no API keys available."""
+    args = argparse.Namespace(model=None)
+    config_data = {}
+
+    with patch("ask.config.get_default_model", return_value=None):
+        with patch("ask.config.get_default_provider", return_value=None):
+            with patch("ask.main.logger") as mock_logger:
+                with pytest.raises(SystemExit):
+                    resolve_provider(args, config_data)
+                mock_logger.error.assert_called()
+
+
+def test_main_success():
+    """Test successful main function execution."""
+    mock_provider = MagicMock()
+    mock_provider.get_bash_command.return_value = "ls -la"
+
+    with patch("ask.main.parse_arguments") as mock_parse:
+        with patch("ask.main.configure_logging"):
+            with patch("ask.main.load_configuration", return_value={}):
+                with patch("ask.main.resolve_provider", return_value=mock_provider):
+                    with patch("builtins.print") as mock_print:
+                        mock_parse.return_value = argparse.Namespace(
+                            prompt="list files", model=None, verbose=False
+                        )
+
+                        main()
+
+                        mock_provider.validate_config.assert_called_once()
+                        mock_provider.get_bash_command.assert_called_once_with(
+                            "list files"
+                        )
+                        mock_print.assert_called_once_with("ls -la")
+
+
+def test_main_authentication_error():
+    """Test authentication error handling."""
+    mock_provider = MagicMock()
+    mock_provider.validate_config.side_effect = AuthenticationError("Invalid API key")
+
+    with patch("ask.main.parse_arguments") as mock_parse:
+        with patch("ask.main.configure_logging"):
+            with patch("ask.main.load_configuration", return_value={}):
+                with patch("ask.main.resolve_provider", return_value=mock_provider):
+                    with patch("ask.main.logger") as mock_logger:
+                        mock_parse.return_value = argparse.Namespace(
+                            prompt="list files", model=None, verbose=False
+                        )
+
+                        with pytest.raises(SystemExit):
+                            main()
+
+                        mock_logger.error.assert_called_once_with("Invalid API key")
+
+
+def test_main_api_error():
+    """Test API error handling."""
+    mock_provider = MagicMock()
+    mock_provider.get_bash_command.side_effect = APIError("API request failed")
+
+    with patch("ask.main.parse_arguments") as mock_parse:
+        with patch("ask.main.configure_logging"):
+            with patch("ask.main.load_configuration", return_value={}):
+                with patch("ask.main.resolve_provider", return_value=mock_provider):
+                    with patch("ask.main.logger") as mock_logger:
+                        mock_parse.return_value = argparse.Namespace(
+                            prompt="list files", model=None, verbose=False
+                        )
+
+                        with pytest.raises(SystemExit):
+                            main()
+
+                        mock_logger.error.assert_called_once_with("API request failed")
test/test_openai.py
ADDED
@@ -0,0 +1,269 @@
+"""Tests for OpenAI provider."""
+
+import os
+import re
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from ask.config import SYSTEM_PROMPT
+from ask.exceptions import APIError, AuthenticationError, RateLimitError
+from ask.providers.openai import OpenAIProvider
+
+
+def test_openai_provider_init():
+    """Test provider initialization."""
+    config = {"model_name": "gpt-4o-mini"}
+    provider = OpenAIProvider(config)
+
+    assert provider.config == config
+    assert provider.client is None
+
+
+def test_validate_config_success(mock_openai_key):
+    """Test successful config validation."""
+    config = {"api_key_env": "OPENAI_API_KEY"}
+    provider = OpenAIProvider(config)
+
+    with patch("openai.OpenAI") as mock_openai:
+        mock_client = MagicMock()
+        mock_openai.return_value = mock_client
+
+        provider.validate_config()
+
+        assert provider.client == mock_client
+        mock_openai.assert_called_once_with(api_key="test-openai-key")
+
+
+def test_validate_config_missing_key(mock_env_vars):
+    """Test missing API key error."""
+    config = {"api_key_env": "OPENAI_API_KEY"}
+    provider = OpenAIProvider(config)
+
+    with pytest.raises(
+        AuthenticationError, match="OPENAI_API_KEY environment variable is required"
+    ):
+        provider.validate_config()
+
+
+def test_validate_config_custom_env():
+    """Test custom environment variable."""
+    config = {"api_key_env": "CUSTOM_OPENAI_KEY"}
+    provider = OpenAIProvider(config)
+
+    with patch.dict(os.environ, {"CUSTOM_OPENAI_KEY": "custom-key"}):
+        with patch("openai.OpenAI") as mock_openai:
+            mock_client = MagicMock()
+            mock_openai.return_value = mock_client
+
+            provider.validate_config()
+
+            mock_openai.assert_called_once_with(api_key="custom-key")
+
+
+def test_get_default_config():
+    """Test default configuration values."""
+    default_config = OpenAIProvider.get_default_config()
+
+    assert default_config["model_name"] == "gpt-4o-mini"
+    assert default_config["max_tokens"] == 150
+    assert default_config["api_key_env"] == "OPENAI_API_KEY"
+    assert default_config["temperature"] == 0.0
+    assert default_config["system_prompt"] == SYSTEM_PROMPT
+
+
+def test_get_bash_command_success(mock_openai_key):
+    """Test successful command generation."""
+    config = {"model_name": "gpt-4o-mini", "max_tokens": 150}
+    provider = OpenAIProvider(config)
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "ls -la"
+
+    with patch("openai.OpenAI") as mock_openai:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_response
+        mock_openai.return_value = mock_client
+
+        result = provider.get_bash_command("list files")
+
+        assert result == "ls -la"
+        mock_client.chat.completions.create.assert_called_once_with(
+            model="gpt-4o-mini",
+            max_completion_tokens=150,
+            temperature=0.0,
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": "list files"},
+            ],
+        )
+
+
+def test_get_bash_command_with_code_block(mock_openai_key):
+    """Test bash code block extraction."""
+    config = {}
+    provider = OpenAIProvider(config)
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "```bash\nls -la\n```"
+
+    with patch("openai.OpenAI") as mock_openai:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_response
+        mock_openai.return_value = mock_client
+
+        result = provider.get_bash_command("list files")
+
+        assert result == "ls -la"
+
+
+def test_get_bash_command_without_code_block(mock_openai_key):
+    """Test plain text response."""
+    config = {}
+    provider = OpenAIProvider(config)
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "ls -la"
+
+    with patch("openai.OpenAI") as mock_openai:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_response
+        mock_openai.return_value = mock_client
+
+        result = provider.get_bash_command("list files")
+
+        assert result == "ls -la"
+
+
+def test_get_bash_command_empty_response(mock_openai_key):
+    """Test empty API response handling."""
+    config = {}
+    provider = OpenAIProvider(config)
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = None
+
+    with patch("openai.OpenAI") as mock_openai:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_response
+        mock_openai.return_value = mock_client
+
+        with pytest.raises(APIError, match="API returned empty response"):
+            provider.get_bash_command("list files")
+
+
+def test_get_bash_command_auto_validate(mock_openai_key):
+    """Test auto-validation behavior."""
+    config = {}
+    provider = OpenAIProvider(config)
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "ls -la"
+
+    with patch("openai.OpenAI") as mock_openai:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_response
+        mock_openai.return_value = mock_client
+
+        # Client should be None initially
+        assert provider.client is None
+
+        result = provider.get_bash_command("list files")
+
+        # Client should be set after auto-validation
+        assert provider.client is not None
+        assert result == "ls -la"
+
+
+def test_handle_api_error_auth():
+    """Test authentication error mapping."""
+    provider = OpenAIProvider({})
+
+    with pytest.raises(AuthenticationError, match="Invalid API key"):
+        provider._handle_api_error(Exception("authentication failed"))
+
+
+def test_handle_api_error_rate_limit():
+    """Test rate limit error mapping."""
+    provider = OpenAIProvider({})
+
+    with pytest.raises(RateLimitError, match="API rate limit exceeded"):
+        provider._handle_api_error(Exception("rate limit exceeded"))
+
+
+def test_handle_api_error_quota():
+    """Test quota error mapping."""
+    provider = OpenAIProvider({})
+
+    with pytest.raises(RateLimitError, match="API rate limit exceeded"):
+        provider._handle_api_error(Exception("quota exceeded"))
+
+
+def test_handle_api_error_generic():
+    """Test generic API error mapping."""
+    provider = OpenAIProvider({})
+
+    with pytest.raises(APIError, match="API request failed"):
+        provider._handle_api_error(Exception("unexpected error"))
+
+
+def test_config_parameter_usage(mock_openai_key):
+    """Test configuration parameter usage."""
+    config = {
+        "model_name": "gpt-4o",
+        "max_tokens": 1024,
+        "temperature": 0.5,
+        "system_prompt": "Custom system prompt",
+    }
+    provider = OpenAIProvider(config)
+
+    mock_response = MagicMock()
+    mock_response.choices = [MagicMock()]
+    mock_response.choices[0].message.content = "custom response"
+
+    with patch("openai.OpenAI") as mock_openai:
+        mock_client = MagicMock()
+        mock_client.chat.completions.create.return_value = mock_response
+        mock_openai.return_value = mock_client
+
+        result = provider.get_bash_command("test prompt")
+
+        assert result == "custom response"
+        mock_client.chat.completions.create.assert_called_once_with(
+            model="gpt-4o",
+            max_completion_tokens=1024,
+            temperature=0.5,
+            messages=[
+                {"role": "system", "content": "Custom system prompt"},
+                {"role": "user", "content": "test prompt"},
+            ],
+        )
+
+
+def test_regex_bash_extraction():
+    """Test regex pattern for bash code extraction."""
+    _ = OpenAIProvider({})
+
+    # Test various bash code block formats
+    test_cases = [
+        ("```bash\nls -la\n```", "ls -la"),
+        ("```bash\nfind . -name '*.py'\n```", "find . -name '*.py'"),
+        ("Here is the command:\n```bash\necho 'hello'\n```", "echo 'hello'"),
+        ("plain text command", "plain text command"),
+        ("```python\nprint('hello')\n```", "```python\nprint('hello')\n```"),
+    ]
+
+    for input_text, expected in test_cases:
+        # Test the regex pattern used in the provider
+        re_match = re.search(r"```bash\n(.*)\n```", input_text)
+        if re_match:
+            result = re_match.group(1)
+        else:
+            result = input_text
+
+        assert result == expected
test/test_providers.py
ADDED
@@ -0,0 +1,77 @@
+"""Tests for the provider registry."""
+
+import pytest
+
+from ask.exceptions import ConfigurationError
+from ask.providers import get_provider, list_providers, register_provider
+from ask.providers.base import ProviderInterface
+
+
+class MockProvider(ProviderInterface):
+    """Mock provider for testing."""
+
+    def get_bash_command(self, prompt: str) -> str:
+        return f"mock command for: {prompt}"
+
+    def validate_config(self) -> None:
+        pass
+
+    @classmethod
+    def get_default_config(cls) -> dict:
+        return {"mock": "config"}
+
+
+def test_register_provider():
+    """Test provider registration."""
+    # Clean up any existing registration
+    from ask.providers import _PROVIDER_REGISTRY
+
+    if "test_provider" in _PROVIDER_REGISTRY:
+        del _PROVIDER_REGISTRY["test_provider"]
+
+    register_provider("test_provider", MockProvider)
+
+    assert "test_provider" in _PROVIDER_REGISTRY
+    assert _PROVIDER_REGISTRY["test_provider"] == MockProvider
+
+
+def test_get_provider_success():
+    """Test successful provider retrieval."""
+    register_provider("test_provider", MockProvider)
+
+    provider = get_provider("test_provider", {"test": "config"})
+
+    assert isinstance(provider, MockProvider)
+    assert provider.config == {"test": "config"}
+
+
+def test_get_provider_not_found():
+    """Test unknown provider error."""
+    with pytest.raises(
+        ConfigurationError, match="Provider 'unknown_provider' not found"
+    ):
+        get_provider("unknown_provider", {})
+
+
+def test_list_providers():
+    """Test provider listing."""
+    register_provider("test_provider", MockProvider)
+
+    providers = list_providers()
+
+    assert isinstance(providers, list)
+    assert "anthropic" in providers
+    assert "openai" in providers
+
+
+def test_provider_registry_isolation():
+    """Test registry isolation between tests."""
+    from ask.providers import _PROVIDER_REGISTRY
+
+    # Register a temporary provider
+    register_provider("temp_provider", MockProvider)
+    assert "temp_provider" in _PROVIDER_REGISTRY
+
+    # Clean up
+    del _PROVIDER_REGISTRY["temp_provider"]
+    assert "temp_provider" not in _PROVIDER_REGISTRY