terminal-sherpa 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ask/providers/__init__.py +2 -0
- ask/providers/grok.py +118 -0
- {terminal_sherpa-0.2.0.dist-info → terminal_sherpa-0.4.0.dist-info}/METADATA +50 -45
- {terminal_sherpa-0.2.0.dist-info → terminal_sherpa-0.4.0.dist-info}/RECORD +10 -8
- test/conftest.py +7 -0
- test/test_grok.py +247 -0
- {terminal_sherpa-0.2.0.dist-info → terminal_sherpa-0.4.0.dist-info}/WHEEL +0 -0
- {terminal_sherpa-0.2.0.dist-info → terminal_sherpa-0.4.0.dist-info}/entry_points.txt +0 -0
- {terminal_sherpa-0.2.0.dist-info → terminal_sherpa-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {terminal_sherpa-0.2.0.dist-info → terminal_sherpa-0.4.0.dist-info}/top_level.txt +0 -0
ask/providers/__init__.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
from .anthropic import AnthropicProvider
|
4
4
|
from .base import ProviderInterface
|
5
5
|
from .gemini import GeminiProvider
|
6
|
+
from .grok import GrokProvider
|
6
7
|
from .openai import OpenAIProvider
|
7
8
|
|
8
9
|
# Provider registry - maps provider names to their classes
|
@@ -35,3 +36,4 @@ def list_providers() -> list[str]:
|
|
35
36
|
# Register every built-in provider class under its short provider name.
for _provider_name, _provider_class in (
    ("anthropic", AnthropicProvider),
    ("openai", OpenAIProvider),
    ("gemini", GeminiProvider),
    ("grok", GrokProvider),
):
    register_provider(_provider_name, _provider_class)
# Do not leak the loop variables as module attributes.
del _provider_name, _provider_class
|
ask/providers/grok.py
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
"""Grok provider implementation using official xAI SDK."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
import re
|
5
|
+
from typing import Any, NoReturn
|
6
|
+
|
7
|
+
from xai_sdk import Client
|
8
|
+
from xai_sdk.chat import system, user
|
9
|
+
|
10
|
+
from ask.config import SYSTEM_PROMPT
|
11
|
+
from ask.exceptions import APIError, AuthenticationError, RateLimitError
|
12
|
+
from ask.providers.base import ProviderInterface
|
13
|
+
|
14
|
+
|
15
|
+
class GrokProvider(ProviderInterface):
    """Grok provider implementation using the official xAI SDK.

    The xAI client is created lazily: the first call to ``get_bash_command``
    runs ``validate_config``, which reads the API key from the environment
    and builds the client.
    """

    # A fenced code block in the model reply, e.g. "```bash\n...\n```".
    # Non-greedy, so only the first block is captured when the reply contains
    # several. Untagged fences and common shell tags are accepted.
    _CODE_BLOCK_RE = re.compile(r"```(?:bash|sh|shell|zsh)?\n(.*?)\n```", re.DOTALL)

    # Lower-cased substrings of SDK error text that identify the error class.
    # "unauthenticated" / "resource_exhausted" cover gRPC status names
    # (StatusCode.UNAUTHENTICATED / StatusCode.RESOURCE_EXHAUSTED).
    _AUTH_ERROR_MARKERS = (
        "authentication",
        "unauthorized",
        "invalid api key",
        "unauthenticated",
    )
    _RATE_LIMIT_ERROR_MARKERS = (
        "rate limit",
        "quota",
        "too many requests",
        "resource_exhausted",
    )

    def __init__(self, config: dict[str, Any]):
        """Initialize Grok provider with configuration.

        Args:
            config: The configuration for the Grok provider
        """
        super().__init__(config)
        self.client: Client | None = None

    def get_bash_command(self, prompt: str) -> str:
        """Generate a bash command from a natural language prompt.

        Args:
            prompt: The natural language prompt to generate a bash command for

        Returns:
            The generated command. If the reply contains a fenced code block
            (untagged or tagged bash/sh/shell/zsh), only the first block's body
            is returned; otherwise the whole reply is returned. Either way the
            result has surrounding whitespace stripped.

        Raises:
            AuthenticationError: If the API key is missing or rejected
            RateLimitError: If the API rate limit is exceeded
            APIError: If the request fails or the reply is empty

        Note:
            Only ``model_name`` and ``system_prompt`` are read from the config.
            ``max_tokens`` and ``temperature`` are not currently forwarded to
            ``chat.create``.
        """
        if self.client is None:
            self.validate_config()

        client = self.client
        # validate_config() either sets the client or raises. This explicit
        # check, unlike ``assert``, still runs under ``python -O``.
        if client is None:
            raise APIError("Error: Client was not initialized")

        # Only the SDK interaction sits inside the try. Errors raised by our
        # own post-processing below must not be re-wrapped by _handle_api_error.
        try:
            model_name = self.config.get("model_name", "grok-3-fast")
            system_prompt = self.config.get("system_prompt", SYSTEM_PROMPT)

            # Create chat using xAI SDK workflow
            chat = client.chat.create(model=model_name)
            chat.append(system(system_prompt))
            chat.append(user(prompt))

            content = chat.sample().content
        except Exception as e:
            self._handle_api_error(e)

        # An empty or whitespace-only reply carries no command, so treat it the
        # same as a missing one.
        if content is None or not content.strip():
            raise APIError("Error: API returned empty response")

        match = self._CODE_BLOCK_RE.search(content)
        if match is None:
            return content.strip()
        return match.group(1).strip()

    def validate_config(self) -> None:
        """Validate provider configuration and API key.

        Reads the key from the environment variable named by ``api_key_env``
        (default ``XAI_API_KEY``) and stores a new xAI ``Client`` on
        ``self.client``.

        Raises:
            AuthenticationError: If the environment variable is unset or empty
        """
        api_key_env = self.config.get("api_key_env", "XAI_API_KEY")
        api_key = os.environ.get(api_key_env)

        if not api_key:
            raise AuthenticationError(
                f"Error: {api_key_env} environment variable is required"
            )

        # Initialize xAI SDK client
        self.client = Client(api_key=api_key)

    def _handle_api_error(self, error: Exception) -> NoReturn:
        """Map an SDK exception to one of the package's standard exceptions.

        The original exception is chained as ``__cause__`` for debugging.

        Args:
            error: The exception to handle

        Raises:
            AuthenticationError: If the error text indicates a bad API key
            RateLimitError: If the error text indicates a rate/quota limit
            APIError: For every other failure
        """
        error_str = str(error).lower()

        if any(marker in error_str for marker in self._AUTH_ERROR_MARKERS):
            raise AuthenticationError("Error: Invalid API key") from error
        if any(marker in error_str for marker in self._RATE_LIMIT_ERROR_MARKERS):
            raise RateLimitError("Error: API rate limit exceeded") from error
        raise APIError(f"Error: API request failed - {error}") from error

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for Grok provider."""
        return {
            "model_name": "grok-3-fast",
            "max_tokens": 150,
            "api_key_env": "XAI_API_KEY",
            "temperature": 0.0,
            "system_prompt": SYSTEM_PROMPT,
        }
|
@@ -1,24 +1,34 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: terminal-sherpa
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.4.0
|
4
4
|
Summary: AI-powered bash command generator
|
5
|
-
|
5
|
+
Project-URL: Homepage, https://github.com/lcford2/terminal-sherpa
|
6
|
+
Project-URL: Issues, https://github.com/lcford2/terminal-sherpa/issues
|
7
|
+
Classifier: Development Status :: 4 - Beta
|
8
|
+
Classifier: Environment :: Console
|
9
|
+
Classifier: Intended Audience :: Developers
|
10
|
+
Classifier: Operating System :: OS Independent
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python
|
13
|
+
Classifier: Programming Language :: Python :: 3.8
|
14
|
+
Classifier: Programming Language :: Python :: 3.9
|
15
|
+
Classifier: Programming Language :: Python :: 3.10
|
16
|
+
Classifier: Programming Language :: Python :: 3.11
|
17
|
+
Classifier: Programming Language :: Python :: 3.12
|
18
|
+
Classifier: Programming Language :: Python :: 3.13
|
19
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
20
|
+
Classifier: Topic :: Utilities
|
21
|
+
Classifier: Topic :: Software Development :: Libraries
|
22
|
+
Requires-Python: >=3.10
|
6
23
|
Description-Content-Type: text/markdown
|
7
24
|
License-File: LICENSE
|
8
25
|
Requires-Dist: anthropic>=0.7.0
|
9
|
-
Requires-Dist: black>=25.1.0
|
10
|
-
Requires-Dist: cosmic-ray>=8.3.4
|
11
26
|
Requires-Dist: google-genai>=1.26.0
|
12
27
|
Requires-Dist: loguru>=0.7.0
|
13
|
-
Requires-Dist: mutatest>=3.1.0
|
14
28
|
Requires-Dist: openai>=1.0.0
|
15
|
-
Requires-Dist: pytest>=8.0.0
|
16
|
-
Requires-Dist: pytest-cov>=4.0.0
|
17
|
-
Requires-Dist: pytest-mock>=3.12.0
|
18
|
-
Requires-Dist: ruff>=0.12.3
|
19
29
|
Requires-Dist: setuptools>=80.9.0
|
20
|
-
Requires-Dist: taskipy>=1.14.1
|
21
30
|
Requires-Dist: toml>=0.10.0
|
31
|
+
Requires-Dist: xai-sdk
|
22
32
|
Dynamic: license-file
|
23
33
|
|
24
34
|
# terminal-sherpa
|
@@ -28,6 +38,10 @@ A lightweight AI chat interface for fellow terminal dwellers.
|
|
28
38
|
Turn natural language into bash commands instantly.
|
29
39
|
Stop googling syntax and start asking.
|
30
40
|
|
41
|
+
[](https://pypi.python.org/pypi/terminal-sherpa)
|
42
|
+
[](https://github.com/lcford2/terminal-sherpa/blob/main/LICENSE)
|
43
|
+
[](https://pypi.python.org/pypi/terminal-sherpa)
|
44
|
+
[](https://github.com/lcford2/terminal-sherpa/actions)
|
31
45
|
[](https://codecov.io/github/lcford2/terminal-sherpa)
|
32
46
|
|
33
47
|
## 🚀 Getting Started
|
@@ -54,7 +68,7 @@ find . -name "*.py" -mtime -7
|
|
54
68
|
## ✨ Features
|
55
69
|
|
56
70
|
- **Natural language to bash conversion** - Describe what you want, get the command
|
57
|
-
- **Multiple AI provider support** - Choose between Anthropic (Claude), OpenAI (GPT),
|
71
|
+
- **Multiple AI provider support** - Choose between Anthropic (Claude), OpenAI (GPT), Google (Gemini), and xAI (Grok) models
|
58
72
|
- **Flexible configuration system** - Set defaults, customize models, and manage API keys
|
59
73
|
- **XDG-compliant config files** - Follows standard configuration file locations
|
60
74
|
- **Verbose logging support** - Debug and understand what's happening under the hood
|
@@ -63,8 +77,8 @@ find . -name "*.py" -mtime -7
|
|
63
77
|
|
64
78
|
### Requirements
|
65
79
|
|
66
|
-
- Python 3.
|
67
|
-
- API key for Anthropic or
|
80
|
+
- Python 3.10+
|
81
|
+
- API key for Anthropic, OpenAI, Google, or xAI
|
68
82
|
|
69
83
|
### Install Methods
|
70
84
|
|
@@ -99,11 +113,15 @@ ask "your natural language prompt"
|
|
99
113
|
|
100
114
|
### Command Options
|
101
115
|
|
102
|
-
| Option | Description | Example
|
103
|
-
| ------------------------ | -------------------------- |
|
104
|
-
| `--model provider:model` | Specify provider and model | `ask --model anthropic
|
105
|
-
| | | `ask --model
|
106
|
-
|
|
116
|
+
| Option | Description | Example |
|
117
|
+
| ------------------------ | -------------------------- | ------------------------------------------- |
|
118
|
+
| `--model provider:model` | Specify provider and model | `ask --model anthropic "list files"` |
|
119
|
+
| | | `ask --model anthropic:sonnet "list files"` |
|
120
|
+
| | | `ask --model openai "list files"` |
|
121
|
+
| | | `ask --model gemini "list files"` |
|
122
|
+
| | | `ask --model gemini:pro "list files"` |
|
123
|
+
| | | `ask --model grok "list files"` |
|
124
|
+
| `--verbose` | Enable verbose logging | `ask --verbose "compress this folder"` |
|
107
125
|
|
108
126
|
### Practical Examples
|
109
127
|
|
@@ -173,6 +191,7 @@ Ask follows XDG Base Directory Specification:
|
|
173
191
|
export ANTHROPIC_API_KEY="your-anthropic-key"
|
174
192
|
export OPENAI_API_KEY="your-openai-key"
|
175
193
|
export GEMINI_API_KEY="your-gemini-key"
|
194
|
+
export XAI_API_KEY="your-xai-key"
|
176
195
|
```
|
177
196
|
|
178
197
|
### Example Configuration File
|
@@ -188,7 +207,7 @@ model = "claude-3-haiku-20240307"
|
|
188
207
|
max_tokens = 512
|
189
208
|
|
190
209
|
[anthropic.sonnet]
|
191
|
-
model = "claude-
|
210
|
+
model = "claude-sonnet-4-20250514"
|
192
211
|
max_tokens = 1024
|
193
212
|
|
194
213
|
[openai]
|
@@ -196,12 +215,17 @@ model = "gpt-4o"
|
|
196
215
|
max_tokens = 1024
|
197
216
|
|
198
217
|
[gemini]
|
199
|
-
model = "gemini-2.5-flash"
|
218
|
+
model = "gemini-2.5-flash-lite-preview-06-17"
|
200
219
|
max_tokens = 150
|
201
220
|
|
202
221
|
[gemini.pro]
|
203
222
|
model = "gemini-2.5-pro"
|
204
223
|
max_tokens = 1024
|
224
|
+
|
225
|
+
[grok]
|
226
|
+
model = "grok-3-fast"
|
227
|
+
max_tokens = 150
|
228
|
+
temperature = 0.0
|
205
229
|
```
|
206
230
|
|
207
231
|
## 🤖 Supported Providers
|
@@ -209,31 +233,15 @@ max_tokens = 1024
|
|
209
233
|
- Anthropic (Claude)
|
210
234
|
- OpenAI (GPT)
|
211
235
|
- Google (Gemini)
|
236
|
+
- xAI (Grok)
|
212
237
|
|
213
|
-
> **Note:** Get API keys from [Anthropic Console](https://console.anthropic.com/), [OpenAI Platform](https://platform.openai.com/),
|
238
|
+
> **Note:** Get API keys from [Anthropic Console](https://console.anthropic.com/), [OpenAI Platform](https://platform.openai.com/), [Google AI Studio](https://aistudio.google.com/), or [xAI Console](https://x.ai/console)
|
214
239
|
|
215
240
|
## 🛣️ Roadmap
|
216
241
|
|
217
|
-
### Near-term
|
218
|
-
|
219
242
|
- [ ] Shell integration and auto-completion
|
220
|
-
- [ ]
|
221
|
-
- [ ] Safety features (command preview/confirmation)
|
222
|
-
- [ ] Output formatting options
|
223
|
-
|
224
|
-
### Medium-term
|
225
|
-
|
226
|
-
- [ ] Additional providers (Google, Cohere, Mistral)
|
227
|
-
- [ ] Interactive mode for complex tasks
|
228
|
-
- [ ] Plugin system for custom providers
|
229
|
-
- [ ] Command validation and testing
|
230
|
-
|
231
|
-
### Long-term
|
232
|
-
|
243
|
+
- [ ] Additional providers (Cohere, Mistral)
|
233
244
|
- [ ] Local model support (Ollama, llama.cpp)
|
234
|
-
- [ ] Learning from user preferences
|
235
|
-
- [ ] Advanced safety and sandboxing
|
236
|
-
- [ ] GUI and web interface options
|
237
245
|
|
238
246
|
## 🔧 Development
|
239
247
|
|
@@ -242,7 +250,7 @@ max_tokens = 1024
|
|
242
250
|
```bash
|
243
251
|
git clone https://github.com/lcford2/terminal-sherpa.git
|
244
252
|
cd terminal-sherpa
|
245
|
-
uv sync
|
253
|
+
uv sync --all-groups
|
246
254
|
uv run pre-commit install
|
247
255
|
```
|
248
256
|
|
@@ -258,16 +266,13 @@ uv run python -m pytest
|
|
258
266
|
2. Create a feature branch
|
259
267
|
3. Make your changes
|
260
268
|
4. Run pre-commit checks: `uv run pre-commit run --all-files`
|
261
|
-
5.
|
269
|
+
5. Run tests: `uv run task test`
|
270
|
+
6. Submit a pull request
|
262
271
|
|
263
272
|
## License
|
264
273
|
|
265
274
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
266
275
|
|
267
|
-
## Contributing
|
268
|
-
|
269
|
-
Contributions are welcome! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
|
270
|
-
|
271
276
|
## Issues
|
272
277
|
|
273
278
|
Found a bug or have a feature request? Please open an issue on [GitHub Issues](https://github.com/lcford2/terminal-sherpa/issues).
|
@@ -2,22 +2,24 @@ ask/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
ask/config.py,sha256=iHIiMKePia80Sno_XlARqa7pEyW3eZm_Bf5SlUMidiQ,2880
|
3
3
|
ask/exceptions.py,sha256=0RLMSbw6j49BEhJN7C8MYaKpuhVeitsBhTGjZmaiHis,434
|
4
4
|
ask/main.py,sha256=9mVXwncU2P4OQxE7Oxcqi376A06xluC76kiIoCCqNSc,3936
|
5
|
-
ask/providers/__init__.py,sha256=
|
5
|
+
ask/providers/__init__.py,sha256=WRcLAqAK6dEbRM9aF5hYWw60_sjCtLlSBiyf8VFb4tA,1255
|
6
6
|
ask/providers/anthropic.py,sha256=3bB335ZxRu4a_j7U0tox_6sQxTf4X287GhU2-gr5ctU,2725
|
7
7
|
ask/providers/base.py,sha256=91ZbVORYWckSHNwNPiTmgfqQN0FLO9AgV6mptuAkIU0,769
|
8
8
|
ask/providers/gemini.py,sha256=c1uQg6ShNkw2AkwDJhjfki7hTBl4vOkT9v1O4ewRfbg,3408
|
9
|
+
ask/providers/grok.py,sha256=Lbue8dgcGoaSAtQ-Jme5e5EQ9hp608W78QEPjBP_vBk,3824
|
9
10
|
ask/providers/openai.py,sha256=9PS1AgMr6Nb-OcYYiNYNrG384wNUS8m2lVT04K3hFV8,3683
|
10
|
-
terminal_sherpa-0.
|
11
|
-
test/conftest.py,sha256=
|
11
|
+
terminal_sherpa-0.4.0.dist-info/licenses/LICENSE,sha256=xLe81eIrf0X6CnEDDJXmoXuDzkdMYM3Eq1BgHUpG1JQ,1067
|
12
|
+
test/conftest.py,sha256=aRJL41pj0ITemxvXO1FqANhF8zi4T1T8ehKNbZKBSao,1724
|
12
13
|
test/test_anthropic.py,sha256=S5OQ67qIZ4VO38eJwAAwJa4JBylJhKCtmcGjCWA8WLY,5687
|
13
14
|
test/test_config.py,sha256=FrJ6bsZ6mK46e-8fQfkFGx9GgwHrNfnoI8211R0V9K8,5565
|
14
15
|
test/test_exceptions.py,sha256=tw-spMitAdYj9uW_8TjnlyVKKXFC06FR3610WGR-494,1754
|
15
16
|
test/test_gemini.py,sha256=sV8FkaU5rLfuu3lGeQdxsa-ZmNnLUYpODaKhayrasSo,8000
|
17
|
+
test/test_grok.py,sha256=udlLJCCZ9NUYz7Szpy4bifTPtOR9-QMQ_S4ZDLNWhw8,8372
|
16
18
|
test/test_main.py,sha256=3gZ83nVHMSEmgHSF2UJoELfK028a4vgxLpIk2P1cH1Y,7745
|
17
19
|
test/test_openai.py,sha256=KAGQWFrXeu4P9umij7XDoxnKQ2cApv6ImuL8EiG_5W8,8388
|
18
20
|
test/test_providers.py,sha256=SejQvCZSEQ5RAfVTCtPZ-39fXnfV17n4gaSxjiHA5UM,2140
|
19
|
-
terminal_sherpa-0.
|
20
|
-
terminal_sherpa-0.
|
21
|
-
terminal_sherpa-0.
|
22
|
-
terminal_sherpa-0.
|
23
|
-
terminal_sherpa-0.
|
21
|
+
terminal_sherpa-0.4.0.dist-info/METADATA,sha256=_WxCjbuTNgns8yq5VFYkFs0LD_2dB50aaoBuUsIzGKk,7637
|
22
|
+
terminal_sherpa-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
terminal_sherpa-0.4.0.dist-info/entry_points.txt,sha256=LxG9-J__nMGmeEIi47WVGYC1LLJo1GaADH21hfxEK70,38
|
24
|
+
terminal_sherpa-0.4.0.dist-info/top_level.txt,sha256=Y7k5b2NSCkKiA_XPU-4fT_GYangD6JVDug5xwfXvmuQ,9
|
25
|
+
terminal_sherpa-0.4.0.dist-info/RECORD,,
|
test/conftest.py
CHANGED
@@ -52,6 +52,13 @@ def mock_openai_key():
|
|
52
52
|
yield
|
53
53
|
|
54
54
|
|
55
|
+
@pytest.fixture
def mock_grok_key():
    """Run the test with XAI_API_KEY as the only environment variable."""
    fake_env = {"XAI_API_KEY": "test-grok-key"}
    # clear=True drops every other variable while the test runs.
    with patch.dict(os.environ, fake_env, clear=True):
        yield
|
60
|
+
|
61
|
+
|
55
62
|
@pytest.fixture
|
56
63
|
def mock_both_keys():
|
57
64
|
"""Mock both API keys in environment."""
|
test/test_grok.py
ADDED
@@ -0,0 +1,247 @@
|
|
1
|
+
"""Tests for Grok provider using xAI SDK."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from unittest.mock import MagicMock, patch
|
5
|
+
|
6
|
+
import pytest
|
7
|
+
|
8
|
+
from ask.config import SYSTEM_PROMPT
|
9
|
+
from ask.exceptions import APIError, AuthenticationError, RateLimitError
|
10
|
+
from ask.providers.grok import GrokProvider
|
11
|
+
|
12
|
+
|
13
|
+
def test_grok_provider_init():
    """A new provider keeps its config and has not built a client yet."""
    settings = {"model_name": "grok-3-fast"}

    grok = GrokProvider(settings)

    assert grok.client is None
    assert grok.config == settings
|
20
|
+
|
21
|
+
|
22
|
+
def test_validate_config_success(mock_grok_key):
    """validate_config builds the xAI client from the XAI_API_KEY variable."""
    grok = GrokProvider({"api_key_env": "XAI_API_KEY"})

    with patch("ask.providers.grok.Client") as client_cls:
        grok.validate_config()

        client_cls.assert_called_once_with(api_key="test-grok-key")
        assert grok.client is client_cls.return_value
|
35
|
+
|
36
|
+
|
37
|
+
def test_validate_config_missing_key(mock_env_vars):
    """An unset key variable raises AuthenticationError naming that variable.

    Assumes the mock_env_vars fixture provides an environment without
    XAI_API_KEY.
    """
    grok = GrokProvider({"api_key_env": "XAI_API_KEY"})
    expected_message = "XAI_API_KEY environment variable is required"

    with pytest.raises(AuthenticationError, match=expected_message):
        grok.validate_config()
|
46
|
+
|
47
|
+
|
48
|
+
def test_validate_config_custom_env():
    """validate_config reads the key from a custom api_key_env variable."""
    grok = GrokProvider({"api_key_env": "CUSTOM_XAI_KEY"})

    with (
        patch.dict(os.environ, {"CUSTOM_XAI_KEY": "custom-xai-key"}),
        patch("ask.providers.grok.Client") as client_cls,
    ):
        grok.validate_config()

        client_cls.assert_called_once_with(api_key="custom-xai-key")
        assert grok.client is client_cls.return_value
|
62
|
+
|
63
|
+
|
64
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_success(mock_user, mock_system, mock_grok_key):
    """A plain reply is returned as-is, and the chat is assembled correctly."""
    grok = GrokProvider({"model_name": "grok-beta", "max_tokens": 150})

    reply = MagicMock(content="ls -la")
    chat = MagicMock(**{"sample.return_value": reply})
    system_msg, user_msg = MagicMock(), MagicMock()
    mock_system.return_value = system_msg
    mock_user.return_value = user_msg

    with patch("ask.providers.grok.Client") as client_cls:
        client = client_cls.return_value
        client.chat.create.return_value = chat
        grok.client = client

        command = grok.get_bash_command("list files")

        assert command == "ls -la"
        client.chat.create.assert_called_once_with(model="grok-beta")
        mock_system.assert_called_once_with(SYSTEM_PROMPT)
        mock_user.assert_called_once_with("list files")
        assert chat.append.call_count == 2
        chat.append.assert_any_call(system_msg)
        chat.append.assert_any_call(user_msg)
        chat.sample.assert_called_once()
|
98
|
+
|
99
|
+
|
100
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_with_code_block(mock_user, mock_system, mock_grok_key):
    """The command inside a ```bash fence is extracted without the fence."""
    grok = GrokProvider({})

    reply = MagicMock(content="```bash\nls -la\n```")
    chat = MagicMock(**{"sample.return_value": reply})
    mock_system.return_value = MagicMock()
    mock_user.return_value = MagicMock()

    with patch("ask.providers.grok.Client") as client_cls:
        client = client_cls.return_value
        client.chat.create.return_value = chat
        grok.client = client

        command = grok.get_bash_command("list files")

        assert command == "ls -la"
|
127
|
+
|
128
|
+
|
129
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_empty_response(mock_user, mock_system, mock_grok_key):
    """A reply whose content is None raises APIError."""
    grok = GrokProvider({})

    reply = MagicMock(content=None)
    chat = MagicMock(**{"sample.return_value": reply})
    mock_system.return_value = MagicMock()
    mock_user.return_value = MagicMock()

    with patch("ask.providers.grok.Client") as client_cls:
        client = client_cls.return_value
        client.chat.create.return_value = chat
        grok.client = client

        with pytest.raises(APIError, match="API returned empty response"):
            grok.get_bash_command("list files")
|
155
|
+
|
156
|
+
|
157
|
+
def test_handle_authentication_error(mock_grok_key):
    """An SDK error mentioning authentication surfaces as AuthenticationError."""
    grok = GrokProvider({})

    with patch("ask.providers.grok.Client") as client_cls:
        client = client_cls.return_value
        client.chat.create.side_effect = Exception("Authentication failed")
        grok.client = client

        with pytest.raises(AuthenticationError, match="Invalid API key"):
            grok.get_bash_command("test")
|
170
|
+
|
171
|
+
|
172
|
+
def test_handle_rate_limit_error(mock_grok_key):
    """An SDK error mentioning a rate limit surfaces as RateLimitError."""
    grok = GrokProvider({})

    with patch("ask.providers.grok.Client") as client_cls:
        client = client_cls.return_value
        client.chat.create.side_effect = Exception("Rate limit exceeded")
        grok.client = client

        with pytest.raises(RateLimitError, match="rate limit exceeded"):
            grok.get_bash_command("test")
|
185
|
+
|
186
|
+
|
187
|
+
def test_handle_general_api_error(mock_grok_key):
    """Any other SDK error becomes APIError carrying the original message."""
    grok = GrokProvider({})

    with patch("ask.providers.grok.Client") as client_cls:
        client = client_cls.return_value
        client.chat.create.side_effect = Exception("Unknown error")
        grok.client = client

        with pytest.raises(APIError, match="API request failed - Unknown error"):
            grok.get_bash_command("test")
|
200
|
+
|
201
|
+
|
202
|
+
def test_get_default_config():
    """get_default_config returns the expected Grok defaults."""
    assert GrokProvider.get_default_config() == {
        "model_name": "grok-3-fast",
        "max_tokens": 150,
        "api_key_env": "XAI_API_KEY",
        "temperature": 0.0,
        "system_prompt": SYSTEM_PROMPT,
    }
|
215
|
+
|
216
|
+
|
217
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_multiline_code_block(mock_user, mock_system, mock_grok_key):
    """A multi-line command inside a ```bash fence is extracted intact."""
    grok = GrokProvider({})

    fenced_reply = (
        "```bash\nfind . -name '*.py' \\\n "
        "-type f \\\n -exec grep -l 'test' {} \\;\n```"
    )
    chat = MagicMock(**{"sample.return_value": MagicMock(content=fenced_reply)})
    mock_system.return_value = MagicMock()
    mock_user.return_value = MagicMock()

    with patch("ask.providers.grok.Client") as client_cls:
        client = client_cls.return_value
        client.chat.create.return_value = chat
        grok.client = client

        command = grok.get_bash_command("find Python files with test")

        expected = "find . -name '*.py' \\\n -type f \\\n -exec grep -l 'test' {} \\;"
        assert command == expected
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|