shotgun_sh-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of shotgun-sh might be problematic.

Files changed (130)
  1. shotgun/__init__.py +5 -0
  2. shotgun/agents/__init__.py +1 -0
  3. shotgun/agents/agent_manager.py +651 -0
  4. shotgun/agents/common.py +549 -0
  5. shotgun/agents/config/__init__.py +13 -0
  6. shotgun/agents/config/constants.py +17 -0
  7. shotgun/agents/config/manager.py +294 -0
  8. shotgun/agents/config/models.py +185 -0
  9. shotgun/agents/config/provider.py +206 -0
  10. shotgun/agents/conversation_history.py +106 -0
  11. shotgun/agents/conversation_manager.py +105 -0
  12. shotgun/agents/export.py +96 -0
  13. shotgun/agents/history/__init__.py +5 -0
  14. shotgun/agents/history/compaction.py +85 -0
  15. shotgun/agents/history/constants.py +19 -0
  16. shotgun/agents/history/context_extraction.py +108 -0
  17. shotgun/agents/history/history_building.py +104 -0
  18. shotgun/agents/history/history_processors.py +426 -0
  19. shotgun/agents/history/message_utils.py +84 -0
  20. shotgun/agents/history/token_counting.py +429 -0
  21. shotgun/agents/history/token_estimation.py +138 -0
  22. shotgun/agents/messages.py +35 -0
  23. shotgun/agents/models.py +275 -0
  24. shotgun/agents/plan.py +98 -0
  25. shotgun/agents/research.py +108 -0
  26. shotgun/agents/specify.py +98 -0
  27. shotgun/agents/tasks.py +96 -0
  28. shotgun/agents/tools/__init__.py +34 -0
  29. shotgun/agents/tools/codebase/__init__.py +28 -0
  30. shotgun/agents/tools/codebase/codebase_shell.py +256 -0
  31. shotgun/agents/tools/codebase/directory_lister.py +141 -0
  32. shotgun/agents/tools/codebase/file_read.py +144 -0
  33. shotgun/agents/tools/codebase/models.py +252 -0
  34. shotgun/agents/tools/codebase/query_graph.py +67 -0
  35. shotgun/agents/tools/codebase/retrieve_code.py +81 -0
  36. shotgun/agents/tools/file_management.py +218 -0
  37. shotgun/agents/tools/user_interaction.py +37 -0
  38. shotgun/agents/tools/web_search/__init__.py +60 -0
  39. shotgun/agents/tools/web_search/anthropic.py +144 -0
  40. shotgun/agents/tools/web_search/gemini.py +85 -0
  41. shotgun/agents/tools/web_search/openai.py +98 -0
  42. shotgun/agents/tools/web_search/utils.py +20 -0
  43. shotgun/build_constants.py +20 -0
  44. shotgun/cli/__init__.py +1 -0
  45. shotgun/cli/codebase/__init__.py +5 -0
  46. shotgun/cli/codebase/commands.py +202 -0
  47. shotgun/cli/codebase/models.py +21 -0
  48. shotgun/cli/config.py +275 -0
  49. shotgun/cli/export.py +81 -0
  50. shotgun/cli/models.py +10 -0
  51. shotgun/cli/plan.py +73 -0
  52. shotgun/cli/research.py +85 -0
  53. shotgun/cli/specify.py +69 -0
  54. shotgun/cli/tasks.py +78 -0
  55. shotgun/cli/update.py +152 -0
  56. shotgun/cli/utils.py +25 -0
  57. shotgun/codebase/__init__.py +12 -0
  58. shotgun/codebase/core/__init__.py +46 -0
  59. shotgun/codebase/core/change_detector.py +358 -0
  60. shotgun/codebase/core/code_retrieval.py +243 -0
  61. shotgun/codebase/core/ingestor.py +1497 -0
  62. shotgun/codebase/core/language_config.py +297 -0
  63. shotgun/codebase/core/manager.py +1662 -0
  64. shotgun/codebase/core/nl_query.py +331 -0
  65. shotgun/codebase/core/parser_loader.py +128 -0
  66. shotgun/codebase/models.py +111 -0
  67. shotgun/codebase/service.py +206 -0
  68. shotgun/logging_config.py +227 -0
  69. shotgun/main.py +167 -0
  70. shotgun/posthog_telemetry.py +158 -0
  71. shotgun/prompts/__init__.py +5 -0
  72. shotgun/prompts/agents/__init__.py +1 -0
  73. shotgun/prompts/agents/export.j2 +350 -0
  74. shotgun/prompts/agents/partials/codebase_understanding.j2 +87 -0
  75. shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +37 -0
  76. shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
  77. shotgun/prompts/agents/partials/interactive_mode.j2 +26 -0
  78. shotgun/prompts/agents/plan.j2 +144 -0
  79. shotgun/prompts/agents/research.j2 +69 -0
  80. shotgun/prompts/agents/specify.j2 +51 -0
  81. shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +19 -0
  82. shotgun/prompts/agents/state/system_state.j2 +31 -0
  83. shotgun/prompts/agents/tasks.j2 +143 -0
  84. shotgun/prompts/codebase/__init__.py +1 -0
  85. shotgun/prompts/codebase/cypher_query_patterns.j2 +223 -0
  86. shotgun/prompts/codebase/cypher_system.j2 +28 -0
  87. shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
  88. shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
  89. shotgun/prompts/codebase/partials/graph_schema.j2 +30 -0
  90. shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
  91. shotgun/prompts/history/__init__.py +1 -0
  92. shotgun/prompts/history/incremental_summarization.j2 +53 -0
  93. shotgun/prompts/history/summarization.j2 +46 -0
  94. shotgun/prompts/loader.py +140 -0
  95. shotgun/py.typed +0 -0
  96. shotgun/sdk/__init__.py +13 -0
  97. shotgun/sdk/codebase.py +219 -0
  98. shotgun/sdk/exceptions.py +17 -0
  99. shotgun/sdk/models.py +189 -0
  100. shotgun/sdk/services.py +23 -0
  101. shotgun/sentry_telemetry.py +87 -0
  102. shotgun/telemetry.py +93 -0
  103. shotgun/tui/__init__.py +0 -0
  104. shotgun/tui/app.py +116 -0
  105. shotgun/tui/commands/__init__.py +76 -0
  106. shotgun/tui/components/prompt_input.py +69 -0
  107. shotgun/tui/components/spinner.py +86 -0
  108. shotgun/tui/components/splash.py +25 -0
  109. shotgun/tui/components/vertical_tail.py +13 -0
  110. shotgun/tui/screens/chat.py +782 -0
  111. shotgun/tui/screens/chat.tcss +43 -0
  112. shotgun/tui/screens/chat_screen/__init__.py +0 -0
  113. shotgun/tui/screens/chat_screen/command_providers.py +219 -0
  114. shotgun/tui/screens/chat_screen/hint_message.py +40 -0
  115. shotgun/tui/screens/chat_screen/history.py +221 -0
  116. shotgun/tui/screens/directory_setup.py +113 -0
  117. shotgun/tui/screens/provider_config.py +221 -0
  118. shotgun/tui/screens/splash.py +31 -0
  119. shotgun/tui/styles.tcss +10 -0
  120. shotgun/tui/utils/__init__.py +5 -0
  121. shotgun/tui/utils/mode_progress.py +257 -0
  122. shotgun/utils/__init__.py +5 -0
  123. shotgun/utils/env_utils.py +35 -0
  124. shotgun/utils/file_system_utils.py +36 -0
  125. shotgun/utils/update_checker.py +375 -0
  126. shotgun_sh-0.1.0.dist-info/METADATA +466 -0
  127. shotgun_sh-0.1.0.dist-info/RECORD +130 -0
  128. shotgun_sh-0.1.0.dist-info/WHEEL +4 -0
  129. shotgun_sh-0.1.0.dist-info/entry_points.txt +2 -0
  130. shotgun_sh-0.1.0.dist-info/licenses/LICENSE +21 -0
shotgun/agents/config/manager.py
@@ -0,0 +1,294 @@
+ """Configuration manager for Shotgun CLI."""
+
+ import json
+ import os
+ import uuid
+ from pathlib import Path
+ from typing import Any
+
+ from pydantic import SecretStr
+
+ from shotgun.logging_config import get_logger
+ from shotgun.utils import get_shotgun_home
+
+ from .constants import (
+     ANTHROPIC_API_KEY_ENV,
+     ANTHROPIC_PROVIDER,
+     API_KEY_FIELD,
+     GEMINI_API_KEY_ENV,
+     GOOGLE_PROVIDER,
+     OPENAI_API_KEY_ENV,
+     OPENAI_PROVIDER,
+ )
+ from .models import ProviderType, ShotgunConfig
+
+ logger = get_logger(__name__)
+
+
+ class ConfigManager:
+     """Manager for Shotgun configuration."""
+
+     def __init__(self, config_path: Path | None = None):
+         """Initialize ConfigManager.
+
+         Args:
+             config_path: Path to config file. If None, uses default ~/.shotgun-sh/config.json
+         """
+         if config_path is None:
+             self.config_path = get_shotgun_home() / "config.json"
+         else:
+             self.config_path = config_path
+
+         self._config: ShotgunConfig | None = None
+
+     def load(self) -> ShotgunConfig:
+         """Load configuration from file.
+
+         Returns:
+             ShotgunConfig: Loaded configuration or default config if file doesn't exist
+         """
+         if self._config is not None:
+             return self._config
+
+         if not self.config_path.exists():
+             logger.info(
+                 "Configuration file not found, creating new config with user_id: %s",
+                 self.config_path,
+             )
+             # Create new config with generated user_id
+             self._config = self.initialize()
+             return self._config
+
+         try:
+             with open(self.config_path, encoding="utf-8") as f:
+                 data = json.load(f)
+
+             # Convert plain text secrets to SecretStr objects
+             self._convert_secrets_to_secretstr(data)
+
+             self._config = ShotgunConfig.model_validate(data)
+             logger.debug("Configuration loaded successfully from %s", self.config_path)
+
+             # Check if the default provider has a key, if not find one that does
+             if not self.has_provider_key(self._config.default_provider):
+                 original_default = self._config.default_provider
+                 # Find first provider with a configured key
+                 for provider in ProviderType:
+                     if self.has_provider_key(provider):
+                         logger.info(
+                             "Default provider %s has no API key, updating to %s",
+                             original_default.value,
+                             provider.value,
+                         )
+                         self._config.default_provider = provider
+                         self.save(self._config)
+                         break
+
+             return self._config
+
+         except Exception as e:
+             logger.error(
+                 "Failed to load configuration from %s: %s", self.config_path, e
+             )
+             logger.info("Creating new configuration with generated user_id")
+             self._config = self.initialize()
+             return self._config
+
+     def save(self, config: ShotgunConfig | None = None) -> None:
+         """Save configuration to file.
+
+         Args:
+             config: Configuration to save. If None, saves current loaded config
+         """
+         if config is None:
+             if self._config:
+                 config = self._config
+             else:
+                 # Create a new config with generated user_id
+                 config = ShotgunConfig(
+                     user_id=str(uuid.uuid4()),
+                     config_version=1,
+                 )
+
+         # Ensure directory exists
+         self.config_path.parent.mkdir(parents=True, exist_ok=True)
+
+         try:
+             # Convert SecretStr to plain text for JSON serialization
+             data = config.model_dump()
+             self._convert_secretstr_to_plain(data)
+
+             with open(self.config_path, "w", encoding="utf-8") as f:
+                 json.dump(data, f, indent=2, ensure_ascii=False)
+
+             logger.debug("Configuration saved to %s", self.config_path)
+             self._config = config
+
+         except Exception as e:
+             logger.error("Failed to save configuration to %s: %s", self.config_path, e)
+             raise
+
+     def update_provider(self, provider: ProviderType | str, **kwargs: Any) -> None:
+         """Update provider configuration.
+
+         Args:
+             provider: Provider to update
+             **kwargs: Configuration fields to update (only api_key supported)
+         """
+         config = self.load()
+         provider_enum = self._ensure_provider_enum(provider)
+         provider_config = self._get_provider_config(config, provider_enum)
+
+         # Only support api_key updates
+         if API_KEY_FIELD in kwargs:
+             api_key_value = kwargs[API_KEY_FIELD]
+             provider_config.api_key = (
+                 SecretStr(api_key_value) if api_key_value is not None else None
+             )
+
+         # Reject other fields
+         unsupported_fields = set(kwargs.keys()) - {API_KEY_FIELD}
+         if unsupported_fields:
+             raise ValueError(f"Unsupported configuration fields: {unsupported_fields}")
+
+         # If no other providers have keys configured and we just added one,
+         # set this provider as the default
+         if API_KEY_FIELD in kwargs and api_key_value is not None:
+             other_providers = [p for p in ProviderType if p != provider_enum]
+             has_other_keys = any(self.has_provider_key(p) for p in other_providers)
+             if not has_other_keys:
+                 config.default_provider = provider_enum
+
+         self.save(config)
+
+     def clear_provider_key(self, provider: ProviderType | str) -> None:
+         """Remove the API key for the given provider."""
+         config = self.load()
+         provider_enum = self._ensure_provider_enum(provider)
+         provider_config = self._get_provider_config(config, provider_enum)
+         provider_config.api_key = None
+         self.save(config)
+
+     def has_provider_key(self, provider: ProviderType | str) -> bool:
+         """Check if the given provider has a non-empty API key configured.
+
+         This checks both the configuration file and environment variables.
+         """
+         config = self.load()
+         provider_enum = self._ensure_provider_enum(provider)
+         provider_config = self._get_provider_config(config, provider_enum)
+
+         # Check config first
+         if self._provider_has_api_key(provider_config):
+             return True
+
+         # Check environment variable
+         if provider_enum == ProviderType.OPENAI:
+             return bool(os.getenv(OPENAI_API_KEY_ENV))
+         elif provider_enum == ProviderType.ANTHROPIC:
+             return bool(os.getenv(ANTHROPIC_API_KEY_ENV))
+         elif provider_enum == ProviderType.GOOGLE:
+             return bool(os.getenv(GEMINI_API_KEY_ENV))
+
+         return False
+
+     def has_any_provider_key(self) -> bool:
+         """Determine whether any provider has a configured API key."""
+         config = self.load()
+         return any(
+             self._provider_has_api_key(self._get_provider_config(config, provider))
+             for provider in (
+                 ProviderType.OPENAI,
+                 ProviderType.ANTHROPIC,
+                 ProviderType.GOOGLE,
+             )
+         )
+
+     def initialize(self) -> ShotgunConfig:
+         """Initialize configuration with defaults and save to file.
+
+         Returns:
+             Default ShotgunConfig
+         """
+         # Generate unique user ID for new config
+         config = ShotgunConfig(
+             user_id=str(uuid.uuid4()),
+             config_version=1,
+         )
+         self.save(config)
+         logger.info(
+             "Configuration initialized at %s with user_id: %s",
+             self.config_path,
+             config.user_id,
+         )
+         return config
+
+     def _convert_secrets_to_secretstr(self, data: dict[str, Any]) -> None:
+         """Convert plain text secrets in data to SecretStr objects."""
+         for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
+             if provider in data and isinstance(data[provider], dict):
+                 if (
+                     API_KEY_FIELD in data[provider]
+                     and data[provider][API_KEY_FIELD] is not None
+                 ):
+                     data[provider][API_KEY_FIELD] = SecretStr(
+                         data[provider][API_KEY_FIELD]
+                     )
+
+     def _convert_secretstr_to_plain(self, data: dict[str, Any]) -> None:
+         """Convert SecretStr objects in data to plain text for JSON serialization."""
+         for provider in [OPENAI_PROVIDER, ANTHROPIC_PROVIDER, GOOGLE_PROVIDER]:
+             if provider in data and isinstance(data[provider], dict):
+                 if (
+                     API_KEY_FIELD in data[provider]
+                     and data[provider][API_KEY_FIELD] is not None
+                 ):
+                     if hasattr(data[provider][API_KEY_FIELD], "get_secret_value"):
+                         data[provider][API_KEY_FIELD] = data[provider][
+                             API_KEY_FIELD
+                         ].get_secret_value()
+
+     def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
+         """Normalize provider values to ProviderType enum."""
+         return (
+             provider if isinstance(provider, ProviderType) else ProviderType(provider)
+         )
+
+     def _get_provider_config(
+         self, config: ShotgunConfig, provider: ProviderType
+     ) -> Any:
+         """Retrieve the provider-specific configuration section."""
+         if provider == ProviderType.OPENAI:
+             return config.openai
+         if provider == ProviderType.ANTHROPIC:
+             return config.anthropic
+         if provider == ProviderType.GOOGLE:
+             return config.google
+         raise ValueError(f"Unsupported provider: {provider}")
+
+     def _provider_has_api_key(self, provider_config: Any) -> bool:
+         """Return True if the provider config contains a usable API key."""
+         api_key = getattr(provider_config, API_KEY_FIELD, None)
+         if api_key is None:
+             return False
+
+         if isinstance(api_key, SecretStr):
+             value = api_key.get_secret_value()
+         else:
+             value = str(api_key)
+
+         return bool(value.strip())
+
+     def get_user_id(self) -> str:
+         """Get the user ID from configuration.
+
+         Returns:
+             The unique user ID string
+         """
+         config = self.load()
+         return config.user_id
+
+
+ def get_config_manager() -> ConfigManager:
+     """Get the global ConfigManager instance."""
+     return ConfigManager()
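
For orientation, a minimal usage sketch of the ConfigManager API above. It assumes the import path matches the file layout shown in the listing and that the API_KEY_FIELD constant resolves to "api_key" (constants.py is not shown here); the key value is a placeholder, not a real credential.

from shotgun.agents.config.manager import get_config_manager
from shotgun.agents.config.models import ProviderType

manager = get_config_manager()
config = manager.load()  # creates ~/.shotgun-sh/config.json with a fresh user_id on first run

# Store a key (placeholder value, assuming the keyword is "api_key"); this also
# becomes the default provider when no other provider has a key configured yet.
manager.update_provider(ProviderType.OPENAI, api_key="sk-placeholder")

print(manager.has_provider_key("openai"))  # True if set in config or via the OpenAI key env var
print(manager.get_user_id())               # stable anonymous UUID persisted in the config file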
shotgun/agents/config/models.py
@@ -0,0 +1,185 @@
+ """Pydantic models for configuration."""
+
+ from enum import Enum
+ from typing import Any
+
+ from pydantic import BaseModel, Field, PrivateAttr, SecretStr
+ from pydantic_ai.direct import model_request
+ from pydantic_ai.messages import ModelMessage, ModelResponse
+ from pydantic_ai.models import Model
+ from pydantic_ai.settings import ModelSettings
+
+
+ class ProviderType(str, Enum):
+     """Provider types for AI services."""
+
+     OPENAI = "openai"
+     ANTHROPIC = "anthropic"
+     GOOGLE = "google"
+
+
+ class ModelSpec(BaseModel):
+     """Static specification for a model - just metadata."""
+
+     name: str  # Model identifier (e.g., "gpt-5", "claude-opus-4-1")
+     provider: ProviderType
+     max_input_tokens: int
+     max_output_tokens: int
+
+
+ class ModelConfig(BaseModel):
+     """A fully configured model with API key and settings."""
+
+     name: str  # Model identifier (e.g., "gpt-5", "claude-opus-4-1")
+     provider: ProviderType
+     max_input_tokens: int
+     max_output_tokens: int
+     api_key: str
+     _model_instance: Model | None = PrivateAttr(default=None)
+
+     class Config:
+         arbitrary_types_allowed = True
+
+     @property
+     def model_instance(self) -> Model:
+         """Lazy load the Model instance."""
+         if self._model_instance is None:
+             from .provider import get_or_create_model
+
+             self._model_instance = get_or_create_model(
+                 self.provider, self.name, self.api_key
+             )
+         return self._model_instance
+
+     @property
+     def pydantic_model_name(self) -> str:
+         """Compute the full Pydantic AI model identifier. For backward compatibility."""
+         provider_prefix = {
+             ProviderType.OPENAI: "openai",
+             ProviderType.ANTHROPIC: "anthropic",
+             ProviderType.GOOGLE: "google-gla",
+         }
+         return f"{provider_prefix[self.provider]}:{self.name}"
+
+     def get_model_settings(self, max_tokens: int | None = None) -> ModelSettings:
+         """Get ModelSettings with optional token override.
+
+         This provides flexibility for specific use cases that need different
+         token limits while defaulting to maximum utilization.
+
+         Args:
+             max_tokens: Optional override for max_tokens. If None, uses max_output_tokens
+
+         Returns:
+             ModelSettings configured with specified or maximum tokens
+         """
+         return ModelSettings(
+             max_tokens=max_tokens if max_tokens is not None else self.max_output_tokens
+         )
+
+
+ # Model specifications registry (static metadata)
+ MODEL_SPECS: dict[str, ModelSpec] = {
+     "gpt-5": ModelSpec(
+         name="gpt-5",
+         provider=ProviderType.OPENAI,
+         max_input_tokens=400_000,
+         max_output_tokens=128_000,
+     ),
+     "gpt-4o": ModelSpec(
+         name="gpt-4o",
+         provider=ProviderType.OPENAI,
+         max_input_tokens=128_000,
+         max_output_tokens=16_000,
+     ),
+     "claude-opus-4-1": ModelSpec(
+         name="claude-opus-4-1",
+         provider=ProviderType.ANTHROPIC,
+         max_input_tokens=200_000,
+         max_output_tokens=32_000,
+     ),
+     "claude-3-5-sonnet-latest": ModelSpec(
+         name="claude-3-5-sonnet-latest",
+         provider=ProviderType.ANTHROPIC,
+         max_input_tokens=200_000,
+         max_output_tokens=8_192,
+     ),
+     "gemini-2.5-pro": ModelSpec(
+         name="gemini-2.5-pro",
+         provider=ProviderType.GOOGLE,
+         max_input_tokens=1_000_000,
+         max_output_tokens=64_000,
+     ),
+ }
+
+
+ class OpenAIConfig(BaseModel):
+     """Configuration for OpenAI provider."""
+
+     api_key: SecretStr | None = None
+
+
+ class AnthropicConfig(BaseModel):
+     """Configuration for Anthropic provider."""
+
+     api_key: SecretStr | None = None
+
+
+ class GoogleConfig(BaseModel):
+     """Configuration for Google provider."""
+
+     api_key: SecretStr | None = None
+
+
+ class ShotgunConfig(BaseModel):
+     """Main configuration for Shotgun CLI."""
+
+     openai: OpenAIConfig = Field(default_factory=OpenAIConfig)
+     anthropic: AnthropicConfig = Field(default_factory=AnthropicConfig)
+     google: GoogleConfig = Field(default_factory=GoogleConfig)
+     default_provider: ProviderType = Field(
+         default=ProviderType.OPENAI, description="Default AI provider to use"
+     )
+     user_id: str = Field(description="Unique anonymous user identifier")
+     config_version: int = Field(default=1, description="Configuration schema version")
+
+
+ async def shotgun_model_request(
+     model_config: ModelConfig,
+     messages: list[ModelMessage],
+     max_tokens: int | None = None,
+     **kwargs: Any,
+ ) -> ModelResponse:
+     """Model request wrapper that uses full token capacity by default.
+
+     This wrapper ensures all LLM calls in Shotgun use the maximum available
+     token capacity of each model, improving response quality and completeness.
+     The most common issue this fixes is truncated summaries that were cut off
+     at default token limits (e.g., 4096 for Claude models).
+
+     Args:
+         model_config: ModelConfig instance with model settings and API key
+         messages: Messages to send to the model
+         max_tokens: Optional override for max_tokens. If None, uses model's max_output_tokens
+         **kwargs: Additional arguments passed to model_request
+
+     Returns:
+         ModelResponse from the model
+
+     Example:
+         # Uses full token capacity (e.g., 4096 for Claude, 128k for GPT-5)
+         response = await shotgun_model_request(model_config, messages)
+
+         # Override for specific use case
+         response = await shotgun_model_request(model_config, messages, max_tokens=1000)
+     """
+     # Get properly configured ModelSettings with maximum or overridden token limit
+     model_settings = model_config.get_model_settings(max_tokens)
+
+     # Make the model request with full token utilization
+     return await model_request(
+         model=model_config.model_instance,
+         messages=messages,
+         model_settings=model_settings,
+         **kwargs,
+     )
1
+ """Provider management for LLM configuration."""
2
+
3
+ import os
4
+
5
+ from pydantic import SecretStr
6
+ from pydantic_ai.models import Model
7
+ from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
8
+ from pydantic_ai.models.google import GoogleModel
9
+ from pydantic_ai.models.openai import OpenAIChatModel
10
+ from pydantic_ai.providers.anthropic import AnthropicProvider
11
+ from pydantic_ai.providers.google import GoogleProvider
12
+ from pydantic_ai.providers.openai import OpenAIProvider
13
+ from pydantic_ai.settings import ModelSettings
14
+
15
+ from shotgun.logging_config import get_logger
16
+
17
+ from .constants import (
18
+ ANTHROPIC_API_KEY_ENV,
19
+ GEMINI_API_KEY_ENV,
20
+ OPENAI_API_KEY_ENV,
21
+ )
22
+ from .manager import get_config_manager
23
+ from .models import MODEL_SPECS, ModelConfig, ProviderType
24
+
25
+ logger = get_logger(__name__)
26
+
27
+ # Global cache for Model instances (singleton pattern)
28
+ _model_cache: dict[tuple[ProviderType, str, str], Model] = {}
29
+
30
+
31
+ def get_or_create_model(provider: ProviderType, model_name: str, api_key: str) -> Model:
32
+ """Get or create a singleton Model instance.
33
+
34
+ Args:
35
+ provider: Provider type
36
+ model_name: Name of the model
37
+ api_key: API key for the provider
38
+
39
+ Returns:
40
+ Cached or newly created Model instance
41
+
42
+ Raises:
43
+ ValueError: If provider is not supported
44
+ """
45
+ cache_key = (provider, model_name, api_key)
46
+
47
+ if cache_key not in _model_cache:
48
+ logger.debug("Creating new %s model instance: %s", provider.value, model_name)
49
+
50
+ if provider == ProviderType.OPENAI:
51
+ # Get max_tokens from MODEL_SPECS to use full capacity
52
+ if model_name in MODEL_SPECS:
53
+ max_tokens = MODEL_SPECS[model_name].max_output_tokens
54
+ else:
55
+ max_tokens = 16_000 # Default for GPT models
56
+
57
+ openai_provider = OpenAIProvider(api_key=api_key)
58
+ _model_cache[cache_key] = OpenAIChatModel(
59
+ model_name,
60
+ provider=openai_provider,
61
+ settings=ModelSettings(max_tokens=max_tokens),
62
+ )
63
+ elif provider == ProviderType.ANTHROPIC:
64
+ # Get max_tokens from MODEL_SPECS to use full capacity
65
+ if model_name in MODEL_SPECS:
66
+ max_tokens = MODEL_SPECS[model_name].max_output_tokens
67
+ else:
68
+ max_tokens = 32_000 # Default for Claude models
69
+
70
+ anthropic_provider = AnthropicProvider(api_key=api_key)
71
+ _model_cache[cache_key] = AnthropicModel(
72
+ model_name,
73
+ provider=anthropic_provider,
74
+ settings=AnthropicModelSettings(
75
+ max_tokens=max_tokens,
76
+ timeout=600, # 10 minutes timeout for large responses
77
+ ),
78
+ )
79
+ elif provider == ProviderType.GOOGLE:
80
+ # Get max_tokens from MODEL_SPECS to use full capacity
81
+ if model_name in MODEL_SPECS:
82
+ max_tokens = MODEL_SPECS[model_name].max_output_tokens
83
+ else:
84
+ max_tokens = 64_000 # Default for Gemini models
85
+
86
+ google_provider = GoogleProvider(api_key=api_key)
87
+ _model_cache[cache_key] = GoogleModel(
88
+ model_name,
89
+ provider=google_provider,
90
+ settings=ModelSettings(max_tokens=max_tokens),
91
+ )
92
+ else:
93
+ raise ValueError(f"Unsupported provider: {provider}")
94
+ else:
95
+ logger.debug("Reusing cached %s model instance: %s", provider.value, model_name)
96
+
97
+ return _model_cache[cache_key]
98
+
99
+
100
+ def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
101
+ """Get a fully configured ModelConfig with API key and Model instance.
102
+
103
+ Args:
104
+ provider: Provider to get model for. If None, uses default provider
105
+
106
+ Returns:
107
+ ModelConfig with API key configured and lazy Model instance
108
+
109
+ Raises:
110
+ ValueError: If provider is not configured properly or model not found
111
+ """
112
+ config_manager = get_config_manager()
113
+ config = config_manager.load()
114
+ # Convert string to ProviderType enum if needed
115
+ provider_enum = (
116
+ provider
117
+ if isinstance(provider, ProviderType)
118
+ else ProviderType(provider)
119
+ if provider
120
+ else config.default_provider
121
+ )
122
+
123
+ if provider_enum == ProviderType.OPENAI:
124
+ api_key = _get_api_key(config.openai.api_key, OPENAI_API_KEY_ENV)
125
+ if not api_key:
126
+ raise ValueError(
127
+ f"OpenAI API key not configured. Set via environment variable {OPENAI_API_KEY_ENV} or config."
128
+ )
129
+
130
+ # Get model spec - hardcoded to gpt-5
131
+ model_name = "gpt-5"
132
+ if model_name not in MODEL_SPECS:
133
+ raise ValueError(f"Model '{model_name}' not found")
134
+ spec = MODEL_SPECS[model_name]
135
+
136
+ # Create fully configured ModelConfig
137
+ return ModelConfig(
138
+ name=spec.name,
139
+ provider=spec.provider,
140
+ max_input_tokens=spec.max_input_tokens,
141
+ max_output_tokens=spec.max_output_tokens,
142
+ api_key=api_key,
143
+ )
144
+
145
+ elif provider_enum == ProviderType.ANTHROPIC:
146
+ api_key = _get_api_key(config.anthropic.api_key, ANTHROPIC_API_KEY_ENV)
147
+ if not api_key:
148
+ raise ValueError(
149
+ f"Anthropic API key not configured. Set via environment variable {ANTHROPIC_API_KEY_ENV} or config."
150
+ )
151
+
152
+ # Get model spec - hardcoded to claude-opus-4-1
153
+ model_name = "claude-opus-4-1"
154
+ if model_name not in MODEL_SPECS:
155
+ raise ValueError(f"Model '{model_name}' not found")
156
+ spec = MODEL_SPECS[model_name]
157
+
158
+ # Create fully configured ModelConfig
159
+ return ModelConfig(
160
+ name=spec.name,
161
+ provider=spec.provider,
162
+ max_input_tokens=spec.max_input_tokens,
163
+ max_output_tokens=spec.max_output_tokens,
164
+ api_key=api_key,
165
+ )
166
+
167
+ elif provider_enum == ProviderType.GOOGLE:
168
+ api_key = _get_api_key(config.google.api_key, GEMINI_API_KEY_ENV)
169
+ if not api_key:
170
+ raise ValueError(
171
+ f"Gemini API key not configured. Set via environment variable {GEMINI_API_KEY_ENV} or config."
172
+ )
173
+
174
+ # Get model spec - hardcoded to gemini-2.5-pro
175
+ model_name = "gemini-2.5-pro"
176
+ if model_name not in MODEL_SPECS:
177
+ raise ValueError(f"Model '{model_name}' not found")
178
+ spec = MODEL_SPECS[model_name]
179
+
180
+ # Create fully configured ModelConfig
181
+ return ModelConfig(
182
+ name=spec.name,
183
+ provider=spec.provider,
184
+ max_input_tokens=spec.max_input_tokens,
185
+ max_output_tokens=spec.max_output_tokens,
186
+ api_key=api_key,
187
+ )
188
+
189
+ else:
190
+ raise ValueError(f"Unsupported provider: {provider_enum}")
191
+
192
+
193
+ def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
194
+ """Get API key from config or environment variable.
195
+
196
+ Args:
197
+ config_key: API key from configuration
198
+ env_var: Environment variable name to check
199
+
200
+ Returns:
201
+ API key string or None
202
+ """
203
+ if config_key is not None:
204
+ return config_key.get_secret_value()
205
+
206
+ return os.getenv(env_var)
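
Putting the three modules together, a typical end-to-end call looks roughly like the sketch below. It assumes an Anthropic key is present in the config file or environment, and it builds the message list with pydantic_ai's ModelRequest/UserPromptPart types purely for illustration.

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart

from shotgun.agents.config.models import ProviderType, shotgun_model_request
from shotgun.agents.config.provider import get_provider_model


async def main() -> None:
    # Resolves the API key (config file first, then environment) and pins the
    # provider to its hardcoded model: gpt-5, claude-opus-4-1, or gemini-2.5-pro.
    model_config = get_provider_model(ProviderType.ANTHROPIC)

    # First use builds an AnthropicModel and caches it under
    # (provider, model_name, api_key); later calls reuse that instance.
    messages = [ModelRequest(parts=[UserPromptPart(content="Summarize this repository layout.")])]
    response = await shotgun_model_request(model_config, messages)
    print(response)


asyncio.run(main())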