terminal-sherpa 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ask/__init__.py +0 -0
- ask/config.py +94 -0
- ask/exceptions.py +25 -0
- ask/main.py +115 -0
- ask/providers/__init__.py +35 -0
- ask/providers/anthropic.py +73 -0
- ask/providers/base.py +28 -0
- ask/providers/openai.py +106 -0
- terminal_sherpa-0.1.0.dist-info/METADATA +242 -0
- terminal_sherpa-0.1.0.dist-info/RECORD +20 -0
- terminal_sherpa-0.1.0.dist-info/WHEEL +5 -0
- terminal_sherpa-0.1.0.dist-info/entry_points.txt +2 -0
- terminal_sherpa-0.1.0.dist-info/top_level.txt +2 -0
- test/conftest.py +58 -0
- test/test_anthropic.py +173 -0
- test/test_config.py +164 -0
- test/test_exceptions.py +55 -0
- test/test_main.py +206 -0
- test/test_openai.py +269 -0
- test/test_providers.py +77 -0
ask/__init__.py
ADDED
File without changes
|
ask/config.py
ADDED
@@ -0,0 +1,94 @@
|
|
1
|
+
"""Configuration loading and management module."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import toml
|
8
|
+
|
9
|
+
from ask.exceptions import ConfigurationError
|
10
|
+
|
11
|
+
# Default system prompt used by every provider unless the provider config
# supplies its own ``system_prompt``. It must be a plain ``str``: the adjacent
# literals below are concatenated by the parser, and there is deliberately NO
# trailing comma inside the parentheses (a comma would turn this into a
# one-element tuple that the provider SDKs would receive instead of text).
SYSTEM_PROMPT = (
    "You are a bash command generator. Given a user request, "
    "respond with ONLY the bash command that accomplishes the task. "
    "Do not include explanations, comments, or any other text. "
    "Just the command."
)
|
17
|
+
|
18
|
+
|
19
|
+
def get_config_path() -> Path | None:
|
20
|
+
"""Find config file using XDG standard."""
|
21
|
+
# Primary location: $XDG_CONFIG_HOME/ask/config.toml
|
22
|
+
xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
|
23
|
+
if xdg_config_home:
|
24
|
+
primary_path = Path(xdg_config_home) / "ask" / "config.toml"
|
25
|
+
else:
|
26
|
+
primary_path = Path.home() / ".config" / "ask" / "config.toml"
|
27
|
+
|
28
|
+
if primary_path.exists():
|
29
|
+
return primary_path
|
30
|
+
|
31
|
+
# Fallback location: ~/.ask/config.toml
|
32
|
+
fallback_path = Path.home() / ".ask" / "config.toml"
|
33
|
+
if fallback_path.exists():
|
34
|
+
return fallback_path
|
35
|
+
|
36
|
+
return None
|
37
|
+
|
38
|
+
|
39
|
+
def load_config() -> dict[str, Any]:
    """Load configuration from the TOML file found by ``get_config_path()``.

    Returns:
        The parsed TOML document, or an empty dict when no config file exists.

    Raises:
        ConfigurationError: If the file exists but cannot be opened, decoded
            or parsed. The underlying exception is chained as ``__cause__``.
    """
    config_path = get_config_path()

    if config_path is None:
        return {}

    try:
        # TOML documents are UTF-8 by specification; don't depend on the
        # platform's locale encoding (e.g. cp1252 on Windows).
        with open(config_path, encoding="utf-8") as f:
            return toml.load(f)
    except Exception as e:
        # Chain the original error so verbose runs/debuggers show the real cause.
        raise ConfigurationError(
            f"Failed to load config file {config_path}: {e}"
        ) from e
|
51
|
+
|
52
|
+
|
53
|
+
def get_provider_config(
    config: dict[str, Any], provider_spec: str
) -> tuple[str, dict[str, Any]]:
    """Split a ``provider[:model]`` spec and build that provider's settings.

    For ``provider:model`` the nested table ``config[provider][model]`` (e.g.
    ``[anthropic.haiku]``) is used when present; otherwise, and for a bare
    ``provider`` spec, the whole ``config[provider]`` table is used. The chosen
    table is layered over the global ``[ask]`` table, provider keys winning.

    NOTE: if ``model`` has no nested table, the model part of the spec is not
    reflected in the returned settings.

    Args:
        config: Parsed configuration document.
        provider_spec: ``"provider"`` or ``"provider:model"``; only the first
            colon separates the two parts.

    Returns:
        ``(provider_name, merged_settings)``.
    """
    provider_name, separator, model_name = provider_spec.partition(":")
    provider_section = config.get(provider_name, {})

    use_model_table = (
        bool(separator)
        and isinstance(provider_section, dict)
        and model_name in provider_section
    )
    selected = provider_section[model_name] if use_model_table else provider_section

    return provider_name, {**config.get("ask", {}), **selected}
|
78
|
+
|
79
|
+
|
80
|
+
def get_default_model(config: dict[str, Any]) -> str | None:
|
81
|
+
"""Get default model from configuration."""
|
82
|
+
global_config = config.get("ask", {})
|
83
|
+
return global_config.get("default_model")
|
84
|
+
|
85
|
+
|
86
|
+
def get_default_provider() -> str | None:
|
87
|
+
"""Determine fallback provider from environment variables."""
|
88
|
+
# Check for API keys in order of preference: claude -> openai
|
89
|
+
if os.environ.get("ANTHROPIC_API_KEY"):
|
90
|
+
return "anthropic"
|
91
|
+
elif os.environ.get("OPENAI_API_KEY"):
|
92
|
+
return "openai"
|
93
|
+
|
94
|
+
return None
|
ask/exceptions.py
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
"""Custom exception classes for the ask CLI tool."""
|
2
|
+
|
3
|
+
|
4
|
+
class ConfigurationError(Exception):
    """Signal a problem with the ask configuration.

    Raised, for example, when the TOML config file cannot be loaded or when an
    unknown provider name is requested.
    """
|
8
|
+
|
9
|
+
|
10
|
+
class AuthenticationError(Exception):
    """Signal that API credentials are missing or were rejected by the provider."""
|
14
|
+
|
15
|
+
|
16
|
+
class APIError(Exception):
    """Signal that a request to an AI provider's API failed."""
|
20
|
+
|
21
|
+
|
22
|
+
class RateLimitError(APIError):
    """Signal that the provider refused a request because of rate limiting.

    Being an ``APIError`` subclass, it is handled wherever API failures are.
    """
|
ask/main.py
ADDED
@@ -0,0 +1,115 @@
|
|
1
|
+
import argparse
|
2
|
+
import sys
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from loguru import logger
|
6
|
+
|
7
|
+
import ask.config as config
|
8
|
+
import ask.providers as providers
|
9
|
+
from ask.exceptions import APIError, AuthenticationError, ConfigurationError
|
10
|
+
from ask.providers.base import ProviderInterface
|
11
|
+
|
12
|
+
|
13
|
+
def configure_logging(verbose: bool) -> None:
    """Send loguru output to stderr, at DEBUG level when verbose, else ERROR.

    Args:
        verbose: True to show debug messages; False to show only errors.
    """
    # Remove every existing loguru handler (including its default one) so
    # messages are emitted exactly once, through the handler added below.
    logger.remove()

    if verbose:
        threshold = "DEBUG"
    else:
        threshold = "ERROR"

    logger.add(
        sys.stderr,
        level=threshold,
        format="{time:YYYY-MM-DD HH:mm:ss.SSS} | <level>{level}</level> | {message}",
        colorize=True,
    )
|
25
|
+
|
26
|
+
|
27
|
+
def parse_arguments() -> argparse.Namespace:
    """Parse ``sys.argv`` into a namespace with ``prompt``, ``model`` and ``verbose``."""
    parser = argparse.ArgumentParser(description="AI-powered bash command generator")

    # (flags, keyword options) for each CLI argument, in declaration order.
    argument_specs = (
        (("prompt",), {"help": "Natural language description of the task"}),
        (
            ("--model",),
            {"help": "Provider and model to use (format: provider[:model])"},
        ),
        (
            ("--verbose", "-v"),
            {"action": "store_true", "help": "Enable verbose logging"},
        ),
    )
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
|
38
|
+
|
39
|
+
|
40
|
+
def load_configuration() -> dict[str, Any]:
    """Return the parsed config file contents.

    A ``ConfigurationError`` from the loader is logged and turned into exit
    status 1.
    """
    try:
        loaded = config.load_config()
    except ConfigurationError as exc:
        logger.error(f"Configuration error: {exc}")
        sys.exit(1)
    else:
        return loaded
|
47
|
+
|
48
|
+
|
49
|
+
def resolve_provider(
    args: argparse.Namespace, config_data: dict[str, Any]
) -> ProviderInterface:
    """Determine which provider to use based on arguments and configuration.

    The ``provider[:model]`` spec is chosen in this order:

    1. ``--model`` from the command line;
    2. ``default_model`` in the ``[ask]`` section of the config file;
    3. the first provider (Anthropic, then OpenAI) whose API key environment
       variable is set, as reported by ``config.get_default_provider()``.

    Args:
        args: Parsed command line arguments.
        config_data: Configuration data loaded from the config file.

    Returns:
        Provider instance built from the merged provider configuration.

    Logs an error and exits with status 1 when no provider can be chosen or
    the chosen provider name is not registered.
    """
    if args.model:
        logger.debug(f"Using model specified via --model argument: {args.model}")
        model_spec = args.model
    else:
        default_model = config.get_default_model(config_data)
        if default_model:
            logger.debug(f"Using default model from config: {default_model}")
            model_spec = default_model
        else:
            logger.warning(
                "No default model configured, falling back to environment variables"
            )
            default_provider = config.get_default_provider()
            if not default_provider:
                # The config key actually consulted is [ask] default_model
                # (config.get_default_model), so point users at that key.
                logger.error(
                    "No default model configured and no API keys found. "
                    "Please set ANTHROPIC_API_KEY or OPENAI_API_KEY environment "
                    "variable, or set default_model in the [ask] section of "
                    "your config file."
                )
                sys.exit(1)
            logger.debug(f"Using default provider from environment: {default_provider}")
            model_spec = default_provider

    provider_name, provider_config = config.get_provider_config(config_data, model_spec)

    try:
        logger.debug(f"Initializing provider: {provider_name}")
        return providers.get_provider(provider_name, provider_config)
    except ConfigurationError as e:
        logger.error(f"Error: {e}")
        sys.exit(1)
|
96
|
+
|
97
|
+
|
98
|
+
def main() -> None:
    """Run the ``ask`` CLI: parse args, choose a provider, print the command.

    Authentication and API failures are logged and turned into exit status 1.
    """
    cli_args = parse_arguments()
    configure_logging(cli_args.verbose)
    settings = load_configuration()

    try:
        provider = resolve_provider(cli_args, settings)
        provider.validate_config()
        command = provider.get_bash_command(cli_args.prompt)
    except (AuthenticationError, APIError) as err:
        logger.error(str(err))
        sys.exit(1)
    else:
        # Logging goes to stderr, so stdout carries only the generated command.
        print(command)


if __name__ == "__main__":
    main()
|
@@ -0,0 +1,35 @@
|
|
1
|
+
"""Provider registry and initialization module."""
|
2
|
+
|
3
|
+
from .anthropic import AnthropicProvider
|
4
|
+
from .base import ProviderInterface
|
5
|
+
from .openai import OpenAIProvider
|
6
|
+
|
7
|
+
# Registry of provider classes keyed by the name used as the ``provider``
# part of a ``provider[:model]`` spec; filled in via register_provider().
_PROVIDER_REGISTRY: dict[str, type[ProviderInterface]] = dict()
|
9
|
+
|
10
|
+
|
11
|
+
def register_provider(name: str, provider_class: type[ProviderInterface]) -> None:
    """Make ``provider_class`` available under ``name``.

    Re-registering an existing name replaces the previous class.
    """
    _PROVIDER_REGISTRY.update({name: provider_class})
|
14
|
+
|
15
|
+
|
16
|
+
def get_provider(name: str, config: dict) -> ProviderInterface:
    """Instantiate the provider registered under ``name``.

    Args:
        name: Registry key, e.g. ``"anthropic"`` or ``"openai"``.
        config: Settings passed unchanged to the provider's constructor.

    Returns:
        A new provider instance.

    Raises:
        ConfigurationError: If no provider is registered under ``name``.
    """
    provider_class = _PROVIDER_REGISTRY.get(name)
    if provider_class is None:
        # Imported only on the error path.
        from ask.exceptions import ConfigurationError

        raise ConfigurationError(
            f"Provider '{name}' not found. Available providers: {list_providers()}"
        )
    return provider_class(config)
|
27
|
+
|
28
|
+
|
29
|
+
def list_providers() -> list[str]:
    """Return the registered provider names in registration order."""
    return [*_PROVIDER_REGISTRY]
|
32
|
+
|
33
|
+
|
34
|
+
# Register the built-in providers when this package is imported.
for _name, _cls in (("anthropic", AnthropicProvider), ("openai", OpenAIProvider)):
    register_provider(_name, _cls)
del _name, _cls
|
@@ -0,0 +1,73 @@
|
|
1
|
+
"""Anthropic provider implementation."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import anthropic
|
7
|
+
|
8
|
+
from ask.config import SYSTEM_PROMPT
|
9
|
+
from ask.exceptions import APIError, AuthenticationError, RateLimitError
|
10
|
+
from ask.providers.base import ProviderInterface
|
11
|
+
|
12
|
+
|
13
|
+
class AnthropicProvider(ProviderInterface):
    """Anthropic (Claude) provider that turns a natural language request into a bash command.

    Settings read from ``self.config`` (each optional, falling back to
    :meth:`get_default_config`): ``model_name``, ``max_tokens``,
    ``temperature``, ``system_prompt`` and ``api_key_env`` (the name of the
    environment variable holding the API key).
    """

    def __init__(self, config: dict[str, Any]):
        """Initialize Anthropic provider with configuration.

        The SDK client is not created here; :meth:`validate_config` builds it
        once an API key has been found.
        """
        super().__init__(config)
        self.client: anthropic.Anthropic | None = None

    def get_bash_command(self, prompt: str) -> str:
        """Generate a bash command from a natural language prompt.

        Args:
            prompt: The user's description of the task.

        Returns:
            The text of the first content block in the model's reply.

        Raises:
            AuthenticationError: If the API key is missing or rejected.
            RateLimitError: If the API reports rate limiting.
            APIError: For an empty reply or any other request failure.
        """
        if self.client is None:
            self.validate_config()

        # After validate_config(), client should be set
        assert self.client is not None, "Client should be initialized after validation"

        defaults = self.get_default_config()
        try:
            response = self.client.messages.create(
                model=self.config.get("model_name", defaults["model_name"]),
                max_tokens=self.config.get("max_tokens", defaults["max_tokens"]),
                temperature=self.config.get("temperature", defaults["temperature"]),
                system=self.config.get("system_prompt", defaults["system_prompt"]),
                messages=[{"role": "user", "content": prompt}],
            )
            blocks = response.content
            if not blocks:
                raise APIError("Error: API returned empty response")
            return blocks[0].text
        except APIError:
            # Already one of our exceptions in its final form; don't let the
            # generic handler below re-wrap it as "API request failed - ...".
            raise
        except Exception as e:
            self._handle_api_error(e)

    def validate_config(self) -> None:
        """Read the API key and create the Anthropic client.

        The key is taken from the environment variable named by the
        ``api_key_env`` setting (``ANTHROPIC_API_KEY`` by default).

        Raises:
            AuthenticationError: If that variable is unset or empty.
        """
        api_key_env = self.config.get("api_key_env", "ANTHROPIC_API_KEY")
        api_key = os.environ.get(api_key_env)

        if not api_key:
            raise AuthenticationError(
                f"Error: {api_key_env} environment variable is required"
            )

        self.client = anthropic.Anthropic(api_key=api_key)

    def _handle_api_error(self, error: Exception):
        """Map an SDK/transport error onto the CLI's exception types and raise it.

        SDK HTTP status errors carry ``status_code`` (401 = bad key, 429 = rate
        limited); errors without one fall back to matching the message text.
        The original error is chained as ``__cause__``.

        Raises:
            AuthenticationError: If the API key was rejected.
            RateLimitError: If the request was rate limited.
            APIError: For any other failure.
        """
        status_code = getattr(error, "status_code", None)
        error_str = str(error).lower()

        if status_code == 401 or "authentication" in error_str or "unauthorized" in error_str:
            raise AuthenticationError("Error: Invalid API key") from error
        if status_code == 429 or "rate limit" in error_str:
            raise RateLimitError("Error: API rate limit exceeded") from error
        raise APIError(f"Error: API request failed - {error}") from error

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for Anthropic provider."""
        return {
            "model_name": "claude-3-haiku-20240307",
            "max_tokens": 150,
            "api_key_env": "ANTHROPIC_API_KEY",
            "temperature": 0.0,
            "system_prompt": SYSTEM_PROMPT,
        }
|
ask/providers/base.py
ADDED
@@ -0,0 +1,28 @@
|
|
1
|
+
"""Abstract base class for all providers."""
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
|
7
|
+
class ProviderInterface(ABC):
    """Contract that every AI backend used by ``ask`` must implement.

    Subclasses receive their settings dict at construction time and can read
    it back from ``self.config``.
    """

    def __init__(self, config: dict[str, Any]):
        """Keep the provider settings on the instance.

        Args:
            config: Provider settings; their keys are interpreted by subclasses.
        """
        self.config: dict[str, Any] = config

    @abstractmethod
    def get_bash_command(self, prompt: str) -> str:
        """Translate the natural language ``prompt`` into a bash command string."""
        ...

    @abstractmethod
    def validate_config(self) -> None:
        """Check the provider's configuration and API key."""
        ...

    @classmethod
    @abstractmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return the built-in default settings for this provider."""
        ...
|
ask/providers/openai.py
ADDED
@@ -0,0 +1,106 @@
|
|
1
|
+
"""OpenAI provider implementation."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
import re
|
5
|
+
from typing import Any, NoReturn
|
6
|
+
|
7
|
+
import openai
|
8
|
+
|
9
|
+
from ask.config import SYSTEM_PROMPT
|
10
|
+
from ask.exceptions import APIError, AuthenticationError, RateLimitError
|
11
|
+
from ask.providers.base import ProviderInterface
|
12
|
+
|
13
|
+
|
14
|
+
class OpenAIProvider(ProviderInterface):
    """OpenAI provider that turns a natural language request into a bash command.

    Settings read from ``self.config`` (each optional, falling back to
    :meth:`get_default_config`): ``model_name``, ``max_tokens``,
    ``temperature``, ``system_prompt`` and ``api_key_env`` (the name of the
    environment variable holding the API key).
    """

    # Captures the body of a ```bash / ```sh / ```shell fenced block. DOTALL
    # lets the body span several lines (loops, line continuations, heredocs);
    # the lazy ``.*?`` stops at the first closing fence. CRLF is tolerated.
    _CODE_FENCE_RE = re.compile(
        r"```(?:bash|sh|shell)[ \t]*\r?\n(.*?)\r?\n```", re.DOTALL
    )

    def __init__(self, config: dict[str, Any]):
        """Initialize OpenAI provider with configuration.

        The SDK client is not created here; :meth:`validate_config` builds it
        once an API key has been found.

        Args:
            config: The configuration for the OpenAI provider
        """
        super().__init__(config)
        self.client: openai.OpenAI | None = None

    def get_bash_command(self, prompt: str) -> str:
        """Generate a bash command from a natural language prompt.

        If the model wraps its answer in a ```bash/```sh/```shell fence, only
        the fenced body is returned; otherwise the reply is returned unchanged.

        Args:
            prompt: The natural language prompt to generate a bash command for

        Returns:
            The generated bash command

        Raises:
            AuthenticationError: If the API key is missing or rejected.
            RateLimitError: If the API reports rate limiting or quota exhaustion.
            APIError: For an empty reply or any other request failure.
        """
        if self.client is None:
            self.validate_config()

        # After validate_config(), client should be set
        assert self.client is not None, "Client should be initialized after validation"

        defaults = self.get_default_config()
        try:
            response = self.client.chat.completions.create(
                model=self.config.get("model_name", defaults["model_name"]),
                max_completion_tokens=self.config.get(
                    "max_tokens", defaults["max_tokens"]
                ),
                temperature=self.config.get("temperature", defaults["temperature"]),
                messages=[
                    {
                        "role": "system",
                        "content": self.config.get(
                            "system_prompt", defaults["system_prompt"]
                        ),
                    },
                    {"role": "user", "content": prompt},
                ],
            )
            content = response.choices[0].message.content
            if content is None:
                raise APIError("Error: API returned empty response")
            fenced = self._CODE_FENCE_RE.search(content)
            return content if fenced is None else fenced.group(1)
        except APIError:
            # Already one of our exceptions in its final form; don't let the
            # generic handler below re-wrap it as "API request failed - ...".
            raise
        except Exception as e:
            self._handle_api_error(e)

    def validate_config(self) -> None:
        """Read the API key and create the OpenAI client.

        The key is taken from the environment variable named by the
        ``api_key_env`` setting (``OPENAI_API_KEY`` by default).

        Raises:
            AuthenticationError: If that variable is unset or empty.
        """
        api_key_env = self.config.get("api_key_env", "OPENAI_API_KEY")
        api_key = os.environ.get(api_key_env)

        if not api_key:
            raise AuthenticationError(
                f"Error: {api_key_env} environment variable is required"
            )

        self.client = openai.OpenAI(api_key=api_key)

    def _handle_api_error(self, error: Exception) -> NoReturn:
        """Handle API errors and map them to standard exceptions.

        SDK HTTP status errors carry ``status_code`` (401 = bad key, 429 = rate
        limited / out of quota); errors without one fall back to matching the
        message text. The original error is chained as ``__cause__``.

        Args:
            error: The exception to handle

        Raises:
            AuthenticationError: If the API key is invalid
            RateLimitError: If the API rate limit or quota is exceeded
            APIError: For any other failure
        """
        status_code = getattr(error, "status_code", None)
        error_str = str(error).lower()

        if status_code == 401 or "authentication" in error_str or "unauthorized" in error_str:
            raise AuthenticationError("Error: Invalid API key") from error
        if status_code == 429 or "rate limit" in error_str or "quota" in error_str:
            raise RateLimitError("Error: API rate limit exceeded") from error
        raise APIError(f"Error: API request failed - {error}") from error

    @classmethod
    def get_default_config(cls) -> dict[str, Any]:
        """Return default configuration for OpenAI provider."""
        return {
            "model_name": "gpt-4o-mini",
            "max_tokens": 150,
            "api_key_env": "OPENAI_API_KEY",
            "temperature": 0.0,
            "system_prompt": SYSTEM_PROMPT,
        }
|