terminal_sherpa-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ask/__init__.py ADDED
File without changes
ask/config.py ADDED
@@ -0,0 +1,94 @@
+ """Configuration loading and management module."""
+
+ import os
+ from pathlib import Path
+ from typing import Any
+
+ import toml
+
+ from ask.exceptions import ConfigurationError
+
+ SYSTEM_PROMPT = (
+     "You are a bash command generator. Given a user request, "
+     "respond with ONLY the bash command that accomplishes the task. "
+     "Do not include explanations, comments, or any other text. "
+     "Just the command."
+ )
+
+
+ def get_config_path() -> Path | None:
+     """Find config file using XDG standard."""
+     # Primary location: $XDG_CONFIG_HOME/ask/config.toml
+     xdg_config_home = os.environ.get("XDG_CONFIG_HOME")
+     if xdg_config_home:
+         primary_path = Path(xdg_config_home) / "ask" / "config.toml"
+     else:
+         primary_path = Path.home() / ".config" / "ask" / "config.toml"
+
+     if primary_path.exists():
+         return primary_path
+
+     # Fallback location: ~/.ask/config.toml
+     fallback_path = Path.home() / ".ask" / "config.toml"
+     if fallback_path.exists():
+         return fallback_path
+
+     return None
+
+
+ def load_config() -> dict[str, Any]:
+     """Load configuration from TOML file."""
+     config_path = get_config_path()
+
+     if config_path is None:
+         return {}
+
+     try:
+         with open(config_path) as f:
+             return toml.load(f)
+     except Exception as e:
+         raise ConfigurationError(f"Failed to load config file {config_path}: {e}")
+
+
+ def get_provider_config(
+     config: dict[str, Any], provider_spec: str
+ ) -> tuple[str, dict[str, Any]]:
+     """Parse provider:model syntax and return provider name and config."""
+     if ":" in provider_spec:
+         provider_name, model_name = provider_spec.split(":", 1)
+
+         # First try to get nested config (e.g., anthropic.haiku)
+         provider_section = config.get(provider_name, {})
+         if isinstance(provider_section, dict) and model_name in provider_section:
+             provider_config = provider_section[model_name]
+         else:
+             # Fall back to base provider config
+             provider_config = provider_section
+     else:
+         provider_name = provider_spec
+         provider_config = config.get(provider_name, {})
+
+     # Add global settings
+     global_config = config.get("ask", {})
+
+     # Merge global and provider-specific config
+     merged_config = {**global_config, **provider_config}
+
+     return provider_name, merged_config
+
+
+ def get_default_model(config: dict[str, Any]) -> str | None:
+     """Get default model from configuration."""
+     global_config = config.get("ask", {})
+     return global_config.get("default_model")
+
+
+ def get_default_provider() -> str | None:
+     """Determine fallback provider from environment variables."""
+     # Check for API keys in order of preference: claude -> openai
+     if os.environ.get("ANTHROPIC_API_KEY"):
+         return "anthropic"
+     elif os.environ.get("OPENAI_API_KEY"):
+         return "openai"
+
+     return None
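
For context, a minimal sketch of how get_provider_config() resolves a "provider:model" spec. The config dict below is illustrative only (the kind of structure load_config() would return from a config.toml), not data shipped with the package:

from ask.config import get_provider_config

# Hypothetical config: the "ask" table holds globals, "anthropic" -> "haiku" a nested model profile.
config_data = {
    "ask": {"default_model": "anthropic:haiku"},
    "anthropic": {"haiku": {"model_name": "claude-3-haiku-20240307", "max_tokens": 100}},
}

provider, merged = get_provider_config(config_data, "anthropic:haiku")
# provider == "anthropic"
# merged == {"default_model": "anthropic:haiku",
#            "model_name": "claude-3-haiku-20240307", "max_tokens": 100}

Note that the nested model section wins over the base provider table, and provider-specific keys override the globals merged in from the "ask" section.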
ask/exceptions.py ADDED
@@ -0,0 +1,25 @@
+ """Custom exception classes for the ask CLI tool."""
+
+
+ class ConfigurationError(Exception):
+     """Raised when there are configuration-related errors."""
+
+     pass
+
+
+ class AuthenticationError(Exception):
+     """Raised when authentication fails."""
+
+     pass
+
+
+ class APIError(Exception):
+     """Raised when API requests fail."""
+
+     pass
+
+
+ class RateLimitError(APIError):
+     """Raised when API rate limits are exceeded."""
+
+     pass
ask/main.py ADDED
@@ -0,0 +1,115 @@
+ import argparse
+ import sys
+ from typing import Any
+
+ from loguru import logger
+
+ import ask.config as config
+ import ask.providers as providers
+ from ask.exceptions import APIError, AuthenticationError, ConfigurationError
+ from ask.providers.base import ProviderInterface
+
+
+ def configure_logging(verbose: bool) -> None:
+     """Configure loguru logging with appropriate level and format."""
+     logger.remove()  # Remove default handler
+
+     level = "DEBUG" if verbose else "ERROR"
+
+     logger.add(
+         sys.stderr,
+         format="{time:YYYY-MM-DD HH:mm:ss.SSS} | <level>{level}</level> | {message}",
+         level=level,
+         colorize=True,
+     )
+
+
+ def parse_arguments() -> argparse.Namespace:
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(description="AI-powered bash command generator")
+     parser.add_argument("prompt", help="Natural language description of the task")
+     parser.add_argument(
+         "--model", help="Provider and model to use (format: provider[:model])"
+     )
+     parser.add_argument(
+         "--verbose", "-v", action="store_true", help="Enable verbose logging"
+     )
+     return parser.parse_args()
+
+
+ def load_configuration() -> dict[str, Any]:
+     """Load configuration from file and environment."""
+     try:
+         return config.load_config()
+     except ConfigurationError as e:
+         logger.error(f"Configuration error: {e}")
+         sys.exit(1)
+
+
+ def resolve_provider(args, config_data) -> ProviderInterface:
+     """Determine which provider to use based on arguments and configuration.
+
+     Args:
+         args: Parsed command line arguments
+         config_data: Configuration data loaded from the config file
+
+     Returns:
+         Provider instance
+     """
+     if args.model:
+         logger.debug(f"Using model specified via --model argument: {args.model}")
+         provider_name, provider_config = config.get_provider_config(
+             config_data, args.model
+         )
+     else:
+         # Check for default model in config first
+         default_model = config.get_default_model(config_data)
+         if default_model:
+             logger.debug(f"Using default model from config: {default_model}")
+             provider_name, provider_config = config.get_provider_config(
+                 config_data, default_model
+             )
+         else:
+             logger.warning(
+                 "No default model configured, falling back to environment variables"
+             )
+             # Use default provider from environment variables
+             default_provider = config.get_default_provider()
+             if not default_provider:
+                 logger.error(
+                     "No default model configured and no API keys found. "
+                     "Please set ANTHROPIC_API_KEY or OPENAI_API_KEY environment "
+                     "variable, or set a default_model in your config file."
+                 )
+                 sys.exit(1)
+             logger.debug(f"Using default provider from environment: {default_provider}")
+             provider_name, provider_config = config.get_provider_config(
+                 config_data, default_provider
+             )
+
+     try:
+         logger.debug(f"Initializing provider: {provider_name}")
+         return providers.get_provider(provider_name, provider_config)
+     except ConfigurationError as e:
+         logger.error(f"Error: {e}")
+         sys.exit(1)
+
+
+ def main() -> None:
+     """Main entry point for the CLI application."""
+     args = parse_arguments()
+     configure_logging(args.verbose)
+     config_data = load_configuration()
+
+     try:
+         provider = resolve_provider(args, config_data)
+         provider.validate_config()
+         bash_command = provider.get_bash_command(args.prompt)
+         print(bash_command)
+     except (AuthenticationError, APIError) as e:
+         logger.error(str(e))
+         sys.exit(1)
+
+
+ if __name__ == "__main__":
+     main()
ask/providers/__init__.py ADDED
@@ -0,0 +1,35 @@
+ """Provider registry and initialization module."""
+
+ from .anthropic import AnthropicProvider
+ from .base import ProviderInterface
+ from .openai import OpenAIProvider
+
+ # Provider registry - maps provider names to their classes
+ _PROVIDER_REGISTRY: dict[str, type[ProviderInterface]] = {}
+
+
+ def register_provider(name: str, provider_class: type[ProviderInterface]) -> None:
+     """Register a provider class with the given name."""
+     _PROVIDER_REGISTRY[name] = provider_class
+
+
+ def get_provider(name: str, config: dict) -> ProviderInterface:
+     """Get a provider instance by name."""
+     if name not in _PROVIDER_REGISTRY:
+         from ask.exceptions import ConfigurationError
+
+         raise ConfigurationError(
+             f"Provider '{name}' not found. Available providers: {list_providers()}"
+         )
+
+     provider_class = _PROVIDER_REGISTRY[name]
+     return provider_class(config)
+
+
+ def list_providers() -> list[str]:
+     """List all available provider names."""
+     return list(_PROVIDER_REGISTRY.keys())
+
+
+ register_provider("anthropic", AnthropicProvider)
+ register_provider("openai", OpenAIProvider)
ask/providers/anthropic.py ADDED
@@ -0,0 +1,73 @@
+ """Anthropic provider implementation."""
+
+ import os
+ from typing import Any
+
+ import anthropic
+
+ from ask.config import SYSTEM_PROMPT
+ from ask.exceptions import APIError, AuthenticationError, RateLimitError
+ from ask.providers.base import ProviderInterface
+
+
+ class AnthropicProvider(ProviderInterface):
+     """Anthropic AI provider implementation."""
+
+     def __init__(self, config: dict[str, Any]):
+         """Initialize Anthropic provider with configuration."""
+         super().__init__(config)
+         self.client: anthropic.Anthropic | None = None
+
+     def get_bash_command(self, prompt: str) -> str:
+         """Generate bash command from natural language prompt."""
+         if self.client is None:
+             self.validate_config()
+
+         # After validate_config(), client should be set
+         assert self.client is not None, "Client should be initialized after validation"
+
+         try:
+             response = self.client.messages.create(
+                 model=self.config.get("model_name", "claude-3-haiku-20240307"),
+                 max_tokens=self.config.get("max_tokens", 150),
+                 temperature=self.config.get("temperature", 0.0),
+                 system=self.config.get("system_prompt", SYSTEM_PROMPT),
+                 messages=[{"role": "user", "content": prompt}],
+             )
+             return response.content[0].text
+         except Exception as e:
+             self._handle_api_error(e)
+
+     def validate_config(self) -> None:
+         """Validate provider configuration and API key."""
+         api_key_env = self.config.get("api_key_env", "ANTHROPIC_API_KEY")
+         api_key = os.environ.get(api_key_env)
+
+         if not api_key:
+             raise AuthenticationError(
+                 f"Error: {api_key_env} environment variable is required"
+             )
+
+         self.client = anthropic.Anthropic(api_key=api_key)
+
+     def _handle_api_error(self, error: Exception):
+         """Handle API errors and map them to standard exceptions."""
+         error_str = str(error).lower()
+
+         if "authentication" in error_str or "unauthorized" in error_str:
+             raise AuthenticationError("Error: Invalid API key")
+         elif "rate limit" in error_str:
+             raise RateLimitError("Error: API rate limit exceeded")
+         else:
+             raise APIError(f"Error: API request failed - {error}")
+
+     @classmethod
+     def get_default_config(cls) -> dict[str, Any]:
+         """Return default configuration for Anthropic provider."""
+         return {
+             "model_name": "claude-3-haiku-20240307",
+             "max_tokens": 150,
+             "api_key_env": "ANTHROPIC_API_KEY",
+             "temperature": 0.0,
+             "system_prompt": SYSTEM_PROMPT,
+         }
ask/providers/base.py ADDED
@@ -0,0 +1,28 @@
+ """Abstract base class for all providers."""
+
+ from abc import ABC, abstractmethod
+ from typing import Any
+
+
+ class ProviderInterface(ABC):
+     """Abstract base class for all AI providers."""
+
+     def __init__(self, config: dict[str, Any]):
+         """Initialize provider with configuration."""
+         self.config = config
+
+     @abstractmethod
+     def get_bash_command(self, prompt: str) -> str:
+         """Generate bash command from natural language prompt."""
+         pass
+
+     @abstractmethod
+     def validate_config(self) -> None:
+         """Validate provider configuration and API key."""
+         pass
+
+     @classmethod
+     @abstractmethod
+     def get_default_config(cls) -> dict[str, Any]:
+         """Return default configuration for this provider."""
+         pass
ask/providers/openai.py ADDED
@@ -0,0 +1,106 @@
+ """OpenAI provider implementation."""
+
+ import os
+ import re
+ from typing import Any, NoReturn
+
+ import openai
+
+ from ask.config import SYSTEM_PROMPT
+ from ask.exceptions import APIError, AuthenticationError, RateLimitError
+ from ask.providers.base import ProviderInterface
+
+
+ class OpenAIProvider(ProviderInterface):
+     """OpenAI provider implementation."""
+
+     def __init__(self, config: dict[str, Any]):
+         """Initialize OpenAI provider with configuration.
+
+         Args:
+             config: The configuration for the OpenAI provider
+         """
+         super().__init__(config)
+         self.client: openai.OpenAI | None = None
+
+     def get_bash_command(self, prompt: str) -> str:
+         """Generate bash command from natural language prompt.
+
+         Args:
+             prompt: The natural language prompt to generate a bash command for
+
+         Returns:
+             The generated bash command
+         """
+         if self.client is None:
+             self.validate_config()
+
+         # After validate_config(), client should be set
+         assert self.client is not None, "Client should be initialized after validation"
+
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.config.get("model_name", "gpt-4o-mini"),
+                 max_completion_tokens=self.config.get("max_tokens", 150),
+                 temperature=self.config.get("temperature", 0.0),
+                 messages=[
+                     {
+                         "role": "system",
+                         "content": self.config.get("system_prompt", SYSTEM_PROMPT),
+                     },
+                     {"role": "user", "content": prompt},
+                 ],
+             )
+             content = response.choices[0].message.content
+             if content is None:
+                 raise APIError("Error: API returned empty response")
+             # Remove ```bash and ``` from the content
+             re_match = re.search(r"```bash\n(.*)\n```", content)
+             if re_match is None:
+                 return content
+             else:
+                 return re_match.group(1)
+         except Exception as e:
+             self._handle_api_error(e)
+
+     def validate_config(self) -> None:
+         """Validate provider configuration and API key."""
+         api_key_env = self.config.get("api_key_env", "OPENAI_API_KEY")
+         api_key = os.environ.get(api_key_env)
+
+         if not api_key:
+             raise AuthenticationError(
+                 f"Error: {api_key_env} environment variable is required"
+             )
+
+         self.client = openai.OpenAI(api_key=api_key)
+
+     def _handle_api_error(self, error: Exception) -> NoReturn:
+         """Handle API errors and map them to standard exceptions.
+
+         Args:
+             error: The exception to handle
+
+         Raises:
+             AuthenticationError: If the API key is invalid
+             RateLimitError: If the API rate limit is exceeded
+         """
+         error_str = str(error).lower()
+
+         if "authentication" in error_str or "unauthorized" in error_str:
+             raise AuthenticationError("Error: Invalid API key")
+         elif "rate limit" in error_str or "quota" in error_str:
+             raise RateLimitError("Error: API rate limit exceeded")
+         else:
+             raise APIError(f"Error: API request failed - {error}")
+
+     @classmethod
+     def get_default_config(cls) -> dict[str, Any]:
+         """Return default configuration for OpenAI provider."""
+         return {
+             "model_name": "gpt-4o-mini",
+             "max_tokens": 150,
+             "api_key_env": "OPENAI_API_KEY",
+             "temperature": 0.0,
+             "system_prompt": SYSTEM_PROMPT,
+         }