shotgun-sh 0.1.0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/__init__.py +3 -0
- shotgun/agents/__init__.py +1 -0
- shotgun/agents/agent_manager.py +196 -0
- shotgun/agents/common.py +295 -0
- shotgun/agents/config/__init__.py +13 -0
- shotgun/agents/config/manager.py +215 -0
- shotgun/agents/config/models.py +120 -0
- shotgun/agents/config/provider.py +91 -0
- shotgun/agents/history/__init__.py +5 -0
- shotgun/agents/history/history_processors.py +213 -0
- shotgun/agents/models.py +94 -0
- shotgun/agents/plan.py +119 -0
- shotgun/agents/research.py +131 -0
- shotgun/agents/tasks.py +122 -0
- shotgun/agents/tools/__init__.py +26 -0
- shotgun/agents/tools/codebase/__init__.py +28 -0
- shotgun/agents/tools/codebase/codebase_shell.py +256 -0
- shotgun/agents/tools/codebase/directory_lister.py +141 -0
- shotgun/agents/tools/codebase/file_read.py +144 -0
- shotgun/agents/tools/codebase/models.py +252 -0
- shotgun/agents/tools/codebase/query_graph.py +67 -0
- shotgun/agents/tools/codebase/retrieve_code.py +81 -0
- shotgun/agents/tools/file_management.py +130 -0
- shotgun/agents/tools/user_interaction.py +36 -0
- shotgun/agents/tools/web_search.py +69 -0
- shotgun/cli/__init__.py +1 -0
- shotgun/cli/codebase/__init__.py +5 -0
- shotgun/cli/codebase/commands.py +202 -0
- shotgun/cli/codebase/models.py +21 -0
- shotgun/cli/config.py +261 -0
- shotgun/cli/models.py +10 -0
- shotgun/cli/plan.py +65 -0
- shotgun/cli/research.py +78 -0
- shotgun/cli/tasks.py +71 -0
- shotgun/cli/utils.py +25 -0
- shotgun/codebase/__init__.py +12 -0
- shotgun/codebase/core/__init__.py +46 -0
- shotgun/codebase/core/change_detector.py +358 -0
- shotgun/codebase/core/code_retrieval.py +243 -0
- shotgun/codebase/core/ingestor.py +1497 -0
- shotgun/codebase/core/language_config.py +297 -0
- shotgun/codebase/core/manager.py +1554 -0
- shotgun/codebase/core/nl_query.py +327 -0
- shotgun/codebase/core/parser_loader.py +152 -0
- shotgun/codebase/models.py +107 -0
- shotgun/codebase/service.py +148 -0
- shotgun/logging_config.py +172 -0
- shotgun/main.py +73 -0
- shotgun/prompts/__init__.py +5 -0
- shotgun/prompts/agents/__init__.py +1 -0
- shotgun/prompts/agents/partials/codebase_understanding.j2 +79 -0
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +10 -0
- shotgun/prompts/agents/partials/interactive_mode.j2 +8 -0
- shotgun/prompts/agents/plan.j2 +57 -0
- shotgun/prompts/agents/research.j2 +38 -0
- shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +13 -0
- shotgun/prompts/agents/state/system_state.j2 +1 -0
- shotgun/prompts/agents/tasks.j2 +67 -0
- shotgun/prompts/codebase/__init__.py +1 -0
- shotgun/prompts/codebase/cypher_query_patterns.j2 +221 -0
- shotgun/prompts/codebase/cypher_system.j2 +28 -0
- shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
- shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
- shotgun/prompts/codebase/partials/graph_schema.j2 +28 -0
- shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
- shotgun/prompts/history/__init__.py +1 -0
- shotgun/prompts/history/summarization.j2 +46 -0
- shotgun/prompts/loader.py +140 -0
- shotgun/prompts/user/research.j2 +5 -0
- shotgun/py.typed +0 -0
- shotgun/sdk/__init__.py +13 -0
- shotgun/sdk/codebase.py +195 -0
- shotgun/sdk/exceptions.py +17 -0
- shotgun/sdk/models.py +189 -0
- shotgun/sdk/services.py +23 -0
- shotgun/telemetry.py +68 -0
- shotgun/tui/__init__.py +0 -0
- shotgun/tui/app.py +49 -0
- shotgun/tui/components/prompt_input.py +69 -0
- shotgun/tui/components/spinner.py +86 -0
- shotgun/tui/components/splash.py +25 -0
- shotgun/tui/components/vertical_tail.py +28 -0
- shotgun/tui/screens/chat.py +415 -0
- shotgun/tui/screens/chat.tcss +28 -0
- shotgun/tui/screens/provider_config.py +221 -0
- shotgun/tui/screens/splash.py +31 -0
- shotgun/tui/styles.tcss +10 -0
- shotgun/utils/__init__.py +5 -0
- shotgun/utils/file_system_utils.py +31 -0
- shotgun_sh-0.1.0.dev1.dist-info/METADATA +318 -0
- shotgun_sh-0.1.0.dev1.dist-info/RECORD +94 -0
- shotgun_sh-0.1.0.dev1.dist-info/WHEEL +4 -0
- shotgun_sh-0.1.0.dev1.dist-info/entry_points.txt +3 -0
- shotgun_sh-0.1.0.dev1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""Configuration manager for Shotgun CLI."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import SecretStr
|
|
8
|
+
|
|
9
|
+
from shotgun.logging_config import get_logger
|
|
10
|
+
from shotgun.utils import get_shotgun_home
|
|
11
|
+
|
|
12
|
+
from .models import ProviderType, ShotgunConfig
|
|
13
|
+
|
|
14
|
+
# Module-level logger used by ConfigManager to report configuration
# load/save outcomes.
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ConfigManager:
    """Manager for Shotgun configuration.

    Reads and writes a JSON file (by default ``get_shotgun_home() /
    "config.json"``) holding one section per provider (``openai``,
    ``anthropic``, ``google``) plus ``default_provider``. API keys are kept as
    ``SecretStr`` in memory and written as plain text on disk. ``load()``
    caches its result on the instance.
    """

    # Providers that have their own section in ShotgunConfig / the JSON file.
    # Single source of truth for the secret-conversion helpers and
    # has_any_provider_key().
    _PROVIDERS: tuple[ProviderType, ...] = (
        ProviderType.OPENAI,
        ProviderType.ANTHROPIC,
        ProviderType.GOOGLE,
    )

    def __init__(self, config_path: Path | None = None):
        """Initialize ConfigManager.

        Args:
            config_path: Path to config file. If None, uses
                ``get_shotgun_home() / "config.json"``.
        """
        if config_path is None:
            self.config_path = get_shotgun_home() / "config.json"
        else:
            self.config_path = config_path

        # Cache populated by load() / save().
        self._config: ShotgunConfig | None = None

    def load(self) -> ShotgunConfig:
        """Load configuration from file.

        The result is cached on the instance; later calls return the same
        object without re-reading the file.

        Returns:
            ShotgunConfig: Loaded configuration, or a default config if the
            file doesn't exist or cannot be read/validated (the error is
            logged).
        """
        if self._config is not None:
            return self._config

        if not self.config_path.exists():
            logger.info(
                "Configuration file not found, using defaults: %s", self.config_path
            )
            self._config = ShotgunConfig()
            return self._config

        try:
            with open(self.config_path, encoding="utf-8") as f:
                data = json.load(f)

            # Convert plain text secrets to SecretStr objects
            self._convert_secrets_to_secretstr(data)

            self._config = ShotgunConfig.model_validate(data)
            logger.debug("Configuration loaded successfully from %s", self.config_path)
            return self._config

        except Exception as e:
            # Deliberate best-effort: an unreadable config file falls back to
            # defaults instead of failing.
            # NOTE(review): a later update_provider()/save() will overwrite the
            # unreadable file with these defaults.
            logger.error(
                "Failed to load configuration from %s: %s", self.config_path, e
            )
            logger.info("Using default configuration")
            self._config = ShotgunConfig()
            return self._config

    def save(self, config: ShotgunConfig | None = None) -> None:
        """Save configuration to file.

        The JSON is written to a temporary file in the same directory and then
        moved over the target with ``os.replace``, so an interrupted write
        cannot leave a truncated config behind. The temporary file comes from
        ``tempfile.mkstemp``, which creates it readable/writable by the owner
        only on POSIX systems; this matters because API keys are stored in
        plain text.

        Args:
            config: Configuration to save. If None, saves the currently loaded
                config (or a default config if nothing was loaded).

        Raises:
            Exception: Any error while creating the directory, serializing or
                writing the file is propagated (write/serialization errors are
                logged first).
        """
        import os
        import tempfile

        if config is None:
            config = self._config or ShotgunConfig()

        # Ensure directory exists
        self.config_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            # Convert SecretStr to plain text for JSON serialization
            data = config.model_dump()
            self._convert_secretstr_to_plain(data)

            fd, tmp_name = tempfile.mkstemp(
                dir=self.config_path.parent,
                prefix=f".{self.config_path.name}.",
                suffix=".tmp",
            )
            try:
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    json.dump(data, f, indent=2, ensure_ascii=False)
                    f.flush()
                    os.fsync(f.fileno())
                # Atomic on the same filesystem; also overwrites on Windows.
                os.replace(tmp_name, self.config_path)
            except BaseException:
                # Best-effort cleanup of the temp file; the original error is
                # what matters and is re-raised below.
                try:
                    os.unlink(tmp_name)
                except OSError:
                    pass
                raise

            logger.debug("Configuration saved to %s", self.config_path)
            self._config = config

        except Exception as e:
            logger.error("Failed to save configuration to %s: %s", self.config_path, e)
            raise

    def update_provider(self, provider: ProviderType | str, **kwargs: Any) -> None:
        """Update provider configuration and persist it.

        Args:
            provider: Provider to update (enum member or its string value).
            **kwargs: Configuration fields to update (only ``api_key`` is
                supported). ``api_key`` may be a str, a SecretStr, or None to
                clear it.

        Raises:
            ValueError: If any field other than ``api_key`` is passed, or the
                provider is unknown.
        """
        # Validate before touching the cached config so a rejected call never
        # leaves in-memory state diverging from what is on disk.
        unsupported_fields = set(kwargs) - {"api_key"}
        if unsupported_fields:
            raise ValueError(f"Unsupported configuration fields: {unsupported_fields}")

        config = self.load()
        provider_enum = self._ensure_provider_enum(provider)
        provider_config = self._get_provider_config(config, provider_enum)

        if "api_key" in kwargs:
            api_key_value = kwargs["api_key"]
            if api_key_value is None or isinstance(api_key_value, SecretStr):
                # Avoid wrapping an existing SecretStr inside another one.
                provider_config.api_key = api_key_value
            else:
                provider_config.api_key = SecretStr(api_key_value)

        self.save(config)

    def clear_provider_key(self, provider: ProviderType | str) -> None:
        """Remove the API key for the given provider and persist the change."""
        config = self.load()
        provider_enum = self._ensure_provider_enum(provider)
        provider_config = self._get_provider_config(config, provider_enum)
        provider_config.api_key = None
        self.save(config)

    def has_provider_key(self, provider: ProviderType | str) -> bool:
        """Check if the given provider has a non-blank API key in the config."""
        config = self.load()
        provider_enum = self._ensure_provider_enum(provider)
        provider_config = self._get_provider_config(config, provider_enum)
        return self._provider_has_api_key(provider_config)

    def has_any_provider_key(self) -> bool:
        """Determine whether any provider has a configured API key."""
        config = self.load()
        return any(
            self._provider_has_api_key(self._get_provider_config(config, provider))
            for provider in self._PROVIDERS
        )

    def initialize(self) -> ShotgunConfig:
        """Write a default configuration to the config file.

        Any existing file at ``config_path`` is overwritten.

        Returns:
            Default ShotgunConfig
        """
        config = ShotgunConfig()
        self.save(config)
        logger.info("Configuration initialized at %s", self.config_path)
        return config

    def _convert_secrets_to_secretstr(self, data: dict[str, Any]) -> None:
        """Wrap plain-text ``api_key`` values in ``data`` with SecretStr, in place."""
        if not isinstance(data, dict):
            # Leave it to model_validate to reject non-object JSON.
            return
        for provider in self._PROVIDERS:
            section = data.get(provider.value)
            if isinstance(section, dict) and section.get("api_key") is not None:
                section["api_key"] = SecretStr(section["api_key"])

    def _convert_secretstr_to_plain(self, data: dict[str, Any]) -> None:
        """Unwrap SecretStr ``api_key`` values in ``data`` to plain text, in place."""
        for provider in self._PROVIDERS:
            section = data.get(provider.value)
            if not isinstance(section, dict):
                continue
            api_key = section.get("api_key")
            if api_key is not None and hasattr(api_key, "get_secret_value"):
                section["api_key"] = api_key.get_secret_value()

    def _ensure_provider_enum(self, provider: ProviderType | str) -> ProviderType:
        """Normalize provider values to ProviderType enum.

        Raises:
            ValueError: If ``provider`` is not a valid ProviderType value.
        """
        return (
            provider if isinstance(provider, ProviderType) else ProviderType(provider)
        )

    def _get_provider_config(
        self, config: ShotgunConfig, provider: ProviderType
    ) -> Any:
        """Retrieve the provider-specific configuration section.

        Raises:
            ValueError: If ``provider`` has no section in ShotgunConfig.
        """
        if provider == ProviderType.OPENAI:
            return config.openai
        if provider == ProviderType.ANTHROPIC:
            return config.anthropic
        if provider == ProviderType.GOOGLE:
            return config.google
        raise ValueError(f"Unsupported provider: {provider}")

    def _provider_has_api_key(self, provider_config: Any) -> bool:
        """Return True if the provider config contains a non-blank API key."""
        api_key = getattr(provider_config, "api_key", None)
        if api_key is None:
            return False

        if isinstance(api_key, SecretStr):
            value = api_key.get_secret_value()
        else:
            value = str(api_key)

        return bool(value.strip())
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_config_manager() -> ConfigManager:
    """Return a ConfigManager for the default config path.

    Despite being the module's accessor, this does not return a shared
    instance: a new ConfigManager is created on every call. Its in-memory
    cache (filled by ``load()``) therefore lives only as long as the returned
    object, and each new instance reads the file again on its first
    ``load()``.
    """
    return ConfigManager()
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
"""Pydantic models for configuration."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field, SecretStr
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ProviderType(str, Enum):
    """Provider types for AI services.

    Because the enum also subclasses ``str``, members compare equal to their
    string values (``ProviderType.OPENAI == "openai"``), and
    ``ProviderType("openai")`` converts a raw string into a member (raising
    ValueError for unknown values).
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GOOGLE = "google"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ModelConfig(BaseModel):
|
|
17
|
+
"""Configuration for an LLM model."""
|
|
18
|
+
|
|
19
|
+
name: str # Model identifier (e.g., "gpt-5", "claude-opus-4-1")
|
|
20
|
+
provider: ProviderType
|
|
21
|
+
max_input_tokens: int
|
|
22
|
+
max_output_tokens: int
|
|
23
|
+
|
|
24
|
+
@property
|
|
25
|
+
def pydantic_model_name(self) -> str:
|
|
26
|
+
"""Compute the full Pydantic AI model identifier."""
|
|
27
|
+
provider_prefix = {
|
|
28
|
+
ProviderType.OPENAI: "openai",
|
|
29
|
+
ProviderType.ANTHROPIC: "anthropic",
|
|
30
|
+
ProviderType.GOOGLE: "google-gla",
|
|
31
|
+
}
|
|
32
|
+
return f"{provider_prefix[self.provider]}:{self.name}"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# OpenAI Models
GPT_5 = ModelConfig(
    name="gpt-5",
    provider=ProviderType.OPENAI,
    # GPT-5's 400K-token context window is split into at most 272K input
    # tokens plus 128K output tokens. Recording the full 400K as *input* would
    # let anything budgeting against max_input_tokens exceed what the API
    # accepts.
    max_input_tokens=272_000,
    max_output_tokens=128_000,
)

GPT_4O = ModelConfig(
    name="gpt-4o",
    provider=ProviderType.OPENAI,
    max_input_tokens=128_000,
    max_output_tokens=16_000,
)

# Anthropic Models
CLAUDE_OPUS_4_1 = ModelConfig(
    name="claude-opus-4-1",
    provider=ProviderType.ANTHROPIC,
    max_input_tokens=200_000,
    max_output_tokens=32_000,
)

CLAUDE_3_5_SONNET = ModelConfig(
    name="claude-3-5-sonnet-latest",
    provider=ProviderType.ANTHROPIC,
    max_input_tokens=200_000,
    # Anthropic documents an 8,192-token output cap for Claude 3.5 Sonnet;
    # a larger value would be rejected if used as a request's max_tokens.
    max_output_tokens=8_192,
)

# Google Models
GEMINI_2_5_PRO = ModelConfig(
    name="gemini-2.5-pro",
    provider=ProviderType.GOOGLE,
    max_input_tokens=1_000_000,
    max_output_tokens=64_000,
)

# List of all available models; get_model_by_name() searches this list.
AVAILABLE_MODELS = [
    GPT_5,
    GPT_4O,
    CLAUDE_OPUS_4_1,
    CLAUDE_3_5_SONNET,
    GEMINI_2_5_PRO,
]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_model_by_name(name: str) -> ModelConfig:
    """Return the first entry of AVAILABLE_MODELS whose ``name`` equals *name*.

    Raises:
        ValueError: If no registered model has that name.
    """
    found = next((candidate for candidate in AVAILABLE_MODELS if candidate.name == name), None)
    if found is None:
        raise ValueError(f"Model '{name}' not found")
    return found
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class OpenAIConfig(BaseModel):
    """Configuration for OpenAI provider."""

    # API key from the config file; None means "not set here" (the
    # OPENAI_API_KEY environment variable can still supply one).
    api_key: SecretStr | None = None
    # Must match a ModelConfig.name in AVAILABLE_MODELS; otherwise
    # get_model_by_name() raises ValueError.
    model_name: str = "gpt-5"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class AnthropicConfig(BaseModel):
    """Configuration for Anthropic provider."""

    # API key from the config file; None means "not set here" (the
    # ANTHROPIC_API_KEY environment variable can still supply one).
    api_key: SecretStr | None = None
    # Must match a ModelConfig.name in AVAILABLE_MODELS; otherwise
    # get_model_by_name() raises ValueError.
    model_name: str = "claude-opus-4-1"
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class GoogleConfig(BaseModel):
    """Configuration for Google provider."""

    # API key from the config file; None means "not set here" (the
    # GOOGLE_API_KEY environment variable can still supply one).
    api_key: SecretStr | None = None
    # Must match a ModelConfig.name in AVAILABLE_MODELS; otherwise
    # get_model_by_name() raises ValueError.
    model_name: str = "gemini-2.5-pro"
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class ShotgunConfig(BaseModel):
    """Main configuration for Shotgun CLI.

    Field names double as the top-level keys of the JSON config file.
    """

    # One section per provider; each defaults to an empty section (no API key,
    # default model name).
    openai: OpenAIConfig = Field(default_factory=OpenAIConfig)
    anthropic: AnthropicConfig = Field(default_factory=AnthropicConfig)
    google: GoogleConfig = Field(default_factory=GoogleConfig)
    # Provider used when callers don't request a specific one.
    default_provider: ProviderType = Field(
        default=ProviderType.OPENAI, description="Default AI provider to use"
    )
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Provider management for LLM configuration."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from pydantic import SecretStr
|
|
6
|
+
|
|
7
|
+
from shotgun.logging_config import get_logger
|
|
8
|
+
|
|
9
|
+
from .manager import get_config_manager
|
|
10
|
+
from .models import ModelConfig, ProviderType, get_model_by_name
|
|
11
|
+
|
|
12
|
+
# Module-level logger; no function in this module logs through it at present.
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_provider_model(provider: ProviderType | None = None) -> ModelConfig:
    """Resolve the ModelConfig to use for a provider and export its API key.

    The key is looked up via ``_get_api_key`` (config first, then the
    provider's environment variable). If that environment variable is not
    already set, it is set to the key so the underlying client can find it.

    Args:
        provider: Provider to get model for (enum member or its string value).
            If None/empty, the config's ``default_provider`` is used.

    Returns:
        ModelConfig with pydantic_model_name and token limits

    Raises:
        ValueError: If the provider is unknown, no API key is available, or
            the configured model name is not registered.
    """
    config = get_config_manager().load()

    # Accept enum members, raw strings, or nothing (meaning "use the default").
    if isinstance(provider, ProviderType):
        provider_enum = provider
    elif provider:
        provider_enum = ProviderType(provider)
    else:
        provider_enum = config.default_provider

    # provider -> (config section attribute, environment variable, display label)
    provider_settings = {
        ProviderType.OPENAI: ("openai", "OPENAI_API_KEY", "OpenAI"),
        ProviderType.ANTHROPIC: ("anthropic", "ANTHROPIC_API_KEY", "Anthropic"),
        ProviderType.GOOGLE: ("google", "GOOGLE_API_KEY", "Google"),
    }
    settings = provider_settings.get(provider_enum)
    if settings is None:
        raise ValueError(f"Unsupported provider: {provider_enum}")

    section_name, env_var, label = settings
    section = getattr(config, section_name)

    api_key = _get_api_key(section.api_key, env_var)
    if not api_key:
        raise ValueError(
            f"{label} API key not configured. "
            f"Set via environment variable {env_var} or config."
        )

    # Publish the key for the client library without clobbering a value the
    # user already exported.
    os.environ.setdefault(env_var, api_key)

    return get_model_by_name(section.model_name)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_api_key(config_key: SecretStr | None, env_var: str) -> str | None:
    """Get API key from config or environment variable.

    A key stored in the config takes precedence. A missing *or blank*
    (empty/whitespace-only) config key falls back to the environment
    variable, matching ConfigManager's notion of "no key configured".

    Args:
        config_key: API key from configuration
        env_var: Environment variable name to check

    Returns:
        API key string or None
    """
    if config_key is not None:
        value = config_key.get_secret_value()
        # Previously a blank stored key was returned as-is, which hid a valid
        # environment variable and made callers report "not configured".
        if value.strip():
            return value

    return os.getenv(env_var)
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""History processors for managing conversation history in Shotgun agents."""
|
|
2
|
+
|
|
3
|
+
from pydantic_ai import RunContext
|
|
4
|
+
from pydantic_ai.direct import model_request
|
|
5
|
+
from pydantic_ai.messages import (
|
|
6
|
+
BuiltinToolCallPart,
|
|
7
|
+
BuiltinToolReturnPart,
|
|
8
|
+
ModelMessage,
|
|
9
|
+
ModelRequest,
|
|
10
|
+
ModelResponse,
|
|
11
|
+
ModelResponsePart,
|
|
12
|
+
RetryPromptPart,
|
|
13
|
+
SystemPromptPart,
|
|
14
|
+
TextPart,
|
|
15
|
+
ThinkingPart,
|
|
16
|
+
ToolCallPart,
|
|
17
|
+
ToolReturnPart,
|
|
18
|
+
UserPromptPart,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from shotgun.agents.models import AgentDeps
|
|
22
|
+
from shotgun.logging_config import get_logger
|
|
23
|
+
from shotgun.prompts import PromptLoader
|
|
24
|
+
|
|
25
|
+
logger = get_logger(__name__)

# Global prompt loader instance, created once at import time and used by
# token_limit_compactor to render the summarization prompt.
prompt_loader: PromptLoader = PromptLoader()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
async def token_limit_compactor(
    ctx: RunContext[AgentDeps],
    messages: list[ModelMessage],
) -> list[ModelMessage]:
    """Compact message history based on token limits.

    While the token count taken from ``ctx.usage`` stays below 80% of the
    model's ``max_input_tokens``, ``messages`` is returned untouched. Once it
    reaches that threshold, every request/response part is rendered to tagged
    text (see ``get_context_from_message``). That text is summarized with one
    ``model_request`` call driven by the ``history/summarization.j2`` prompt,
    and the whole history is replaced by two messages:

    1. a request holding the original system prompt and first user prompt;
    2. a response holding the summary.

    Args:
        ctx: Run context with usage information and dependencies
        messages: List of messages in the conversation history

    Returns:
        ``messages`` unchanged when under the limit or when the summarizer
        returned no parts; otherwise the two-message compacted history.
    """
    # NOTE(review): ctx.usage looks like usage accumulated over the current
    # run (each model request re-sends the history), not the size of
    # `messages` itself, so it may over- or under-estimate the real context
    # size. Confirm this is the intended trigger.
    current_tokens = ctx.usage.total_tokens if ctx.usage else 0

    # Get token limit from model configuration; use 80% of the max to leave
    # room for the response.
    max_tokens = int(ctx.deps.llm_model.max_input_tokens * 0.8)
    percentage_of_limit_used = (
        (current_tokens / max_tokens) * 100 if max_tokens > 0 else 0
    )
    logger.debug(
        "History compactor: current tokens=%d, limit=%d, percentage used=%.2f%%",
        current_tokens,
        max_tokens,
        percentage_of_limit_used,
    )

    # If we're under the limit, return all messages
    if current_tokens < max_tokens:
        logger.debug("Under token limit, keeping all %d messages", len(messages))
        return messages

    context = _render_history_for_summary(messages)
    summarization_prompt = prompt_loader.render("history/summarization.j2")
    summary_response = await model_request(
        model=ctx.model,
        messages=[
            ModelRequest.user_text_prompt(context, instructions=summarization_prompt)
        ],
    )

    # Usage before and after
    summary_tokens = summary_response.usage.output_tokens
    # current_tokens can be 0 here when the limit itself rounds down to 0
    # (max_input_tokens of 0 or 1), so guard the division.
    reduction_percentage = (
        ((current_tokens - summary_tokens) / current_tokens) * 100
        if current_tokens > 0
        else 0.0
    )
    logger.debug(
        "Compacted %s tokens into %s tokens for a %.2f percent reduction",
        current_tokens,
        summary_tokens,
        reduction_percentage,
    )

    summarization_part = _select_summary_part(summary_response)
    if summarization_part is None:
        # Nothing usable came back; keeping the full history is safer than
        # replacing it with an empty summary (previously an IndexError).
        logger.warning(
            "Summarization returned no content; keeping all %d messages",
            len(messages),
        )
        return messages

    system_prompt = get_system_promt(messages) or ""
    user_prompt = get_first_user_request(messages) or ""
    # NOTE(review): the compacted history ends with a ModelResponse, so
    # whatever request was last in `messages` (e.g. a new user prompt or tool
    # returns) survives only inside the summary. Confirm the agent loop
    # accepts a history in this shape.
    return [
        ModelRequest(
            parts=[
                SystemPromptPart(content=system_prompt),
                UserPromptPart(content=user_prompt),
            ]
        ),
        ModelResponse(
            parts=[
                summarization_part,
            ]
        ),
    ]


def _render_history_for_summary(messages: list[ModelMessage]) -> str:
    """Render every request/response part as tagged text, one part per line block."""
    chunks: list[str] = []
    for msg in messages:
        if not isinstance(msg, (ModelRequest, ModelResponse)):
            # Any other message kind is skipped.
            continue
        for part in msg.parts:
            rendered = get_context_from_message(part)
            if rendered:
                chunks.append(rendered + "\n")
    return "".join(chunks)


def _select_summary_part(response: ModelResponse) -> ModelResponsePart | None:
    """Return the response part that carries the summary text, if any.

    Prefers the first ``TextPart`` so that a leading ``ThinkingPart`` is not
    mistaken for the summary; falls back to the first part, or None when the
    response has no parts at all.
    """
    for part in response.parts:
        if isinstance(part, TextPart):
            return part
    return response.parts[0] if response.parts else None
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def get_first_user_request(messages: list[ModelMessage]) -> str | None:
    """Find the earliest plain-text user prompt in the history.

    Only ``UserPromptPart``s inside ``ModelRequest`` messages whose content is
    a ``str`` are considered; multi-part (non-string) prompts are skipped.

    Args:
        messages: Conversation history, searched in list order.

    Returns:
        The first such prompt's text, or None if there is none.
    """
    text_prompts = (
        part.content
        for message in messages
        if isinstance(message, ModelRequest)
        for part in message.parts
        if isinstance(part, UserPromptPart) and isinstance(part.content, str)
    )
    return next(text_prompts, None)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def get_system_promt(messages: list[ModelMessage]) -> str | None:
    """Find the content of the first ``SystemPromptPart`` in the history.

    The misspelled name is kept because token_limit_compactor calls it by
    this name.

    Args:
        messages: Conversation history, searched in list order.

    Returns:
        The first system prompt found in any ``ModelRequest``, or None.
    """
    system_parts = (
        part
        for message in messages
        if isinstance(message, ModelRequest)
        for part in message.parts
        if isinstance(part, SystemPromptPart)
    )
    first_part = next(system_parts, None)
    return None if first_part is None else first_part.content
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def get_context_from_message(
    message_part: SystemPromptPart
    | UserPromptPart
    | ToolReturnPart
    | RetryPromptPart
    | ModelResponsePart,
) -> str:
    """Render one message part as a tagged text snippet for summarization.

    Args:
        message_part: The request or response part to render.

    Returns:
        The part's text wrapped in a tag naming its kind (e.g.
        ``<TOOL_CALL>``), or an empty string for system prompts (deliberately
        excluded), parts with no textual content, and unknown part types.
    """
    if isinstance(message_part, SystemPromptPart):
        return ""  # We do not include system prompts in the summary
    elif isinstance(message_part, UserPromptPart):
        content = message_part.content
        if isinstance(content, str):
            return "<USER_PROMPT>\n" + content + "\n</USER_PROMPT>"
        # Multi-part prompt: keep the textual items; non-text items (images,
        # audio, documents, ...) cannot be summarized as text.
        text_items = [item for item in content if isinstance(item, str)]
        if not text_items:
            return ""
        return "<USER_PROMPT>\n" + "\n".join(text_items) + "\n</USER_PROMPT>"
    elif isinstance(message_part, ToolReturnPart):
        return "<TOOL_RETURN>\n" + str(message_part.content) + "\n</TOOL_RETURN>"
    elif isinstance(message_part, RetryPromptPart):
        # Non-string content (structured validation errors) used to be
        # dropped; render it as text so the summary keeps the failure reason.
        return "<RETRY_PROMPT>\n" + str(message_part.content) + "\n</RETRY_PROMPT>"

    # TextPart | ToolCallPart | BuiltinToolCallPart | BuiltinToolReturnPart | ThinkingPart
    if isinstance(message_part, TextPart):
        return "<ASSISTANT_TEXT>\n" + message_part.content + "\n</ASSISTANT_TEXT>"
    elif isinstance(message_part, ToolCallPart):
        tool_call_str = _format_tool_call(message_part.tool_name, message_part.args)
        return "<TOOL_CALL>\n" + tool_call_str + "\n</TOOL_CALL>"
    elif isinstance(message_part, BuiltinToolCallPart):
        # Same rendering as regular tool calls so built-in calls keep their args.
        tool_call_str = _format_tool_call(message_part.tool_name, message_part.args)
        return "<BUILTIN_TOOL_CALL>\n" + tool_call_str + "\n</BUILTIN_TOOL_CALL>"
    elif isinstance(message_part, BuiltinToolReturnPart):
        # Include the returned content, not just the tool name, so results of
        # built-in tools survive summarization like regular tool returns do.
        return (
            "<BUILTIN_TOOL_RETURN>\n"
            + message_part.tool_name
            + "\n"
            + str(message_part.content)
            + "\n</BUILTIN_TOOL_RETURN>"
        )
    elif isinstance(message_part, ThinkingPart):
        return "<THINKING>\n" + message_part.content + "\n</THINKING>"

    return ""


def _format_tool_call(tool_name: str, args: object) -> str:
    """Render a tool call as ``name(k=v, ...)`` (dict args) or ``name(<args>)``."""
    if isinstance(args, dict):
        args_str = ", ".join(f"{k}={v!r}" for k, v in args.items())
        return f"{tool_name}({args_str})"
    return f"{tool_name}({args})"
|