quantalogic 0.2.26__py3-none-any.whl → 0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/coding_agent.py +3 -1
- quantalogic/tools/__init__.py +7 -1
- quantalogic/tools/execute_bash_command_tool.py +70 -53
- quantalogic/tools/generate_database_report_tool.py +52 -0
- quantalogic/tools/grep_app_tool.py +499 -0
- quantalogic/tools/sql_query_tool.py +167 -0
- quantalogic/tools/utils/__init__.py +13 -0
- quantalogic/tools/utils/create_sample_database.py +124 -0
- quantalogic/tools/utils/generate_database_report.py +289 -0
- {quantalogic-0.2.26.dist-info → quantalogic-0.28.dist-info}/METADATA +6 -2
- {quantalogic-0.2.26.dist-info → quantalogic-0.28.dist-info}/RECORD +14 -9
- quantalogic/.DS_Store +0 -0
- {quantalogic-0.2.26.dist-info → quantalogic-0.28.dist-info}/LICENSE +0 -0
- {quantalogic-0.2.26.dist-info → quantalogic-0.28.dist-info}/WHEEL +0 -0
- {quantalogic-0.2.26.dist-info → quantalogic-0.28.dist-info}/entry_points.txt +0 -0
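
For orientation, here is a minimal usage sketch of the two new tools added in this release, assembled from the module paths and signatures visible in the hunks below. It is not an example shipped with the package; the `sample.db` path is illustrative, and the grep.app call performs a live HTTP request.

```python
# Hedged sketch: exercising the new tools introduced in 0.28, based only on
# the module paths and signatures visible in the diff below.
from quantalogic.tools.grep_app_tool import GrepAppTool
from quantalogic.tools.sql_query_tool import SQLQueryTool
from quantalogic.tools.utils.create_sample_database import create_sample_database

# Search GitHub code via grep.app (returns a Markdown-formatted string).
grep_tool = GrepAppTool()
print(grep_tool.execute(search_query="lang:python def __init__", per_page=5))

# Query a local SQLite database and get a paginated Markdown table back.
create_sample_database("sample.db")  # helper shipped in quantalogic/tools/utils
sql_tool = SQLQueryTool(connection_string="sqlite:///sample.db")
print(sql_tool.execute("SELECT * FROM customers", 1, 10))
```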
quantalogic/tools/grep_app_tool.py (new file)

@@ -0,0 +1,499 @@

```python
# quantalogic/tools/grep_app_tool.py

import random
import sys
import time
from typing import Any, ClassVar, Dict, Optional, Union

import requests
from loguru import logger
from pydantic import BaseModel, Field, ValidationError, model_validator

from quantalogic.tools.tool import Tool, ToolArgument

# Configurable User Agents
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/91.0.4472.124 Safari/537.36",
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:89.0) Gecko/20100101 Firefox/89.0",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:89.0) Gecko/20100101 Firefox/89.0",
    "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) "
    "Version/14.1.1 Safari/605.1.15"
]

class SearchError(Exception):
    """Custom exception for search-related errors"""
    pass

class GrepAppArguments(BaseModel):
    """Pydantic model for grep.app search arguments"""
    search_query: str = Field(
        ...,
        description="GitHub Code search using simple keyword or regular expression",
        example="code2prompt"
    )
    repository: Optional[str] = Field(
        None,
        description="Filter by repository (e.g. user/repo)",
        example="quantalogic/quantalogic",
    )
    page: int = Field(
        1,
        description="Results page number",
        ge=1
    )
    per_page: int = Field(
        10,
        description="Number of results per page",
        ge=1,
        le=100
    )
    regexp: bool = Field(
        False,
        description="Enable regular expression search"
    )
    case: bool = Field(
        False,
        description="Enable case-sensitive search"
    )
    words: bool = Field(
        False,
        description="Match whole words only"
    )

    @model_validator(mode='before')
    @classmethod
    def convert_types(cls, data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert input types before validation"""
        # Convert string numbers to integers
        for field in ['page', 'per_page']:
            if field in data and isinstance(data[field], str):
                try:
                    data[field] = int(data[field])
                except ValueError:
                    raise ValueError(f"{field} must be a valid integer")

        # Convert various string representations to booleans
        for field in ['regexp', 'case', 'words']:
            if field in data:
                if isinstance(data[field], str):
                    data[field] = data[field].lower() in ['true', '1', 'yes', 'on']

        return data

    @model_validator(mode='after')
    def validate_search_query(self) -> 'GrepAppArguments':
        """Validate search query is not empty and has reasonable length"""
        if not self.search_query or not self.search_query.strip():
            raise ValueError("Search query cannot be empty")
        if len(self.search_query) > 500:  # Reasonable limit for search query
            raise ValueError("Search query is too long (max 500 characters)")
        return self

class GrepAppTool(Tool):
    """Tool for searching GitHub code via grep.app API"""

    BASE_URL: ClassVar[str] = "https://grep.app/api/search"
    TIMEOUT: ClassVar[int] = 10

    def __init__(self):
        super().__init__(
            name="grep_app_tool",
            description="Searches GitHub code using grep.app API. Returns code matches with metadata."
        )
        self.arguments = [
            ToolArgument(
                name="search_query",
                arg_type="string",
                description="Search query using grep.app syntax",
                required=True
            ),
            ToolArgument(
                name="repository",
                arg_type="string",
                description="Filter by repository",
                required=False
            ),
            ToolArgument(
                name="page",
                arg_type="int",
                description="Pagination page number",
                default="1",
                required=False
            ),
            ToolArgument(
                name="per_page",
                arg_type="int",
                description="Results per page",
                default="10",
                required=False
            ),
            ToolArgument(
                name="regexp",
                arg_type="boolean",
                description="Enable regular expression search",
                default="False",
                required=False
            ),
            ToolArgument(
                name="case",
                arg_type="boolean",
                description="Enable case-sensitive search",
                default="False",
                required=False
            ),
            ToolArgument(
                name="words",
                arg_type="boolean",
                description="Match whole words only",
                default="False",
                required=False
            )
        ]

    def _build_headers(self) -> Dict[str, str]:
        """Build request headers with random User-Agent"""
        headers = {
            "User-Agent": random.choice(USER_AGENTS),
            "Accept": "application/json",
            "Accept-Language": "en-US,en;q=0.5",
            "DNT": "1"
        }
        logger.debug(f"Built headers: {headers}")
        return headers

    def _build_params(self, args: GrepAppArguments) -> Dict[str, Any]:
        """Build request parameters from arguments"""
        params = {
            "q": args.search_query,
            "page": args.page,
            "per_page": args.per_page
        }
        if args.repository:
            params["filter[repo][0]"] = args.repository
        if args.regexp:
            params["regexp"] = "true"
        if args.case:
            params["case"] = "true"
        if args.words:
            params["words"] = "true"
        logger.debug(f"Built params: {params}")
        return params

    def _make_request(self, params: Dict[str, Any], headers: Dict[str, str]) -> Dict[str, Any]:
        """Make the API request"""
        logger.info("Making API request to grep.app")
        response = requests.get(
            self.BASE_URL,
            params=params,
            headers=headers,
            timeout=self.TIMEOUT
        )
        logger.debug(f"API Response Status Code: {response.status_code}")
        response.raise_for_status()
        data = response.json()
        if not isinstance(data, dict):
            raise SearchError("Invalid response format from API")
        logger.debug(f"API Response Data: {data}")
        return data

    def execute(self,
                search_query: str,
                repository: Optional[str] = None,
                page: Union[int, str] = 1,
                per_page: Union[int, str] = 10,
                regexp: bool = False,
                case: bool = False,
                words: bool = False,
                skip_delay: bool = False) -> str:
        """Execute grep.app API search with pagination and return formatted results as a string"""
        try:
            # Validate and convert arguments
            args = GrepAppArguments(
                search_query=search_query,
                repository=repository,
                page=int(page),
                per_page=int(per_page),
                regexp=regexp,
                case=case,
                words=words
            )

            logger.info(f"Executing search: '{args.search_query}'")
            logger.debug(f"Search parameters: {args.model_dump()}")

            # Add random delay to mimic human behavior (unless skipped for testing)
            if not skip_delay:
                delay = random.uniform(0.5, 1.5)
                logger.debug(f"Sleeping for {delay:.2f} seconds to mimic human behavior")
                time.sleep(delay)

            # Make API request
            headers = self._build_headers()
            params = self._build_params(args)
            results = self._make_request(params, headers)

            # Format and return results
            return self._format_results(results)

        except ValidationError as e:
            logger.error(f"Validation error: {e}")
            return self._format_error("Validation Error", str(e))
        except requests.RequestException as e:
            logger.error(f"API request failed: {e}")
            return self._format_error(
                "API Error",
                str(e),
                {"Request URL": getattr(e.response, 'url', 'N/A') if hasattr(e, 'response') else 'N/A'}
            )
        except SearchError as e:
            logger.error(f"Search error: {e}")
            return self._format_error("Search Error", str(e))
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            return self._format_error("Unexpected Error", str(e))

    def _format_results(self, data: Dict[str, Any]) -> str:
        """Format API results into a structured Markdown string"""
        query = data.get('query', '')
        total_results = data.get('hits', {}).get('total', 0)
        hits = data.get("hits", {}).get("hits", [])

        output = [
            "# 🔍 Search Results",
            "",
            f"**Query:** `{query if query else '<empty>'}` • **Found:** {total_results} matches",
            ""
        ]

        if not hits:
            output.append("> No matches found for your search query.")
        else:
            for idx, result in enumerate(hits, 1):
                repo = result.get('repo', {}).get('raw', 'N/A')
                file_path = result.get('path', {}).get('raw', 'N/A')
                language = result.get('language', 'N/A').lower()
                content = result.get("content", {})

                # Extract the actual code and line info
                snippet = content.get("snippet", "")
                line_num = content.get("line", "")

                # Clean up the snippet
                import re
                clean_snippet = re.sub(r'<[^>]+>', '', snippet)
                clean_snippet = re.sub(r'&quot;', '"', clean_snippet)
                clean_snippet = re.sub(r'&lt;', '<', clean_snippet)
                clean_snippet = re.sub(r'&gt;', '>', clean_snippet)
                clean_snippet = clean_snippet.strip()

                # Split into lines and clean each line
                raw_lines = clean_snippet.split('\n')
                lines = []
                current_line_num = int(line_num) if line_num else 1

                # First pass: collect all lines and their content
                for line in raw_lines:
                    # Remove excess whitespace but preserve indentation
                    stripped = line.rstrip()
                    if not stripped:
                        lines.append(('', current_line_num))
                        current_line_num += 1
                        continue

                    # Remove duplicate indentation
                    if stripped.startswith('    '):
                        stripped = stripped[4:]

                    # Handle URLs that might be split across lines
                    if stripped.startswith(('prompt', '-working')):
                        if lines and lines[-1][0].endswith('/'):
                            # Combine with previous line
                            prev_content, prev_num = lines.pop()
                            lines.append((prev_content + stripped, prev_num))
                            continue

                    # Handle concatenated lines by looking for line numbers
                    line_parts = re.split(r'(\d+)(?=\s*[^\d])', stripped)
                    if len(line_parts) > 1:
                        # Process each part that might be a new line
                        for i in range(0, len(line_parts)-1, 2):
                            prefix = line_parts[i].rstrip()
                            if prefix:
                                if not any(l[0] == prefix for l in lines):  # Avoid duplicates
                                    lines.append((prefix, current_line_num))

                            # Update line number if found
                            try:
                                current_line_num = int(line_parts[i+1])
                            except ValueError:
                                current_line_num += 1

                            # Add the content after the line number
                            if i+2 < len(line_parts):
                                content = line_parts[i+2].lstrip()
                                if content and not any(l[0] == content for l in lines):  # Avoid duplicates
                                    lines.append((content, current_line_num))
                    else:
                        if not any(l[0] == stripped for l in lines):  # Avoid duplicates
                            lines.append((stripped, current_line_num))
                            current_line_num += 1

                # Format line numbers and code
                formatted_lines = []
                max_line_width = len(str(max(line[1] for line in lines))) if lines else 3

                # Second pass: format each line
                for line_content, line_no in lines:
                    if not line_content:  # Empty line
                        formatted_lines.append('')
                        continue

                    # Special handling for markdown badges and links
                    if '[![' in line_content or '[!' in line_content:
                        badges = re.findall(r'(\[!\[.*?\]\(.*?\)\]\(.*?\))', line_content)
                        if badges:
                            for badge in badges:
                                if not any(badge in l for l in formatted_lines):  # Avoid duplicates
                                    formatted_lines.append(f"{str(line_no).rjust(max_line_width)} │ {badge}")
                            continue

                    # Add syntax highlighting for comments
                    if line_content.lstrip().startswith(('// ', '# ', '/* ', '* ', '*/')):
                        line_str = f"{str(line_no).rjust(max_line_width)} │ <dim>{line_content}</dim>"
                        if not any(line_str in l for l in formatted_lines):  # Avoid duplicates
                            formatted_lines.append(line_str)
                    else:
                        # Split line into indentation and content for better formatting
                        indent = len(line_content) - len(line_content.lstrip())
                        indentation = line_content[:indent]
                        content = line_content[indent:]

                        # Highlight strings and special syntax
                        content = re.sub(r'(["\'])(.*?)\1', r'<str>\1\2\1</str>', content)
                        content = re.sub(r'\b(function|const|let|var|import|export|class|interface|type|enum)\b',
                                         r'<keyword>\1</keyword>', content)

                        line_str = f"{str(line_no).rjust(max_line_width)} │ {indentation}{content}"
                        if not any(line_str in l for l in formatted_lines):  # Avoid duplicates
                            formatted_lines.append(line_str)

                # Truncate if too long and add line count
                if len(formatted_lines) > 5:
                    remaining = len(formatted_lines) - 5
                    formatted_lines = formatted_lines[:5]
                    if remaining > 0:
                        formatted_lines.append(f" ┆ {remaining} more line{'s' if remaining > 1 else ''}")

                clean_snippet = '\n'.join(formatted_lines)

                # Format the repository link to be clickable
                if '/' in repo:
                    repo_link = f"[`{repo}`](https://github.com/{repo})"
                else:
                    repo_link = f"`{repo}`"

                # Determine the best language display and icon
                lang_display = language if language != 'n/a' else ''
                lang_icon = {
                    'python': '🐍',
                    'typescript': '📘',
                    'javascript': '📒',
                    'markdown': '📝',
                    'toml': '⚙️',
                    'yaml': '📋',
                    'json': '📦',
                    'shell': '🐚',
                    'rust': '🦀',
                    'go': '🔵',
                    'java': '☕',
                    'ruby': '💎',
                }.get(lang_display, '📄')

                # Format file path with language icon and line info
                file_info = [f"{lang_icon} `{file_path}`"]
                if line_num:
                    file_info.append(f"Line {line_num}")

                output.extend([
                    f"### {repo_link}",
                    " • ".join(file_info),
                    "```",
                    clean_snippet,
                    "```",
                    ""
                ])

        return "\n".join(filter(None, output))

    def _format_error(self, error_type: str, message: str, additional_info: Dict[str, str] = None) -> str:
        """Format error messages consistently using Markdown"""
        output = [
            f"## {error_type}",
            f"**Message:** {message}"
        ]

        if additional_info:
            output.append("**Additional Information:**")
            for key, value in additional_info.items():
                output.append(f"- **{key}:** {value}")

        output.append(f"## End {error_type}")
        return "\n\n".join(output)

if __name__ == "__main__":
    # Configure logger
    logger.remove()  # Remove default handlers
    logger.add(sys.stderr, level="INFO", format="<green>{time}</green> <level>{message}</level>")

    logger.info("Starting GrepAppTool test cases")
    tool = GrepAppTool()

    test_cases = [
        {
            "name": "Python __init__ Methods Search",
            "args": {
                "search_query": "lang:python def __init__",
                "per_page": 5,
                "skip_delay": True  # Skip delay for testing
            }
        },
        {
            "name": "Logging Patterns Search",
            "args": {
                "search_query": "logger",
                "per_page": 3,
                "skip_delay": True
            }
        },
        {
            "name": "Repository-Specific Search",
            "args": {
                "search_query": "def",
                "repository": "quantalogic/quantalogic",
                "per_page": 5,
                "words": True,
                "skip_delay": True
            }
        },
        {
            "name": "Raphaël MANSUY",
            "args": {
                "search_query": "raphaelmansuy",
                "per_page": 3,
                "skip_delay": True
            }
        }
    ]

    for test in test_cases:
        try:
            logger.info(f"Running test: {test['name']}")
            logger.info(f"Executing with arguments: {test['args']}")
            result = tool.execute(**test['args'])
            print(f"\n### Test: {test['name']}\n{result}\n")
            time.sleep(1)  # Add a small delay between tests to avoid rate limiting
        except Exception as e:
            logger.error(f"{test['name']} Failed: {e}", exc_info=True)
```
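
For reference, a hypothetical response payload in the shape `_format_results` expects, inferred purely from the keys the parser reads above; this is not an authoritative grep.app API schema, and the values are illustrative.

```python
# Assumed response shape, reverse-engineered from _format_results above.
from quantalogic.tools.grep_app_tool import GrepAppTool

sample_response = {
    "query": "code2prompt",
    "hits": {
        "total": 42,
        "hits": [
            {
                "repo": {"raw": "quantalogic/quantalogic"},
                "path": {"raw": "README.md"},
                "language": "Markdown",
                "content": {
                    "snippet": "<pre>def example_function(self):</pre>",
                    "line": "12",
                },
            }
        ],
    },
}

tool = GrepAppTool()
# _format_results is a private helper; called here only to show the Markdown it produces.
print(tool._format_results(sample_response))
```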
quantalogic/tools/sql_query_tool.py (new file)

@@ -0,0 +1,167 @@

```python
"""Tool for executing SQL queries and returning paginated results in markdown format."""

from typing import Any, Dict, List

from pydantic import Field, ValidationError
from sqlalchemy import create_engine, text
from sqlalchemy.exc import SQLAlchemyError

from quantalogic.tools.tool import Tool, ToolArgument


class SQLQueryTool(Tool):
    """Tool for executing SQL queries and returning paginated results in markdown format."""

    name: str = "sql_query_tool"
    description: str = (
        "Executes a SQL query and returns results in markdown table format "
        "with pagination support. Results are truncated based on start/end row numbers."
    )
    arguments: list = [
        ToolArgument(
            name="query",
            arg_type="string",
            description="The SQL query to execute",
            required=True,
            example="SELECT * FROM customers WHERE country = 'France'"
        ),
        ToolArgument(
            name="start_row",
            arg_type="int",
            description="1-based starting row number for results",
            required=True,
            example="1"
        ),
        ToolArgument(
            name="end_row",
            arg_type="int",
            description="1-based ending row number for results",
            required=True,
            example="100"
        ),
    ]
    connection_string: str = Field(
        ...,
        description="SQLAlchemy-compatible database connection string",
        example="postgresql://user:password@localhost/mydb"
    )

    def execute(self, query: str, start_row: Any, end_row: Any) -> str:
        """
        Executes a SQL query and returns formatted results.

        Args:
            query: SQL query to execute
            start_row: 1-based starting row number (supports various numeric types)
            end_row: 1-based ending row number (supports various numeric types)

        Returns:
            str: Markdown-formatted results with pagination metadata

        Raises:
            ValueError: For invalid parameters or query errors
            RuntimeError: For database connection issues
        """
        try:
            # Convert and validate row numbers
            start = self._convert_row_number(start_row, "start_row")
            end = self._convert_row_number(end_row, "end_row")

            if start > end:
                raise ValueError(f"start_row ({start}) must be <= end_row ({end})")

            # Execute query
            engine = create_engine(self.connection_string)
            with engine.connect() as conn:
                result = conn.execute(text(query))
                columns: List[str] = result.keys()
                all_rows: List[Dict] = [dict(row._mapping) for row in result]

            # Apply pagination
            total_rows = len(all_rows)
            actual_start = max(1, start)
            actual_end = min(end, total_rows)

            if actual_start > total_rows:
                return f"No results found (total rows: {total_rows})"

            # Slice results (convert to 0-based index)
            displayed_rows = all_rows[actual_start-1:actual_end]

            # Format results
            markdown = [
                f"**Query Results:** `{actual_start}-{actual_end}` of `{total_rows}` rows",
                self._format_table(columns, displayed_rows)
            ]

            # Add pagination notice
            if actual_end < total_rows:
                remaining = total_rows - actual_end
                markdown.append(f"\n*Showing first {actual_end} rows - {remaining} more row{'s' if remaining > 1 else ''} available*")

            return "\n".join(markdown)

        except SQLAlchemyError as e:
            raise ValueError(f"SQL Error: {str(e)}") from e
        except ValidationError as e:
            raise ValueError(f"Validation Error: {str(e)}") from e
        except Exception as e:
            raise RuntimeError(f"Database Error: {str(e)}") from e

    def _convert_row_number(self, value: Any, field_name: str) -> int:
        """Convert and validate row number input."""
        try:
            # Handle numeric strings and floats
            if isinstance(value, str):
                if "." in value:
                    num = float(value)
                else:
                    num = int(value)
            else:
                num = value

            converted = int(num)
            if converted != num:  # Check if float had decimal part
                raise ValueError("Decimal values are not allowed for row numbers")

            if converted <= 0:
                raise ValueError(f"{field_name} must be a positive integer")

            return converted
        except (ValueError, TypeError) as e:
            raise ValueError(f"Invalid value for {field_name}: {repr(value)}") from e

    def _format_table(self, columns: List[str], rows: List[Dict]) -> str:
        """Format results as markdown table with truncation."""
        if not rows:
            return "No results found"

        # Create header
        header = "| " + " | ".join(columns) + " |"
        separator = "| " + " | ".join(["---"] * len(columns)) + " |"

        # Create rows with truncation
        body = []
        for row in rows:
            values = []
            for col in columns:
                val = str(row.get(col, ""))
                # Truncate long values
                values.append(val[:50] + "..." if len(val) > 50 else val)
            body.append("| " + " | ".join(values) + " |")

        return "\n".join([header, separator] + body)


if __name__ == "__main__":
    from quantalogic.tools.utils.create_sample_database import create_sample_database

    # Create and document sample database
    create_sample_database("sample.db")
    tool = SQLQueryTool(connection_string="sqlite:///sample.db")
    print(tool.execute("select * from customers", 1, 10))
    print(tool.execute("select * from customers", 11, 20))
```
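
A short illustration of the row-number coercion performed by `_convert_row_number` above: string and float inputs are accepted as long as they represent positive whole numbers. The connection string is only needed to instantiate the tool; no query is executed in this sketch.

```python
# Sketch of the input coercion behavior defined in SQLQueryTool._convert_row_number.
from quantalogic.tools.sql_query_tool import SQLQueryTool

tool = SQLQueryTool(connection_string="sqlite:///sample.db")  # illustrative path
assert tool._convert_row_number("5", "start_row") == 5    # numeric string accepted
assert tool._convert_row_number(10.0, "end_row") == 10    # whole-number float accepted
try:
    tool._convert_row_number(2.5, "start_row")            # fractional part -> rejected
except ValueError as exc:
    print(exc)                                            # "Invalid value for start_row: 2.5"
```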
quantalogic/tools/utils/__init__.py (new file)

@@ -0,0 +1,13 @@

```python
"""
Utility functions and classes for quantalogic tools.

This module provides common utility functions used across the quantalogic package.
"""

from .create_sample_database import create_sample_database
from .generate_database_report import generate_database_report

__all__ = [
    'create_sample_database',
    'generate_database_report'
]
```