daita-agents 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- daita/__init__.py +216 -0
- daita/agents/__init__.py +33 -0
- daita/agents/base.py +743 -0
- daita/agents/substrate.py +1141 -0
- daita/cli/__init__.py +145 -0
- daita/cli/__main__.py +7 -0
- daita/cli/ascii_art.py +44 -0
- daita/cli/core/__init__.py +0 -0
- daita/cli/core/create.py +254 -0
- daita/cli/core/deploy.py +473 -0
- daita/cli/core/deployments.py +309 -0
- daita/cli/core/import_detector.py +219 -0
- daita/cli/core/init.py +481 -0
- daita/cli/core/logs.py +239 -0
- daita/cli/core/managed_deploy.py +709 -0
- daita/cli/core/run.py +648 -0
- daita/cli/core/status.py +421 -0
- daita/cli/core/test.py +239 -0
- daita/cli/core/webhooks.py +172 -0
- daita/cli/main.py +588 -0
- daita/cli/utils.py +541 -0
- daita/config/__init__.py +62 -0
- daita/config/base.py +159 -0
- daita/config/settings.py +184 -0
- daita/core/__init__.py +262 -0
- daita/core/decision_tracing.py +701 -0
- daita/core/exceptions.py +480 -0
- daita/core/focus.py +251 -0
- daita/core/interfaces.py +76 -0
- daita/core/plugin_tracing.py +550 -0
- daita/core/relay.py +779 -0
- daita/core/reliability.py +381 -0
- daita/core/scaling.py +459 -0
- daita/core/tools.py +554 -0
- daita/core/tracing.py +770 -0
- daita/core/workflow.py +1144 -0
- daita/display/__init__.py +1 -0
- daita/display/console.py +160 -0
- daita/execution/__init__.py +58 -0
- daita/execution/client.py +856 -0
- daita/execution/exceptions.py +92 -0
- daita/execution/models.py +317 -0
- daita/llm/__init__.py +60 -0
- daita/llm/anthropic.py +291 -0
- daita/llm/base.py +530 -0
- daita/llm/factory.py +101 -0
- daita/llm/gemini.py +355 -0
- daita/llm/grok.py +219 -0
- daita/llm/mock.py +172 -0
- daita/llm/openai.py +220 -0
- daita/plugins/__init__.py +141 -0
- daita/plugins/base.py +37 -0
- daita/plugins/base_db.py +167 -0
- daita/plugins/elasticsearch.py +849 -0
- daita/plugins/mcp.py +481 -0
- daita/plugins/mongodb.py +520 -0
- daita/plugins/mysql.py +362 -0
- daita/plugins/postgresql.py +342 -0
- daita/plugins/redis_messaging.py +500 -0
- daita/plugins/rest.py +537 -0
- daita/plugins/s3.py +770 -0
- daita/plugins/slack.py +729 -0
- daita/utils/__init__.py +18 -0
- daita_agents-0.2.0.dist-info/METADATA +409 -0
- daita_agents-0.2.0.dist-info/RECORD +69 -0
- daita_agents-0.2.0.dist-info/WHEEL +5 -0
- daita_agents-0.2.0.dist-info/entry_points.txt +2 -0
- daita_agents-0.2.0.dist-info/licenses/LICENSE +56 -0
- daita_agents-0.2.0.dist-info/top_level.txt +1 -0
daita/plugins/mysql.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MySQL plugin for Daita Agents.
|
|
3
|
+
|
|
4
|
+
Simple MySQL connection and querying - no over-engineering.
|
|
5
|
+
"""
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
|
8
|
+
from .base_db import BaseDatabasePlugin
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..core.tools import AgentTool
|
|
12
|
+
|
|
13
|
+
# Module-level logger used by MySQLPlugin for connect/disconnect/config messages.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
class MySQLPlugin(BaseDatabasePlugin):
    """
    MySQL plugin for agents with standardized connection management.

    Inherits common database functionality from BaseDatabasePlugin.
    Connections are pooled with aiomysql; the pool is created lazily on the
    first ``query``/``execute``/``insert_many`` call, or explicitly via
    ``connect()``.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 3306,
        database: str = "",
        username: str = "",
        password: str = "",
        connection_string: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize MySQL connection settings (no connection is opened here).

        Args:
            host: Database host
            port: Database port
            database: Database name
            username: Username
            password: Password
            connection_string: Full connection string of the form
                ``mysql://user:password@host:port/database``. Every component
                present in it overrides the matching individual param;
                missing components fall back to those params.
            **kwargs: Additional options. ``min_size``, ``max_size``,
                ``charset`` and ``autocommit`` configure the aiomysql pool;
                all kwargs are also forwarded to BaseDatabasePlugin.
        """
        if connection_string:
            self.connection_string = connection_string
            # connect() talks to aiomysql with discrete parameters, so the
            # URL must be decomposed for it to actually take effect.
            host, port, database, username, password = self._parse_connection_string(
                connection_string, host, port, database, username, password
            )
        else:
            from urllib.parse import quote

            # Percent-encode credentials so the string can be parsed back.
            self.connection_string = (
                f"mysql://{quote(username, safe='')}:{quote(password, safe='')}"
                f"@{host}:{port}/{database}"
            )

        # Store connection parameters for aiomysql
        self.host = host
        self.port = port
        self.user = username
        self.password = password
        self.db = database

        self.pool_config = {
            'minsize': kwargs.get('min_size', 1),
            'maxsize': kwargs.get('max_size', 10),
            'charset': kwargs.get('charset', 'utf8mb4'),
            'autocommit': kwargs.get('autocommit', True),
        }

        # Initialize base class with all config
        super().__init__(
            host=host, port=port, database=database,
            username=username, connection_string=connection_string,
            **kwargs
        )

        logger.debug(f"MySQL plugin configured for {host}:{port}/{database}")

    @staticmethod
    def _parse_connection_string(
        connection_string: str,
        host: str,
        port: int,
        database: str,
        username: str,
        password: str,
    ) -> tuple:
        """
        Split a ``mysql://user:pass@host:port/db`` URL into its components.

        Components absent from the URL keep the supplied fallback values.
        Percent-encoded user names, passwords and database names are decoded.

        Returns:
            Tuple ``(host, port, database, username, password)``.

        Raises:
            ValueError: If the URL contains a non-numeric or out-of-range port.
        """
        from urllib.parse import unquote, urlsplit

        parts = urlsplit(connection_string)
        if parts.hostname:
            host = parts.hostname
        if parts.port:
            port = parts.port
        path_db = parts.path.lstrip("/")
        if path_db:
            database = unquote(path_db)
        if parts.username:
            username = unquote(parts.username)
        if parts.password is not None:
            password = unquote(parts.password)
        return host, port, database, username, password

    async def connect(self):
        """Connect to MySQL database (no-op if a pool already exists)."""
        if self._pool is not None:
            return  # Already connected

        try:
            import aiomysql
            self._pool = await aiomysql.create_pool(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db=self.db,
                **self.pool_config
            )
            logger.info("Connected to MySQL")
        except ImportError:
            self._handle_connection_error(
                ImportError("aiomysql not installed. Run: pip install aiomysql"),
                "connection"
            )
        except Exception as e:
            self._handle_connection_error(e, "connection")

    async def disconnect(self):
        """Close the connection pool, if one is open."""
        if self._pool:
            self._pool.close()
            await self._pool.wait_closed()
            self._pool = None
            logger.info("Disconnected from MySQL")

    async def _commit_if_needed(self, conn) -> None:
        """
        Commit on ``conn`` when the pool was configured with autocommit off.

        Without an explicit commit the write is never committed, and the
        transaction is discarded when the pool drops the connection.
        """
        if not self.pool_config.get('autocommit', True):
            await conn.commit()

    async def query(self, sql: str, params: Optional[List] = None) -> List[Dict[str, Any]]:
        """
        Run a SELECT query and return results.

        Args:
            sql: SQL query with %s placeholders
            params: List of parameters for the query

        Returns:
            List of rows as dictionaries. Statements that produce no result
            set (e.g. UPDATE) yield an empty list.

        Example:
            results = await db.query("SELECT * FROM users WHERE age > %s", [25])
        """
        # Only auto-connect if pool is None - allows manual mocking
        if self._pool is None:
            await self.connect()

        async with self._pool.acquire() as conn:
            async with conn.cursor() as cursor:
                if params:
                    await cursor.execute(sql, params)
                else:
                    await cursor.execute(sql)

                # DB-API: description is None when there is no result set.
                if cursor.description is None:
                    return []

                rows = await cursor.fetchall()
                columns = [desc[0] for desc in cursor.description]

                return [dict(zip(columns, row)) for row in rows]

    async def execute(self, sql: str, params: Optional[List] = None) -> int:
        """
        Execute INSERT/UPDATE/DELETE and return affected rows.

        Commits explicitly when the plugin was created with
        ``autocommit=False``.

        Args:
            sql: SQL statement
            params: List of parameters

        Returns:
            Number of affected rows
        """
        # Only auto-connect if pool is None - allows manual mocking
        if self._pool is None:
            await self.connect()

        async with self._pool.acquire() as conn:
            async with conn.cursor() as cursor:
                if params:
                    await cursor.execute(sql, params)
                else:
                    await cursor.execute(sql)
                affected = cursor.rowcount

            await self._commit_if_needed(conn)
            return affected

    async def insert_many(self, table: str, data: List[Dict[str, Any]]) -> int:
        """
        Bulk insert data into a table.

        Column names are taken from the first row; every row must contain
        those keys (a missing key raises KeyError). ``table`` is interpolated
        as-is, so it must come from trusted code, not from user input.

        Args:
            table: Table name
            data: List of dictionaries to insert

        Returns:
            Number of rows inserted
        """
        if not data:
            return 0

        # Only auto-connect if pool is None - allows manual mocking
        if self._pool is None:
            await self.connect()

        # Get columns from first row
        columns = list(data[0].keys())
        placeholders = ', '.join(['%s'] * len(columns))

        # Escape embedded backticks so a key cannot break out of the quoting.
        quoted_columns = ', '.join(
            '`' + str(col).replace('`', '``') + '`' for col in columns
        )
        sql = f"INSERT INTO {table} ({quoted_columns}) VALUES ({placeholders})"

        # Convert to list of tuples for executemany
        rows = [tuple(row[col] for col in columns) for row in data]

        async with self._pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.executemany(sql, rows)
                inserted = cursor.rowcount

            await self._commit_if_needed(conn)
            return inserted

    async def tables(self) -> List[str]:
        """List all base tables (not views) in the current database."""
        sql = """
        SELECT TABLE_NAME as table_name
        FROM INFORMATION_SCHEMA.TABLES
        WHERE TABLE_TYPE = 'BASE TABLE'
        AND TABLE_SCHEMA = DATABASE()
        ORDER BY TABLE_NAME
        """
        results = await self.query(sql)
        return [row['table_name'] for row in results]

    async def describe(self, table: str) -> List[Dict[str, Any]]:
        """Get column metadata for ``table`` in the current database, in ordinal order."""
        sql = """
        SELECT
            COLUMN_NAME as column_name,
            DATA_TYPE as data_type,
            IS_NULLABLE as is_nullable,
            COLUMN_DEFAULT as column_default,
            COLUMN_TYPE as column_type
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME = %s
        AND TABLE_SCHEMA = DATABASE()
        ORDER BY ORDINAL_POSITION
        """
        return await self.query(sql, [table])

    def get_tools(self) -> List['AgentTool']:
        """
        Expose MySQL operations as agent tools.

        Returns:
            List of AgentTool instances for database operations
        """
        from ..core.tools import AgentTool

        return [
            AgentTool(
                name="query_database",
                description="Execute a SQL SELECT query on the MySQL database and return results as a list of dictionaries",
                parameters={
                    "type": "object",
                    "properties": {
                        "sql": {
                            "type": "string",
                            "description": "SQL SELECT query with %s placeholders for parameters"
                        },
                        "params": {
                            "type": "array",
                            "description": "Optional list of parameter values for query placeholders",
                            "items": {"type": "string"}
                        }
                    },
                    "required": ["sql"]
                },
                handler=self._tool_query,
                category="database",
                source="plugin",
                plugin_name="MySQL",
                timeout_seconds=60
            ),
            AgentTool(
                name="list_tables",
                description="List all tables in the MySQL database",
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": []
                },
                handler=self._tool_list_tables,
                category="database",
                source="plugin",
                plugin_name="MySQL",
                timeout_seconds=30
            ),
            AgentTool(
                name="get_table_schema",
                description="Get column information (name, data type, nullable) for a specific table in MySQL",
                parameters={
                    "type": "object",
                    "properties": {
                        "table_name": {
                            "type": "string",
                            "description": "Name of the table to inspect"
                        }
                    },
                    "required": ["table_name"]
                },
                handler=self._tool_get_schema,
                category="database",
                source="plugin",
                plugin_name="MySQL",
                timeout_seconds=30
            ),
            AgentTool(
                name="execute_sql",
                description="Execute an INSERT, UPDATE, or DELETE SQL statement on MySQL. Returns the number of affected rows.",
                parameters={
                    "type": "object",
                    "properties": {
                        "sql": {
                            "type": "string",
                            "description": "SQL statement to execute (INSERT, UPDATE, or DELETE)"
                        },
                        "params": {
                            "type": "array",
                            "description": "Optional list of parameter values for statement placeholders",
                            "items": {"type": "string"}
                        }
                    },
                    "required": ["sql"]
                },
                handler=self._tool_execute,
                category="database",
                source="plugin",
                plugin_name="MySQL",
                timeout_seconds=60
            )
        ]

    async def _tool_query(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for query_database: run ``args['sql']`` with optional ``args['params']``."""
        sql = args.get("sql")
        params = args.get("params")

        results = await self.query(sql, params)

        return {
            "success": True,
            "rows": results,
            "row_count": len(results)
        }

    async def _tool_list_tables(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for list_tables (``args`` is unused)."""
        tables = await self.tables()

        return {
            "success": True,
            "tables": tables,
            "count": len(tables)
        }

    async def _tool_get_schema(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for get_table_schema: describe ``args['table_name']``."""
        table_name = args.get("table_name")

        columns = await self.describe(table_name)

        return {
            "success": True,
            "table": table_name,
            "columns": columns,
            "column_count": len(columns)
        }

    async def _tool_execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for execute_sql: run a write statement and report affected rows."""
        sql = args.get("sql")
        params = args.get("params")

        affected_rows = await self.execute(sql, params)

        return {
            "success": True,
            "affected_rows": affected_rows
        }
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
def mysql(**kwargs) -> MySQLPlugin:
    """
    Factory shortcut that builds a MySQLPlugin.

    Every keyword argument is handed unchanged to ``MySQLPlugin.__init__``.
    """
    plugin = MySQLPlugin(**kwargs)
    return plugin
|
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PostgreSQL plugin for Daita Agents.
|
|
3
|
+
|
|
4
|
+
Simple database connection and querying - no over-engineering.
|
|
5
|
+
"""
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
|
8
|
+
from .base_db import BaseDatabasePlugin
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..core.tools import AgentTool
|
|
12
|
+
|
|
13
|
+
# Module-level logger used by PostgreSQLPlugin for connect/disconnect/config messages.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
class PostgreSQLPlugin(BaseDatabasePlugin):
    """
    PostgreSQL plugin for agents with standardized connection management.

    Inherits common database functionality from BaseDatabasePlugin.
    Connections are pooled with asyncpg; the pool is created lazily on the
    first ``query``/``execute``/``insert_many`` call, or explicitly via
    ``connect()``.
    """

    def __init__(
        self,
        host: str = "localhost",
        port: int = 5432,
        database: str = "",
        username: str = "",
        user: Optional[str] = None,
        password: str = "",
        connection_string: Optional[str] = None,
        **kwargs
    ):
        """
        Initialize PostgreSQL connection settings (no connection is opened here).

        Args:
            host: Database host
            port: Database port
            database: Database name
            username: Username
            user: Username (alias for username; wins when both are given)
            password: Password
            connection_string: Full connection string (overrides individual params)
            **kwargs: Additional options. ``min_size``, ``max_size``,
                ``command_timeout`` and ``statement_cache_size`` configure the
                asyncpg pool; all kwargs are also forwarded to
                BaseDatabasePlugin.
        """
        from urllib.parse import quote

        # Use 'user' parameter as alias for 'username' if provided
        effective_username = user if user is not None else username

        # Build connection string
        if connection_string:
            self.connection_string = connection_string
        else:
            # Percent-encode user/password/database so characters such as
            # '@', ':', '/' or '#' cannot corrupt the DSN passed to asyncpg.
            self.connection_string = (
                f"postgresql://{quote(effective_username, safe='')}:{quote(password, safe='')}"
                f"@{host}:{port}/{quote(database, safe='')}"
            )

        # PostgreSQL-specific pool configuration
        self.pool_config = {
            'min_size': kwargs.get('min_size', 1),
            'max_size': kwargs.get('max_size', 10),
            'command_timeout': kwargs.get('command_timeout', 60),
            'statement_cache_size': kwargs.get('statement_cache_size', 0),  # Set to 0 for pgbouncer compatibility
        }

        # Initialize base class with all config
        super().__init__(
            host=host, port=port, database=database,
            username=effective_username, connection_string=connection_string,
            **kwargs
        )

        logger.debug(f"PostgreSQL plugin configured for {host}:{port}/{database}")

    async def connect(self):
        """Connect to PostgreSQL database (no-op if a pool already exists)."""
        if self._pool is not None:
            return  # Already connected

        try:
            import asyncpg
            self._pool = await asyncpg.create_pool(
                self.connection_string,
                **self.pool_config
            )
            logger.info("Connected to PostgreSQL")
        except ImportError:
            self._handle_connection_error(
                ImportError("asyncpg not installed. Run: pip install asyncpg"),
                "connection"
            )
        except Exception as e:
            self._handle_connection_error(e, "connection")

    async def disconnect(self):
        """Close the connection pool, if one is open."""
        if self._pool:
            await self._pool.close()
            self._pool = None
            logger.info("Disconnected from PostgreSQL")

    async def query(self, sql: str, params: Optional[List] = None) -> List[Dict[str, Any]]:
        """
        Run a SELECT query and return results.

        Args:
            sql: SQL query with $1, $2, etc. placeholders
            params: List of parameters for the query

        Returns:
            List of rows as dictionaries

        Example:
            results = await db.query("SELECT * FROM users WHERE age > $1", [25])
        """
        # Only auto-connect if pool is None - allows manual mocking
        if self._pool is None:
            await self.connect()

        async with self._pool.acquire() as conn:
            if params:
                rows = await conn.fetch(sql, *params)
            else:
                rows = await conn.fetch(sql)

            return [dict(row) for row in rows]

    async def execute(self, sql: str, params: Optional[List] = None) -> int:
        """
        Execute INSERT/UPDATE/DELETE and return affected rows.

        Args:
            sql: SQL statement
            params: List of parameters

        Returns:
            Number of affected rows, or 0 when the command status carries no
            row count (e.g. "CREATE TABLE").
        """
        # Only auto-connect if pool is None - allows manual mocking
        if self._pool is None:
            await self.connect()

        async with self._pool.acquire() as conn:
            if params:
                result = await conn.execute(sql, *params)
            else:
                result = await conn.execute(sql)

            # The status tag looks like "INSERT 0 5" or "UPDATE 3", but DDL
            # tags ("CREATE TABLE") end in a word; int() on that would raise
            # after the statement already succeeded.
            tokens = result.split() if result else []
            if tokens and tokens[-1].isdigit():
                return int(tokens[-1])
            return 0

    async def insert_many(self, table: str, data: List[Dict[str, Any]]) -> int:
        """
        Bulk insert data into a table.

        Column names are taken from the first row; every row must contain
        those keys (a missing key raises KeyError). ``table`` and column names
        are interpolated unquoted, so they must come from trusted code.

        Args:
            table: Table name
            data: List of dictionaries to insert

        Returns:
            Number of rows inserted
        """
        if not data:
            return 0

        # Only auto-connect if pool is None - allows manual mocking
        if self._pool is None:
            await self.connect()

        # Get columns from first row
        columns = list(data[0].keys())
        placeholders = ', '.join([f'${i+1}' for i in range(len(columns))])

        sql = f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"

        # One positional argument list per row for executemany
        rows = [[row[col] for col in columns] for row in data]

        async with self._pool.acquire() as conn:
            await conn.executemany(sql, rows)

        return len(data)

    async def tables(self) -> List[str]:
        """List all tables in the 'public' schema."""
        sql = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
        ORDER BY table_name
        """
        results = await self.query(sql)
        return [row['table_name'] for row in results]

    def get_tools(self) -> List['AgentTool']:
        """
        Expose PostgreSQL operations as agent tools.

        Returns:
            List of AgentTool instances for database operations
        """
        from ..core.tools import AgentTool

        return [
            AgentTool(
                name="query_database",
                description="Execute a SQL SELECT query on the PostgreSQL database and return results as a list of dictionaries",
                parameters={
                    "type": "object",
                    "properties": {
                        "sql": {
                            "type": "string",
                            "description": "SQL SELECT query with $1, $2, etc. placeholders for parameters"
                        },
                        "params": {
                            "type": "array",
                            "description": "Optional list of parameter values for query placeholders",
                            "items": {"type": "string"}
                        }
                    },
                    "required": ["sql"]
                },
                handler=self._tool_query,
                category="database",
                source="plugin",
                plugin_name="PostgreSQL",
                timeout_seconds=60
            ),
            AgentTool(
                name="list_tables",
                description="List all tables in the PostgreSQL database",
                parameters={
                    "type": "object",
                    "properties": {},
                    "required": []
                },
                handler=self._tool_list_tables,
                category="database",
                source="plugin",
                plugin_name="PostgreSQL",
                timeout_seconds=30
            ),
            AgentTool(
                name="get_table_schema",
                description="Get column information (name, data type, nullable) for a specific table in PostgreSQL",
                parameters={
                    "type": "object",
                    "properties": {
                        "table_name": {
                            "type": "string",
                            "description": "Name of the table to inspect"
                        }
                    },
                    "required": ["table_name"]
                },
                handler=self._tool_get_schema,
                category="database",
                source="plugin",
                plugin_name="PostgreSQL",
                timeout_seconds=30
            ),
            AgentTool(
                name="execute_sql",
                description="Execute an INSERT, UPDATE, or DELETE SQL statement on PostgreSQL. Returns the number of affected rows.",
                parameters={
                    "type": "object",
                    "properties": {
                        "sql": {
                            "type": "string",
                            "description": "SQL statement to execute (INSERT, UPDATE, or DELETE)"
                        },
                        "params": {
                            "type": "array",
                            "description": "Optional list of parameter values for statement placeholders",
                            "items": {"type": "string"}
                        }
                    },
                    "required": ["sql"]
                },
                handler=self._tool_execute,
                category="database",
                source="plugin",
                plugin_name="PostgreSQL",
                timeout_seconds=60
            )
        ]

    async def _tool_query(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for query_database: run ``args['sql']`` with optional ``args['params']``."""
        sql = args.get("sql")
        params = args.get("params")

        results = await self.query(sql, params)

        return {
            "success": True,
            "rows": results,
            "row_count": len(results)
        }

    async def _tool_list_tables(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for list_tables (``args`` is unused)."""
        tables = await self.tables()

        return {
            "success": True,
            "tables": tables,
            "count": len(tables)
        }

    async def _tool_get_schema(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for get_table_schema: describe ``args['table_name']`` in 'public'."""
        table_name = args.get("table_name")

        # Restrict to the same schema tables() lists; without this filter a
        # same-named table in another schema mixes/duplicates columns.
        schema_query = """
        SELECT column_name, data_type, is_nullable
        FROM information_schema.columns
        WHERE table_name = $1
        AND table_schema = 'public'
        ORDER BY ordinal_position
        """

        columns = await self.query(schema_query, [table_name])

        return {
            "success": True,
            "table": table_name,
            "columns": columns,
            "column_count": len(columns)
        }

    async def _tool_execute(self, args: Dict[str, Any]) -> Dict[str, Any]:
        """Tool handler for execute_sql: run a write statement and report affected rows."""
        sql = args.get("sql")
        params = args.get("params")

        affected_rows = await self.execute(sql, params)

        return {
            "success": True,
            "affected_rows": affected_rows
        }
|
|
339
|
+
|
|
340
|
+
def postgresql(**kwargs) -> PostgreSQLPlugin:
    """
    Factory shortcut that builds a PostgreSQLPlugin.

    Every keyword argument is handed unchanged to ``PostgreSQLPlugin.__init__``.
    """
    plugin = PostgreSQLPlugin(**kwargs)
    return plugin
|