shotgun-sh 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh might be problematic. Click here for more details.
- shotgun/__init__.py +5 -0
- shotgun/agents/__init__.py +1 -0
- shotgun/agents/agent_manager.py +651 -0
- shotgun/agents/common.py +549 -0
- shotgun/agents/config/__init__.py +13 -0
- shotgun/agents/config/constants.py +17 -0
- shotgun/agents/config/manager.py +294 -0
- shotgun/agents/config/models.py +185 -0
- shotgun/agents/config/provider.py +206 -0
- shotgun/agents/conversation_history.py +106 -0
- shotgun/agents/conversation_manager.py +105 -0
- shotgun/agents/export.py +96 -0
- shotgun/agents/history/__init__.py +5 -0
- shotgun/agents/history/compaction.py +85 -0
- shotgun/agents/history/constants.py +19 -0
- shotgun/agents/history/context_extraction.py +108 -0
- shotgun/agents/history/history_building.py +104 -0
- shotgun/agents/history/history_processors.py +426 -0
- shotgun/agents/history/message_utils.py +84 -0
- shotgun/agents/history/token_counting.py +429 -0
- shotgun/agents/history/token_estimation.py +138 -0
- shotgun/agents/messages.py +35 -0
- shotgun/agents/models.py +275 -0
- shotgun/agents/plan.py +98 -0
- shotgun/agents/research.py +108 -0
- shotgun/agents/specify.py +98 -0
- shotgun/agents/tasks.py +96 -0
- shotgun/agents/tools/__init__.py +34 -0
- shotgun/agents/tools/codebase/__init__.py +28 -0
- shotgun/agents/tools/codebase/codebase_shell.py +256 -0
- shotgun/agents/tools/codebase/directory_lister.py +141 -0
- shotgun/agents/tools/codebase/file_read.py +144 -0
- shotgun/agents/tools/codebase/models.py +252 -0
- shotgun/agents/tools/codebase/query_graph.py +67 -0
- shotgun/agents/tools/codebase/retrieve_code.py +81 -0
- shotgun/agents/tools/file_management.py +218 -0
- shotgun/agents/tools/user_interaction.py +37 -0
- shotgun/agents/tools/web_search/__init__.py +60 -0
- shotgun/agents/tools/web_search/anthropic.py +144 -0
- shotgun/agents/tools/web_search/gemini.py +85 -0
- shotgun/agents/tools/web_search/openai.py +98 -0
- shotgun/agents/tools/web_search/utils.py +20 -0
- shotgun/build_constants.py +20 -0
- shotgun/cli/__init__.py +1 -0
- shotgun/cli/codebase/__init__.py +5 -0
- shotgun/cli/codebase/commands.py +202 -0
- shotgun/cli/codebase/models.py +21 -0
- shotgun/cli/config.py +275 -0
- shotgun/cli/export.py +81 -0
- shotgun/cli/models.py +10 -0
- shotgun/cli/plan.py +73 -0
- shotgun/cli/research.py +85 -0
- shotgun/cli/specify.py +69 -0
- shotgun/cli/tasks.py +78 -0
- shotgun/cli/update.py +152 -0
- shotgun/cli/utils.py +25 -0
- shotgun/codebase/__init__.py +12 -0
- shotgun/codebase/core/__init__.py +46 -0
- shotgun/codebase/core/change_detector.py +358 -0
- shotgun/codebase/core/code_retrieval.py +243 -0
- shotgun/codebase/core/ingestor.py +1497 -0
- shotgun/codebase/core/language_config.py +297 -0
- shotgun/codebase/core/manager.py +1662 -0
- shotgun/codebase/core/nl_query.py +331 -0
- shotgun/codebase/core/parser_loader.py +128 -0
- shotgun/codebase/models.py +111 -0
- shotgun/codebase/service.py +206 -0
- shotgun/logging_config.py +227 -0
- shotgun/main.py +167 -0
- shotgun/posthog_telemetry.py +158 -0
- shotgun/prompts/__init__.py +5 -0
- shotgun/prompts/agents/__init__.py +1 -0
- shotgun/prompts/agents/export.j2 +350 -0
- shotgun/prompts/agents/partials/codebase_understanding.j2 +87 -0
- shotgun/prompts/agents/partials/common_agent_system_prompt.j2 +37 -0
- shotgun/prompts/agents/partials/content_formatting.j2 +65 -0
- shotgun/prompts/agents/partials/interactive_mode.j2 +26 -0
- shotgun/prompts/agents/plan.j2 +144 -0
- shotgun/prompts/agents/research.j2 +69 -0
- shotgun/prompts/agents/specify.j2 +51 -0
- shotgun/prompts/agents/state/codebase/codebase_graphs_available.j2 +19 -0
- shotgun/prompts/agents/state/system_state.j2 +31 -0
- shotgun/prompts/agents/tasks.j2 +143 -0
- shotgun/prompts/codebase/__init__.py +1 -0
- shotgun/prompts/codebase/cypher_query_patterns.j2 +223 -0
- shotgun/prompts/codebase/cypher_system.j2 +28 -0
- shotgun/prompts/codebase/enhanced_query_context.j2 +10 -0
- shotgun/prompts/codebase/partials/cypher_rules.j2 +24 -0
- shotgun/prompts/codebase/partials/graph_schema.j2 +30 -0
- shotgun/prompts/codebase/partials/temporal_context.j2 +21 -0
- shotgun/prompts/history/__init__.py +1 -0
- shotgun/prompts/history/incremental_summarization.j2 +53 -0
- shotgun/prompts/history/summarization.j2 +46 -0
- shotgun/prompts/loader.py +140 -0
- shotgun/py.typed +0 -0
- shotgun/sdk/__init__.py +13 -0
- shotgun/sdk/codebase.py +219 -0
- shotgun/sdk/exceptions.py +17 -0
- shotgun/sdk/models.py +189 -0
- shotgun/sdk/services.py +23 -0
- shotgun/sentry_telemetry.py +87 -0
- shotgun/telemetry.py +93 -0
- shotgun/tui/__init__.py +0 -0
- shotgun/tui/app.py +116 -0
- shotgun/tui/commands/__init__.py +76 -0
- shotgun/tui/components/prompt_input.py +69 -0
- shotgun/tui/components/spinner.py +86 -0
- shotgun/tui/components/splash.py +25 -0
- shotgun/tui/components/vertical_tail.py +13 -0
- shotgun/tui/screens/chat.py +782 -0
- shotgun/tui/screens/chat.tcss +43 -0
- shotgun/tui/screens/chat_screen/__init__.py +0 -0
- shotgun/tui/screens/chat_screen/command_providers.py +219 -0
- shotgun/tui/screens/chat_screen/hint_message.py +40 -0
- shotgun/tui/screens/chat_screen/history.py +221 -0
- shotgun/tui/screens/directory_setup.py +113 -0
- shotgun/tui/screens/provider_config.py +221 -0
- shotgun/tui/screens/splash.py +31 -0
- shotgun/tui/styles.tcss +10 -0
- shotgun/tui/utils/__init__.py +5 -0
- shotgun/tui/utils/mode_progress.py +257 -0
- shotgun/utils/__init__.py +5 -0
- shotgun/utils/env_utils.py +35 -0
- shotgun/utils/file_system_utils.py +36 -0
- shotgun/utils/update_checker.py +375 -0
- shotgun_sh-0.1.0.dist-info/METADATA +466 -0
- shotgun_sh-0.1.0.dist-info/RECORD +130 -0
- shotgun_sh-0.1.0.dist-info/WHEEL +4 -0
- shotgun_sh-0.1.0.dist-info/entry_points.txt +2 -0
- shotgun_sh-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Natural language to Cypher query conversion for code graphs."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.messages import (
|
|
8
|
+
ModelRequest,
|
|
9
|
+
SystemPromptPart,
|
|
10
|
+
TextPart,
|
|
11
|
+
UserPromptPart,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from shotgun.agents.config import get_provider_model
|
|
15
|
+
from shotgun.agents.config.models import shotgun_model_request
|
|
16
|
+
from shotgun.logging_config import get_logger
|
|
17
|
+
from shotgun.prompts import PromptLoader
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from openai import AsyncOpenAI
|
|
21
|
+
|
|
22
|
+
logger = get_logger(__name__)
|
|
23
|
+
|
|
24
|
+
# Global prompt loader instance
|
|
25
|
+
prompt_loader = PromptLoader()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
async def llm_cypher_prompt(system_prompt: str, user_prompt: str) -> str:
    """Generate a Cypher query from a natural language prompt using the configured LLM provider.

    Args:
        system_prompt: The system prompt defining the behavior and context for the LLM
        user_prompt: The user's natural language query

    Returns:
        The raw text of the first text part of the model response. It is not
        cleaned here; the callers in this module pass it through
        ``clean_cypher_response``.

    Raises:
        ValueError: If the response has no parts, contains no ``TextPart``,
            or the text is empty / whitespace-only.
    """
    model_config = get_provider_model()
    # Use shotgun wrapper to maximize response quality for codebase queries.
    # Limit max_tokens to 2000 for Cypher queries (they're typically 50-200 tokens);
    # this prevents the Anthropic SDK from requiring streaming for longer token limits.
    query_cypher_response = await shotgun_model_request(
        model_config=model_config,
        messages=[
            ModelRequest(
                parts=[
                    SystemPromptPart(content=system_prompt),
                    UserPromptPart(content=user_prompt),
                ]
            ),
        ],
        max_tokens=2000,  # Cypher queries are short, 2000 tokens is plenty
    )

    if not query_cypher_response.parts:
        raise ValueError("Empty response from LLM")

    # The response may carry non-text parts ahead of the answer (presumably
    # reasoning/thinking output on some models -- TODO confirm per provider),
    # so take the first TextPart instead of assuming it is parts[0].
    text_part = next(
        (part for part in query_cypher_response.parts if isinstance(part, TextPart)),
        None,
    )
    if text_part is None:
        raise ValueError("Unexpected response part type from LLM")

    cypher_query = str(text_part.content)
    # Whitespace-only output would otherwise be "cleaned" into a bare ";".
    if not cypher_query.strip():
        raise ValueError("Empty content in LLM response")
    return cypher_query
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
async def generate_cypher(natural_language_query: str) -> str:
    """Translate a natural-language question into a validated Cypher query.

    The system prompt comes from ``codebase/cypher_system.j2``. The user message
    comes from ``codebase/enhanced_query_context.j2``, rendered with the current
    local datetime and Unix timestamp so the model can resolve relative times.

    Args:
        natural_language_query: The user's query in natural language

    Returns:
        Generated Cypher query

    Raises:
        RuntimeError: Wraps any failure from the LLM call, response cleaning,
            or UNION ALL validation (the original exception is chained).
    """
    # Capture "now" once up front so both representations describe the same instant.
    now_epoch = int(time.time())
    now_text = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    base_prompt = prompt_loader.render("codebase/cypher_system.j2")
    user_message = prompt_loader.render(
        "codebase/enhanced_query_context.j2",
        current_datetime=now_text,
        current_timestamp=now_epoch,
        natural_language_query=natural_language_query,
    )

    try:
        raw_answer = await llm_cypher_prompt(base_prompt, user_message)
        candidate = clean_cypher_response(raw_answer)

        ok, problem = validate_union_query(candidate)
        if ok:
            return candidate

        logger.warning(f"Generated query failed validation: {problem}")
        logger.warning(f"Problematic query: {candidate}")
        raise ValueError(f"Generated query validation failed: {problem}")
    except Exception as exc:
        raise RuntimeError(f"Failed to generate Cypher query: {exc}") from exc
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
async def generate_cypher_with_error_context(
    natural_language_query: str, error_context: str = ""
) -> str:
    """Convert a natural language query to Cypher with additional error context for retry scenarios.

    The system prompt is the standard ``codebase/cypher_system.j2`` prompt
    extended with UNION ALL error-recovery instructions. The "ERROR CONTEXT"
    section of the user message is emitted only when ``error_context`` is
    non-empty. Calling with the default therefore does not tell the model that a
    previous attempt failed when no failure information is available.

    Args:
        natural_language_query: The user's query in natural language
        error_context: Additional context about previous errors to help generate better query

    Returns:
        Generated Cypher query

    Raises:
        RuntimeError: Wraps any failure from building the system prompt, the
            LLM call, response cleaning, or UNION ALL validation.
    """
    # Get current time for context
    current_timestamp = int(time.time())
    current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Generate enhanced query with (optional) error context using template.
    # With a non-empty error_context the rendered text is identical to always
    # including the section; with an empty one the section is skipped entirely.
    enhanced_query = prompt_loader.render_string(
        """Current datetime: {{ current_datetime }} (Unix timestamp: {{ current_timestamp }})

User query: {{ natural_language_query }}
{% if error_context %}
ERROR CONTEXT (CRITICAL - Previous attempt failed):
{{ error_context }}
{% endif %}
IMPORTANT: All timestamps in the database are stored as Unix timestamps (INT64). When generating time-based queries:
- For "2 minutes ago": use {{ current_timestamp - 120 }}
- For "1 hour ago": use {{ current_timestamp - 3600 }}
- For "today": use timestamps >= {{ current_timestamp - (current_timestamp % 86400) }}
- For "yesterday": use timestamps between {{ current_timestamp - 86400 - (current_timestamp % 86400) }} and {{ current_timestamp - (current_timestamp % 86400) }}
- NEVER use placeholder values like 1704067200, always calculate based on the current timestamp: {{ current_timestamp }}""",
        current_datetime=current_datetime,
        current_timestamp=current_timestamp,
        natural_language_query=natural_language_query,
        error_context=error_context,
    )

    try:
        # Create enhanced system prompt with error recovery instructions
        enhanced_system_prompt = prompt_loader.render_string(
            """{{ base_system_prompt }}

**CRITICAL ERROR RECOVERY INSTRUCTIONS:**
When retrying after a UNION ALL error:
1. Each UNION ALL branch MUST return exactly the same number of columns
2. Column names MUST be in the same order across all branches
3. Use explicit column aliases to ensure consistency: RETURN prop1 as name, prop2 as qualified_name, 'Type' as type
4. If different node types have different properties, use COALESCE or NULL for missing properties
5. Test each UNION branch separately before combining

Example of CORRECT UNION ALL:
```cypher
MATCH (c:Class) RETURN c.name as name, c.qualified_name as qualified_name, 'Class' as type
UNION ALL
MATCH (f:Function) RETURN f.name as name, f.qualified_name as qualified_name, 'Function' as type
```

Example of INCORRECT UNION ALL (different column counts):
```cypher
MATCH (c:Class) RETURN c.name, c.qualified_name, c.docstring
UNION ALL
MATCH (f:Function) RETURN f.name, f.qualified_name // WRONG: missing third column
```""",
            base_system_prompt=prompt_loader.render("codebase/cypher_system.j2"),
        )

        cypher_query = await llm_cypher_prompt(enhanced_system_prompt, enhanced_query)
        cleaned_query = clean_cypher_response(cypher_query)

        # Validate UNION ALL queries
        is_valid, validation_error = validate_union_query(cleaned_query)
        if not is_valid:
            logger.warning(f"Retry query failed validation: {validation_error}")
            logger.warning(f"Problematic retry query: {cleaned_query}")
            raise ValueError(f"Retry query validation failed: {validation_error}")

        return cleaned_query

    except Exception as e:
        raise RuntimeError(
            f"Failed to generate Cypher query with error context: {e}"
        ) from e
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
async def generate_cypher_openai_async(
    client: "AsyncOpenAI", natural_language_query: str, model: str = "gpt-4o"
) -> str:
    """Convert a natural language query to Cypher using async OpenAI client.

    This function is for standalone usage without Shotgun's LLM infrastructure:
    the request is sent through ``client`` using ``model``, not through the
    provider configured in Shotgun. (Previously both arguments were ignored
    and the configured provider was used instead.)

    Args:
        client: Async OpenAI client instance
        natural_language_query: The user's query in natural language
        model: OpenAI model to use (default: gpt-4o)

    Returns:
        Generated Cypher query

    Raises:
        RuntimeError: If the API call fails or returns no message content.
    """
    # Get current time for context
    current_timestamp = int(time.time())
    current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Generate system prompt using template
    system_prompt = prompt_loader.render("codebase/cypher_system.j2")

    # Generate enhanced query using template
    enhanced_query = prompt_loader.render(
        "codebase/enhanced_query_context.j2",
        current_datetime=current_datetime,
        current_timestamp=current_timestamp,
        natural_language_query=natural_language_query,
    )

    try:
        response = await client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": enhanced_query},
            ],
        )
        # message.content can be None (e.g. refusals / tool calls), and choices
        # can in principle be empty; treat both as an empty response.
        content = response.choices[0].message.content if response.choices else None
        if not content:
            raise ValueError("Empty response from LLM")
        return clean_cypher_response(content)

    except Exception as e:
        logger.error(f"OpenAI API error: {e}")
        raise RuntimeError(f"Failed to generate Cypher query: {e}") from e
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _split_top_level_commas(clause: str) -> list[str]:
    """Split ``clause`` on commas that are not inside brackets or quotes.

    RETURN items can themselves contain commas (``coalesce(a, b)``, list
    literals ``[a, b]``, string literals ``'a, b'``); a plain ``split(",")``
    would count those as extra columns.
    """
    pieces: list[str] = []
    current: list[str] = []
    depth = 0
    quote: str | None = None
    escaped = False
    for ch in clause:
        current.append(ch)
        if quote is not None:
            # Inside a quoted literal/identifier: only watch for its end.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch in "'\"`":
            quote = ch
        elif ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth = max(depth - 1, 0)
        elif ch == "," and depth == 0:
            current.pop()  # drop the separator itself
            pieces.append("".join(current))
            current = []
    pieces.append("".join(current))
    return pieces


def validate_union_query(cypher_query: str) -> tuple[bool, str]:
    """Validate that all UNION ALL branches return the same number of columns.

    Only column *counts* are compared. Column names (the alias after ``AS``,
    or the bare expression) are collected, upper-cased, only for the error
    message. Branches without a RETURN clause are ignored.

    Args:
        cypher_query: The Cypher query to validate

    Returns:
        Tuple of (is_valid, error_message); error_message is "" when valid.
    """
    import re

    query_upper = cypher_query.upper()
    # Any whitespace (including newlines) may separate UNION and ALL.
    parts = re.split(r"\bUNION\s+ALL\b", query_upper)
    if len(parts) < 2:
        return True, ""

    return_patterns: list[tuple[int, list[str]]] = []
    for i, part in enumerate(parts):
        # Whole-word match so identifiers like ``c.return_type`` are not
        # mistaken for the RETURN keyword; use the last RETURN in the branch.
        return_matches = list(re.finditer(r"\bRETURN\b", part))
        if not return_matches:
            continue
        return_clause = part[return_matches[-1].end() :]

        # The projection ends at ORDER BY / SKIP / LIMIT or the statement terminator.
        stop = re.search(r"\b(?:ORDER\s+BY|SKIP|LIMIT)\b|;", return_clause)
        if stop:
            return_clause = return_clause[: stop.start()]

        columns = []
        for col in _split_top_level_commas(return_clause):
            # Use the alias after the last AS when present, else the expression.
            columns.append(re.split(r"\s+AS\s+", col.strip())[-1].strip())

        return_patterns.append((i, columns))

    # Check all parts have same number of columns
    if len(return_patterns) < 2:
        return True, ""

    _, first_columns = return_patterns[0]
    first_count = len(first_columns)

    for part_idx, columns in return_patterns[1:]:
        if len(columns) != first_count:
            return (
                False,
                f"UNION ALL part {part_idx + 1} has {len(columns)} columns, expected {first_count}. First part columns: {first_columns}, this part: {columns}",
            )

    return True, ""
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
def clean_cypher_response(response_text: str) -> str:
    """Clean up common LLM formatting artifacts from a Cypher query.

    - If the text contains a fenced code block (```), its contents are used.
      This holds even when the fence is surrounded by prose, and an
      unterminated fence runs to the end. A stray fence with no content is
      simply removed.
    - A leading ``cypher`` language tag is dropped.
    - A single pair of backticks wrapping the whole query is removed. Backticks
      *inside* the query are kept, because Cypher uses them to escape
      identifiers (e.g. ``(n:`My Label`)``).
    - A trailing ``;`` is appended if missing.

    Args:
        response_text: Raw response from LLM

    Returns:
        Cleaned Cypher query
    """
    import re

    query = response_text.strip()

    # Remove markdown code fences, preferring the first fenced block's contents.
    fence = re.search(r"```(.*?)(?:```|\Z)", query, re.DOTALL)
    if fence:
        fenced_body = fence.group(1).strip()
        query = fenced_body if fenced_body else query.replace("```", "").strip()

    # Remove a 'cypher' language tag (e.g. left over from ```cypher).
    tag = re.match(r"cypher(?=\s|$)", query, re.IGNORECASE)
    if tag:
        query = query[tag.end() :].strip()

    # Unwrap inline-code formatting like `MATCH ...` without touching
    # backtick-escaped identifiers inside the query.
    if len(query) >= 2 and query.startswith("`") and query.endswith("`"):
        query = query[1:-1].strip()

    # Ensure it ends with semicolon
    if not query.endswith(";"):
        query += ";"

    return query
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Tree-sitter parser loader for code parsing."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from tree_sitter import Language, Parser
|
|
8
|
+
|
|
9
|
+
from shotgun.codebase.core.language_config import LANGUAGE_CONFIGS
|
|
10
|
+
from shotgun.logging_config import get_logger
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def load_parsers() -> tuple[dict[str, Parser], dict[str, Any]]:
    """Load available Tree-sitter parsers and compile their queries.

    Each grammar package is imported independently, so a missing package only
    disables that language (a warning is logged). A parser is created only for
    languages that also have an entry in ``LANGUAGE_CONFIGS``.

    Returns:
        Tuple of (parsers dict, queries dict). ``queries`` maps a language name
        to its compiled queries keyed by query type ("function_query",
        "class_query", "call_query", "import_query"). Languages for which no
        query compiled are omitted.

    Note:
        Calls ``sys.exit(1)`` when no parser at all could be loaded.
    """
    import importlib

    parsers: dict[str, Parser] = {}
    queries: dict[str, Any] = {}

    # (language name, grammar module, factory returning the language object)
    grammar_specs = (
        ("python", "tree_sitter_python", "language"),
        ("javascript", "tree_sitter_javascript", "language"),
        ("typescript", "tree_sitter_typescript", "language_typescript"),
        ("go", "tree_sitter_go", "language"),
        ("rust", "tree_sitter_rust", "language"),
    )
    query_types = ("function_query", "class_query", "call_query", "import_query")

    # Try individual language imports first
    language_loaders: dict[str, Callable[[], Any]] = {}
    for lang_name, module_name, factory_name in grammar_specs:
        try:
            module = importlib.import_module(module_name)
        except ImportError as e:
            logger.warning(f"Failed to import {module_name}: {e}")
            continue
        # Defaults bind the current loop values (avoids late binding). The
        # attribute lookup is deferred so a missing factory is reported by the
        # per-language error handler below rather than aborting the whole load.
        language_loaders[lang_name] = (
            lambda module=module, factory_name=factory_name: getattr(
                module, factory_name
            )()
        )

    logger.info(f"Available languages: {', '.join(language_loaders)}")

    # Create parsers for available languages
    for lang_name, lang_loader in language_loaders.items():
        if lang_name not in LANGUAGE_CONFIGS:
            continue
        try:
            lang_obj = lang_loader()
            # Wrap whatever the grammar package returned in a Language if needed.
            if not isinstance(lang_obj, Language):
                lang_obj = Language(lang_obj)

            parser = Parser()
            parser.language = lang_obj
            parsers[lang_name] = parser

            # Compile queries for this language
            config = LANGUAGE_CONFIGS[lang_name]
            lang_queries = {}
            for query_type in query_types:
                query_text = getattr(config, query_type)
                if not query_text:
                    continue
                try:
                    lang_queries[query_type] = lang_obj.query(query_text)
                except Exception as e:
                    # A failed query silently disables that kind of extraction
                    # for the language, so surface it above debug level.
                    logger.warning(
                        f"Failed to compile {query_type} for {lang_name}: {e}"
                    )

            if lang_queries:
                queries[lang_name] = lang_queries

            logger.debug(f"Loaded parser for {lang_name}")

        except Exception as e:
            logger.error(f"Failed to load parser for {lang_name}: {e}")

    if not parsers:
        logger.error(
            "No parsers could be loaded. Please install language-specific tree-sitter packages."
        )
        logger.error(
            "Install with: pip install tree-sitter-python tree-sitter-javascript tree-sitter-typescript tree-sitter-go tree-sitter-rust"
        )
        sys.exit(1)

    return parsers, queries
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Data models for codebase service."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GraphStatus(str, Enum):
    """Lifecycle state of a code knowledge graph (values are plain strings)."""

    # Available for queries.
    READY = "READY"
    # The first build is running.
    BUILDING = "BUILDING"
    # An incremental update is running.
    UPDATING = "UPDATING"
    # The most recent operation did not succeed.
    ERROR = "ERROR"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class QueryType(str, Enum):
    """How a submitted query should be interpreted."""

    # Free-form question to be translated into Cypher first.
    NATURAL_LANGUAGE = "natural_language"
    # Raw Cypher to execute as-is.
    CYPHER = "cypher"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OperationStats(BaseModel):
    """Timing and outcome record for a single graph build or update."""

    # Required: which operation this record describes.
    operation_type: str = Field(description="Type of operation: build or update")
    # Required start time (seconds since the epoch).
    started_at: float = Field(description="Unix timestamp when operation started")
    # Unset until the operation finishes.
    completed_at: float | None = Field(
        default=None,
        description="Unix timestamp when operation completed",
    )
    success: bool = Field(default=False, description="Whether operation succeeded")
    error: str | None = Field(
        default=None,
        description="Error message if operation failed",
    )
    # Free-form counters collected during the operation.
    stats: dict[str, Any] = Field(
        default_factory=dict,
        description="Operation statistics",
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class CodebaseGraph(BaseModel):
    """Metadata describing one code knowledge graph and its backing database."""

    # --- identity and location (all required) ---
    graph_id: str = Field(description="Unique graph ID (hash of repo path)")
    repo_path: str = Field(description="Absolute path to repository")
    graph_path: str = Field(description="Path to Kuzu database")
    name: str = Field(description="Human-readable name for the graph")

    # --- timestamps (required) and versioning ---
    created_at: float = Field(description="Unix timestamp of creation")
    updated_at: float = Field(description="Unix timestamp of last update")
    schema_version: str = Field(
        default="1.0.0",
        description="Graph schema version",
    )

    # --- build configuration and content statistics ---
    build_options: dict[str, Any] = Field(
        default_factory=dict,
        description="Build configuration",
    )
    language_stats: dict[str, int] = Field(
        default_factory=dict,
        description="File count by language",
    )
    node_count: int = Field(default=0, description="Total number of nodes")
    relationship_count: int = Field(
        default=0,
        description="Total number of relationships",
    )
    node_stats: dict[str, int] = Field(
        default_factory=dict,
        description="Node counts by type",
    )
    relationship_stats: dict[str, int] = Field(
        default_factory=dict,
        description="Relationship counts by type",
    )

    # --- runtime state ---
    is_watching: bool = Field(default=False, description="Whether watcher is active")
    status: GraphStatus = Field(
        default=GraphStatus.READY,
        description="Current status of the graph",
    )
    last_operation: OperationStats | None = Field(
        default=None,
        description="Statistics from the last operation",
    )
    current_operation_id: str | None = Field(
        default=None,
        description="ID of current in-progress operation",
    )

    # --- access scope; an empty list means the graph is visible everywhere ---
    indexed_from_cwds: list[str] = Field(
        default_factory=list,
        description="List of working directories from which this graph is accessible. Empty list means globally accessible.",
    )
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class QueryResult(BaseModel):
    """Outcome of running one query (natural language or Cypher) against a graph."""

    # What the caller asked, verbatim.
    query: str = Field(description="Original query (natural language or Cypher)")
    # Filled in only when the query had to be translated first.
    cypher_query: str | None = Field(
        default=None,
        description="Generated Cypher query if from natural language",
    )

    # --- returned data ---
    results: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Query results",
    )
    column_names: list[str] = Field(
        default_factory=list,
        description="Result column names",
    )
    row_count: int = Field(default=0, description="Number of result rows")

    # --- execution metadata (timing is required) ---
    execution_time_ms: float = Field(
        description="Query execution time in milliseconds",
    )
    success: bool = Field(default=True, description="Whether query succeeded")
    error: str | None = Field(default=None, description="Error message if failed")
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class FileChange(BaseModel):
    """A single file-system event observed for the repository."""

    # Required event kind.
    event_type: str = Field(
        description="Type of change: created, modified, deleted, moved",
    )
    # Required path the event refers to.
    src_path: str = Field(description="Source file path")
    # Only meaningful for moves.
    dest_path: str | None = Field(
        default=None,
        description="Destination path for moves",
    )
    is_directory: bool = Field(
        default=False,
        description="Whether path is a directory",
    )