arcade-postgres 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arcade_postgres-0.1.0/.gitignore +175 -0
- arcade_postgres-0.1.0/Makefile +53 -0
- arcade_postgres-0.1.0/PKG-INFO +23 -0
- arcade_postgres-0.1.0/arcade_postgres/__init__.py +0 -0
- arcade_postgres-0.1.0/arcade_postgres/database_engine.py +104 -0
- arcade_postgres-0.1.0/arcade_postgres/tools/__init__.py +0 -0
- arcade_postgres-0.1.0/arcade_postgres/tools/postgres.py +178 -0
- arcade_postgres-0.1.0/evals/eval_postgres.py +94 -0
- arcade_postgres-0.1.0/pyproject.toml +65 -0
- arcade_postgres-0.1.0/tests/__init__.py +0 -0
- arcade_postgres-0.1.0/tests/dump.sql +114 -0
- arcade_postgres-0.1.0/tests/test_postgres.py +119 -0
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
.DS_Store
|
|
2
|
+
credentials.yaml
|
|
3
|
+
docker/credentials.yaml
|
|
4
|
+
|
|
5
|
+
*.lock
|
|
6
|
+
|
|
7
|
+
# example data
|
|
8
|
+
examples/data
|
|
9
|
+
scratch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
docs/source
|
|
13
|
+
|
|
14
|
+
# From https://raw.githubusercontent.com/github/gitignore/main/Python.gitignore
|
|
15
|
+
|
|
16
|
+
# Byte-compiled / optimized / DLL files
|
|
17
|
+
__pycache__/
|
|
18
|
+
*.py[cod]
|
|
19
|
+
*$py.class
|
|
20
|
+
|
|
21
|
+
# C extensions
|
|
22
|
+
*.so
|
|
23
|
+
|
|
24
|
+
# Distribution / packaging
|
|
25
|
+
.Python
|
|
26
|
+
build/
|
|
27
|
+
develop-eggs/
|
|
28
|
+
dist/
|
|
29
|
+
downloads/
|
|
30
|
+
eggs/
|
|
31
|
+
.eggs/
|
|
32
|
+
lib/
|
|
33
|
+
lib64/
|
|
34
|
+
parts/
|
|
35
|
+
sdist/
|
|
36
|
+
var/
|
|
37
|
+
wheels/
|
|
38
|
+
share/python-wheels/
|
|
39
|
+
*.egg-info/
|
|
40
|
+
.installed.cfg
|
|
41
|
+
*.egg
|
|
42
|
+
MANIFEST
|
|
43
|
+
|
|
44
|
+
# PyInstaller
|
|
45
|
+
# Usually these files are written by a python script from a template
|
|
46
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
47
|
+
*.manifest
|
|
48
|
+
*.spec
|
|
49
|
+
|
|
50
|
+
# Installer logs
|
|
51
|
+
pip-log.txt
|
|
52
|
+
pip-delete-this-directory.txt
|
|
53
|
+
|
|
54
|
+
# Unit test / coverage reports
|
|
55
|
+
htmlcov/
|
|
56
|
+
.tox/
|
|
57
|
+
.nox/
|
|
58
|
+
.coverage
|
|
59
|
+
.coverage.*
|
|
60
|
+
.cache
|
|
61
|
+
nosetests.xml
|
|
62
|
+
coverage.xml
|
|
63
|
+
*.cover
|
|
64
|
+
*.py,cover
|
|
65
|
+
.hypothesis/
|
|
66
|
+
.pytest_cache/
|
|
67
|
+
cover/
|
|
68
|
+
|
|
69
|
+
# Translations
|
|
70
|
+
*.mo
|
|
71
|
+
*.pot
|
|
72
|
+
|
|
73
|
+
# Django stuff:
|
|
74
|
+
*.log
|
|
75
|
+
local_settings.py
|
|
76
|
+
db.sqlite3
|
|
77
|
+
db.sqlite3-journal
|
|
78
|
+
|
|
79
|
+
# Flask stuff:
|
|
80
|
+
instance/
|
|
81
|
+
.webassets-cache
|
|
82
|
+
|
|
83
|
+
# Scrapy stuff:
|
|
84
|
+
.scrapy
|
|
85
|
+
|
|
86
|
+
# Sphinx documentation
|
|
87
|
+
docs/_build/
|
|
88
|
+
|
|
89
|
+
# PyBuilder
|
|
90
|
+
.pybuilder/
|
|
91
|
+
target/
|
|
92
|
+
|
|
93
|
+
# Jupyter Notebook
|
|
94
|
+
.ipynb_checkpoints
|
|
95
|
+
|
|
96
|
+
# IPython
|
|
97
|
+
profile_default/
|
|
98
|
+
ipython_config.py
|
|
99
|
+
|
|
100
|
+
# pyenv
|
|
101
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
102
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
103
|
+
# .python-version
|
|
104
|
+
|
|
105
|
+
# pipenv
|
|
106
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
107
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
108
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
109
|
+
# install all needed dependencies.
|
|
110
|
+
#Pipfile.lock
|
|
111
|
+
|
|
112
|
+
# poetry
|
|
113
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
114
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
115
|
+
# commonly ignored for libraries.
|
|
116
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
117
|
+
poetry.lock
|
|
118
|
+
|
|
119
|
+
# pdm
|
|
120
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
121
|
+
#pdm.lock
|
|
122
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
123
|
+
# in version control.
|
|
124
|
+
# https://pdm.fming.dev/#use-with-ide
|
|
125
|
+
.pdm.toml
|
|
126
|
+
|
|
127
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
128
|
+
__pypackages__/
|
|
129
|
+
|
|
130
|
+
# Celery stuff
|
|
131
|
+
celerybeat-schedule
|
|
132
|
+
celerybeat.pid
|
|
133
|
+
|
|
134
|
+
# SageMath parsed files
|
|
135
|
+
*.sage.py
|
|
136
|
+
|
|
137
|
+
# Environments
|
|
138
|
+
.env
|
|
139
|
+
.venv
|
|
140
|
+
env/
|
|
141
|
+
venv/
|
|
142
|
+
ENV/
|
|
143
|
+
env.bak/
|
|
144
|
+
venv.bak/
|
|
145
|
+
|
|
146
|
+
# Spyder project settings
|
|
147
|
+
.spyderproject
|
|
148
|
+
.spyproject
|
|
149
|
+
|
|
150
|
+
# Rope project settings
|
|
151
|
+
.ropeproject
|
|
152
|
+
|
|
153
|
+
# mkdocs documentation
|
|
154
|
+
/site
|
|
155
|
+
|
|
156
|
+
# mypy
|
|
157
|
+
.mypy_cache/
|
|
158
|
+
.dmypy.json
|
|
159
|
+
dmypy.json
|
|
160
|
+
|
|
161
|
+
# Pyre type checker
|
|
162
|
+
.pyre/
|
|
163
|
+
|
|
164
|
+
# pytype static type analyzer
|
|
165
|
+
.pytype/
|
|
166
|
+
|
|
167
|
+
# Cython debug symbols
|
|
168
|
+
cython_debug/
|
|
169
|
+
|
|
170
|
+
# PyCharm
|
|
171
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
172
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
173
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
174
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
175
|
+
#.idea/
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
.PHONY: help
|
|
2
|
+
|
|
3
|
+
help:
|
|
4
|
+
@echo "🛠️ github Commands:\n"
|
|
5
|
+
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
|
6
|
+
|
|
7
|
+
.PHONY: install
|
|
8
|
+
install: ## Install the uv environment and install all packages with dependencies
|
|
9
|
+
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
|
10
|
+
@uv sync --active --all-extras --no-sources
|
|
11
|
+
@uv run pre-commit install
|
|
12
|
+
@echo "✅ All packages and dependencies installed via uv"
|
|
13
|
+
|
|
14
|
+
.PHONY: install-local
|
|
15
|
+
install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources
|
|
16
|
+
@echo "🚀 Creating virtual environment and installing all packages using uv"
|
|
17
|
+
@uv sync --active --all-extras
|
|
18
|
+
@uv run pre-commit install
|
|
19
|
+
@echo "✅ All packages and dependencies installed via uv"
|
|
20
|
+
|
|
21
|
+
.PHONY: build
|
|
22
|
+
build: clean-build ## Build wheel file using uv
|
|
23
|
+
@echo "🚀 Creating wheel file"
|
|
24
|
+
uv build
|
|
25
|
+
|
|
26
|
+
.PHONY: clean-build
|
|
27
|
+
clean-build: ## clean build artifacts
|
|
28
|
+
@echo "🗑️ Cleaning dist directory"
|
|
29
|
+
rm -rf dist
|
|
30
|
+
|
|
31
|
+
.PHONY: test
|
|
32
|
+
test: ## Test the code with pytest
|
|
33
|
+
@echo "🚀 Testing code: Running pytest"
|
|
34
|
+
@uv run pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml
|
|
35
|
+
|
|
36
|
+
.PHONY: coverage
|
|
37
|
+
coverage: ## Generate coverage report
|
|
38
|
+
@echo "coverage report"
|
|
39
|
+
coverage report
|
|
40
|
+
@echo "Generating coverage report"
|
|
41
|
+
coverage html
|
|
42
|
+
|
|
43
|
+
.PHONY: bump-version
|
|
44
|
+
bump-version: ## Bump the version in the pyproject.toml file by a patch version
|
|
45
|
+
@echo "🚀 Bumping version in pyproject.toml"
|
|
46
|
+
uv version --bump patch
|
|
47
|
+
|
|
48
|
+
.PHONY: check
|
|
49
|
+
check: ## Run code quality tools.
|
|
50
|
+
@echo "🚀 Linting code: Running pre-commit"
|
|
51
|
+
@uv run pre-commit run -a
|
|
52
|
+
@echo "🚀 Static type checking: Running mypy"
|
|
53
|
+
@uv run mypy --config-file=pyproject.toml
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: arcade_postgres
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Tools to query and explore a postgres database
|
|
5
|
+
Author-email: evantahler <support@arcade.dev>
|
|
6
|
+
Requires-Python: >=3.10
|
|
7
|
+
Requires-Dist: arcade-tdk<3.0.0,>=2.0.0
|
|
8
|
+
Requires-Dist: asyncpg>=0.30.0
|
|
9
|
+
Requires-Dist: greenlet>=3.2.3
|
|
10
|
+
Requires-Dist: psycopg2-binary>=2.9.10
|
|
11
|
+
Requires-Dist: pydantic>=2.11.7
|
|
12
|
+
Requires-Dist: sqlalchemy>=2.0.41
|
|
13
|
+
Provides-Extra: dev
|
|
14
|
+
Requires-Dist: arcade-ai[evals]<3.0.0,>=2.0.0; extra == 'dev'
|
|
15
|
+
Requires-Dist: arcade-serve<3.0.0,>=2.0.0; extra == 'dev'
|
|
16
|
+
Requires-Dist: mypy<1.6.0,>=1.5.1; extra == 'dev'
|
|
17
|
+
Requires-Dist: pre-commit<3.5.0,>=3.4.0; extra == 'dev'
|
|
18
|
+
Requires-Dist: pytest-asyncio<0.25.0,>=0.24.0; extra == 'dev'
|
|
19
|
+
Requires-Dist: pytest-cov<4.1.0,>=4.0.0; extra == 'dev'
|
|
20
|
+
Requires-Dist: pytest-mock<3.12.0,>=3.11.1; extra == 'dev'
|
|
21
|
+
Requires-Dist: pytest<8.4.0,>=8.3.0; extra == 'dev'
|
|
22
|
+
Requires-Dist: ruff<0.8.0,>=0.7.4; extra == 'dev'
|
|
23
|
+
Requires-Dist: tox<4.12.0,>=4.11.1; extra == 'dev'
|
|
File without changes
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Any, ClassVar
|
|
2
|
+
from urllib.parse import urlparse
|
|
3
|
+
|
|
4
|
+
from arcade_tdk.errors import RetryableToolError
|
|
5
|
+
from sqlalchemy import text
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
|
7
|
+
|
|
8
|
+
MAX_ROWS_RETURNED = 1000
|
|
9
|
+
TEST_QUERY = "SELECT 1"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DatabaseEngine:
    """Caches one async SQLAlchemy engine per connection string.

    Engines are created lazily and validated with a cheap test query on every
    checkout; a stale pool is disposed and retried once before failing.
    """

    # Unused legacy attribute; kept so the public surface is unchanged.
    _instance: ClassVar[None] = None
    # Cache of engines keyed by the rewritten (asyncpg) connection string.
    _engines: ClassVar[dict[str, AsyncEngine]] = {}

    @classmethod
    async def get_instance(cls, connection_string: str) -> AsyncEngine:
        """Return a cached (or newly created) engine for `connection_string`.

        Raises:
            RetryableToolError: if the database cannot be reached even after
                disposing the cached pool and reconnecting once.
        """
        parsed_url = urlparse(connection_string)

        # TODO: something strange with sslmode= and friends
        # query_params = parse_qs(parsed_url.query)
        # query_params = {
        #     k: v[0] for k, v in query_params.items()
        # }  # assume one value allowed for each query param

        # NOTE(review): query parameters (e.g. sslmode=...) are dropped by this
        # rewrite — only scheme/netloc/path survive. See the TODO above.
        async_connection_string = f"{parsed_url.scheme.replace('postgresql', 'postgresql+asyncpg')}://{parsed_url.netloc}{parsed_url.path}"
        key = f"{async_connection_string}"
        if key not in cls._engines:
            cls._engines[key] = create_async_engine(async_connection_string)

        # Try a simple query to see if the (possibly cached) connection is valid.
        try:
            async with cls._engines[key].connect() as connection:
                await connection.execute(text(TEST_QUERY))
                return cls._engines[key]
        except Exception:
            # The pool may hold stale connections; dispose them and retry once.
            await cls._engines[key].dispose()

        # try again
        try:
            async with cls._engines[key].connect() as connection:
                await connection.execute(text(TEST_QUERY))
                return cls._engines[key]
        except Exception as e:
            raise RetryableToolError(
                f"Connection failed: {e}",
                developer_message="Connection to postgres failed.",
                additional_prompt_content="Check the connection string and try again.",
            ) from e

    @classmethod
    async def get_engine(cls, connection_string: str) -> Any:
        """Return an async context manager that yields a validated AsyncEngine."""
        engine = await cls.get_instance(connection_string)

        class ConnectionContextManager:
            """Thin async context manager handing back the shared engine."""

            def __init__(self, engine: AsyncEngine) -> None:
                self.engine = engine

            async def __aenter__(self) -> AsyncEngine:
                return self.engine

            async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                # Connection cleanup is handled by the async context manager
                pass

        return ConnectionContextManager(engine)

    @classmethod
    async def cleanup(cls) -> None:
        """Clean up all cached engines. Call this when shutting down."""
        for engine in cls._engines.values():
            await engine.dispose()
        cls._engines.clear()

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the engine cache without disposing engines. Use with caution."""
        cls._engines.clear()

    @classmethod
    def sanitize_query(cls, query: str) -> str:
        """
        Sanitize a query to not break our read-only session.
        THIS IS REALLY UNSAFE AND SHOULD NOT BE USED IN PRODUCTION. USE A DATABASE CONNECTION WITH A READ-ONLY USER AND PREPARE STATEMENTS.
        There are also valid reasons for the ";" character, and this prevents that.
        """

        # Drop empty fragments so a single statement with a trailing ";" is
        # not mistaken for multiple statements (the old split(";") check
        # rejected "SELECT 1;").
        statements = [part for part in query.split(";") if part.strip()]
        if len(statements) > 1:
            raise RetryableToolError(
                "Multiple statements are not allowed in a single query.",
                developer_message="Multiple statements are not allowed in a single query.",
                additional_prompt_content="Split your query into multiple queries and try again.",
            )

        # split() (not split(" ")) tolerates leading whitespace/newlines,
        # which previously produced an empty first token and a false reject.
        first_token = statements[0].split()[0] if statements else ""
        if first_token.upper() != "SELECT":
            raise RetryableToolError(
                "Only SELECT queries are allowed.",
                developer_message="Only SELECT queries are allowed.",
                additional_prompt_content="Use the <DiscoverTables> and <GetTableSchema> tools to discover the tables and try again.",
            )

        return f"{query}"
|
|
File without changes
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
from typing import Annotated, Any
|
|
2
|
+
|
|
3
|
+
from arcade_tdk import ToolContext, tool
|
|
4
|
+
from arcade_tdk.errors import RetryableToolError
|
|
5
|
+
from sqlalchemy import inspect, text
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
7
|
+
|
|
8
|
+
from ..database_engine import MAX_ROWS_RETURNED, DatabaseEngine
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
async def discover_schemas(
    context: ToolContext,
) -> list[str]:
    """Discover all the schemas in the postgres database."""
    # Resolve the connection secret first, then borrow the shared engine.
    connection_string = context.get_secret("DATABASE_CONNECTION_STRING")
    async with await DatabaseEngine.get_engine(connection_string) as engine:
        return await _get_schemas(engine)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
async def discover_tables(
    context: ToolContext,
    schema_name: Annotated[
        str, "The database schema to discover tables in (default value: 'public')"
    ] = "public",
) -> list[str]:
    """Discover all the tables in the postgres database when the list of tables is not known.

    THIS TOOL SHOULD ALWAYS BE USED BEFORE ANY OTHER TOOL THAT REQUIRES A TABLE NAME.
    """
    # Resolve the connection secret first, then borrow the shared engine.
    connection_string = context.get_secret("DATABASE_CONNECTION_STRING")
    async with await DatabaseEngine.get_engine(connection_string) as engine:
        return await _get_tables(engine, schema_name)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
async def get_table_schema(
    context: ToolContext,
    schema_name: Annotated[str, "The database schema to get the table schema of"],
    table_name: Annotated[str, "The table to get the schema of"],
) -> list[str]:
    """
    Get the schema/structure of a postgres table in the postgres database when the schema is not known, and the name of the table is provided.

    THIS TOOL SHOULD ALWAYS BE USED BEFORE EXECUTING ANY QUERY. ALL TABLES IN THE QUERY MUST BE DISCOVERED FIRST USING THE <DiscoverTables> TOOL.
    """
    # Resolve the connection secret first, then borrow the shared engine.
    connection_string = context.get_secret("DATABASE_CONNECTION_STRING")
    async with await DatabaseEngine.get_engine(connection_string) as engine:
        return await _get_table_schema(engine, schema_name, table_name)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@tool(requires_secrets=["DATABASE_CONNECTION_STRING"])
async def execute_query(
    context: ToolContext,
    query: Annotated[str, "The postgres SQL query to execute. Only SELECT queries are allowed."],
) -> list[str]:
    """
    You have a connection to a postgres database.
    Execute a query and return the results against the postgres database.

    ONLY USE THIS TOOL IF YOU HAVE ALREADY LOADED THE SCHEMA OF THE TABLES YOU NEED TO QUERY. USE THE <GetTableSchema> TOOL TO LOAD THE SCHEMA IF NOT ALREADY KNOWN.

    When running queries, follow these rules which will help avoid errors:
    * Always use case-insensitive queries to match strings in the query.
    * Always trim strings in the query.
    * Prefer LIKE queries over direct string matches or regex queries.
    * Only join on columns that are indexed or the primary key. Do not join on arbitrary columns.

    Only SELECT queries are allowed. Do not use INSERT, UPDATE, DELETE, or other DML statements. This tool will reject them.

    Unless otherwise specified, ensure that query has a LIMIT of 100 for all results. This tool will enforce that no more than 1000 rows are returned at maximum.
    """
    # Resolve the connection secret first, then borrow the shared engine.
    connection_string = context.get_secret("DATABASE_CONNECTION_STRING")
    async with await DatabaseEngine.get_engine(connection_string) as engine:
        try:
            return await _execute_query(engine, query)
        except Exception as e:
            # Surface failures as retryable so the model can adjust the query.
            raise RetryableToolError(
                f"Query failed: {e}",
                developer_message=f"Query '{query}' failed.",
                additional_prompt_content=(
                    "Load the database schema <GetTableSchema> or use the "
                    "<DiscoverTables> tool to discover the tables and try again."
                ),
                retry_after_ms=10,
            ) from e
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
async def _get_schemas(engine: AsyncEngine) -> list[str]:
|
|
94
|
+
"""Get all the schemas in the database"""
|
|
95
|
+
async with engine.connect() as conn:
|
|
96
|
+
|
|
97
|
+
def get_schema_names(sync_conn: Any) -> list[str]:
|
|
98
|
+
return list(inspect(sync_conn).get_schema_names())
|
|
99
|
+
|
|
100
|
+
schemas: list[str] = await conn.run_sync(get_schema_names)
|
|
101
|
+
schemas = [schema for schema in schemas if schema != "information_schema"]
|
|
102
|
+
|
|
103
|
+
return schemas
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
async def _get_tables(engine: AsyncEngine, schema_name: str) -> list[str]:
|
|
107
|
+
"""Get all the tables in the database"""
|
|
108
|
+
async with engine.connect() as conn:
|
|
109
|
+
|
|
110
|
+
def get_schema_names(sync_conn: Any) -> list[str]:
|
|
111
|
+
return list(inspect(sync_conn).get_schema_names())
|
|
112
|
+
|
|
113
|
+
schemas: list[str] = await conn.run_sync(get_schema_names)
|
|
114
|
+
tables = []
|
|
115
|
+
for schema in schemas:
|
|
116
|
+
if schema == schema_name:
|
|
117
|
+
|
|
118
|
+
def get_table_names(sync_conn: Any, s: str = schema) -> list[str]:
|
|
119
|
+
return list(inspect(sync_conn).get_table_names(schema=s))
|
|
120
|
+
|
|
121
|
+
these_tables = await conn.run_sync(get_table_names)
|
|
122
|
+
tables.extend(these_tables)
|
|
123
|
+
return tables
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
async def _get_table_schema(engine: AsyncEngine, schema_name: str, table_name: str) -> list[str]:
    """Describe each column of a table as "name: type" plus PK/index markers."""
    async with engine.connect() as connection:

        def get_columns(sync_conn: Any) -> list[Any]:
            # inspect() requires a sync connection, hence the run_sync bridge.
            return list(inspect(sync_conn).get_columns(table_name, schema_name))

        columns = await connection.run_sync(get_columns)

        # Primary-key columns for the table.
        pk_info = await connection.run_sync(
            lambda sync_conn: inspect(sync_conn).get_pk_constraint(table_name, schema_name)
        )
        pk_columns = set(pk_info.get("constrained_columns", []))

        # All columns that participate in any index.
        index_info = await connection.run_sync(
            lambda sync_conn: inspect(sync_conn).get_indexes(table_name, schema_name)
        )
        indexed_columns: set = set()
        for idx in index_info:
            indexed_columns.update(idx.get("column_names", []))

        descriptions = []
        for col in columns:
            name = col["name"]
            entry = f"{name}: {col['type'].python_type.__name__}"
            if name in pk_columns:
                entry += " (PRIMARY KEY)"
            if name in indexed_columns:
                entry += " (INDEXED)"
            descriptions.append(entry)

        return descriptions[:MAX_ROWS_RETURNED]
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
async def _execute_query(
    engine: AsyncEngine, query: str, params: dict[str, Any] | None = None
) -> list[str]:
    """Run a sanitized query and return stringified rows, capped at MAX_ROWS_RETURNED."""
    async with engine.connect() as connection:
        cursor = await connection.execute(text(DatabaseEngine.sanitize_query(query)), params)
        rows = [str(row) for row in cursor.fetchall()]
    return rows[:MAX_ROWS_RETURNED]
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
import arcade_postgres
|
|
2
|
+
from arcade_evals import (
|
|
3
|
+
BinaryCritic,
|
|
4
|
+
EvalRubric,
|
|
5
|
+
EvalSuite,
|
|
6
|
+
ExpectedToolCall,
|
|
7
|
+
SimilarityCritic,
|
|
8
|
+
tool_eval,
|
|
9
|
+
)
|
|
10
|
+
from arcade_postgres.tools.postgres import (
|
|
11
|
+
discover_tables,
|
|
12
|
+
execute_query,
|
|
13
|
+
get_table_schema,
|
|
14
|
+
)
|
|
15
|
+
from arcade_tdk import ToolCatalog
|
|
16
|
+
|
|
17
|
+
# Evaluation rubric
|
|
18
|
+
rubric = EvalRubric(
|
|
19
|
+
fail_threshold=0.85,
|
|
20
|
+
warn_threshold=0.95,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
catalog = ToolCatalog()
|
|
25
|
+
catalog.add_module(arcade_postgres)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@tool_eval()
def sql_eval_suite() -> EvalSuite:
    """Assemble the evaluation suite for the postgres SQL tools."""
    suite = EvalSuite(
        name="sql Tools Evaluation",
        system_message=(
            "You are an AI assistant with access to sql tools. "
            "Use them to help the user with their tasks."
        ),
        catalog=catalog,
        rubric=rubric,
    )

    # Shared critic pair for the two table-schema cases.
    schema_and_table_critics = [
        BinaryCritic(critic_field="schema_name", weight=0.5),
        BinaryCritic(critic_field="table_name", weight=0.5),
    ]

    cases = [
        {
            "name": "Get user by id (schema known)",
            "user_message": "Tell me the name and email of user #1 in my database. The table 'users' has the following schema: id: int, name: str, email: str, password_hash: str, created_at: datetime, updated_at: datetime",
            "expected_tool_calls": [
                ExpectedToolCall(
                    func=execute_query, args={"query": "SELECT name, email FROM users WHERE id = 1"}
                )
            ],
            "critics": [SimilarityCritic(critic_field="query", weight=1.0)],
        },
        {
            "name": "Discover tables",
            "user_message": "What tables are in my database?",
            "expected_tool_calls": [ExpectedToolCall(func=discover_tables, args={})],
        },
        {
            "name": "Get table schema (schema provided)",
            "user_message": "What columns are in the table 'public.users' in my database?",
            "expected_tool_calls": [
                ExpectedToolCall(
                    func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
                ),
            ],
            "critics": schema_and_table_critics,
        },
        {
            "name": "Get table schema (schema not provided)",
            "user_message": "What columns are in the table 'users' in my database?",
            "additional_messages": [
                {"role": "user", "content": "When not provided, the schema is 'public'."}
            ],
            "expected_tool_calls": [
                ExpectedToolCall(
                    func=get_table_schema, args={"schema_name": "public", "table_name": "users"}
                ),
            ],
            "critics": schema_and_table_critics,
        },
    ]

    for case in cases:
        suite.add_case(rubric=rubric, **case)

    return suite
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
[build-system]
requires = [ "hatchling",]
build-backend = "hatchling.build"

[project]
name = "arcade_postgres"
version = "0.1.0"
description = "Tools to query and explore a postgres database"
requires-python = ">=3.10"
# NOTE: "psycopg2-binary" was listed twice; the duplicate entry was removed.
dependencies = [
    "arcade-tdk>=2.0.0,<3.0.0",
    "psycopg2-binary>=2.9.10",
    "pydantic>=2.11.7",
    "sqlalchemy>=2.0.41",
    "asyncpg>=0.30.0",
    "greenlet>=3.2.3",
]
[[project.authors]]
name = "evantahler"
email = "support@arcade.dev"


[project.optional-dependencies]
dev = [
    "arcade-ai[evals]>=2.0.0,<3.0.0",
    "arcade-serve>=2.0.0,<3.0.0",
    "pytest>=8.3.0,<8.4.0",
    "pytest-cov>=4.0.0,<4.1.0",
    "pytest-mock>=3.11.1,<3.12.0",
    "pytest-asyncio>=0.24.0,<0.25.0",
    "mypy>=1.5.1,<1.6.0",
    "pre-commit>=3.4.0,<3.5.0",
    "tox>=4.11.1,<4.12.0",
    "ruff>=0.7.4,<0.8.0",
]

# Use local path sources for arcade libs when working locally
[tool.uv.sources]
arcade-ai = { path = "../../", editable = true }
arcade-serve = { path = "../../libs/arcade-serve/", editable = true }
arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true }


[tool.mypy]
files = [ "arcade_postgres/**/*.py",]
python_version = "3.10"
disallow_untyped_defs = "True"
disallow_any_unimported = "True"
no_implicit_optional = "True"
check_untyped_defs = "True"
warn_return_any = "True"
warn_unused_ignores = "True"
show_error_codes = "True"
ignore_missing_imports = "True"

[tool.pytest.ini_options]
testpaths = [ "tests",]
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.report]
skip_empty = true

[tool.hatch.build.targets.wheel]
packages = [ "arcade_postgres",]
|
|
File without changes
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
DROP TABLE IF EXISTS "public"."messages";
|
|
2
|
+
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
|
|
3
|
+
-- Sequence and defined type
|
|
4
|
+
CREATE SEQUENCE IF NOT EXISTS messages_id_seq;
|
|
5
|
+
-- Table Definition
|
|
6
|
+
CREATE TABLE "public"."messages" (
|
|
7
|
+
"id" int4 NOT NULL DEFAULT nextval('messages_id_seq'::regclass),
|
|
8
|
+
"body" text NOT NULL,
|
|
9
|
+
"user_id" int4 NOT NULL,
|
|
10
|
+
"created_at" timestamp NOT NULL DEFAULT now(),
|
|
11
|
+
"updated_at" timestamp NOT NULL DEFAULT now(),
|
|
12
|
+
PRIMARY KEY ("id")
|
|
13
|
+
);
|
|
14
|
+
DROP TABLE IF EXISTS "public"."users";
|
|
15
|
+
-- This script only contains the table creation statements and does not fully represent the table in the database. Do not use it as a backup.
|
|
16
|
+
-- Sequence and defined type
|
|
17
|
+
CREATE SEQUENCE IF NOT EXISTS users_id_seq;
|
|
18
|
+
-- Table Definition
|
|
19
|
+
CREATE TABLE "public"."users" (
|
|
20
|
+
"id" int4 NOT NULL DEFAULT nextval('users_id_seq'::regclass),
|
|
21
|
+
"name" varchar(256) NOT NULL,
|
|
22
|
+
"email" text NOT NULL,
|
|
23
|
+
"password_hash" text NOT NULL,
|
|
24
|
+
"created_at" timestamp NOT NULL DEFAULT now(),
|
|
25
|
+
"updated_at" timestamp NOT NULL DEFAULT now(),
|
|
26
|
+
"status" varchar,
|
|
27
|
+
PRIMARY KEY ("id")
|
|
28
|
+
);
|
|
29
|
+
INSERT INTO "public"."messages" (
|
|
30
|
+
"id",
|
|
31
|
+
"body",
|
|
32
|
+
"user_id",
|
|
33
|
+
"created_at",
|
|
34
|
+
"updated_at"
|
|
35
|
+
)
|
|
36
|
+
VALUES (
|
|
37
|
+
1,
|
|
38
|
+
'Evan says hello',
|
|
39
|
+
3,
|
|
40
|
+
'2025-04-10 17:21:05.504468',
|
|
41
|
+
'2025-04-10 17:21:05.504468'
|
|
42
|
+
),
|
|
43
|
+
(
|
|
44
|
+
5100,
|
|
45
|
+
'Hello! The current time is 2025-01-13T14:38:39.204Z',
|
|
46
|
+
12,
|
|
47
|
+
'2025-01-13 06:38:39.210897',
|
|
48
|
+
'2025-01-13 06:38:39.210897'
|
|
49
|
+
),
|
|
50
|
+
(
|
|
51
|
+
5101,
|
|
52
|
+
'Hello! The current time is 2025-01-13T14:55:32.560Z',
|
|
53
|
+
12,
|
|
54
|
+
'2025-01-13 06:55:32.56934',
|
|
55
|
+
'2025-01-13 06:55:32.56934'
|
|
56
|
+
),
|
|
57
|
+
(
|
|
58
|
+
5102,
|
|
59
|
+
'Hello! The current time is 2025-01-13T15:00:37.250Z',
|
|
60
|
+
12,
|
|
61
|
+
'2025-01-13 07:00:37.261816',
|
|
62
|
+
'2025-01-13 07:00:37.261816'
|
|
63
|
+
),
|
|
64
|
+
(
|
|
65
|
+
5319,
|
|
66
|
+
'Hello! The current time is 2025-01-14T07:17:07.115Z',
|
|
67
|
+
12,
|
|
68
|
+
'2025-01-13 23:17:07.123393',
|
|
69
|
+
'2025-01-13 23:17:07.123393'
|
|
70
|
+
);
|
|
71
|
+
INSERT INTO "public"."users" (
|
|
72
|
+
"id",
|
|
73
|
+
"name",
|
|
74
|
+
"email",
|
|
75
|
+
"password_hash",
|
|
76
|
+
"created_at",
|
|
77
|
+
"updated_at",
|
|
78
|
+
"status"
|
|
79
|
+
)
|
|
80
|
+
VALUES (
|
|
81
|
+
1,
|
|
82
|
+
'Mario',
|
|
83
|
+
'mario@example.com',
|
|
84
|
+
'$argon2id$v=19$m=65536,t=2,p=1$tMg1Rd3IEDnp3iFKrqsF4Dsbw6/Cbf6seRB/H5bhaPg$zZj5yn4x3D3O3mDHcW2aczQNiYfAs3cw21XMEIgkF0E',
|
|
85
|
+
'2024-09-01 20:49:38.759432',
|
|
86
|
+
'2024-09-02 03:49:39.927',
|
|
87
|
+
'active'
|
|
88
|
+
),
|
|
89
|
+
(
|
|
90
|
+
3,
|
|
91
|
+
'Evan',
|
|
92
|
+
'evantahler@gmail.com',
|
|
93
|
+
'$argon2id$v=19$m=65536,t=2,p=1$CvOMK1WUd99R7kYXpiBPNYw4OQP53pYIgeMnwz92mrE$HPthId4phMoPT1TWuCRHHCr9BSQA8XoUkQuB1HZsqTY',
|
|
94
|
+
'2024-09-02 17:49:23.377425',
|
|
95
|
+
'2024-09-02 17:49:23.377425',
|
|
96
|
+
'active'
|
|
97
|
+
),
|
|
98
|
+
(
|
|
99
|
+
12,
|
|
100
|
+
'Admin',
|
|
101
|
+
'admin@arcade.dev',
|
|
102
|
+
'$argon2id$v=19$m=65536,t=2,p=1$paCAAD1HVZkncP/WvecuUO6zFXp2/8BISpgr5rXRxps$M5kBFc9JHHGNw9SXnPu2ggpJY0mFFCska7TXMrllndo',
|
|
103
|
+
'2024-10-13 15:01:30.792909',
|
|
104
|
+
'2024-10-13 15:01:30.792909',
|
|
105
|
+
'inactive'
|
|
106
|
+
);
|
|
107
|
+
ALTER TABLE "public"."messages"
|
|
108
|
+
ADD FOREIGN KEY ("user_id") REFERENCES "public"."users"("id");
|
|
109
|
+
-- set pk to 13
|
|
110
|
+
ALTER SEQUENCE users_id_seq RESTART WITH 13;
|
|
111
|
+
-- Indices
|
|
112
|
+
CREATE UNIQUE INDEX name_idx ON public.users USING btree (name);
|
|
113
|
+
CREATE UNIQUE INDEX email_idx ON public.users USING btree (email);
|
|
114
|
+
CREATE UNIQUE INDEX users_email_unique ON public.users USING btree (email);
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from os import environ
|
|
3
|
+
|
|
4
|
+
import pytest
|
|
5
|
+
import pytest_asyncio
|
|
6
|
+
from arcade_postgres.tools.postgres import (
|
|
7
|
+
DatabaseEngine,
|
|
8
|
+
discover_schemas,
|
|
9
|
+
discover_tables,
|
|
10
|
+
execute_query,
|
|
11
|
+
get_table_schema,
|
|
12
|
+
)
|
|
13
|
+
from arcade_tdk import ToolContext, ToolSecretItem
|
|
14
|
+
from arcade_tdk.errors import RetryableToolError
|
|
15
|
+
from sqlalchemy import text
|
|
16
|
+
from sqlalchemy.ext.asyncio import create_async_engine
|
|
17
|
+
|
|
18
|
+
# Connection string for the test database; override it via the
# TEST_POSTGRES_DATABASE_CONNECTION_STRING environment variable.
# `or` (rather than a get() default) means an empty env value also falls back.
DATABASE_CONNECTION_STRING = (
    environ.get("TEST_POSTGRES_DATABASE_CONNECTION_STRING")
    or "postgresql://evan@localhost:5432/postgres"
)
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def mock_context():
    """Build a ToolContext carrying the test DB connection string as a secret."""
    secret = ToolSecretItem(
        key="DATABASE_CONNECTION_STRING", value=DATABASE_CONNECTION_STRING
    )
    ctx = ToolContext()
    ctx.secrets = [secret]
    return ctx
|
|
34
|
+
|
|
35
|
+
# Before each test, reload the database from the SQL dump so every test
# starts from identical fixture data.
@pytest_asyncio.fixture(autouse=True)
async def restore_database():
    """Restore the test database from tests/dump.sql.

    The dump is read in full, split on ';' into individual statements (the
    fixture data contains no literal semicolons), and executed inside a
    single transaction against an asyncpg engine.
    """
    # Read the dump up front so the file handle isn't held open across awaits.
    with open(f"{os.path.dirname(__file__)}/dump.sql") as f:
        sql = f.read()

    # Rewrite only the URL scheme to the asyncpg driver: a blanket
    # str.replace("postgresql", "postgresql+asyncpg") would corrupt a URL
    # that already names a driver (-> "postgresql+asyncpg+asyncpg") or that
    # contains "postgresql" elsewhere (e.g. in the database name).
    url = DATABASE_CONNECTION_STRING.split("?")[0]
    if url.startswith("postgresql://"):
        url = "postgresql+asyncpg://" + url[len("postgresql://") :]

    engine = create_async_engine(url)
    try:
        async with engine.connect() as conn:
            await conn.execute(text("BEGIN"))
            for statement in sql.split(";"):
                if statement.strip():
                    await conn.execute(text(statement))
            await conn.commit()
    finally:
        # Always dispose, even if a statement fails, so a broken dump does
        # not leak pooled connections into the next test.
        await engine.dispose()
|
|
51
|
+
|
|
52
|
+
@pytest_asyncio.fixture(autouse=True)
async def cleanup_engines():
    """Dispose all cached database engines after each test (no connection leaks)."""
    # Run the test body first...
    yield
    # ...then drop every engine DatabaseEngine has cached during the test.
    await DatabaseEngine.cleanup()
|
|
59
|
+
|
|
60
|
+
@pytest.mark.asyncio
async def test_discover_schemas(mock_context) -> None:
    """The fixture database exposes only the default 'public' schema."""
    schemas = await discover_schemas(mock_context)
    assert schemas == ["public"]
|
|
64
|
+
|
|
65
|
+
@pytest.mark.asyncio
async def test_discover_tables(mock_context) -> None:
    """Both fixture tables are discovered, users before messages."""
    tables = await discover_tables(mock_context)
    assert tables == ["users", "messages"]
|
|
69
|
+
|
|
70
|
+
@pytest.mark.asyncio
async def test_get_table_schema(mock_context) -> None:
    """Column descriptions carry the type plus PRIMARY KEY / INDEXED markers."""
    expected_users = [
        "id: int (PRIMARY KEY)",
        "name: str (INDEXED)",
        "email: str (INDEXED)",
        "password_hash: str",
        "created_at: datetime",
        "updated_at: datetime",
        "status: str",
    ]
    assert await get_table_schema(mock_context, "public", "users") == expected_users

    expected_messages = [
        "id: int (PRIMARY KEY)",
        "body: str",
        "user_id: int",
        "created_at: datetime",
        "updated_at: datetime",
    ]
    assert await get_table_schema(mock_context, "public", "messages") == expected_messages
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@pytest.mark.asyncio
async def test_execute_query(mock_context) -> None:
    """Result rows come back as stringified tuples."""
    rows = await execute_query(mock_context, "SELECT id, name, email FROM users WHERE id = 1")
    assert rows == ["(1, 'Mario', 'mario@example.com')"]
|
|
97
|
+
|
|
98
|
+
@pytest.mark.asyncio
async def test_execute_query_with_no_results(mock_context) -> None:
    """An empty result set is returned as an empty list, not raised as an error."""
    rows = await execute_query(mock_context, "SELECT * FROM users WHERE id = 9999999999")
    assert rows == []
|
|
103
|
+
|
|
104
|
+
@pytest.mark.asyncio
async def test_execute_query_with_problem(mock_context) -> None:
    """A query Postgres rejects surfaces as a RetryableToolError."""
    with pytest.raises(RetryableToolError) as excinfo:
        # 'foo' is not a valid integer id
        await execute_query(mock_context, "SELECT * FROM users WHERE id = 'foo'")
    assert "invalid input syntax" in str(excinfo.value)
|
|
111
|
+
|
|
112
|
+
@pytest.mark.asyncio
async def test_execute_query_rejects_non_select(mock_context) -> None:
    """Write statements are refused before reaching the database."""
    insert_sql = (
        "INSERT INTO users (name, email, password_hash) "
        "VALUES ('Luigi', 'luigi@example.com', 'password')"
    )
    with pytest.raises(RetryableToolError) as excinfo:
        await execute_query(mock_context, insert_sql)
    assert "Only SELECT queries are allowed" in str(excinfo.value)