shotgun-sh 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/__init__.py +5 -0
- shotgun/agents/__init__.py +1 -0
- shotgun/agents/agent_manager.py +651 -0
- shotgun/agents/common.py +549 -0
- shotgun/agents/config/__init__.py +13 -0
- shotgun/agents/config/constants.py +17 -0
- shotgun/agents/config/manager.py +294 -0
- shotgun/agents/config/models.py +185 -0
- shotgun/agents/config/provider.py +206 -0
- shotgun/agents/conversation_history.py +106 -0
- shotgun/agents/conversation_manager.py +105 -0
- shotgun/agents/export.py +96 -0
- shotgun/agents/history/__init__.py +5 -0
- shotgun/agents/history/compaction.py +85 -0
- shotgun/agents/history/constants.py +19 -0
- shotgun/agents/history/context_extraction.py +108 -0
- shotgun/agents/history/history_building.py +104 -0
- shotgun/agents/history/history_processors.py +426 -0
- shotgun/agents/history/message_utils.py +84 -0
- shotgun/agents/history/token_counting.py +429 -0
- shotgun/agents/history/token_estimation.py +138 -0
- shotgun/agents/messages.py +35 -0
- shotgun/agents/models.py +275 -0
- shotgun/agents/plan.py +98 -0
- shotgun/agents/research.py +108 -0
- shotgun/agents/specify.py +98 -0
- shotgun/agents/tasks.py +96 -0
- shotgun/agents/tools/__init__.py +34 -0
- shotgun/agents/tools/codebase/__init__.py +28 -0
- shotgun/agents/tools/codebase/codebase_shell.py +256 -0
- shotgun/agents/tools/codebase/directory_lister.py +141 -0
- shotgun/agents/tools/codebase/file_read.py +144 -0
- shotgun/agents/tools/codebase/models.py +252 -0
- shotgun/agents/tools/codebase/query_graph.py +67 -0
- shotgun/agents/tools/codebase/retrieve_code.py +81 -0
- shotgun/agents/tools/file_management.py +218 -0
- shotgun/agents/tools/user_interaction.py +37 -0
- shotgun/agents/tools/web_search/__init__.py +60 -0
- shotgun/agents/tools/web_search/anthropic.py +144 -0
- shotgun/agents/tools/web_search/gemini.py +85 -0
- shotgun/agents/tools/web_search/openai.py +98 -0
- shotgun/agents/tools/web_search/utils.py +20 -0
- shotgun/build_constants.py +20 -0
- shotgun/cli/__init__.py +1 -0
- shotgun/cli/codebase/__init__.py +5 -0
- shotgun/cli/codebase/commands.py +202 -0
- shotgun/cli/codebase/models.py +21 -0
- shotgun/cli/config.py +275 -0
- shotgun/cli/export.py +81 -0
- shotgun/cli/models.py +10 -0
- shotgun/cli/plan.py +73 -0
- shotgun/cli/research.py +85 -0
- shotgun/cli/specify.py +69 -0
- shotgun/cli/tasks.py +78 -0
- shotgun/cli/update.py +152 -0
- shotgun/cli/utils.py +25 -0
- shotgun/codebase/__init__.py +12 -0
- shotgun/codebase/core/__init__.py +46 -0
- shotgun/codebase/core/change_detector.py +358 -0
- shotgun/codebase/core/code_retrieval.py +243 -0
- shotgun/codebase/core/ingestor.py +1497 -0
- shotgun/codebase/core/language_config.py +297 -0
- shotgun/codebase/core/manager.py +1662 -0
- shotgun/codebase/core/nl_query.py +331 -0
- shotgun/codebase/core/parser_loader.py +128 -0
- shotgun/codebase/models.py +111 -0
- shotgun/codebase/service.py +206 -0
- shotgun/logging_config.py +227 -0
- shotgun/main.py +167 -0
- shotgun/posthog_telemetry.py +158 -0
- shotgun/prompts/__init__.py +5 -0
- shotgun/prompts/agents/__init__.py +1 -0
- shotgun/prompts/agents/export.j2 +350 -0
- shotgun/prompts/agents/partials/codebase_understanding.j2 +87 -0
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +37 -0
- shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
- shotgun/prompts/agents/partials/interactive_mode.j2 +26 -0
- shotgun/prompts/agents/plan.j2 +144 -0
- shotgun/prompts/agents/research.j2 +69 -0
- shotgun/prompts/agents/specify.j2 +51 -0
- shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +19 -0
- shotgun/prompts/agents/state/system_state.j2 +31 -0
- shotgun/prompts/agents/tasks.j2 +143 -0
- shotgun/prompts/codebase/__init__.py +1 -0
- shotgun/prompts/codebase/cypher_query_patterns.j2 +223 -0
- shotgun/prompts/codebase/cypher_system.j2 +28 -0
- shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
- shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
- shotgun/prompts/codebase/partials/graph_schema.j2 +30 -0
- shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
- shotgun/prompts/history/__init__.py +1 -0
- shotgun/prompts/history/incremental_summarization.j2 +53 -0
- shotgun/prompts/history/summarization.j2 +46 -0
- shotgun/prompts/loader.py +140 -0
- shotgun/py.typed +0 -0
- shotgun/sdk/__init__.py +13 -0
- shotgun/sdk/codebase.py +219 -0
- shotgun/sdk/exceptions.py +17 -0
- shotgun/sdk/models.py +189 -0
- shotgun/sdk/services.py +23 -0
- shotgun/sentry_telemetry.py +87 -0
- shotgun/telemetry.py +93 -0
- shotgun/tui/__init__.py +0 -0
- shotgun/tui/app.py +116 -0
- shotgun/tui/commands/__init__.py +76 -0
- shotgun/tui/components/prompt_input.py +69 -0
- shotgun/tui/components/spinner.py +86 -0
- shotgun/tui/components/splash.py +25 -0
- shotgun/tui/components/vertical_tail.py +13 -0
- shotgun/tui/screens/chat.py +782 -0
- shotgun/tui/screens/chat.tcss +43 -0
- shotgun/tui/screens/chat_screen/__init__.py +0 -0
- shotgun/tui/screens/chat_screen/command_providers.py +219 -0
- shotgun/tui/screens/chat_screen/hint_message.py +40 -0
- shotgun/tui/screens/chat_screen/history.py +221 -0
- shotgun/tui/screens/directory_setup.py +113 -0
- shotgun/tui/screens/provider_config.py +221 -0
- shotgun/tui/screens/splash.py +31 -0
- shotgun/tui/styles.tcss +10 -0
- shotgun/tui/utils/__init__.py +5 -0
- shotgun/tui/utils/mode_progress.py +257 -0
- shotgun/utils/__init__.py +5 -0
- shotgun/utils/env_utils.py +35 -0
- shotgun/utils/file_system_utils.py +36 -0
- shotgun/utils/update_checker.py +375 -0
- shotgun_sh-0.1.0.dist-info/METADATA +466 -0
- shotgun_sh-0.1.0.dist-info/RECORD +130 -0
- shotgun_sh-0.1.0.dist-info/WHEEL +4 -0
- shotgun_sh-0.1.0.dist-info/entry_points.txt +2 -0
- shotgun_sh-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
"""Configuration manager for Shotgun CLI."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import uuid
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from pydantic import SecretStr
|
|
10
|
+
|
|
11
|
+
from shotgun.logging_config import get_logger
|
|
12
|
+
from shotgun.utils import get_shotgun_home
|
|
13
|
+
|
|
14
|
+
from .constants import (
|
|
15
|
+
ANTHROPIC_API_KEY_ENV,
|
|
16
|
+
ANTHROPIC_PROVIDER,
|
|
17
|
+
API_KEY_FIELD,
|
|
18
|
+
GEMINI_API_KEY_ENV,
|
|
19
|
+
GOOGLE_PROVIDER,
|
|
20
|
+
OPENAI_API_KEY_ENV,
|
|
21
|
+
OPENAI_PROVIDER,
|
|
22
|
+
)
|
|
23
|
+
from .models import ProviderType, ShotgunConfig
|
|
24
|
+
|
|
25
|
+
# Module-level logger for configuration-manager messages.
logger = get_logger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ConfigManager:
    """Manager for Shotgun configuration.

    Reads and writes a JSON config file and caches the parsed ``ShotgunConfig``
    on the instance. After the first ``load()`` (or a successful ``save()``),
    later ``load()`` calls return the cached object without re-reading the file.
    """

    def __init__(self, config_path: Path | None = None):
        """Initialize ConfigManager.

        Args:
            config_path: Path to config file. If None, uses
                ``get_shotgun_home() / "config.json"`` (typically
                ~/.shotgun-sh/config.json).
        """
        if config_path is None:
            self.config_path = get_shotgun_home() / "config.json"
        else:
            self.config_path = config_path

        # Parsed configuration; None until load()/save() populates it.
        self._config: ShotgunConfig | None = None

    def load(self) -> ShotgunConfig:
        """Load configuration from file, creating a default one if needed.

        Behavior:
        - If a config is already cached on this instance, it is returned.
        - If the file does not exist, a new config with a fresh user_id is
          created and saved.
        - If the file cannot be read, parsed or validated, it is first moved
          aside to ``<name>.bak`` (best effort), so its contents (including any
          API keys) are not silently overwritten. A new config is then created.
        - After a successful load, if the default provider has no API key (in
          config or environment) but another provider does, the default is
          switched to the first such provider and the change is saved.

        Returns:
            ShotgunConfig: Loaded configuration or a newly initialized one
        """
        if self._config is not None:
            return self._config

        if not self.config_path.exists():
            logger.info(
                "Configuration file not found at %s, creating new config",
                self.config_path,
            )
            # Create new config with generated user_id
            self._config = self.initialize()
            return self._config

        # Only reading/parsing/validation is guarded. Failures elsewhere (e.g. a
        # failing save below) must not trigger a re-initialization that would
        # overwrite the user's file.
        try:
            with open(self.config_path, encoding="utf-8") as f:
                data = json.load(f)

            # Convert plain text secrets to SecretStr objects
            self._convert_secrets_to_secretstr(data)

            config = ShotgunConfig.model_validate(data)
        except Exception as e:
            # Deliberate best-effort recovery: any unreadable/invalid file falls
            # back to a fresh config, but the old file is preserved first.
            logger.error(
                "Failed to load configuration from %s: %s", self.config_path, e
            )
            self._backup_unreadable_config()
            logger.info("Creating new configuration with generated user_id")
            self._config = self.initialize()
            return self._config

        self._config = config
        logger.debug("Configuration loaded successfully from %s", self.config_path)

        # Check if the default provider has a key, if not find one that does.
        # has_provider_key() calls load(), which now returns the cached config.
        if not self.has_provider_key(config.default_provider):
            original_default = config.default_provider
            # Find first provider (in ProviderType order) with a configured key
            for provider in ProviderType:
                if self.has_provider_key(provider):
                    logger.info(
                        "Default provider %s has no API key, updating to %s",
                        original_default.value,
                        provider.value,
                    )
                    config.default_provider = provider
                    self.save(config)
                    break

        return config

    def _backup_unreadable_config(self) -> None:
        """Move an unreadable config file to ``<name>.bak`` (best effort).

        Any existing backup at that path is replaced. Failure to move the file
        is logged as a warning and otherwise ignored.
        """
        backup_path = self.config_path.with_name(self.config_path.name + ".bak")
        try:
            os.replace(self.config_path, backup_path)
        except OSError as e:
            logger.warning(
                "Could not back up unreadable configuration %s: %s",
                self.config_path,
                e,
            )
        else:
            logger.warning("Moved unreadable configuration to %s", backup_path)

    @staticmethod
    def _new_default_config() -> ShotgunConfig:
        """Build a fresh default config with a newly generated anonymous user_id."""
        return ShotgunConfig(
            user_id=str(uuid.uuid4()),
            config_version=1,
        )

    def save(self, config: ShotgunConfig | None = None) -> None:
        """Save configuration to file atomically.

        The JSON is written to a sibling ``<name>.tmp`` file, then moved over
        the real file with ``os.replace``. A crash or serialization error
        mid-write therefore cannot leave a truncated config behind. The temp
        file is created with mode 0o600 (owner-only on POSIX) because it
        contains plain-text API keys.

        Args:
            config: Configuration to save. If None, saves the currently loaded
                config, or a brand-new default config if nothing is loaded.

        Raises:
            Exception: Any error while serializing or writing is logged and
                re-raised; the previously saved file is left untouched.
        """
        if config is None:
            if self._config is not None:
                config = self._config
            else:
                # Create a new config with generated user_id
                config = self._new_default_config()

        # Ensure directory exists
        self.config_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.config_path.with_name(self.config_path.name + ".tmp")

        try:
            # Convert SecretStr to plain text for JSON serialization
            data = config.model_dump()
            self._convert_secretstr_to_plain(data)

            # Remove a stale temp file so the restrictive mode applies to a
            # freshly created file (os.open's mode only affects creation).
            tmp_path.unlink(missing_ok=True)
            fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with open(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, self.config_path)
        except Exception as e:
            logger.error("Failed to save configuration to %s: %s", self.config_path, e)
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise

        logger.debug("Configuration saved to %s", self.config_path)
        self._config = config

    def update_provider(self, provider: ProviderType | str, **kwargs: Any) -> None:
        """Update provider configuration and persist it.

        If an API key is set and no other provider has a key (in config or
        environment), this provider also becomes the default.

        Args:
            provider: Provider to update (enum member or its string value)
            **kwargs: Configuration fields to update. Only the api_key field is
                supported; passing it as None clears the stored key.

        Raises:
            ValueError: If an unsupported field is passed (checked before
                anything is modified) or the provider is unknown.
        """
        # Reject unsupported fields up front so a bad call cannot leave a
        # half-applied, unsaved change in the cached config.
        unsupported_fields = set(kwargs.keys()) - {API_KEY_FIELD}
        if unsupported_fields:
            raise ValueError(f"Unsupported configuration fields: {unsupported_fields}")

        config = self.load()
        provider_enum = self._ensure_provider_enum(provider)
        provider_config = self._get_provider_config(config, provider_enum)

        if API_KEY_FIELD in kwargs:
            api_key_value = kwargs[API_KEY_FIELD]
            provider_config.api_key = (
                SecretStr(api_key_value) if api_key_value is not None else None
            )

            # If no other providers have keys configured and we just added one,
            # set this provider as the default
            if api_key_value is not None:
                other_providers = [p for p in ProviderType if p != provider_enum]
                has_other_keys = any(self.has_provider_key(p) for p in other_providers)
                if not has_other_keys:
                    config.default_provider = provider_enum

        self.save(config)

    def clear_provider_key(self, provider: ProviderType | str) -> None:
        """Remove the stored API key for the given provider and save.

        Environment variables are not affected.
        """
        config = self.load()
        provider_enum = self._ensure_provider_enum(provider)
        provider_config = self._get_provider_config(config, provider_enum)
        provider_config.api_key = None
        self.save(config)

    def has_provider_key(self, provider: ProviderType | str) -> bool:
        """Check if the given provider has a non-empty API key configured.

        The configuration file is checked first (whitespace-only keys count as
        missing), then the provider's environment variable.
        """
        config = self.load()
        provider_enum = self._ensure_provider_enum(provider)
        provider_config = self._get_provider_config(config, provider_enum)

        # Check config first
        if self._provider_has_api_key(provider_config):
            return True

        # Check environment variable
        if provider_enum == ProviderType.OPENAI:
            return bool(os.getenv(OPENAI_API_KEY_ENV))
        elif provider_enum == ProviderType.ANTHROPIC:
            return bool(os.getenv(ANTHROPIC_API_KEY_ENV))
        elif provider_enum == ProviderType.GOOGLE:
            return bool(os.getenv(GEMINI_API_KEY_ENV))

        return False

    def has_any_provider_key(self) -> bool:
        """Return True if any provider has a non-blank API key in the config file.

        Unlike has_provider_key(), environment variables are NOT consulted here.
        """
        config = self.load()
        return any(
            self._provider_has_api_key(self._get_provider_config(config, provider))
            for provider in (
                ProviderType.OPENAI,
                ProviderType.ANTHROPIC,
                ProviderType.GOOGLE,
            )
        )

    def initialize(self) -> ShotgunConfig:
        """Create a default configuration with a fresh user_id and save it.

        Returns:
            The newly created ShotgunConfig (also cached on this instance by save())
        """
        # Generate unique user ID for new config
        config = self._new_default_config()
        self.save(config)
        logger.info(
            "Configuration initialized at %s with user_id: %s",
            self.config_path,
            config.user_id,
        )
        return config

    def _convert_secrets_to_secretstr(self, data: dict[str, Any]) -> None:
        """Wrap non-None API keys of each provider section of ``data`` in SecretStr, in place."""
        for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
            if provider in data and isinstance(data[provider], dict):
                if (
                    API_KEY_FIELD in data[provider]
                    and data[provider][API_KEY_FIELD] is not None
                ):
                    data[provider][API_KEY_FIELD] = SecretStr(
                        data[provider][API_KEY_FIELD]
                    )

    def _convert_secretstr_to_plain(self, data: dict[str, Any]) -> None:
        """Unwrap SecretStr API keys in ``data`` to plain text, in place, for JSON output."""
        for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
            if provider in data and isinstance(data[provider], dict):
                if (
                    API_KEY_FIELD in data[provider]
                    and data[provider][API_KEY_FIELD] is not None
                ):
                    if hasattr(data[provider][API_KEY_FIELD], "get_secret_value"):
                        data[provider][API_KEY_FIELD] = data[provider][
                            API_KEY_FIELD
                        ].get_secret_value()

    def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
        """Normalize provider values to ProviderType enum (ValueError if unknown)."""
        return (
            provider if isinstance(provider, ProviderType) else ProviderType(provider)
        )

    def _get_provider_config(
        self, config: ShotgunConfig, provider: ProviderType
    ) -> Any:
        """Retrieve the provider-specific configuration section."""
        if provider == ProviderType.OPENAI:
            return config.openai
        if provider == ProviderType.ANTHROPIC:
            return config.anthropic
        if provider == ProviderType.GOOGLE:
            return config.google
        raise ValueError(f"Unsupported provider: {provider}")

    def _provider_has_api_key(self, provider_config: Any) -> bool:
        """Return True if the provider config contains a non-blank API key."""
        api_key = getattr(provider_config, API_KEY_FIELD, None)
        if api_key is None:
            return False

        if isinstance(api_key, SecretStr):
            value = api_key.get_secret_value()
        else:
            value = str(api_key)

        return bool(value.strip())

    def get_user_id(self) -> str:
        """Get the user ID from configuration.

        Returns:
            The unique user ID string
        """
        config = self.load()
        return config.user_id
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def get_config_manager() -> ConfigManager:
    """Return a new ConfigManager for the default config path.

    Note: this is not a process-wide singleton. Each call builds a fresh
    instance with its own empty cache, so the first ``load()`` on the returned
    manager re-reads the config file from disk.
    """
    return ConfigManager()
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""Pydantic models for configuration."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, SecretStr
|
|
7
|
+
from pydantic_ai.direct import model_request
|
|
8
|
+
from pydantic_ai.messages import ModelMessage, ModelResponse
|
|
9
|
+
from pydantic_ai.models import Model
|
|
10
|
+
from pydantic_ai.settings import ModelSettings
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ProviderType(str, Enum):
    """Supported AI service providers.

    Members subclass ``str``, so each compares equal to its lowercase value
    (e.g. ``ProviderType.OPENAI == "openai"``). Iteration follows declaration
    order.
    """

    OPENAI = "openai"  # OpenAI (GPT models)
    ANTHROPIC = "anthropic"  # Anthropic (Claude models)
    GOOGLE = "google"  # Google (Gemini models)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ModelSpec(BaseModel):
    """Static, key-free description of one model: identity plus token limits."""

    # Model identifier (e.g., "gpt-5", "claude-opus-4-1")
    name: str
    # Provider that serves this model
    provider: ProviderType
    # Token limits for requests to this model
    max_input_tokens: int
    max_output_tokens: int
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class ModelConfig(BaseModel):
    """A fully configured model: static spec fields plus the API key to use.

    The pydantic-ai ``Model`` is created lazily on first access to
    ``model_instance`` and then reused by this instance.
    """

    # Pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str  # Model identifier (e.g., "gpt-5", "claude-opus-4-1")
    provider: ProviderType
    max_input_tokens: int
    max_output_tokens: int
    # Plain-text secret: excluded from repr()/str() so logging a ModelConfig
    # does not leak the key. Still a required str field.
    api_key: str = Field(repr=False)
    _model_instance: Model | None = PrivateAttr(default=None)

    @property
    def model_instance(self) -> Model:
        """Lazy load the Model instance via get_or_create_model (memoized there)."""
        if self._model_instance is None:
            # Imported here to avoid a circular import with .provider
            from .provider import get_or_create_model

            self._model_instance = get_or_create_model(
                self.provider, self.name, self.api_key
            )
        return self._model_instance

    @property
    def pydantic_model_name(self) -> str:
        """Compute the full Pydantic AI model identifier. For backward compatibility."""
        provider_prefix = {
            ProviderType.OPENAI: "openai",
            ProviderType.ANTHROPIC: "anthropic",
            ProviderType.GOOGLE: "google-gla",
        }
        return f"{provider_prefix[self.provider]}:{self.name}"

    def get_model_settings(self, max_tokens: int | None = None) -> ModelSettings:
        """Get ModelSettings with optional token override.

        This provides flexibility for specific use cases that need different
        token limits while defaulting to maximum utilization.

        Args:
            max_tokens: Optional override for max_tokens. If None, uses max_output_tokens

        Returns:
            ModelSettings configured with specified or maximum tokens
        """
        return ModelSettings(
            max_tokens=max_tokens if max_tokens is not None else self.max_output_tokens
        )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Registry of static model metadata, keyed by model name (key == spec.name).
# Insertion order follows the tuple below.
MODEL_SPECS: dict[str, ModelSpec] = {
    spec.name: spec
    for spec in (
        ModelSpec(
            name="gpt-5",
            provider=ProviderType.OPENAI,
            max_input_tokens=400_000,
            max_output_tokens=128_000,
        ),
        ModelSpec(
            name="gpt-4o",
            provider=ProviderType.OPENAI,
            max_input_tokens=128_000,
            max_output_tokens=16_000,
        ),
        ModelSpec(
            name="claude-opus-4-1",
            provider=ProviderType.ANTHROPIC,
            max_input_tokens=200_000,
            max_output_tokens=32_000,
        ),
        ModelSpec(
            name="claude-3-5-sonnet-latest",
            provider=ProviderType.ANTHROPIC,
            max_input_tokens=200_000,
            max_output_tokens=8_192,
        ),
        ModelSpec(
            name="gemini-2.5-pro",
            provider=ProviderType.GOOGLE,
            max_input_tokens=1_000_000,
            max_output_tokens=64_000,
        ),
    )
}
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class OpenAIConfig(BaseModel):
    """OpenAI section of the Shotgun configuration."""

    # Stored API key; None when no key is set in the config.
    api_key: SecretStr | None = Field(default=None)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class AnthropicConfig(BaseModel):
    """Anthropic section of the Shotgun configuration."""

    # Stored API key; None when no key is set in the config.
    api_key: SecretStr | None = Field(default=None)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class GoogleConfig(BaseModel):
    """Google (Gemini) section of the Shotgun configuration."""

    # Stored API key; None when no key is set in the config.
    api_key: SecretStr | None = Field(default=None)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class ShotgunConfig(BaseModel):
    """Top-level Shotgun CLI configuration (the root of the config document)."""

    # Per-provider sections; each defaults to an empty (key-less) section.
    openai: OpenAIConfig = Field(default_factory=OpenAIConfig)
    anthropic: AnthropicConfig = Field(default_factory=AnthropicConfig)
    google: GoogleConfig = Field(default_factory=GoogleConfig)

    # Provider used when a caller does not name one explicitly.
    default_provider: ProviderType = Field(
        description="Default AI provider to use",
        default=ProviderType.OPENAI,
    )

    # Required (no default): must be supplied when constructing a config.
    user_id: str = Field(description="Unique anonymous user identifier")

    config_version: int = Field(
        description="Configuration schema version",
        default=1,
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
async def shotgun_model_request(
    model_config: ModelConfig,
    messages: list[ModelMessage],
    max_tokens: int | None = None,
    **kwargs: Any,
) -> ModelResponse:
    """Model request wrapper that uses full token capacity by default.

    Requests made through this wrapper set ``max_tokens`` to the model's
    ``max_output_tokens`` unless explicitly overridden, instead of relying on
    a provider default. The most common issue this fixes is truncated
    summaries that were cut off at default token limits (e.g., 4096 for
    Claude models).

    Args:
        model_config: ModelConfig instance with model settings and API key
        messages: Messages to send to the model
        max_tokens: Optional override for max_tokens. If None, uses model's max_output_tokens
        **kwargs: Additional arguments passed to model_request

    Returns:
        ModelResponse from the model

    Example:
        # Uses full token capacity (per MODEL_SPECS: 32k for claude-opus-4-1,
        # 128k for gpt-5)
        response = await shotgun_model_request(model_config, messages)

        # Override for specific use case
        response = await shotgun_model_request(model_config, messages, max_tokens=1000)
    """
    # Get properly configured ModelSettings with maximum or overridden token limit
    model_settings = model_config.get_model_settings(max_tokens)

    # Make the model request with full token utilization
    return await model_request(
        model=model_config.model_instance,
        messages=messages,
        model_settings=model_settings,
        **kwargs,
    )
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""Provider management for LLM configuration."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from pydantic import SecretStr
|
|
6
|
+
from pydantic_ai.models import Model
|
|
7
|
+
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
|
|
8
|
+
from pydantic_ai.models.google import GoogleModel
|
|
9
|
+
from pydantic_ai.models.openai import OpenAIChatModel
|
|
10
|
+
from pydantic_ai.providers.anthropic import AnthropicProvider
|
|
11
|
+
from pydantic_ai.providers.google import GoogleProvider
|
|
12
|
+
from pydantic_ai.providers.openai import OpenAIProvider
|
|
13
|
+
from pydantic_ai.settings import ModelSettings
|
|
14
|
+
|
|
15
|
+
from shotgun.logging_config import get_logger
|
|
16
|
+
|
|
17
|
+
from .constants import (
|
|
18
|
+
ANTHROPIC_API_KEY_ENV,
|
|
19
|
+
GEMINI_API_KEY_ENV,
|
|
20
|
+
OPENAI_API_KEY_ENV,
|
|
21
|
+
)
|
|
22
|
+
from .manager import get_config_manager
|
|
23
|
+
from .models import MODEL_SPECS, ModelConfig, ProviderType
|
|
24
|
+
|
|
25
|
+
# Module-level logger for provider/model creation messages.
logger = get_logger(__name__)
|
|
26
|
+
|
|
27
|
+
# Process-wide memoization cache of Model instances: one Model per distinct key, reused on later lookups.
|
|
28
|
+
# Keyed by (provider, model name, API key). Nothing in this module evicts
# entries, so the cache grows with each distinct combination requested.
_model_cache: dict[tuple[ProviderType, str, str], Model] = {}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -> Model:
    """Return the memoized pydantic-ai Model for (provider, model_name, api_key).

    On the first call for a combination, the provider client and model are
    built and stored in ``_model_cache``. The model's ``max_tokens`` setting is
    the model's ``max_output_tokens`` from MODEL_SPECS, or a per-provider
    fallback for unlisted models. Later calls return the same cached object.

    Args:
        provider: Provider type
        model_name: Name of the model
        api_key: API key for the provider

    Returns:
        Cached or newly created Model instance

    Raises:
        ValueError: If provider is not supported
    """
    cache_key = (provider, model_name, api_key)

    if cache_key in _model_cache:
        logger.debug("Reusing cached %s model instance: %s", provider.value, model_name)
        return _model_cache[cache_key]

    logger.debug("Creating new %s model instance: %s", provider.value, model_name)

    # Output-token budget used when the model is not listed in MODEL_SPECS.
    fallback_max_tokens = {
        ProviderType.OPENAI: 16_000,  # Default for GPT models
        ProviderType.ANTHROPIC: 32_000,  # Default for Claude models
        ProviderType.GOOGLE: 64_000,  # Default for Gemini models
    }
    if provider not in fallback_max_tokens:
        raise ValueError(f"Unsupported provider: {provider}")

    # Prefer the model's full output capacity when its spec is known.
    spec = MODEL_SPECS.get(model_name)
    if spec is not None:
        max_tokens = spec.max_output_tokens
    else:
        max_tokens = fallback_max_tokens[provider]

    if provider == ProviderType.ANTHROPIC:
        model: Model = AnthropicModel(
            model_name,
            provider=AnthropicProvider(api_key=api_key),
            settings=AnthropicModelSettings(
                max_tokens=max_tokens,
                timeout=600,  # 10 minutes timeout for large responses
            ),
        )
    elif provider == ProviderType.GOOGLE:
        model = GoogleModel(
            model_name,
            provider=GoogleProvider(api_key=api_key),
            settings=ModelSettings(max_tokens=max_tokens),
        )
    else:  # ProviderType.OPENAI (the only remaining supported provider)
        model = OpenAIChatModel(
            model_name,
            provider=OpenAIProvider(api_key=api_key),
            settings=ModelSettings(max_tokens=max_tokens),
        )

    _model_cache[cache_key] = model
    return model
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
    """Build a ModelConfig for a provider's built-in model with its API key resolved.

    Each provider maps to one fixed model: gpt-5 (OpenAI), claude-opus-4-1
    (Anthropic) or gemini-2.5-pro (Google). The API key is resolved by
    ``_get_api_key`` from the config file and the provider's environment
    variable. The pydantic-ai Model is only created later, on first access to
    ``ModelConfig.model_instance``.

    Args:
        provider: Provider to get model for. A string value is also accepted
            and converted to ProviderType. If None (or otherwise falsy), the
            configured default provider is used.

    Returns:
        ModelConfig with API key configured and lazy Model instance

    Raises:
        ValueError: If the provider is unknown, has no API key, or its model
            is missing from MODEL_SPECS
    """
    config = get_config_manager().load()

    # Normalize the requested provider (enum, string value, or fall back to default).
    if isinstance(provider, ProviderType):
        provider_enum = provider
    elif provider:
        provider_enum = ProviderType(provider)
    else:
        provider_enum = config.default_provider

    # provider -> (config section attribute, API-key env var, error label, model name)
    provider_table = {
        ProviderType.OPENAI: ("openai", OPENAI_API_KEY_ENV, "OpenAI", "gpt-5"),
        ProviderType.ANTHROPIC: (
            "anthropic",
            ANTHROPIC_API_KEY_ENV,
            "Anthropic",
            "claude-opus-4-1",
        ),
        ProviderType.GOOGLE: ("google", GEMINI_API_KEY_ENV, "Gemini", "gemini-2.5-pro"),
    }
    if provider_enum not in provider_table:
        raise ValueError(f"Unsupported provider: {provider_enum}")
    section_name, env_var, label, model_name = provider_table[provider_enum]

    api_key = _get_api_key(getattr(config, section_name).api_key, env_var)
    if not api_key:
        raise ValueError(
            f"{label} API key not configured. Set via environment variable {env_var} or config."
        )

    spec = MODEL_SPECS.get(model_name)
    if spec is None:
        raise ValueError(f"Model '{model_name}' not found")

    # Create fully configured ModelConfig
    return ModelConfig(
        name=spec.name,
        provider=spec.provider,
        max_input_tokens=spec.max_input_tokens,
        max_output_tokens=spec.max_output_tokens,
        api_key=api_key,
    )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
    """Get API key from config, falling back to an environment variable.

    A config key that is missing or blank (empty or whitespace-only) is
    treated as unset, and the environment variable is consulted instead. This
    matches ConfigManager.has_provider_key, which also ignores blank config
    keys, so the two checks cannot disagree.

    Args:
        config_key: API key from configuration, if any
        env_var: Environment variable name to check

    Returns:
        The config key (unmodified) if it is non-blank, otherwise the
        environment variable's value, or None if that is not set.
    """
    if config_key is not None:
        value = config_key.get_secret_value()
        if value.strip():
            return value

    return os.getenv(env_var)
|