terminal-sherpa 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ask/config.py +4 -4
- ask/providers/__init__.py +2 -0
- ask/providers/anthropic.py +2 -2
- ask/providers/gemini.py +2 -2
- ask/providers/grok.py +118 -0
- ask/providers/openai.py +2 -2
- {terminal_sherpa-0.3.0.dist-info → terminal_sherpa-0.4.0.dist-info}/METADATA +28 -36
- terminal_sherpa-0.4.0.dist-info/RECORD +25 -0
- test/conftest.py +7 -0
- test/test_grok.py +247 -0
- terminal_sherpa-0.3.0.dist-info/RECORD +0 -23
- {terminal_sherpa-0.3.0.dist-info → terminal_sherpa-0.4.0.dist-info}/WHEEL +0 -0
- {terminal_sherpa-0.3.0.dist-info → terminal_sherpa-0.4.0.dist-info}/entry_points.txt +0 -0
- {terminal_sherpa-0.3.0.dist-info → terminal_sherpa-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {terminal_sherpa-0.3.0.dist-info → terminal_sherpa-0.4.0.dist-info}/top_level.txt +0 -0
ask/config.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
from pathlib import Path
|
5
|
-
from typing import Any
|
5
|
+
from typing import Any
|
6
6
|
|
7
7
|
import toml
|
8
8
|
|
@@ -16,7 +16,7 @@ SYSTEM_PROMPT = (
|
|
16
16
|
)
|
17
17
|
|
18
18
|
|
19
|
-
def get_config_path() ->
|
19
|
+
def get_config_path() -> Path | None:
|
20
20
|
"""Find config file using XDG standard."""
|
21
21
|
# Primary location: $XDG_CONFIG_HOME/ask/config.toml
|
22
22
|
xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
|
@@ -77,13 +77,13 @@ def get_provider_config(
|
|
77
77
|
return provider_name, merged_config
|
78
78
|
|
79
79
|
|
80
|
-
def get_default_model(config: dict[str, Any]) ->
|
80
|
+
def get_default_model(config: dict[str, Any]) -> str | None:
|
81
81
|
"""Get default model from configuration."""
|
82
82
|
global_config = config.get("ask", {})
|
83
83
|
return global_config.get("default_model")
|
84
84
|
|
85
85
|
|
86
|
-
def get_default_provider() ->
|
86
|
+
def get_default_provider() -> str | None:
|
87
87
|
"""Determine fallback provider from environment variables."""
|
88
88
|
# Check for API keys in order of preference: claude -> openai
|
89
89
|
if os.environ.get("ANTHROPIC_API_KEY"):
|
ask/providers/__init__.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3
3
|
from .anthropic import AnthropicProvider
|
4
4
|
from .base import ProviderInterface
|
5
5
|
from .gemini import GeminiProvider
|
6
|
+
from .grok import GrokProvider
|
6
7
|
from .openai import OpenAIProvider
|
7
8
|
|
8
9
|
# Provider registry - maps provider names to their classes
|
@@ -35,3 +36,4 @@ def list_providers() -> list[str]:
|
|
35
36
|
register_provider("anthropic", AnthropicProvider)
|
36
37
|
register_provider("openai", OpenAIProvider)
|
37
38
|
register_provider("gemini", GeminiProvider)
|
39
|
+
register_provider("grok", GrokProvider)
|
ask/providers/anthropic.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""Anthropic provider implementation."""
|
2
2
|
|
3
3
|
import os
|
4
|
-
from typing import Any
|
4
|
+
from typing import Any
|
5
5
|
|
6
6
|
import anthropic
|
7
7
|
|
@@ -16,7 +16,7 @@ class AnthropicProvider(ProviderInterface):
|
|
16
16
|
def __init__(self, config: dict[str, Any]):
|
17
17
|
"""Initialize Anthropic provider with configuration."""
|
18
18
|
super().__init__(config)
|
19
|
-
self.client:
|
19
|
+
self.client: anthropic.Anthropic | None = None
|
20
20
|
|
21
21
|
def get_bash_command(self, prompt: str) -> str:
|
22
22
|
"""Generate bash command from natural language prompt."""
|
ask/providers/gemini.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
"""Anthropic provider implementation."""
|
2
2
|
|
3
3
|
import os
|
4
|
-
from typing import Any
|
4
|
+
from typing import Any
|
5
5
|
|
6
6
|
from google import genai
|
7
7
|
from google.genai.types import GenerateContentConfig, GenerateContentResponse
|
@@ -17,7 +17,7 @@ class GeminiProvider(ProviderInterface):
|
|
17
17
|
def __init__(self, config: dict[str, Any]):
|
18
18
|
"""Initialize Gemini provider with configuration."""
|
19
19
|
super().__init__(config)
|
20
|
-
self.client:
|
20
|
+
self.client: genai.Client | None = None
|
21
21
|
|
22
22
|
def _parse_response(self, response: GenerateContentResponse) -> str:
|
23
23
|
"""Parse response from Gemini API."""
|
ask/providers/grok.py
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
"""Grok provider implementation using official xAI SDK."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
import re
|
5
|
+
from typing import Any, NoReturn
|
6
|
+
|
7
|
+
from xai_sdk import Client
|
8
|
+
from xai_sdk.chat import system, user
|
9
|
+
|
10
|
+
from ask.config import SYSTEM_PROMPT
|
11
|
+
from ask.exceptions import APIError, AuthenticationError, RateLimitError
|
12
|
+
from ask.providers.base import ProviderInterface
|
13
|
+
|
14
|
+
|
15
|
+
class GrokProvider(ProviderInterface):
    """Grok provider implementation using official xAI SDK."""

    # Captures the body of the FIRST ```bash fenced block. Non-greedy so a reply
    # containing several fenced blocks does not collapse them (plus the prose in
    # between) into a single "command".
    _BASH_BLOCK_RE = re.compile(r"```bash\n(.*?)\n```", re.DOTALL)

    # Lower-cased substrings used to classify SDK failures. xai_sdk is gRPC-based,
    # so besides HTTP-style wording we also match gRPC status names as they
    # appear in str(error) (e.g. "StatusCode.UNAUTHENTICATED").
    _AUTH_ERROR_MARKERS = (
        "authentication",
        "unauthorized",
        "unauthenticated",
        "invalid api key",
    )
    _RATE_LIMIT_MARKERS = (
        "rate limit",
        "quota",
        "too many requests",
        "resource_exhausted",
    )

    def __init__(self, config: dict[str, Any]):
        """Initialize Grok provider with configuration.

        Args:
            config: The configuration for the Grok provider
        """
        super().__init__(config)
        # Created lazily by validate_config() on first use.
        self.client: Client | None = None

    def get_bash_command(self, prompt: str) -> str:
        """Generate bash command from natural language prompt.

        Args:
            prompt: The natural language prompt to generate a bash command for

        Returns:
            The generated bash command (unwrapped from a ```bash fence if present)

        Raises:
            AuthenticationError: If the API key is missing or rejected
            RateLimitError: If the API rate limit is exceeded
            APIError: If the request fails or the API returns an empty response
        """
        if self.client is None:
            self.validate_config()
        client = self.client
        if client is None:
            # validate_config() either sets the client or raises; this explicit
            # check (unlike an assert) also survives ``python -O``.
            raise APIError("Error: Grok client was not initialized")

        model_name = self.config.get("model_name", "grok-3-fast")
        system_prompt = self.config.get("system_prompt", SYSTEM_PROMPT)

        # NOTE(review): "max_tokens" and "temperature" are listed in
        # get_default_config() but are not forwarded to chat.create(); confirm
        # whether they should be (tests currently pin create(model=...) only).
        try:
            chat = client.chat.create(model=model_name)
            chat.append(system(system_prompt))
            chat.append(user(prompt))
            content = chat.sample().content
        except Exception as e:
            # Only SDK/network failures are translated here; errors raised below
            # propagate as-is instead of being re-wrapped as
            # "API request failed - ...".
            self._handle_api_error(e)

        if content is None or not content.strip():
            raise APIError("Error: API returned empty response")

        return self._extract_command(content)

    @classmethod
    def _extract_command(cls, content: str) -> str:
        """Return the first ```bash fenced block in *content*, else *content*, stripped."""
        match = cls._BASH_BLOCK_RE.search(content)
        return (match.group(1) if match else content).strip()

    def validate_config(self) -> None:
        """Validate provider configuration and API key, then create the client.

        The key is read from the environment variable named by ``api_key_env``
        (default ``XAI_API_KEY``).

        Raises:
            AuthenticationError: If the environment variable is unset or empty
        """
        api_key_env = self.config.get("api_key_env", "XAI_API_KEY")
        api_key = os.environ.get(api_key_env)

        if not api_key:
            raise AuthenticationError(
                f"Error: {api_key_env} environment variable is required"
            )

        # Initialize xAI SDK client
        self.client = Client(api_key=api_key)

    def _handle_api_error(self, error: Exception) -> NoReturn:
        """Handle API errors and map them to standard exceptions.

        The original exception is chained as ``__cause__`` so the underlying
        SDK/gRPC details stay available in tracebacks.

        Args:
            error: The exception to handle

        Raises:
            AuthenticationError: If the API key is invalid
            RateLimitError: If the API rate limit is exceeded
            APIError: For any other failure
        """
        error_str = str(error).lower()

        if any(marker in error_str for marker in self._AUTH_ERROR_MARKERS):
            raise AuthenticationError("Error: Invalid API key") from error
        if any(marker in error_str for marker in self._RATE_LIMIT_MARKERS):
            raise RateLimitError("Error: API rate limit exceeded") from error
        raise APIError(f"Error: API request failed - {error}") from error

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for Grok provider."""
        return {
            "model_name": "grok-3-fast",
            "max_tokens": 150,
            "api_key_env": "XAI_API_KEY",
            "temperature": 0.0,
            "system_prompt": SYSTEM_PROMPT,
        }
|
ask/providers/openai.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
import re
|
5
|
-
from typing import Any, NoReturn
|
5
|
+
from typing import Any, NoReturn
|
6
6
|
|
7
7
|
import openai
|
8
8
|
|
@@ -21,7 +21,7 @@ class OpenAIProvider(ProviderInterface):
|
|
21
21
|
config: The configuration for the OpenAI provider
|
22
22
|
"""
|
23
23
|
super().__init__(config)
|
24
|
-
self.client:
|
24
|
+
self.client: openai.OpenAI | None = None
|
25
25
|
|
26
26
|
def get_bash_command(self, prompt: str) -> str:
|
27
27
|
"""Generate bash command from natural language prompt.
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: terminal-sherpa
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.4.0
|
4
4
|
Summary: AI-powered bash command generator
|
5
5
|
Project-URL: Homepage, https://github.com/lcford2/terminal-sherpa
|
6
6
|
Project-URL: Issues, https://github.com/lcford2/terminal-sherpa/issues
|
@@ -19,7 +19,7 @@ Classifier: Programming Language :: Python :: 3.13
|
|
19
19
|
Classifier: Programming Language :: Python :: 3 :: Only
|
20
20
|
Classifier: Topic :: Utilities
|
21
21
|
Classifier: Topic :: Software Development :: Libraries
|
22
|
-
Requires-Python: >=3.
|
22
|
+
Requires-Python: >=3.10
|
23
23
|
Description-Content-Type: text/markdown
|
24
24
|
License-File: LICENSE
|
25
25
|
Requires-Dist: anthropic>=0.7.0
|
@@ -28,6 +28,7 @@ Requires-Dist: loguru>=0.7.0
|
|
28
28
|
Requires-Dist: openai>=1.0.0
|
29
29
|
Requires-Dist: setuptools>=80.9.0
|
30
30
|
Requires-Dist: toml>=0.10.0
|
31
|
+
Requires-Dist: xai-sdk
|
31
32
|
Dynamic: license-file
|
32
33
|
|
33
34
|
# terminal-sherpa
|
@@ -67,7 +68,7 @@ find . -name "*.py" -mtime -7
|
|
67
68
|
## ✨ Features
|
68
69
|
|
69
70
|
- **Natural language to bash conversion** - Describe what you want, get the command
|
70
|
-
- **Multiple AI provider support** - Choose between Anthropic (Claude), OpenAI (GPT),
|
71
|
+
- **Multiple AI provider support** - Choose between Anthropic (Claude), OpenAI (GPT), Google (Gemini), and xAI (Grok) models
|
71
72
|
- **Flexible configuration system** - Set defaults, customize models, and manage API keys
|
72
73
|
- **XDG-compliant config files** - Follows standard configuration file locations
|
73
74
|
- **Verbose logging support** - Debug and understand what's happening under the hood
|
@@ -77,7 +78,7 @@ find . -name "*.py" -mtime -7
|
|
77
78
|
### Requirements
|
78
79
|
|
79
80
|
- Python 3.10+
|
80
|
-
- API key for Anthropic or
|
81
|
+
- API key for Anthropic, OpenAI, Google, or xAI
|
81
82
|
|
82
83
|
### Install Methods
|
83
84
|
|
@@ -112,11 +113,15 @@ ask "your natural language prompt"
|
|
112
113
|
|
113
114
|
### Command Options
|
114
115
|
|
115
|
-
| Option | Description | Example
|
116
|
-
| ------------------------ | -------------------------- |
|
117
|
-
| `--model provider:model` | Specify provider and model | `ask --model anthropic
|
118
|
-
| | | `ask --model
|
119
|
-
|
|
116
|
+
| Option | Description | Example |
|
117
|
+
| ------------------------ | -------------------------- | ------------------------------------------- |
|
118
|
+
| `--model provider:model` | Specify provider and model | `ask --model anthropic "list files"` |
|
119
|
+
| | | `ask --model anthropic:sonnet "list files"` |
|
120
|
+
| | | `ask --model openai "list files"` |
|
121
|
+
| | | `ask --model gemini "list files"` |
|
122
|
+
| | | `ask --model gemini:pro "list files"` |
|
123
|
+
| | | `ask --model grok "list files"` |
|
124
|
+
| `--verbose` | Enable verbose logging | `ask --verbose "compress this folder"` |
|
120
125
|
|
121
126
|
### Practical Examples
|
122
127
|
|
@@ -186,6 +191,7 @@ Ask follows XDG Base Directory Specification:
|
|
186
191
|
export ANTHROPIC_API_KEY="your-anthropic-key"
|
187
192
|
export OPENAI_API_KEY="your-openai-key"
|
188
193
|
export GEMINI_API_KEY="your-gemini-key"
|
194
|
+
export XAI_API_KEY="your-xai-key"
|
189
195
|
```
|
190
196
|
|
191
197
|
### Example Configuration File
|
@@ -201,7 +207,7 @@ model = "claude-3-haiku-20240307"
|
|
201
207
|
max_tokens = 512
|
202
208
|
|
203
209
|
[anthropic.sonnet]
|
204
|
-
model = "claude-
|
210
|
+
model = "claude-sonnet-4-20250514"
|
205
211
|
max_tokens = 1024
|
206
212
|
|
207
213
|
[openai]
|
@@ -209,12 +215,17 @@ model = "gpt-4o"
|
|
209
215
|
max_tokens = 1024
|
210
216
|
|
211
217
|
[gemini]
|
212
|
-
model = "gemini-2.5-flash"
|
218
|
+
model = "gemini-2.5-flash-lite-preview-06-17"
|
213
219
|
max_tokens = 150
|
214
220
|
|
215
221
|
[gemini.pro]
|
216
222
|
model = "gemini-2.5-pro"
|
217
223
|
max_tokens = 1024
|
224
|
+
|
225
|
+
[grok]
|
226
|
+
model = "grok-3-fast"
|
227
|
+
max_tokens = 150
|
228
|
+
temperature = 0.0
|
218
229
|
```
|
219
230
|
|
220
231
|
## 🤖 Supported Providers
|
@@ -222,31 +233,15 @@ max_tokens = 1024
|
|
222
233
|
- Anthropic (Claude)
|
223
234
|
- OpenAI (GPT)
|
224
235
|
- Google (Gemini)
|
236
|
+
- xAI (Grok)
|
225
237
|
|
226
|
-
> **Note:** Get API keys from [Anthropic Console](https://console.anthropic.com/), [OpenAI Platform](https://platform.openai.com/),
|
238
|
+
> **Note:** Get API keys from [Anthropic Console](https://console.anthropic.com/), [OpenAI Platform](https://platform.openai.com/), [Google AI Studio](https://aistudio.google.com/), or [xAI Console](https://x.ai/console)
|
227
239
|
|
228
240
|
## 🛣️ Roadmap
|
229
241
|
|
230
|
-
### Near-term
|
231
|
-
|
232
242
|
- [ ] Shell integration and auto-completion
|
233
|
-
- [ ]
|
234
|
-
- [ ] Safety features (command preview/confirmation)
|
235
|
-
- [ ] Output formatting options
|
236
|
-
|
237
|
-
### Medium-term
|
238
|
-
|
239
|
-
- [ ] Additional providers (Google, Cohere, Mistral)
|
240
|
-
- [ ] Interactive mode for complex tasks
|
241
|
-
- [ ] Plugin system for custom providers
|
242
|
-
- [ ] Command validation and testing
|
243
|
-
|
244
|
-
### Long-term
|
245
|
-
|
243
|
+
- [ ] Additional providers (Cohere, Mistral)
|
246
244
|
- [ ] Local model support (Ollama, llama.cpp)
|
247
|
-
- [ ] Learning from user preferences
|
248
|
-
- [ ] Advanced safety and sandboxing
|
249
|
-
- [ ] GUI and web interface options
|
250
245
|
|
251
246
|
## 🔧 Development
|
252
247
|
|
@@ -255,7 +250,7 @@ max_tokens = 1024
|
|
255
250
|
```bash
|
256
251
|
git clone https://github.com/lcford2/terminal-sherpa.git
|
257
252
|
cd terminal-sherpa
|
258
|
-
uv sync
|
253
|
+
uv sync --all-groups
|
259
254
|
uv run pre-commit install
|
260
255
|
```
|
261
256
|
|
@@ -271,16 +266,13 @@ uv run python -m pytest
|
|
271
266
|
2. Create a feature branch
|
272
267
|
3. Make your changes
|
273
268
|
4. Run pre-commit checks: `uv run pre-commit run --all-files`
|
274
|
-
5.
|
269
|
+
5. Run tests: `uv run task test`
|
270
|
+
6. Submit a pull request
|
275
271
|
|
276
272
|
## License
|
277
273
|
|
278
274
|
This project is licensed under the MIT License - see the LICENSE file for details.
|
279
275
|
|
280
|
-
## Contributing
|
281
|
-
|
282
|
-
Contributions are welcome! Please see our [Contributing Guidelines](CONTRIBUTING.md) for details.
|
283
|
-
|
284
276
|
## Issues
|
285
277
|
|
286
278
|
Found a bug or have a feature request? Please open an issue on [GitHub Issues](https://github.com/lcford2/terminal-sherpa/issues).
|
@@ -0,0 +1,25 @@
|
|
1
|
+
ask/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
ask/config.py,sha256=iHIiMKePia80Sno_XlARqa7pEyW3eZm_Bf5SlUMidiQ,2880
|
3
|
+
ask/exceptions.py,sha256=0RLMSbw6j49BEhJN7C8MYaKpuhVeitsBhTGjZmaiHis,434
|
4
|
+
ask/main.py,sha256=9mVXwncU2P4OQxE7Oxcqi376A06xluC76kiIoCCqNSc,3936
|
5
|
+
ask/providers/__init__.py,sha256=WRcLAqAK6dEbRM9aF5hYWw60_sjCtLlSBiyf8VFb4tA,1255
|
6
|
+
ask/providers/anthropic.py,sha256=3bB335ZxRu4a_j7U0tox_6sQxTf4X287GhU2-gr5ctU,2725
|
7
|
+
ask/providers/base.py,sha256=91ZbVORYWckSHNwNPiTmgfqQN0FLO9AgV6mptuAkIU0,769
|
8
|
+
ask/providers/gemini.py,sha256=c1uQg6ShNkw2AkwDJhjfki7hTBl4vOkT9v1O4ewRfbg,3408
|
9
|
+
ask/providers/grok.py,sha256=Lbue8dgcGoaSAtQ-Jme5e5EQ9hp608W78QEPjBP_vBk,3824
|
10
|
+
ask/providers/openai.py,sha256=9PS1AgMr6Nb-OcYYiNYNrG384wNUS8m2lVT04K3hFV8,3683
|
11
|
+
terminal_sherpa-0.4.0.dist-info/licenses/LICENSE,sha256=xLe81eIrf0X6CnEDDJXmoXuDzkdMYM3Eq1BgHUpG1JQ,1067
|
12
|
+
test/conftest.py,sha256=aRJL41pj0ITemxvXO1FqANhF8zi4T1T8ehKNbZKBSao,1724
|
13
|
+
test/test_anthropic.py,sha256=S5OQ67qIZ4VO38eJwAAwJa4JBylJhKCtmcGjCWA8WLY,5687
|
14
|
+
test/test_config.py,sha256=FrJ6bsZ6mK46e-8fQfkFGx9GgwHrNfnoI8211R0V9K8,5565
|
15
|
+
test/test_exceptions.py,sha256=tw-spMitAdYj9uW_8TjnlyVKKXFC06FR3610WGR-494,1754
|
16
|
+
test/test_gemini.py,sha256=sV8FkaU5rLfuu3lGeQdxsa-ZmNnLUYpODaKhayrasSo,8000
|
17
|
+
test/test_grok.py,sha256=udlLJCCZ9NUYz7Szpy4bifTPtOR9-QMQ_S4ZDLNWhw8,8372
|
18
|
+
test/test_main.py,sha256=3gZ83nVHMSEmgHSF2UJoELfK028a4vgxLpIk2P1cH1Y,7745
|
19
|
+
test/test_openai.py,sha256=KAGQWFrXeu4P9umij7XDoxnKQ2cApv6ImuL8EiG_5W8,8388
|
20
|
+
test/test_providers.py,sha256=SejQvCZSEQ5RAfVTCtPZ-39fXnfV17n4gaSxjiHA5UM,2140
|
21
|
+
terminal_sherpa-0.4.0.dist-info/METADATA,sha256=_WxCjbuTNgns8yq5VFYkFs0LD_2dB50aaoBuUsIzGKk,7637
|
22
|
+
terminal_sherpa-0.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
23
|
+
terminal_sherpa-0.4.0.dist-info/entry_points.txt,sha256=LxG9-J__nMGmeEIi47WVGYC1LLJo1GaADH21hfxEK70,38
|
24
|
+
terminal_sherpa-0.4.0.dist-info/top_level.txt,sha256=Y7k5b2NSCkKiA_XPU-4fT_GYangD6JVDug5xwfXvmuQ,9
|
25
|
+
terminal_sherpa-0.4.0.dist-info/RECORD,,
|
test/conftest.py
CHANGED
@@ -52,6 +52,13 @@ def mock_openai_key():
|
|
52
52
|
yield
|
53
53
|
|
54
54
|
|
55
|
+
@pytest.fixture
def mock_grok_key():
    """Run the test with an environment holding only a fake xAI API key."""
    fake_environment = {"XAI_API_KEY": "test-grok-key"}
    with patch.dict(os.environ, fake_environment, clear=True):
        yield
|
60
|
+
|
61
|
+
|
55
62
|
@pytest.fixture
|
56
63
|
def mock_both_keys():
|
57
64
|
"""Mock both API keys in environment."""
|
test/test_grok.py
ADDED
@@ -0,0 +1,247 @@
|
|
1
|
+
"""Tests for Grok provider using xAI SDK."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from unittest.mock import MagicMock, patch
|
5
|
+
|
6
|
+
import pytest
|
7
|
+
|
8
|
+
from ask.config import SYSTEM_PROMPT
|
9
|
+
from ask.exceptions import APIError, AuthenticationError, RateLimitError
|
10
|
+
from ask.providers.grok import GrokProvider
|
11
|
+
|
12
|
+
|
13
|
+
def test_grok_provider_init():
    """A new provider stores its config verbatim and starts without a client."""
    provider_config = {"model_name": "grok-3-fast"}

    grok = GrokProvider(provider_config)

    assert grok.client is None
    assert grok.config == provider_config
|
20
|
+
|
21
|
+
|
22
|
+
def test_validate_config_success(mock_grok_key):
    """validate_config() builds an xAI client from the key in XAI_API_KEY."""
    provider = GrokProvider({"api_key_env": "XAI_API_KEY"})
    fake_client = MagicMock()

    with patch("ask.providers.grok.Client", return_value=fake_client) as client_cls:
        provider.validate_config()

    client_cls.assert_called_once_with(api_key="test-grok-key")
    assert provider.client == fake_client
|
35
|
+
|
36
|
+
|
37
|
+
def test_validate_config_missing_key(mock_env_vars):
    """validate_config() refuses to proceed when the key variable is unset."""
    grok = GrokProvider({"api_key_env": "XAI_API_KEY"})
    expected_message = "XAI_API_KEY environment variable is required"

    with pytest.raises(AuthenticationError, match=expected_message):
        grok.validate_config()
|
46
|
+
|
47
|
+
|
48
|
+
def test_validate_config_custom_env():
    """A custom ``api_key_env`` name is used when looking up the key."""
    provider = GrokProvider({"api_key_env": "CUSTOM_XAI_KEY"})
    fake_client = MagicMock()
    custom_env = {"CUSTOM_XAI_KEY": "custom-xai-key"}

    with patch.dict(os.environ, custom_env), patch(
        "ask.providers.grok.Client", return_value=fake_client
    ) as client_cls:
        provider.validate_config()

    assert provider.client == fake_client
    client_cls.assert_called_once_with(api_key="custom-xai-key")
|
62
|
+
|
63
|
+
|
64
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_success(mock_user, mock_system, mock_grok_key):
    """get_bash_command() drives the xAI chat workflow and returns the reply."""
    provider = GrokProvider({"model_name": "grok-beta", "max_tokens": 150})

    mock_chat = MagicMock()
    mock_chat.sample.return_value = MagicMock(content="ls -la")
    system_msg, user_msg = MagicMock(), MagicMock()
    mock_system.return_value = system_msg
    mock_user.return_value = user_msg

    # Inject the client directly; patching ``Client`` is unnecessary because
    # validate_config() is never reached once a client is set.
    mock_client = MagicMock()
    mock_client.chat.create.return_value = mock_chat
    provider.client = mock_client

    result = provider.get_bash_command("list files")

    assert result == "ls -la"
    mock_client.chat.create.assert_called_once_with(model="grok-beta")
    mock_system.assert_called_once_with(SYSTEM_PROMPT)
    mock_user.assert_called_once_with("list files")
    # The system prompt must be appended before the user prompt.
    appended = [c.args for c in mock_chat.append.call_args_list]
    assert appended == [(system_msg,), (user_msg,)]
    mock_chat.sample.assert_called_once()
|
98
|
+
|
99
|
+
|
100
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_with_code_block(mock_user, mock_system, mock_grok_key):
    """A reply wrapped in a ```bash fence is unwrapped to the bare command."""
    provider = GrokProvider({})

    mock_chat = MagicMock()
    mock_chat.sample.return_value = MagicMock(content="```bash\nls -la\n```")
    # Client injected directly; system()/user() are patched via decorators so
    # no real SDK message objects are built.
    provider.client = MagicMock()
    provider.client.chat.create.return_value = mock_chat

    assert provider.get_bash_command("list files") == "ls -la"
|
127
|
+
|
128
|
+
|
129
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_empty_response(mock_user, mock_system, mock_grok_key):
    """A reply with no content is reported as an APIError."""
    provider = GrokProvider({})

    mock_chat = MagicMock()
    mock_chat.sample.return_value = MagicMock(content=None)
    provider.client = MagicMock()
    provider.client.chat.create.return_value = mock_chat

    with pytest.raises(APIError, match="API returned empty response"):
        provider.get_bash_command("list files")
|
155
|
+
|
156
|
+
|
157
|
+
def test_handle_authentication_error(mock_grok_key):
    """SDK errors mentioning authentication surface as AuthenticationError."""
    provider = GrokProvider({})
    # Client injected directly; no need to patch the ``Client`` constructor.
    provider.client = MagicMock()
    provider.client.chat.create.side_effect = Exception("Authentication failed")

    with pytest.raises(AuthenticationError, match="Invalid API key"):
        provider.get_bash_command("test")
|
170
|
+
|
171
|
+
|
172
|
+
def test_handle_rate_limit_error(mock_grok_key):
    """SDK errors mentioning rate limits surface as RateLimitError."""
    provider = GrokProvider({})
    # Client injected directly; no need to patch the ``Client`` constructor.
    provider.client = MagicMock()
    provider.client.chat.create.side_effect = Exception("Rate limit exceeded")

    with pytest.raises(RateLimitError, match="rate limit exceeded"):
        provider.get_bash_command("test")
|
185
|
+
|
186
|
+
|
187
|
+
def test_handle_general_api_error(mock_grok_key):
    """Unclassified SDK errors surface as APIError carrying the original text."""
    provider = GrokProvider({})
    # Client injected directly; no need to patch the ``Client`` constructor.
    provider.client = MagicMock()
    provider.client.chat.create.side_effect = Exception("Unknown error")

    with pytest.raises(APIError, match="API request failed - Unknown error"):
        provider.get_bash_command("test")
|
200
|
+
|
201
|
+
|
202
|
+
def test_get_default_config():
    """get_default_config() exposes the documented Grok defaults."""
    assert GrokProvider.get_default_config() == {
        "model_name": "grok-3-fast",
        "max_tokens": 150,
        "api_key_env": "XAI_API_KEY",
        "temperature": 0.0,
        "system_prompt": SYSTEM_PROMPT,
    }
|
215
|
+
|
216
|
+
|
217
|
+
@patch("ask.providers.grok.system")
@patch("ask.providers.grok.user")
def test_get_bash_command_multiline_code_block(mock_user, mock_system, mock_grok_key):
    """A multi-line command inside a ```bash fence is returned intact."""
    provider = GrokProvider({})

    fenced_reply = (
        "```bash\nfind . -name '*.py' \\\n "
        "-type f \\\n -exec grep -l 'test' {} \\;\n```"
    )
    mock_chat = MagicMock()
    mock_chat.sample.return_value = MagicMock(content=fenced_reply)
    provider.client = MagicMock()
    provider.client.chat.create.return_value = mock_chat

    result = provider.get_bash_command("find Python files with test")

    expected = "find . -name '*.py' \\\n -type f \\\n -exec grep -l 'test' {} \\;"
    assert result == expected
|
@@ -1,23 +0,0 @@
|
|
1
|
-
ask/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
ask/config.py,sha256=12YKck6Q9vaSESfuI7kKFbs1hkxzts2FOsxR9eK96MY,2899
|
3
|
-
ask/exceptions.py,sha256=0RLMSbw6j49BEhJN7C8MYaKpuhVeitsBhTGjZmaiHis,434
|
4
|
-
ask/main.py,sha256=9mVXwncU2P4OQxE7Oxcqi376A06xluC76kiIoCCqNSc,3936
|
5
|
-
ask/providers/__init__.py,sha256=Y0NswA6O8PpE_PDWa-GZ1FNmSXrwReZ9-roUoTOksXU,1184
|
6
|
-
ask/providers/anthropic.py,sha256=M6bcOxDbbWQrJ6kWWtYqsVk6NsZOqX7PzxyiI_yji1Q,2738
|
7
|
-
ask/providers/base.py,sha256=91ZbVORYWckSHNwNPiTmgfqQN0FLO9AgV6mptuAkIU0,769
|
8
|
-
ask/providers/gemini.py,sha256=taFZYEygiEf05XCm4AKSKL2F_BQwLrU4Y5Ac-WN5owk,3421
|
9
|
-
ask/providers/openai.py,sha256=jVyRH4FRdF_91iuK5Tga4as9zbyGKPPFW90ewGG5s5k,3696
|
10
|
-
terminal_sherpa-0.3.0.dist-info/licenses/LICENSE,sha256=xLe81eIrf0X6CnEDDJXmoXuDzkdMYM3Eq1BgHUpG1JQ,1067
|
11
|
-
test/conftest.py,sha256=pjDI0SbIhHxDqJW-BdL7s6lTqM2f8hItxWY8EjC-dL8,1548
|
12
|
-
test/test_anthropic.py,sha256=S5OQ67qIZ4VO38eJwAAwJa4JBylJhKCtmcGjCWA8WLY,5687
|
13
|
-
test/test_config.py,sha256=FrJ6bsZ6mK46e-8fQfkFGx9GgwHrNfnoI8211R0V9K8,5565
|
14
|
-
test/test_exceptions.py,sha256=tw-spMitAdYj9uW_8TjnlyVKKXFC06FR3610WGR-494,1754
|
15
|
-
test/test_gemini.py,sha256=sV8FkaU5rLfuu3lGeQdxsa-ZmNnLUYpODaKhayrasSo,8000
|
16
|
-
test/test_main.py,sha256=3gZ83nVHMSEmgHSF2UJoELfK028a4vgxLpIk2P1cH1Y,7745
|
17
|
-
test/test_openai.py,sha256=KAGQWFrXeu4P9umij7XDoxnKQ2cApv6ImuL8EiG_5W8,8388
|
18
|
-
test/test_providers.py,sha256=SejQvCZSEQ5RAfVTCtPZ-39fXnfV17n4gaSxjiHA5UM,2140
|
19
|
-
terminal_sherpa-0.3.0.dist-info/METADATA,sha256=r0ARy4RTO3vnlxn6tizVE-aEDLj79MyXLTl5ZxLp0Yk,7521
|
20
|
-
terminal_sherpa-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
21
|
-
terminal_sherpa-0.3.0.dist-info/entry_points.txt,sha256=LxG9-J__nMGmeEIi47WVGYC1LLJo1GaADH21hfxEK70,38
|
22
|
-
terminal_sherpa-0.3.0.dist-info/top_level.txt,sha256=Y7k5b2NSCkKiA_XPU-4fT_GYangD6JVDug5xwfXvmuQ,9
|
23
|
-
terminal_sherpa-0.3.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|