shotgun-sh 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh may be problematic; see the package's registry page for more details.
- shotgun/__init__.py +5 -0
- shotgun/agents/__init__.py +1 -0
- shotgun/agents/agent_manager.py +651 -0
- shotgun/agents/common.py +549 -0
- shotgun/agents/config/__init__.py +13 -0
- shotgun/agents/config/constants.py +17 -0
- shotgun/agents/config/manager.py +294 -0
- shotgun/agents/config/models.py +185 -0
- shotgun/agents/config/provider.py +206 -0
- shotgun/agents/conversation_history.py +106 -0
- shotgun/agents/conversation_manager.py +105 -0
- shotgun/agents/export.py +96 -0
- shotgun/agents/history/__init__.py +5 -0
- shotgun/agents/history/compaction.py +85 -0
- shotgun/agents/history/constants.py +19 -0
- shotgun/agents/history/context_extraction.py +108 -0
- shotgun/agents/history/history_building.py +104 -0
- shotgun/agents/history/history_processors.py +426 -0
- shotgun/agents/history/message_utils.py +84 -0
- shotgun/agents/history/token_counting.py +429 -0
- shotgun/agents/history/token_estimation.py +138 -0
- shotgun/agents/messages.py +35 -0
- shotgun/agents/models.py +275 -0
- shotgun/agents/plan.py +98 -0
- shotgun/agents/research.py +108 -0
- shotgun/agents/specify.py +98 -0
- shotgun/agents/tasks.py +96 -0
- shotgun/agents/tools/__init__.py +34 -0
- shotgun/agents/tools/codebase/__init__.py +28 -0
- shotgun/agents/tools/codebase/codebase_shell.py +256 -0
- shotgun/agents/tools/codebase/directory_lister.py +141 -0
- shotgun/agents/tools/codebase/file_read.py +144 -0
- shotgun/agents/tools/codebase/models.py +252 -0
- shotgun/agents/tools/codebase/query_graph.py +67 -0
- shotgun/agents/tools/codebase/retrieve_code.py +81 -0
- shotgun/agents/tools/file_management.py +218 -0
- shotgun/agents/tools/user_interaction.py +37 -0
- shotgun/agents/tools/web_search/__init__.py +60 -0
- shotgun/agents/tools/web_search/anthropic.py +144 -0
- shotgun/agents/tools/web_search/gemini.py +85 -0
- shotgun/agents/tools/web_search/openai.py +98 -0
- shotgun/agents/tools/web_search/utils.py +20 -0
- shotgun/build_constants.py +20 -0
- shotgun/cli/__init__.py +1 -0
- shotgun/cli/codebase/__init__.py +5 -0
- shotgun/cli/codebase/commands.py +202 -0
- shotgun/cli/codebase/models.py +21 -0
- shotgun/cli/config.py +275 -0
- shotgun/cli/export.py +81 -0
- shotgun/cli/models.py +10 -0
- shotgun/cli/plan.py +73 -0
- shotgun/cli/research.py +85 -0
- shotgun/cli/specify.py +69 -0
- shotgun/cli/tasks.py +78 -0
- shotgun/cli/update.py +152 -0
- shotgun/cli/utils.py +25 -0
- shotgun/codebase/__init__.py +12 -0
- shotgun/codebase/core/__init__.py +46 -0
- shotgun/codebase/core/change_detector.py +358 -0
- shotgun/codebase/core/code_retrieval.py +243 -0
- shotgun/codebase/core/ingestor.py +1497 -0
- shotgun/codebase/core/language_config.py +297 -0
- shotgun/codebase/core/manager.py +1662 -0
- shotgun/codebase/core/nl_query.py +331 -0
- shotgun/codebase/core/parser_loader.py +128 -0
- shotgun/codebase/models.py +111 -0
- shotgun/codebase/service.py +206 -0
- shotgun/logging_config.py +227 -0
- shotgun/main.py +167 -0
- shotgun/posthog_telemetry.py +158 -0
- shotgun/prompts/__init__.py +5 -0
- shotgun/prompts/agents/__init__.py +1 -0
- shotgun/prompts/agents/export.j2 +350 -0
- shotgun/prompts/agents/partials/codebase_understanding.j2 +87 -0
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +37 -0
- shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
- shotgun/prompts/agents/partials/interactive_mode.j2 +26 -0
- shotgun/prompts/agents/plan.j2 +144 -0
- shotgun/prompts/agents/research.j2 +69 -0
- shotgun/prompts/agents/specify.j2 +51 -0
- shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +19 -0
- shotgun/prompts/agents/state/system_state.j2 +31 -0
- shotgun/prompts/agents/tasks.j2 +143 -0
- shotgun/prompts/codebase/__init__.py +1 -0
- shotgun/prompts/codebase/cypher_query_patterns.j2 +223 -0
- shotgun/prompts/codebase/cypher_system.j2 +28 -0
- shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
- shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
- shotgun/prompts/codebase/partials/graph_schema.j2 +30 -0
- shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
- shotgun/prompts/history/__init__.py +1 -0
- shotgun/prompts/history/incremental_summarization.j2 +53 -0
- shotgun/prompts/history/summarization.j2 +46 -0
- shotgun/prompts/loader.py +140 -0
- shotgun/py.typed +0 -0
- shotgun/sdk/__init__.py +13 -0
- shotgun/sdk/codebase.py +219 -0
- shotgun/sdk/exceptions.py +17 -0
- shotgun/sdk/models.py +189 -0
- shotgun/sdk/services.py +23 -0
- shotgun/sentry_telemetry.py +87 -0
- shotgun/telemetry.py +93 -0
- shotgun/tui/__init__.py +0 -0
- shotgun/tui/app.py +116 -0
- shotgun/tui/commands/__init__.py +76 -0
- shotgun/tui/components/prompt_input.py +69 -0
- shotgun/tui/components/spinner.py +86 -0
- shotgun/tui/components/splash.py +25 -0
- shotgun/tui/components/vertical_tail.py +13 -0
- shotgun/tui/screens/chat.py +782 -0
- shotgun/tui/screens/chat.tcss +43 -0
- shotgun/tui/screens/chat_screen/__init__.py +0 -0
- shotgun/tui/screens/chat_screen/command_providers.py +219 -0
- shotgun/tui/screens/chat_screen/hint_message.py +40 -0
- shotgun/tui/screens/chat_screen/history.py +221 -0
- shotgun/tui/screens/directory_setup.py +113 -0
- shotgun/tui/screens/provider_config.py +221 -0
- shotgun/tui/screens/splash.py +31 -0
- shotgun/tui/styles.tcss +10 -0
- shotgun/tui/utils/__init__.py +5 -0
- shotgun/tui/utils/mode_progress.py +257 -0
- shotgun/utils/__init__.py +5 -0
- shotgun/utils/env_utils.py +35 -0
- shotgun/utils/file_system_utils.py +36 -0
- shotgun/utils/update_checker.py +375 -0
- shotgun_sh-0.1.0.dist-info/METADATA +466 -0
- shotgun_sh-0.1.0.dist-info/RECORD +130 -0
- shotgun_sh-0.1.0.dist-info/WHEEL +4 -0
- shotgun_sh-0.1.0.dist-info/entry_points.txt +2 -0
- shotgun_sh-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Pydantic models for codebase tool outputs."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class QueryGraphResult(BaseModel):
    """Result from graph query with formatted string output.

    ``str(result)`` renders a Markdown summary for LLM consumption: the
    generated Cypher (if any), a row-count/timing header, and a table of at
    most ten rows.
    """

    success: bool = Field(description="Whether the query was successful")
    query: str = Field(description="Original natural language query")
    cypher_query: str | None = Field(None, description="Generated Cypher query")
    column_names: list[str] = Field(
        default_factory=list, description="Result column names"
    )
    results: list[dict[str, Any]] = Field(
        default_factory=list, description="Query results"
    )
    row_count: int = Field(0, description="Number of result rows")
    execution_time_ms: float = Field(
        0.0, description="Query execution time in milliseconds"
    )
    error: str | None = Field(default=None, description="Error message if failed")

    @staticmethod
    def _format_cell(value: Any, max_width: int = 50) -> str:
        """Render one value as a single-line, pipe-safe Markdown table cell.

        Args:
            value: Raw cell value; ``None`` renders as an empty cell.
            max_width: Maximum length before truncation with ``"..."``.

        Returns:
            The cell text with line breaks collapsed to spaces and ``|``
            escaped, so multi-line values (e.g. code) cannot break the table.
        """
        if value is None:
            return ""
        text = " ".join(str(value).splitlines())
        if len(text) > max_width:
            text = text[: max_width - 3] + "..."
        # Escape after truncating so an escape sequence is never cut in half.
        return text.replace("|", "\\|")

    def __str__(self) -> str:
        """Format query result for LLM consumption."""
        if not self.success:
            return f"**Query Failed**: {self.error}"

        if not self.results:
            return f"**No Results Found** for query: {self.query}"

        max_rows = 10  # keep large result sets from overwhelming the output
        # `row_count` may be left at its default of 0 by callers that only
        # populate `results`; never report fewer rows than are actually held.
        total_rows = max(self.row_count, len(self.results))

        output_lines: list[str] = []
        if self.cypher_query:
            output_lines.append(f"**Generated Cypher**: `{self.cypher_query}`")
            output_lines.append("")

        output_lines.append(
            f"**Results** ({total_rows} rows, {self.execution_time_ms:.1f}ms):"
        )
        output_lines.append("")

        # Fall back to the first row's keys when no column names were given,
        # instead of silently omitting the rows from the output.
        headers = self.column_names or list(self.results[0])
        if headers:
            # Create markdown table
            output_lines.append(
                "| " + " | ".join(str(h).replace("|", "\\|") for h in headers) + " |"
            )
            output_lines.append("|" + "|".join(" --- " for _ in headers) + "|")

            shown_rows = self.results[:max_rows]
            for row in shown_rows:
                cells = [self._format_cell(row.get(col)) for col in headers]
                output_lines.append("| " + " | ".join(cells) + " |")

            remaining = total_rows - len(shown_rows)
            if remaining > 0:
                output_lines.append(f"... and {remaining} more rows")

        return "\n".join(output_lines)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class CodeSnippetResult(BaseModel):
    """Result from code retrieval with formatted output.

    ``str(result)`` yields Markdown. A not-found result produces an error
    message with a hint to use ``query_graph``. A found result lists the
    qualified name, file, line range and docstring, followed by the source
    in a fenced code block.
    """

    found: bool = Field(description="Whether the code entity was found")
    qualified_name: str = Field(description="Fully qualified name searched for")
    file_path: str | None = Field(None, description="Path to source file")
    line_start: int | None = Field(None, description="Starting line number")
    line_end: int | None = Field(None, description="Ending line number")
    source_code: str | None = Field(None, description="Source code content")
    docstring: str | None = Field(None, description="Docstring if available")
    language: str = Field(
        default="", description="Programming language for syntax highlighting"
    )
    error: str | None = Field(None, description="Error message if not found")

    @staticmethod
    def _code_fence(text: str) -> str:
        """Return a backtick fence longer than any backtick run in ``text``.

        A fenced block only closes on a fence at least as long as the opening
        one, so this keeps embedded ``` (e.g. in Markdown docstrings) from
        terminating the block early.
        """
        fence = "```"
        while fence in text:
            fence += "`"
        return fence

    def __str__(self) -> str:
        """Format code snippet for LLM consumption."""
        if not self.found:
            error_msg = (
                self.error or f"Entity '{self.qualified_name}' not found in graph"
            )
            return f"**Not Found**: {error_msg}\n\nTry using `query_graph` to search for similar entities or check the qualified name."

        output_lines = [f"**Qualified Name**: `{self.qualified_name}`"]

        if self.file_path:
            output_lines.append(f"**File**: `{self.file_path}`")

        # Compare against None so a line number of 0 is still reported;
        # a plain truthiness test silently dropped it.
        if self.line_start is not None and self.line_end is not None:
            output_lines.append(f"**Lines**: {self.line_start}-{self.line_end}")

        if self.docstring:
            output_lines.append(f"**Docstring**: {self.docstring}")

        if self.source_code:
            fence = self._code_fence(self.source_code)
            output_lines.append("")
            output_lines.append("**Source Code**:")
            output_lines.append(f"{fence}{self.language}")
            output_lines.append(self.source_code)
            output_lines.append(fence)

        return "\n".join(output_lines)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class FileReadResult(BaseModel):
    """Result from file reading with content output.

    ``str(result)`` is either an error line or a Markdown header (path, size,
    and non-UTF-8 encoding if any) followed by the content in a fenced block.
    """

    success: bool = Field(description="Whether file was read successfully")
    file_path: str = Field(description="Path to file that was read")
    content: str | None = Field(None, description="File content")
    encoding: str = Field("utf-8", description="Encoding used to read file")
    size_bytes: int = Field(0, description="File size in bytes")
    language: str = Field(
        default="", description="Programming language for syntax highlighting"
    )
    error: str | None = Field(default=None, description="Error message if failed")

    @staticmethod
    def _code_fence(text: str) -> str:
        """Return a backtick fence longer than any backtick run in ``text``.

        A fenced block only closes on a fence at least as long as the opening
        one, so files that themselves contain ``` (Markdown, docs) cannot
        terminate the block early.
        """
        fence = "```"
        while fence in text:
            fence += "`"
        return fence

    def __str__(self) -> str:
        """Return file content or error message."""
        if not self.success:
            return f"**Error reading file `{self.file_path}`**: {self.error}"

        content = self.content or ""
        fence = self._code_fence(content)

        output_lines = [
            f"**File**: `{self.file_path}`",
            f"**Size**: {self.size_bytes} bytes",
        ]

        # Only mention the encoding when it differs from the default.
        if self.encoding != "utf-8":
            output_lines.append(f"**Encoding**: {self.encoding}")

        output_lines.append("")
        output_lines.append("**Content**:")
        output_lines.append(f"{fence}{self.language}")
        output_lines.append(content)
        output_lines.append(fence)

        return "\n".join(output_lines)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class DirectoryListResult(BaseModel):
    """Outcome of listing one directory, rendered as Markdown via ``str()``.

    Subdirectories are listed by name and files by name plus a human-readable
    size, followed by a totals line.
    """

    success: bool = Field(description="Whether directory was listed successfully")
    directory: str = Field(description="Directory path that was listed")
    full_path: str = Field(description="Absolute path to directory")
    directories: list[str] = Field(
        default_factory=list, description="Subdirectory names"
    )
    files: list[tuple[str, int]] = Field(
        default_factory=list, description="Files as (name, size_bytes) tuples"
    )
    error: str | None = Field(default=None, description="Error message if failed")

    @staticmethod
    def _human_size(num_bytes: int) -> str:
        """Render a byte count as B, KB (1024) or MB (1024**2), one decimal."""
        kib = 1024
        mib = kib * kib
        if num_bytes >= mib:
            return f"{num_bytes / mib:.1f}MB"
        if num_bytes >= kib:
            return f"{num_bytes / kib:.1f}KB"
        return f"{num_bytes}B"

    def __str__(self) -> str:
        """Build the Markdown listing (or an error line) for LLM consumption."""
        if not self.success:
            return f"**Error listing directory `{self.directory}`**: {self.error}"

        header = [
            f"**Directory**: `{self.directory}`",
            f"**Full Path**: `{self.full_path}`",
            "",
        ]
        if not (self.directories or self.files):
            return "\n".join([*header, "Directory is empty"])

        body: list[str] = []
        if self.directories:
            body.append("**Directories**:")
            body.extend(f" š {name}/" for name in self.directories)
            body.append("")

        if self.files:
            body.append("**Files**:")
            body.extend(
                f" š {name} ({self._human_size(size)})" for name, size in self.files
            )

        body.append("")
        body.append(
            f"**Total**: {len(self.directories)} directories, {len(self.files)} files"
        )
        return "\n".join(header + body)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
class ShellCommandResult(BaseModel):
    """Result from shell command execution with formatted output.

    ``str(result)`` yields Markdown: the command line, timing, fenced
    stdout/stderr, the exit code when non-zero, and a note when there was no
    output at all.
    """

    success: bool = Field(description="Whether command executed without errors")
    command: str = Field(description="Command that was executed")
    args: list[str] = Field(default_factory=list, description="Command arguments")
    stdout: str = Field(default="", description="Standard output")
    stderr: str = Field(default="", description="Standard error output")
    return_code: int = Field(default=0, description="Process return code")
    execution_time_ms: float = Field(
        default=0.0, description="Execution time in milliseconds"
    )
    error: str | None = Field(
        default=None, description="Error message if execution failed"
    )

    @staticmethod
    def _code_fence(text: str) -> str:
        """Return a backtick fence longer than any backtick run in ``text``.

        Prevents command output that itself contains ``` from closing the
        fenced block early.
        """
        fence = "```"
        while fence in text:
            fence += "`"
        return fence

    def __str__(self) -> str:
        """Format command output for LLM consumption."""
        if self.error:
            return f"**Command Failed**: {self.error}"

        cmd_str = f"{self.command} {' '.join(self.args)}".strip()
        output_lines = [
            f"**Command**: `{cmd_str}`",
            f"**Execution Time**: {self.execution_time_ms:.1f}ms",
        ]

        for title, stream in (
            ("**Output**:", self.stdout),
            ("**Error Output**:", self.stderr),
        ):
            if stream:
                fence = self._code_fence(stream)
                output_lines.extend(["", title, fence, stream.rstrip(), fence])

        if self.return_code != 0:
            output_lines.append("")
            output_lines.append(f"**Exit Code**: {self.return_code}")

        if not self.stdout and not self.stderr:
            output_lines.append("")
            # Only claim success when the command actually succeeded; a silent
            # failing command was previously reported as successful.
            if self.success and self.return_code == 0:
                output_lines.append("Command executed successfully with no output")
            else:
                output_lines.append("Command produced no output")

        return "\n".join(output_lines)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Query codebase knowledge graph using natural language."""
|
|
2
|
+
|
|
3
|
+
from pydantic_ai import RunContext
|
|
4
|
+
|
|
5
|
+
from shotgun.agents.models import AgentDeps
|
|
6
|
+
from shotgun.codebase.models import QueryType
|
|
7
|
+
from shotgun.logging_config import get_logger
|
|
8
|
+
|
|
9
|
+
from .models import QueryGraphResult
|
|
10
|
+
|
|
11
|
+
# Module-level logger obtained from shotgun's logging configuration.
logger = get_logger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
async def query_graph(
    ctx: RunContext[AgentDeps], graph_id: str, query: str
) -> QueryGraphResult:
    """Query codebase knowledge graph using natural language.

    Args:
        ctx: RunContext containing AgentDeps with codebase service
        graph_id: Graph ID to query (use the ID, not the name)
        query: Natural language question about the codebase

    Returns:
        QueryGraphResult with formatted output via __str__
    """
    # NOTE: this docstring is consumed by pydantic-ai as the tool description
    # shown to the model, so its wording is intentionally left unchanged.
    # Errors never propagate: every failure is converted into a
    # QueryGraphResult with success=False.
    logger.debug("š§ Querying graph %s with query: %s", graph_id, query)

    try:
        service = ctx.deps.codebase_service
        if not service:
            return QueryGraphResult(
                success=False, query=query, error="No codebase indexed"
            )

        outcome = await service.execute_query(
            graph_id=graph_id,
            query=query,
            query_type=QueryType.NATURAL_LANGUAGE,
        )

        # Copy the service result field-by-field into the tool's output model.
        graph_result = QueryGraphResult(
            success=outcome.success,
            query=query,
            cypher_query=outcome.cypher_query,
            column_names=outcome.column_names,
            results=outcome.results,
            row_count=outcome.row_count,
            execution_time_ms=outcome.execution_time_ms,
            error=outcome.error,
        )

        status = "success" if graph_result.success else "failed"
        logger.debug(
            "š Query completed: %s with %d results", status, graph_result.row_count
        )
        return graph_result

    except Exception as exc:
        message = f"Error querying graph: {exc!s}"
        logger.error("ā Query graph failed: %s", str(exc))
        return QueryGraphResult(success=False, query=query, error=message)
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
"""Retrieve source code by qualified name from codebase."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from pydantic_ai import RunContext
|
|
6
|
+
|
|
7
|
+
from shotgun.agents.models import AgentDeps
|
|
8
|
+
from shotgun.codebase.core.code_retrieval import retrieve_code_by_qualified_name
|
|
9
|
+
from shotgun.codebase.core.language_config import get_language_config
|
|
10
|
+
from shotgun.logging_config import get_logger
|
|
11
|
+
|
|
12
|
+
from .models import CodeSnippetResult
|
|
13
|
+
|
|
14
|
+
# Module-level logger obtained from shotgun's logging configuration.
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
async def retrieve_code(
    ctx: RunContext[AgentDeps], graph_id: str, qualified_name: str
) -> CodeSnippetResult:
    """Get source code by fully qualified name.

    Args:
        ctx: RunContext containing AgentDeps with codebase service
        graph_id: Graph ID to query (use the ID, not the name)
        qualified_name: Fully qualified name like "module.Class.method"

    Returns:
        CodeSnippetResult with formatted output via __str__
    """
    # NOTE: this docstring doubles as the pydantic-ai tool description, so its
    # wording is left unchanged. Failures are returned as found=False results
    # rather than raised.
    logger.debug("š§ Retrieving code for: %s in graph %s", qualified_name, graph_id)

    try:
        if not ctx.deps.codebase_service:
            return CodeSnippetResult(
                found=False,
                qualified_name=qualified_name,
                error="No codebase indexed",
            )

        snippet = await retrieve_code_by_qualified_name(
            manager=ctx.deps.codebase_service.manager,
            graph_id=graph_id,
            qualified_name=qualified_name,
        )

        # Syntax-highlighting hint derived from the file extension, if known.
        language = ""
        if snippet.file_path:
            config = get_language_config(Path(snippet.file_path).suffix)
            language = config.name if config else ""

        # Location and source are only meaningful for a hit; the error message
        # only for a miss. The docstring is passed through unconditionally.
        hit = snippet.found
        result = CodeSnippetResult(
            found=hit,
            qualified_name=snippet.qualified_name,
            file_path=snippet.file_path if hit else None,
            line_start=snippet.line_start if hit else None,
            line_end=snippet.line_end if hit else None,
            source_code=snippet.source_code if hit else None,
            docstring=snippet.docstring,
            language=language,
            error=None if hit else snippet.error_message,
        )

        logger.debug(
            "š Retrieved code: %s for %s",
            "found" if result.found else "not found",
            qualified_name,
        )
        return result

    except Exception as exc:
        message = f"Error retrieving code: {exc!s}"
        logger.error("ā Retrieve code failed: %s", str(exc))
        return CodeSnippetResult(
            found=False, qualified_name=qualified_name, error=message
        )
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
"""File management tools for Pydantic AI agents.
|
|
2
|
+
|
|
3
|
+
These tools are restricted to the .shotgun directory for security.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
from pydantic_ai import RunContext
|
|
10
|
+
|
|
11
|
+
from shotgun.agents.models import AgentDeps, AgentType, FileOperationType
|
|
12
|
+
from shotgun.logging_config import get_logger
|
|
13
|
+
from shotgun.utils.file_system_utils import get_shotgun_base_path
|
|
14
|
+
|
|
15
|
+
# Module-level logger obtained from shotgun's logging configuration.
logger = get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
# Map agent modes to their allowed directories/files (in workflow order).
# Values are file names relative to the .shotgun directory; "*" marks the
# export agent, which may write any file there except the protected ones.
AGENT_DIRECTORIES: dict[AgentType, str] = {
    AgentType.RESEARCH: "research.md",
    AgentType.SPECIFY: "specification.md",
    AgentType.PLAN: "plan.md",
    AgentType.TASKS: "tasks.md",
    AgentType.EXPORT: "*",  # Export agent can write anywhere except protected files
}

# Files protected from export agent modifications: every concrete file owned
# by another agent. Derived from AGENT_DIRECTORIES so the two cannot drift
# apart when an agent is added or renamed.
PROTECTED_AGENT_FILES: set[str] = {
    name for name in AGENT_DIRECTORIES.values() if name != "*"
}
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _validate_agent_scoped_path(filename: str, agent_mode: AgentType | None) -> Path:
    """Validate and resolve a file path within the agent's scoped directory.

    Non-export agents may only write the single file mapped to them in
    AGENT_DIRECTORIES. The export agent may write anywhere in .shotgun except
    the protected agent files. Without a (known) agent mode, only the
    containment check applies.

    Args:
        filename: Relative filename
        agent_mode: The current agent mode

    Returns:
        Absolute path to the file within the agent's scoped directory

    Raises:
        ValueError: If the path attempts to access files outside the agent's scope
    """
    base_path = get_shotgun_base_path()
    resolved_base = base_path.resolve()
    full_path = (base_path / filename).resolve()

    if agent_mode and agent_mode in AGENT_DIRECTORIES:
        if agent_mode == AgentType.EXPORT:
            # Compare resolved paths rather than raw strings. A plain
            # `filename in PROTECTED_AGENT_FILES` test is bypassed by
            # "./research.md", "x/../plan.md" or an absolute path into .shotgun.
            # samefile() additionally catches aliases such as hardlinks or a
            # differently-cased name on a case-insensitive filesystem.
            protected_paths = [
                (resolved_base / name).resolve() for name in PROTECTED_AGENT_FILES
            ]
            target_exists = full_path.exists()
            hits_protected = filename in PROTECTED_AGENT_FILES or any(
                full_path == protected
                or (
                    target_exists
                    and protected.exists()
                    and full_path.samefile(protected)
                )
                for protected in protected_paths
            )
            if hits_protected:
                raise ValueError(
                    f"Export agent cannot write to protected file '{filename}'. "
                    f"Protected files are: {', '.join(sorted(PROTECTED_AGENT_FILES))}"
                )
        else:
            # For other agents, only allow writing to their specific file
            allowed_file = AGENT_DIRECTORIES[agent_mode]
            if filename != allowed_file:
                raise ValueError(
                    f"{agent_mode.value.capitalize()} agent can only write to '{allowed_file}'. "
                    f"Attempted to write to '{filename}'"
                )

    # Ensure the resolved path is within the .shotgun directory; resolve()
    # follows symlinks and ".." so escapes are caught here.
    try:
        full_path.relative_to(resolved_base)
    except ValueError as e:
        raise ValueError(
            f"Access denied: Path '{filename}' is outside .shotgun directory"
        ) from e

    return full_path
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _validate_shotgun_path(filename: str) -> Path:
    """Resolve ``filename`` against the .shotgun directory and refuse escapes.

    Args:
        filename: Path relative to the .shotgun directory.

    Returns:
        The resolved absolute path, which lies inside the .shotgun directory.

    Raises:
        ValueError: If the resolved path falls outside the .shotgun directory
            (for example via ".." segments, an absolute path, or a symlink
            that resolve() follows out of the tree).
    """
    root = get_shotgun_base_path()
    candidate = (root / filename).resolve()

    try:
        candidate.relative_to(root.resolve())
    except ValueError as exc:
        raise ValueError(
            f"Access denied: Path '{filename}' is outside .shotgun directory"
        ) from exc
    else:
        return candidate
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
async def read_file(ctx: RunContext[AgentDeps], filename: str) -> str:
    """Read a file from the .shotgun directory.

    Args:
        filename: Relative path to file within .shotgun directory

    Returns:
        File contents as string, or an error message if the file is outside
        the .shotgun directory, does not exist, or cannot be read
    """
    # NOTE: this docstring is the pydantic-ai tool description. The previous
    # "Raises" section was wrong: every failure below is caught and returned
    # as an "Error reading file ..." string instead of propagating.
    logger.debug("š§ Reading file: %s", filename)

    try:
        file_path = _validate_shotgun_path(filename)

        if not file_path.exists():
            raise FileNotFoundError(f"File not found: {filename}")
        # Give a clear message for directories instead of a raw errno text.
        if not file_path.is_file():
            raise IsADirectoryError(f"Not a regular file: {filename}")

        content = file_path.read_text(encoding="utf-8")
        logger.debug("š Read %d characters from %s", len(content), filename)
        return content

    except Exception as e:
        error_msg = f"Error reading file '{filename}': {str(e)}"
        logger.error("ā File read failed: %s", error_msg)
        return error_msg
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
async def write_file(
    ctx: RunContext[AgentDeps],
    filename: str,
    content: str,
    mode: Literal["w", "a"] = "w",
) -> str:
    """Write content to a file in the .shotgun directory.

    Args:
        filename: Relative path to file within .shotgun directory
        content: Content to write to the file
        mode: Write mode - 'w' for overwrite, 'a' for append

    Returns:
        Success message or error message

    Raises:
        ValueError: If mode is not 'w' or 'a' (path and I/O problems are
            returned as an error message instead)
    """
    # NOTE: this docstring is the pydantic-ai tool description. Path checks
    # use the current agent's scope, so each agent can only touch its own file
    # (export: anything except the protected agent files).
    logger.debug("š§ Writing file: %s (mode: %s)", filename, mode)

    if mode not in ("w", "a"):
        raise ValueError(f"Invalid mode '{mode}'. Use 'w' for write or 'a' for append")

    try:
        # Use agent-scoped validation for write operations
        file_path = _validate_agent_scoped_path(filename, ctx.deps.agent_mode)

        # Decide CREATED vs UPDATED from whether the file existed beforehand,
        # for both modes: appending to a missing file creates it, which was
        # previously recorded as UPDATED.
        operation = (
            FileOperationType.UPDATED
            if file_path.exists()
            else FileOperationType.CREATED
        )

        # Ensure parent directory exists
        file_path.parent.mkdir(parents=True, exist_ok=True)

        with open(file_path, mode, encoding="utf-8") as f:
            f.write(content)

        if mode == "a":
            logger.debug("š Appended %d characters to %s", len(content), filename)
            result = f"Successfully appended {len(content)} characters to {filename}"
        else:
            logger.debug("š Wrote %d characters to %s", len(content), filename)
            result = f"Successfully wrote {len(content)} characters to {filename}"

        # Track the file operation
        ctx.deps.file_tracker.add_operation(file_path, operation)

        return result

    except Exception as e:
        error_msg = f"Error writing file '{filename}': {str(e)}"
        logger.error("ā File write failed: %s", error_msg)
        return error_msg
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
async def append_file(ctx: RunContext[AgentDeps], filename: str, content: str) -> str:
    """Append content to a file in the .shotgun directory.

    Args:
        filename: Relative path to file within .shotgun directory
        content: Content to append to the file

    Returns:
        Success message or error message
    """
    # Thin wrapper kept as a separate tool for the model: scoping, tracking
    # and error handling all happen in write_file; "a" selects append mode.
    return await write_file(ctx=ctx, filename=filename, content=content, mode="a")
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""User interaction tools for Pydantic AI agents."""
|
|
2
|
+
|
|
3
|
+
from asyncio import get_running_loop
|
|
4
|
+
|
|
5
|
+
from pydantic_ai import CallDeferred, RunContext
|
|
6
|
+
|
|
7
|
+
from shotgun.agents.models import AgentDeps, UserQuestion
|
|
8
|
+
from shotgun.logging_config import get_logger
|
|
9
|
+
|
|
10
|
+
# Module-level logger obtained from shotgun's logging configuration.
logger = get_logger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
async def ask_user(ctx: RunContext[AgentDeps], question: str) -> str:
    """Ask the human a question and return the answer.

    Args:
        question: The question to ask the user with a clear CTA at the end. Needs to be readable, clear, and easy to understand. Use Markdown formatting. Make key phrases and words stand out.

    Returns:
        The user's response as a string
    """
    # The question is not answered inline. A UserQuestion carrying a fresh
    # future is queued for the UI, the future is recorded in ctx.deps.tasks,
    # and CallDeferred is raised so pydantic-ai defers this tool call.
    # Presumably the UI resolves the future with the answer and the deferred
    # call is completed later — TODO confirm against the queue consumer.
    tool_call_id = ctx.tool_call_id
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if tool_call_id is None:
        raise RuntimeError("ask_user must be invoked as a tool call (missing tool_call_id)")

    try:
        logger.debug("\nš %s\n", question)
        future = get_running_loop().create_future()
        await ctx.deps.queue.put(
            UserQuestion(question=question, tool_call_id=tool_call_id, result=future)
        )
        ctx.deps.tasks.append(future)
        raise CallDeferred(question)

    except (EOFError, KeyboardInterrupt):
        logger.warning("User input interrupted or unavailable")
        return "User input not available or interrupted"
|