shotgun-sh 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of shotgun-sh might be problematic; see the package's registry page for more details.

Files changed (130)
  1. shotgun/__init__.py +5 -0
  2. shotgun/agents/__init__.py +1 -0
  3. shotgun/agents/agent_manager.py +651 -0
  4. shotgun/agents/common.py +549 -0
  5. shotgun/agents/config/__init__.py +13 -0
  6. shotgun/agents/config/constants.py +17 -0
  7. shotgun/agents/config/manager.py +294 -0
  8. shotgun/agents/config/models.py +185 -0
  9. shotgun/agents/config/provider.py +206 -0
  10. shotgun/agents/conversation_history.py +106 -0
  11. shotgun/agents/conversation_manager.py +105 -0
  12. shotgun/agents/export.py +96 -0
  13. shotgun/agents/history/__init__.py +5 -0
  14. shotgun/agents/history/compaction.py +85 -0
  15. shotgun/agents/history/constants.py +19 -0
  16. shotgun/agents/history/context_extraction.py +108 -0
  17. shotgun/agents/history/history_building.py +104 -0
  18. shotgun/agents/history/history_processors.py +426 -0
  19. shotgun/agents/history/message_utils.py +84 -0
  20. shotgun/agents/history/token_counting.py +429 -0
  21. shotgun/agents/history/token_estimation.py +138 -0
  22. shotgun/agents/messages.py +35 -0
  23. shotgun/agents/models.py +275 -0
  24. shotgun/agents/plan.py +98 -0
  25. shotgun/agents/research.py +108 -0
  26. shotgun/agents/specify.py +98 -0
  27. shotgun/agents/tasks.py +96 -0
  28. shotgun/agents/tools/__init__.py +34 -0
  29. shotgun/agents/tools/codebase/__init__.py +28 -0
  30. shotgun/agents/tools/codebase/codebase_shell.py +256 -0
  31. shotgun/agents/tools/codebase/directory_lister.py +141 -0
  32. shotgun/agents/tools/codebase/file_read.py +144 -0
  33. shotgun/agents/tools/codebase/models.py +252 -0
  34. shotgun/agents/tools/codebase/query_graph.py +67 -0
  35. shotgun/agents/tools/codebase/retrieve_code.py +81 -0
  36. shotgun/agents/tools/file_management.py +218 -0
  37. shotgun/agents/tools/user_interaction.py +37 -0
  38. shotgun/agents/tools/web_search/__init__.py +60 -0
  39. shotgun/agents/tools/web_search/anthropic.py +144 -0
  40. shotgun/agents/tools/web_search/gemini.py +85 -0
  41. shotgun/agents/tools/web_search/openai.py +98 -0
  42. shotgun/agents/tools/web_search/utils.py +20 -0
  43. shotgun/build_constants.py +20 -0
  44. shotgun/cli/__init__.py +1 -0
  45. shotgun/cli/codebase/__init__.py +5 -0
  46. shotgun/cli/codebase/commands.py +202 -0
  47. shotgun/cli/codebase/models.py +21 -0
  48. shotgun/cli/config.py +275 -0
  49. shotgun/cli/export.py +81 -0
  50. shotgun/cli/models.py +10 -0
  51. shotgun/cli/plan.py +73 -0
  52. shotgun/cli/research.py +85 -0
  53. shotgun/cli/specify.py +69 -0
  54. shotgun/cli/tasks.py +78 -0
  55. shotgun/cli/update.py +152 -0
  56. shotgun/cli/utils.py +25 -0
  57. shotgun/codebase/__init__.py +12 -0
  58. shotgun/codebase/core/__init__.py +46 -0
  59. shotgun/codebase/core/change_detector.py +358 -0
  60. shotgun/codebase/core/code_retrieval.py +243 -0
  61. shotgun/codebase/core/ingestor.py +1497 -0
  62. shotgun/codebase/core/language_config.py +297 -0
  63. shotgun/codebase/core/manager.py +1662 -0
  64. shotgun/codebase/core/nl_query.py +331 -0
  65. shotgun/codebase/core/parser_loader.py +128 -0
  66. shotgun/codebase/models.py +111 -0
  67. shotgun/codebase/service.py +206 -0
  68. shotgun/logging_config.py +227 -0
  69. shotgun/main.py +167 -0
  70. shotgun/posthog_telemetry.py +158 -0
  71. shotgun/prompts/__init__.py +5 -0
  72. shotgun/prompts/agents/__init__.py +1 -0
  73. shotgun/prompts/agents/export.j2 +350 -0
  74. shotgun/prompts/agents/partials/codebase_understanding.j2 +87 -0
  75. shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +37 -0
  76. shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
  77. shotgun/prompts/agents/partials/interactive_mode.j2 +26 -0
  78. shotgun/prompts/agents/plan.j2 +144 -0
  79. shotgun/prompts/agents/research.j2 +69 -0
  80. shotgun/prompts/agents/specify.j2 +51 -0
  81. shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +19 -0
  82. shotgun/prompts/agents/state/system_state.j2 +31 -0
  83. shotgun/prompts/agents/tasks.j2 +143 -0
  84. shotgun/prompts/codebase/__init__.py +1 -0
  85. shotgun/prompts/codebase/cypher_query_patterns.j2 +223 -0
  86. shotgun/prompts/codebase/cypher_system.j2 +28 -0
  87. shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
  88. shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
  89. shotgun/prompts/codebase/partials/graph_schema.j2 +30 -0
  90. shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
  91. shotgun/prompts/history/__init__.py +1 -0
  92. shotgun/prompts/history/incremental_summarization.j2 +53 -0
  93. shotgun/prompts/history/summarization.j2 +46 -0
  94. shotgun/prompts/loader.py +140 -0
  95. shotgun/py.typed +0 -0
  96. shotgun/sdk/__init__.py +13 -0
  97. shotgun/sdk/codebase.py +219 -0
  98. shotgun/sdk/exceptions.py +17 -0
  99. shotgun/sdk/models.py +189 -0
  100. shotgun/sdk/services.py +23 -0
  101. shotgun/sentry_telemetry.py +87 -0
  102. shotgun/telemetry.py +93 -0
  103. shotgun/tui/__init__.py +0 -0
  104. shotgun/tui/app.py +116 -0
  105. shotgun/tui/commands/__init__.py +76 -0
  106. shotgun/tui/components/prompt_input.py +69 -0
  107. shotgun/tui/components/spinner.py +86 -0
  108. shotgun/tui/components/splash.py +25 -0
  109. shotgun/tui/components/vertical_tail.py +13 -0
  110. shotgun/tui/screens/chat.py +782 -0
  111. shotgun/tui/screens/chat.tcss +43 -0
  112. shotgun/tui/screens/chat_screen/__init__.py +0 -0
  113. shotgun/tui/screens/chat_screen/command_providers.py +219 -0
  114. shotgun/tui/screens/chat_screen/hint_message.py +40 -0
  115. shotgun/tui/screens/chat_screen/history.py +221 -0
  116. shotgun/tui/screens/directory_setup.py +113 -0
  117. shotgun/tui/screens/provider_config.py +221 -0
  118. shotgun/tui/screens/splash.py +31 -0
  119. shotgun/tui/styles.tcss +10 -0
  120. shotgun/tui/utils/__init__.py +5 -0
  121. shotgun/tui/utils/mode_progress.py +257 -0
  122. shotgun/utils/__init__.py +5 -0
  123. shotgun/utils/env_utils.py +35 -0
  124. shotgun/utils/file_system_utils.py +36 -0
  125. shotgun/utils/update_checker.py +375 -0
  126. shotgun_sh-0.1.0.dist-info/METADATA +466 -0
  127. shotgun_sh-0.1.0.dist-info/RECORD +130 -0
  128. shotgun_sh-0.1.0.dist-info/WHEEL +4 -0
  129. shotgun_sh-0.1.0.dist-info/entry_points.txt +2 -0
  130. shotgun_sh-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,252 @@
1
+ """Pydantic models for codebase tool outputs."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class QueryGraphResult(BaseModel):
    """Result from graph query with formatted string output."""

    success: bool = Field(description="Whether the query was successful")
    query: str = Field(description="Original natural language query")
    cypher_query: str | None = Field(None, description="Generated Cypher query")
    column_names: list[str] = Field(
        default_factory=list, description="Result column names"
    )
    results: list[dict[str, Any]] = Field(
        default_factory=list, description="Query results"
    )
    row_count: int = Field(0, description="Number of result rows")
    execution_time_ms: float = Field(
        0.0, description="Query execution time in milliseconds"
    )
    error: str | None = Field(default=None, description="Error message if failed")

    def __str__(self) -> str:
        """Format query result for LLM consumption.

        Failed queries and empty result sets collapse to a single line.
        Otherwise the output shows the generated Cypher (if any), a summary
        line, and a Markdown table of at most 10 rows with cells truncated
        to 50 characters. Table headers come from ``column_names``, falling
        back to the keys of the first result row when none were supplied.
        """
        if not self.success:
            return f"**Query Failed**: {self.error}"

        if not self.results:
            return f"**No Results Found** for query: {self.query}"

        max_rows = 10  # limit rows to avoid overwhelming output
        max_cell_chars = 50

        def format_cell(value: Any) -> str:
            # Keep every table row on a single line: newlines inside a value
            # would split the row, and a raw "|" would add a phantom column.
            if value is None:
                return ""
            text = " ".join(str(value).splitlines())
            if len(text) > max_cell_chars:
                text = text[: max_cell_chars - 3] + "..."
            # Escape after truncating so the cut never lands inside "\|".
            return text.replace("|", "\\|")

        # row_count may lag behind the rows actually attached (e.g. left at
        # its default of 0); never report fewer rows than are present.
        total_rows = max(self.row_count, len(self.results))

        output_lines: list[str] = []
        if self.cypher_query:
            output_lines.append(f"**Generated Cypher**: `{self.cypher_query}`")
            output_lines.append("")

        output_lines.append(
            f"**Results** ({total_rows} rows, {self.execution_time_ms:.1f}ms):"
        )
        output_lines.append("")

        headers = self.column_names or list(self.results[0].keys())
        if headers:
            # Create markdown table
            output_lines.append("| " + " | ".join(headers) + " |")
            output_lines.append("|" + "|".join(" --- " for _ in headers) + "|")

            shown_rows = self.results[:max_rows]
            for row in shown_rows:
                cells = (format_cell(row.get(col)) for col in headers)
                output_lines.append("| " + " | ".join(cells) + " |")

            if total_rows > len(shown_rows):
                output_lines.append(
                    f"... and {total_rows - len(shown_rows)} more rows"
                )

        return "\n".join(output_lines)
69
+
70
+
71
class CodeSnippetResult(BaseModel):
    """Result from code retrieval with formatted output."""

    found: bool = Field(description="Whether the code entity was found")
    qualified_name: str = Field(description="Fully qualified name searched for")
    file_path: str | None = Field(None, description="Path to source file")
    line_start: int | None = Field(None, description="Starting line number")
    line_end: int | None = Field(None, description="Ending line number")
    source_code: str | None = Field(None, description="Source code content")
    docstring: str | None = Field(None, description="Docstring if available")
    language: str = Field(
        default="", description="Programming language for syntax highlighting"
    )
    error: str | None = Field(None, description="Error message if not found")

    def __str__(self) -> str:
        """Format code snippet for LLM consumption.

        A not-found result becomes an error message with a hint to use
        ``query_graph``. A found result lists the qualified name, file,
        line range and docstring (each only when present), followed by the
        source code in a fenced block tagged with ``language``.
        """
        if not self.found:
            error_msg = (
                self.error or f"Entity '{self.qualified_name}' not found in graph"
            )
            return (
                f"**Not Found**: {error_msg}\n\n"
                "Try using `query_graph` to search for similar entities "
                "or check the qualified name."
            )

        output_lines = [f"**Qualified Name**: `{self.qualified_name}`"]

        if self.file_path:
            output_lines.append(f"**File**: `{self.file_path}`")

        # Compare against None so that a legitimate line number 0 is kept.
        if self.line_start is not None and self.line_end is not None:
            output_lines.append(f"**Lines**: {self.line_start}-{self.line_end}")

        if self.docstring:
            output_lines.append(f"**Docstring**: {self.docstring}")

        if self.source_code:
            # Pick a fence longer than any backtick run in the code so that
            # embedded ``` (e.g. in Markdown docstrings) cannot end the block.
            fence = "```"
            while fence in self.source_code:
                fence += "`"
            output_lines.append("")
            output_lines.append("**Source Code**:")
            output_lines.append(f"{fence}{self.language or ''}")
            output_lines.append(self.source_code)
            output_lines.append(fence)

        return "\n".join(output_lines)
115
+
116
+
117
class FileReadResult(BaseModel):
    """Result from file reading with content output."""

    success: bool = Field(description="Whether file was read successfully")
    file_path: str = Field(description="Path to file that was read")
    content: str | None = Field(None, description="File content")
    encoding: str = Field("utf-8", description="Encoding used to read file")
    size_bytes: int = Field(0, description="File size in bytes")
    language: str = Field(
        default="", description="Programming language for syntax highlighting"
    )
    error: str | None = Field(default=None, description="Error message if failed")

    def __str__(self) -> str:
        """Return file content or error message.

        On success the output lists the path and size (plus the encoding
        when it is not utf-8), followed by the content in a fenced block
        tagged with ``language``.
        """
        if not self.success:
            return f"**Error reading file `{self.file_path}`**: {self.error}"

        content = self.content or ""
        # Pick a fence longer than any backtick run in the content so a file
        # that itself contains ``` (e.g. Markdown) cannot close the block.
        fence = "```"
        while fence in content:
            fence += "`"

        output_lines = [
            f"**File**: `{self.file_path}`",
            f"**Size**: {self.size_bytes} bytes",
        ]

        if self.encoding != "utf-8":
            output_lines.append(f"**Encoding**: {self.encoding}")

        output_lines.extend(
            ["", "**Content**:", f"{fence}{self.language or ''}", content, fence]
        )

        return "\n".join(output_lines)
150
+
151
+
152
class DirectoryListResult(BaseModel):
    """Result from directory listing with structured output."""

    success: bool = Field(description="Whether directory was listed successfully")
    directory: str = Field(description="Directory path that was listed")
    full_path: str = Field(description="Absolute path to directory")
    directories: list[str] = Field(
        default_factory=list, description="Subdirectory names"
    )
    files: list[tuple[str, int]] = Field(
        default_factory=list, description="Files as (name, size_bytes) tuples"
    )
    error: str | None = Field(default=None, description="Error message if failed")

    def __str__(self) -> str:
        """Format directory listing for LLM consumption.

        Lists subdirectories and files (with human-readable sizes), followed
        by a total count; an empty directory is reported explicitly.
        """
        if not self.success:
            return f"**Error listing directory `{self.directory}`**: {self.error}"

        def human_size(size_bytes: int) -> str:
            # Bytes below 1 KiB are shown exactly; larger sizes get one
            # decimal in the largest unit that keeps the value below 1024.
            if size_bytes < 1024:
                return f"{size_bytes}B"
            value = size_bytes / 1024
            for unit in ("KB", "MB", "GB"):
                if value < 1024:
                    return f"{value:.1f}{unit}"
                value /= 1024
            return f"{value:.1f}TB"

        output_lines = [
            f"**Directory**: `{self.directory}`",
            f"**Full Path**: `{self.full_path}`",
            "",
        ]

        if not self.directories and not self.files:
            output_lines.append("Directory is empty")
            return "\n".join(output_lines)

        if self.directories:
            output_lines.append("**Directories**:")
            output_lines.extend(f" 📁 {dir_name}/" for dir_name in self.directories)
            output_lines.append("")

        if self.files:
            output_lines.append("**Files**:")
            output_lines.extend(
                f" 📄 {file_name} ({human_size(size_bytes)})"
                for file_name, size_bytes in self.files
            )
            # Each section ends with one blank line, so "Total" is always
            # separated by exactly one blank line.
            output_lines.append("")

        output_lines.append(
            f"**Total**: {len(self.directories)} directories, {len(self.files)} files"
        )

        return "\n".join(output_lines)
202
+
203
+
204
class ShellCommandResult(BaseModel):
    """Result from shell command execution with formatted output."""

    success: bool = Field(description="Whether command executed without errors")
    command: str = Field(description="Command that was executed")
    args: list[str] = Field(default_factory=list, description="Command arguments")
    stdout: str = Field(default="", description="Standard output")
    stderr: str = Field(default="", description="Standard error output")
    return_code: int = Field(default=0, description="Process return code")
    execution_time_ms: float = Field(
        default=0.0, description="Execution time in milliseconds"
    )
    error: str | None = Field(
        default=None, description="Error message if execution failed"
    )

    def __str__(self) -> str:
        """Format command output for LLM consumption.

        When ``error`` is set only the error is reported. Otherwise the
        output shows the command line, the execution time, fenced stdout
        and stderr sections (when non-empty), and the exit code when it is
        non-zero.
        """
        if self.error:
            return f"**Command Failed**: {self.error}"

        cmd_str = f"{self.command} {' '.join(self.args)}".strip()
        output_lines = [
            f"**Command**: `{cmd_str}`",
            f"**Execution Time**: {self.execution_time_ms:.1f}ms",
        ]

        def add_section(title: str, text: str) -> None:
            body = text.rstrip()
            # Use a fence longer than any backtick run in the output so that
            # command output containing ``` cannot close the block early.
            fence = "```"
            while fence in body:
                fence += "`"
            output_lines.extend(["", title, fence, body, fence])

        if self.stdout:
            add_section("**Output**:", self.stdout)

        if self.stderr:
            add_section("**Error Output**:", self.stderr)

        if self.return_code != 0:
            output_lines.append("")
            output_lines.append(f"**Exit Code**: {self.return_code}")

        if not self.stdout and not self.stderr:
            output_lines.append("")
            # Only claim success when the process actually succeeded; a
            # silent non-zero exit must not be reported as successful.
            if self.success and self.return_code == 0:
                output_lines.append("Command executed successfully with no output")
            else:
                output_lines.append("Command produced no output")

        return "\n".join(output_lines)
@@ -0,0 +1,67 @@
1
+ """Query codebase knowledge graph using natural language."""
2
+
3
+ from pydantic_ai import RunContext
4
+
5
+ from shotgun.agents.models import AgentDeps
6
+ from shotgun.codebase.models import QueryType
7
+ from shotgun.logging_config import get_logger
8
+
9
+ from .models import QueryGraphResult
10
+
11
logger = get_logger(__name__)


async def query_graph(
    ctx: RunContext[AgentDeps], graph_id: str, query: str
) -> QueryGraphResult:
    """Query codebase knowledge graph using natural language.

    Args:
        ctx: RunContext containing AgentDeps with codebase service
        graph_id: Graph ID to query (use the ID, not the name)
        query: Natural language question about the codebase

    Returns:
        QueryGraphResult with formatted output via __str__; failures
        (no codebase service, or any exception) are reported through
        ``success=False`` and ``error`` rather than raised
    """
    logger.debug("šŸ”§ Querying graph %s with query: %s", graph_id, query)

    try:
        service = ctx.deps.codebase_service
        if not service:
            return QueryGraphResult(
                success=False, query=query, error="No codebase indexed"
            )

        # Let the service translate the question and run it against the graph
        raw = await service.execute_query(
            graph_id=graph_id,
            query=query,
            query_type=QueryType.NATURAL_LANGUAGE,
        )

        outcome = QueryGraphResult(
            success=raw.success,
            query=query,
            cypher_query=raw.cypher_query,
            column_names=raw.column_names,
            results=raw.results,
            row_count=raw.row_count,
            execution_time_ms=raw.execution_time_ms,
            error=raw.error,
        )

        status = "success" if outcome.success else "failed"
        logger.debug(
            "šŸ“„ Query completed: %s with %d results", status, outcome.row_count
        )
        return outcome

    except Exception as exc:
        reason = str(exc)
        logger.error("āŒ Query graph failed: %s", reason)
        return QueryGraphResult(
            success=False, query=query, error=f"Error querying graph: {reason}"
        )
@@ -0,0 +1,81 @@
1
+ """Retrieve source code by qualified name from codebase."""
2
+
3
+ from pathlib import Path
4
+
5
+ from pydantic_ai import RunContext
6
+
7
+ from shotgun.agents.models import AgentDeps
8
+ from shotgun.codebase.core.code_retrieval import retrieve_code_by_qualified_name
9
+ from shotgun.codebase.core.language_config import get_language_config
10
+ from shotgun.logging_config import get_logger
11
+
12
+ from .models import CodeSnippetResult
13
+
14
logger = get_logger(__name__)


async def retrieve_code(
    ctx: RunContext[AgentDeps], graph_id: str, qualified_name: str
) -> CodeSnippetResult:
    """Get source code by fully qualified name.

    Args:
        ctx: RunContext containing AgentDeps with codebase service
        graph_id: Graph ID to query (use the ID, not the name)
        qualified_name: Fully qualified name like "module.Class.method"

    Returns:
        CodeSnippetResult with formatted output via __str__; failures
        (no codebase service, or any exception) are reported through
        ``found=False`` and ``error`` rather than raised
    """
    logger.debug("šŸ”§ Retrieving code for: %s in graph %s", qualified_name, graph_id)

    try:
        service = ctx.deps.codebase_service
        if not service:
            return CodeSnippetResult(
                found=False, qualified_name=qualified_name, error="No codebase indexed"
            )

        snippet = await retrieve_code_by_qualified_name(
            manager=service.manager,
            graph_id=graph_id,
            qualified_name=qualified_name,
        )

        # Syntax-highlighting tag derived from the file extension, if known
        language = ""
        if snippet.file_path:
            config = get_language_config(Path(snippet.file_path).suffix)
            if config:
                language = config.name

        # Location and source are only meaningful when the entity was found;
        # the error message only when it was not.
        found = snippet.found
        result = CodeSnippetResult(
            found=found,
            qualified_name=snippet.qualified_name,
            file_path=snippet.file_path if found else None,
            line_start=snippet.line_start if found else None,
            line_end=snippet.line_end if found else None,
            source_code=snippet.source_code if found else None,
            docstring=snippet.docstring,
            language=language,
            error=None if found else snippet.error_message,
        )

        status = "found" if result.found else "not found"
        logger.debug("šŸ“„ Retrieved code: %s for %s", status, qualified_name)
        return result

    except Exception as exc:
        reason = str(exc)
        logger.error("āŒ Retrieve code failed: %s", reason)
        return CodeSnippetResult(
            found=False,
            qualified_name=qualified_name,
            error=f"Error retrieving code: {reason}",
        )
@@ -0,0 +1,218 @@
1
+ """File management tools for Pydantic AI agents.
2
+
3
+ These tools are restricted to the .shotgun directory for security.
4
+ """
5
+
6
+ from pathlib import Path
7
+ from typing import Literal
8
+
9
+ from pydantic_ai import RunContext
10
+
11
+ from shotgun.agents.models import AgentDeps, AgentType, FileOperationType
12
+ from shotgun.logging_config import get_logger
13
+ from shotgun.utils.file_system_utils import get_shotgun_base_path
14
+
15
logger = get_logger(__name__)

# Each agent mode's writable target, in workflow order. Every mode except
# export owns exactly one file; export gets the "*" wildcard, meaning it may
# write anywhere in .shotgun except the files owned by the other agents.
AGENT_DIRECTORIES = {
    AgentType.RESEARCH: "research.md",
    AgentType.SPECIFY: "specification.md",
    AgentType.PLAN: "plan.md",
    AgentType.TASKS: "tasks.md",
    AgentType.EXPORT: "*",
}

# Files the export agent must never modify: every concrete file owned by
# another agent. Derived from AGENT_DIRECTORIES so the two cannot drift.
PROTECTED_AGENT_FILES = {
    owned_file for owned_file in AGENT_DIRECTORIES.values() if owned_file != "*"
}
33
+
34
+
35
def _validate_agent_scoped_path(filename: str, agent_mode: AgentType | None) -> Path:
    """Validate and resolve a file path within the agent's scoped directory.

    Rules:
      * Agents mapped to a single file in ``AGENT_DIRECTORIES`` may only use
        exactly that filename.
      * The export agent may write anywhere under .shotgun except
        ``PROTECTED_AGENT_FILES``. This check uses the *resolved* path, so
        spellings such as ``./plan.md``, ``notes/../plan.md``, a symlink to
        a protected file, or ``PLAN.md`` (case-insensitive match, to cover
        case-insensitive filesystems) cannot bypass it.
      * Every mode (including ``None``) is confined to the .shotgun directory.

    Args:
        filename: Relative filename
        agent_mode: The current agent mode

    Returns:
        Absolute, resolved path to the file within the .shotgun directory

    Raises:
        ValueError: If the path attempts to access files outside the agent's scope
    """
    base_path = get_shotgun_base_path().resolve()
    full_path = (base_path / filename).resolve()

    # Single-file agents: exact filename match only
    if (
        agent_mode
        and agent_mode in AGENT_DIRECTORIES
        and agent_mode != AgentType.EXPORT
    ):
        allowed_file = AGENT_DIRECTORIES[agent_mode]
        if filename != allowed_file:
            raise ValueError(
                f"{agent_mode.value.capitalize()} agent can only write to '{allowed_file}'. "
                f"Attempted to write to '{filename}'"
            )

    # Ensure the resolved path is within the .shotgun directory
    try:
        relative_path = full_path.relative_to(base_path)
    except ValueError as e:
        raise ValueError(
            f"Access denied: Path '{filename}' is outside .shotgun directory"
        ) from e

    if agent_mode == AgentType.EXPORT:
        # Compare the normalized location rather than the raw string the
        # model supplied, otherwise path tricks reach the protected files.
        protected = {name.casefold() for name in PROTECTED_AGENT_FILES}
        if relative_path.as_posix().casefold() in protected:
            raise ValueError(
                f"Export agent cannot write to protected file '{filename}'. "
                f"Protected files are: {', '.join(sorted(PROTECTED_AGENT_FILES))}"
            )

    return full_path
84
+
85
+
86
def _validate_shotgun_path(filename: str) -> Path:
    """Resolve ``filename`` against the .shotgun directory and confine it there.

    Args:
        filename: Path relative to the .shotgun directory

    Returns:
        The absolute, resolved path, guaranteed to lie inside .shotgun

    Raises:
        ValueError: If the resolved path escapes the .shotgun directory
            (e.g. via ``..``, an absolute path, or a symlink)
    """
    root = get_shotgun_base_path()
    candidate = (root / filename).resolve()

    try:
        candidate.relative_to(root.resolve())
    except ValueError as exc:
        raise ValueError(
            f"Access denied: Path '{filename}' is outside .shotgun directory"
        ) from exc
    else:
        return candidate
112
+
113
+
114
async def read_file(ctx: RunContext[AgentDeps], filename: str) -> str:
    """Read a file from the .shotgun directory.

    Args:
        filename: Relative path to file within .shotgun directory

    Returns:
        File contents as string, or an error message starting with
        "Error reading file" if the path is outside .shotgun, is not an
        existing regular file, or cannot be read as UTF-8 text
    """
    logger.debug("šŸ”§ Reading file: %s", filename)

    try:
        file_path = _validate_shotgun_path(filename)

        # is_file() rather than exists(): a directory should produce the same
        # clear "not found" message instead of an OS-specific error.
        if not file_path.is_file():
            raise FileNotFoundError(f"File not found: {filename}")

        content = file_path.read_text(encoding="utf-8")
        logger.debug("šŸ“„ Read %d characters from %s", len(content), filename)
        return content

    except Exception as e:
        # Errors are returned (not raised) so the agent sees them as tool output
        error_msg = f"Error reading file '{filename}': {str(e)}"
        logger.error("āŒ File read failed: %s", error_msg)
        return error_msg
143
+
144
+
145
async def write_file(
    ctx: RunContext[AgentDeps],
    filename: str,
    content: str,
    mode: Literal["w", "a"] = "w",
) -> str:
    """Write content to a file in the .shotgun directory.

    Args:
        filename: Relative path to file within .shotgun directory
        content: Content to write to the file
        mode: Write mode - 'w' for overwrite, 'a' for append

    Returns:
        Success message, or an error message starting with "Error writing
        file" if the path is outside the current agent's scope or the write
        fails

    Raises:
        ValueError: If mode is not 'w' or 'a'
    """
    logger.debug("šŸ”§ Writing file: %s (mode: %s)", filename, mode)

    if mode not in ("w", "a"):
        raise ValueError(f"Invalid mode '{mode}'. Use 'w' for write or 'a' for append")

    try:
        # Use agent-scoped validation for write operations
        file_path = _validate_agent_scoped_path(filename, ctx.deps.agent_mode)

        # Both 'w' and 'a' create a missing file, so report CREATED whenever
        # the file did not exist beforehand, regardless of mode.
        operation = (
            FileOperationType.UPDATED
            if file_path.exists()
            else FileOperationType.CREATED
        )

        # Ensure parent directory exists
        file_path.parent.mkdir(parents=True, exist_ok=True)

        with file_path.open(mode, encoding="utf-8") as f:
            f.write(content)

        if mode == "a":
            logger.debug("šŸ“„ Appended %d characters to %s", len(content), filename)
            result = f"Successfully appended {len(content)} characters to {filename}"
        else:
            logger.debug("šŸ“„ Wrote %d characters to %s", len(content), filename)
            result = f"Successfully wrote {len(content)} characters to {filename}"

        # Track the file operation
        ctx.deps.file_tracker.add_operation(file_path, operation)

        return result

    except Exception as e:
        # Errors are returned (not raised) so the agent sees them as tool output
        error_msg = f"Error writing file '{filename}': {str(e)}"
        logger.error("āŒ File write failed: %s", error_msg)
        return error_msg
206
+
207
+
208
async def append_file(ctx: RunContext[AgentDeps], filename: str, content: str) -> str:
    """Append content to the end of a file in the .shotgun directory.

    Equivalent to ``write_file`` with mode 'a'.

    Args:
        filename: Relative path to file within .shotgun directory
        content: Content to append to the file

    Returns:
        Success message or error message
    """
    outcome = await write_file(ctx, filename, content, "a")
    return outcome
@@ -0,0 +1,37 @@
1
+ """User interaction tools for Pydantic AI agents."""
2
+
3
+ from asyncio import get_running_loop
4
+
5
+ from pydantic_ai import CallDeferred, RunContext
6
+
7
+ from shotgun.agents.models import AgentDeps, UserQuestion
8
+ from shotgun.logging_config import get_logger
9
+
10
logger = get_logger(__name__)


async def ask_user(ctx: RunContext[AgentDeps], question: str) -> str:
    """Ask the human a question and return the answer.

    Args:
        question: The question to ask the user with a clear CTA at the end. Needs to be readable, clear, and easy to understand. Use Markdown formatting. Make key phrases and words stand out.

    Returns:
        The user's response as a string
    """
    # UserQuestion carries the tool call id, so this tool only works when
    # invoked as a tool call. An explicit check (instead of `assert`) keeps
    # the guard active under `python -O`.
    tool_call_id = ctx.tool_call_id
    if tool_call_id is None:
        raise RuntimeError(
            "ask_user must be invoked as a tool call (tool_call_id is missing)"
        )

    try:
        logger.debug("\nšŸ‘‰ %s\n", question)
        future = get_running_loop().create_future()
        # Presumably the consumer of ctx.deps.queue resolves `future` with the
        # user's answer — TODO confirm against the queue consumer.
        await ctx.deps.queue.put(
            UserQuestion(question=question, tool_call_id=tool_call_id, result=future)
        )
        ctx.deps.tasks.append(future)
        # Defer this tool call; the agent run resumes once an answer exists.
        raise CallDeferred(question)

    except (EOFError, KeyboardInterrupt):
        # NOTE(review): nothing visible in the try body reads stdin, so this
        # handler looks like a leftover from an input()-based version —
        # confirm before removing.
        logger.warning("User input interrupted or unavailable")
        return "User input not available or interrupted"