amd-gaia 0.15.0__py3-none-any.whl → 0.15.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/METADATA +223 -223
- amd_gaia-0.15.1.dist-info/RECORD +178 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/entry_points.txt +1 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/licenses/LICENSE.md +20 -20
- gaia/__init__.py +29 -29
- gaia/agents/__init__.py +19 -19
- gaia/agents/base/__init__.py +9 -9
- gaia/agents/base/agent.py +2177 -2177
- gaia/agents/base/api_agent.py +120 -120
- gaia/agents/base/console.py +1841 -1841
- gaia/agents/base/errors.py +237 -237
- gaia/agents/base/mcp_agent.py +86 -86
- gaia/agents/base/tools.py +83 -83
- gaia/agents/blender/agent.py +556 -556
- gaia/agents/blender/agent_simple.py +133 -135
- gaia/agents/blender/app.py +211 -211
- gaia/agents/blender/app_simple.py +41 -41
- gaia/agents/blender/core/__init__.py +16 -16
- gaia/agents/blender/core/materials.py +506 -506
- gaia/agents/blender/core/objects.py +316 -316
- gaia/agents/blender/core/rendering.py +225 -225
- gaia/agents/blender/core/scene.py +220 -220
- gaia/agents/blender/core/view.py +146 -146
- gaia/agents/chat/__init__.py +9 -9
- gaia/agents/chat/agent.py +835 -835
- gaia/agents/chat/app.py +1058 -1058
- gaia/agents/chat/session.py +508 -508
- gaia/agents/chat/tools/__init__.py +15 -15
- gaia/agents/chat/tools/file_tools.py +96 -96
- gaia/agents/chat/tools/rag_tools.py +1729 -1729
- gaia/agents/chat/tools/shell_tools.py +436 -436
- gaia/agents/code/__init__.py +7 -7
- gaia/agents/code/agent.py +549 -549
- gaia/agents/code/cli.py +377 -0
- gaia/agents/code/models.py +135 -135
- gaia/agents/code/orchestration/__init__.py +24 -24
- gaia/agents/code/orchestration/checklist_executor.py +1763 -1763
- gaia/agents/code/orchestration/checklist_generator.py +713 -713
- gaia/agents/code/orchestration/factories/__init__.py +9 -9
- gaia/agents/code/orchestration/factories/base.py +63 -63
- gaia/agents/code/orchestration/factories/nextjs_factory.py +118 -118
- gaia/agents/code/orchestration/factories/python_factory.py +106 -106
- gaia/agents/code/orchestration/orchestrator.py +841 -841
- gaia/agents/code/orchestration/project_analyzer.py +391 -391
- gaia/agents/code/orchestration/steps/__init__.py +67 -67
- gaia/agents/code/orchestration/steps/base.py +188 -188
- gaia/agents/code/orchestration/steps/error_handler.py +314 -314
- gaia/agents/code/orchestration/steps/nextjs.py +828 -828
- gaia/agents/code/orchestration/steps/python.py +307 -307
- gaia/agents/code/orchestration/template_catalog.py +469 -469
- gaia/agents/code/orchestration/workflows/__init__.py +14 -14
- gaia/agents/code/orchestration/workflows/base.py +80 -80
- gaia/agents/code/orchestration/workflows/nextjs.py +186 -186
- gaia/agents/code/orchestration/workflows/python.py +94 -94
- gaia/agents/code/prompts/__init__.py +11 -11
- gaia/agents/code/prompts/base_prompt.py +77 -77
- gaia/agents/code/prompts/code_patterns.py +2036 -2036
- gaia/agents/code/prompts/nextjs_prompt.py +40 -40
- gaia/agents/code/prompts/python_prompt.py +109 -109
- gaia/agents/code/schema_inference.py +365 -365
- gaia/agents/code/system_prompt.py +41 -41
- gaia/agents/code/tools/__init__.py +42 -42
- gaia/agents/code/tools/cli_tools.py +1138 -1138
- gaia/agents/code/tools/code_formatting.py +319 -319
- gaia/agents/code/tools/code_tools.py +769 -769
- gaia/agents/code/tools/error_fixing.py +1347 -1347
- gaia/agents/code/tools/external_tools.py +180 -180
- gaia/agents/code/tools/file_io.py +845 -845
- gaia/agents/code/tools/prisma_tools.py +190 -190
- gaia/agents/code/tools/project_management.py +1016 -1016
- gaia/agents/code/tools/testing.py +321 -321
- gaia/agents/code/tools/typescript_tools.py +122 -122
- gaia/agents/code/tools/validation_parsing.py +461 -461
- gaia/agents/code/tools/validation_tools.py +806 -806
- gaia/agents/code/tools/web_dev_tools.py +1758 -1758
- gaia/agents/code/validators/__init__.py +16 -16
- gaia/agents/code/validators/antipattern_checker.py +241 -241
- gaia/agents/code/validators/ast_analyzer.py +197 -197
- gaia/agents/code/validators/requirements_validator.py +145 -145
- gaia/agents/code/validators/syntax_validator.py +171 -171
- gaia/agents/docker/__init__.py +7 -7
- gaia/agents/docker/agent.py +642 -642
- gaia/agents/emr/__init__.py +8 -8
- gaia/agents/emr/agent.py +1506 -1506
- gaia/agents/emr/cli.py +1322 -1322
- gaia/agents/emr/constants.py +475 -475
- gaia/agents/emr/dashboard/__init__.py +4 -4
- gaia/agents/emr/dashboard/server.py +1974 -1974
- gaia/agents/jira/__init__.py +11 -11
- gaia/agents/jira/agent.py +894 -894
- gaia/agents/jira/jql_templates.py +299 -299
- gaia/agents/routing/__init__.py +7 -7
- gaia/agents/routing/agent.py +567 -570
- gaia/agents/routing/system_prompt.py +75 -75
- gaia/agents/summarize/__init__.py +11 -0
- gaia/agents/summarize/agent.py +885 -0
- gaia/agents/summarize/prompts.py +129 -0
- gaia/api/__init__.py +23 -23
- gaia/api/agent_registry.py +238 -238
- gaia/api/app.py +305 -305
- gaia/api/openai_server.py +575 -575
- gaia/api/schemas.py +186 -186
- gaia/api/sse_handler.py +373 -373
- gaia/apps/__init__.py +4 -4
- gaia/apps/llm/__init__.py +6 -6
- gaia/apps/llm/app.py +173 -169
- gaia/apps/summarize/app.py +116 -633
- gaia/apps/summarize/html_viewer.py +133 -133
- gaia/apps/summarize/pdf_formatter.py +284 -284
- gaia/audio/__init__.py +2 -2
- gaia/audio/audio_client.py +439 -439
- gaia/audio/audio_recorder.py +269 -269
- gaia/audio/kokoro_tts.py +599 -599
- gaia/audio/whisper_asr.py +432 -432
- gaia/chat/__init__.py +16 -16
- gaia/chat/app.py +430 -430
- gaia/chat/prompts.py +522 -522
- gaia/chat/sdk.py +1228 -1225
- gaia/cli.py +5481 -5632
- gaia/database/__init__.py +10 -10
- gaia/database/agent.py +176 -176
- gaia/database/mixin.py +290 -290
- gaia/database/testing.py +64 -64
- gaia/eval/batch_experiment.py +2332 -2332
- gaia/eval/claude.py +542 -542
- gaia/eval/config.py +37 -37
- gaia/eval/email_generator.py +512 -512
- gaia/eval/eval.py +3179 -3179
- gaia/eval/groundtruth.py +1130 -1130
- gaia/eval/transcript_generator.py +582 -582
- gaia/eval/webapp/README.md +167 -167
- gaia/eval/webapp/package-lock.json +875 -875
- gaia/eval/webapp/package.json +20 -20
- gaia/eval/webapp/public/app.js +3402 -3402
- gaia/eval/webapp/public/index.html +87 -87
- gaia/eval/webapp/public/styles.css +3661 -3661
- gaia/eval/webapp/server.js +415 -415
- gaia/eval/webapp/test-setup.js +72 -72
- gaia/llm/__init__.py +9 -2
- gaia/llm/base_client.py +60 -0
- gaia/llm/exceptions.py +12 -0
- gaia/llm/factory.py +70 -0
- gaia/llm/lemonade_client.py +3236 -3221
- gaia/llm/lemonade_manager.py +294 -294
- gaia/llm/providers/__init__.py +9 -0
- gaia/llm/providers/claude.py +108 -0
- gaia/llm/providers/lemonade.py +120 -0
- gaia/llm/providers/openai_provider.py +79 -0
- gaia/llm/vlm_client.py +382 -382
- gaia/logger.py +189 -189
- gaia/mcp/agent_mcp_server.py +245 -245
- gaia/mcp/blender_mcp_client.py +138 -138
- gaia/mcp/blender_mcp_server.py +648 -648
- gaia/mcp/context7_cache.py +332 -332
- gaia/mcp/external_services.py +518 -518
- gaia/mcp/mcp_bridge.py +811 -550
- gaia/mcp/servers/__init__.py +6 -6
- gaia/mcp/servers/docker_mcp.py +83 -83
- gaia/perf_analysis.py +361 -0
- gaia/rag/__init__.py +10 -10
- gaia/rag/app.py +293 -293
- gaia/rag/demo.py +304 -304
- gaia/rag/pdf_utils.py +235 -235
- gaia/rag/sdk.py +2194 -2194
- gaia/security.py +163 -163
- gaia/talk/app.py +289 -289
- gaia/talk/sdk.py +538 -538
- gaia/testing/__init__.py +87 -87
- gaia/testing/assertions.py +330 -330
- gaia/testing/fixtures.py +333 -333
- gaia/testing/mocks.py +493 -493
- gaia/util.py +46 -46
- gaia/utils/__init__.py +33 -33
- gaia/utils/file_watcher.py +675 -675
- gaia/utils/parsing.py +223 -223
- gaia/version.py +100 -100
- amd_gaia-0.15.0.dist-info/RECORD +0 -168
- gaia/agents/code/app.py +0 -266
- gaia/llm/llm_client.py +0 -723
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/WHEEL +0 -0
- {amd_gaia-0.15.0.dist-info → amd_gaia-0.15.1.dist-info}/top_level.txt +0 -0
gaia/database/mixin.py
CHANGED
|
@@ -1,290 +1,290 @@
|
|
|
1
|
-
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
-
# SPDX-License-Identifier: MIT
|
|
3
|
-
|
|
4
|
-
"""SQLite database mixin for GAIA agents."""
|
|
5
|
-
|
|
6
|
-
import logging
|
|
7
|
-
import sqlite3
|
|
8
|
-
from contextlib import contextmanager
|
|
9
|
-
from pathlib import Path
|
|
10
|
-
from typing import Any, Dict, List, Optional, Union
|
|
11
|
-
|
|
12
|
-
# Module-level logger; init_db() reports the opened database path through it.
logger = logging.getLogger(__name__)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class DatabaseMixin:
    """
    Mixin providing SQLite database access for GAIA agents.

    A lean, zero-dependency mixin that uses Python's built-in sqlite3 module.

    Note: only column *values* are bound as SQL parameters. Table names,
    column names and WHERE-clause text passed to insert()/update()/delete()
    are interpolated directly into the SQL string, so they must come from
    trusted code, never from user input.

    Example:
        class MyAgent(Agent, DatabaseMixin):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.init_db("data/app.db")

                if not self.table_exists("items"):
                    self.execute('''
                        CREATE TABLE items (
                            id INTEGER PRIMARY KEY,
                            name TEXT NOT NULL
                        )
                    ''')

            def _register_tools(self):
                @tool
                def add_item(name: str) -> dict:
                    item_id = self.insert("items", {"name": name})
                    return {"id": item_id}
    """

    # Class-level defaults so methods can detect "not initialized" before init_db().
    _db: Optional[sqlite3.Connection] = None
    # True while inside a transaction() block; suppresses per-call commits.
    _in_tx: bool = False

    def init_db(self, path: str = ":memory:") -> None:
        """
        Initialize SQLite database.

        Any connection previously opened by this mixin is closed first.

        Args:
            path: Database file path, or ":memory:" for in-memory database.
                Parent directories are created automatically.

        Example:
            self.init_db("data/myagent.db")  # File-based
            self.init_db()                   # In-memory (for testing)
        """
        if self._db is not None:
            self.close_db()

        if path != ":memory:":
            Path(path).parent.mkdir(parents=True, exist_ok=True)

        # check_same_thread=False allows use from threads other than the
        # creating one; sqlite3 then leaves serializing access to the caller.
        self._db = sqlite3.connect(path, check_same_thread=False)
        # sqlite3.Row lets query() turn each row into a column-keyed dict.
        self._db.row_factory = sqlite3.Row
        # SQLite does not enforce foreign keys unless enabled per connection.
        self._db.execute("PRAGMA foreign_keys = ON")
        self._in_tx = False
        logger.info("Database initialized: %s", path)

    def close_db(self) -> None:
        """
        Close database connection.

        Safe to call multiple times.
        """
        if self._db is not None:
            self._db.close()
            self._db = None
        self._in_tx = False

    @property
    def db_ready(self) -> bool:
        """True if database is initialized."""
        return self._db is not None

    def _require_db(self) -> None:
        """Raise RuntimeError if database not initialized."""
        if self._db is None:
            raise RuntimeError("Database not initialized. Call init_db() first.")

    def query(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        one: bool = False,
    ) -> Union[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Execute SELECT query and return results as dicts.

        Args:
            sql: SQL query with :param_name placeholders
            params: Dictionary of parameter values
            one: If True, return single row dict or None

        Returns:
            List of row dicts, or single dict/None if one=True

        Raises:
            RuntimeError: If init_db() has not been called.

        Example:
            # Get all
            users = self.query("SELECT * FROM users")

            # Get one
            user = self.query(
                "SELECT * FROM users WHERE id = :id",
                {"id": 42},
                one=True
            )
        """
        self._require_db()
        cursor = self._db.execute(sql, params or {})
        rows = [dict(row) for row in cursor.fetchall()]
        if one:
            return rows[0] if rows else None
        return rows

    def insert(self, table: str, data: Dict[str, Any]) -> int:
        """
        Insert a row and return its ID.

        Commits immediately unless called inside a transaction() block.

        Args:
            table: Table name (interpolated into the SQL; must be trusted)
            data: Column-value dictionary. Column names are interpolated into
                the SQL; values are bound as parameters. An empty dict inserts
                a row made entirely of column defaults.

        Returns:
            The inserted row's ID (lastrowid)

        Example:
            user_id = self.insert("users", {
                "name": "Alice",
                "email": "alice@example.com"
            })
        """
        self._require_db()
        if data:
            cols = ", ".join(data.keys())
            placeholders = ", ".join(f":{k}" for k in data.keys())
            sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
        else:
            # "INSERT INTO t () VALUES ()" is a syntax error in SQLite;
            # DEFAULT VALUES is the supported form for an all-defaults row.
            sql = f"INSERT INTO {table} DEFAULT VALUES"
        cursor = self._db.execute(sql, data)
        if not self._in_tx:
            self._db.commit()
        return cursor.lastrowid

    def update(
        self,
        table: str,
        data: Dict[str, Any],
        where: str,
        params: Dict[str, Any],
    ) -> int:
        """
        Update rows matching condition and return affected count.

        Commits immediately unless called inside a transaction() block.

        Args:
            table: Table name (interpolated into the SQL; must be trusted)
            data: Column-value dictionary to update (must be non-empty)
            where: WHERE clause with :param placeholders (without WHERE keyword)
            params: Parameters for WHERE clause

        Returns:
            Number of rows affected

        Example:
            count = self.update(
                "users",
                {"email": "new@example.com"},
                "id = :id",
                {"id": 42}
            )
        """
        self._require_db()
        # Prefix data params with __set_ to avoid collision with where params
        set_clause = ", ".join(f"{k} = :__set_{k}" for k in data.keys())
        merged_params = {f"__set_{k}": v for k, v in data.items()}
        merged_params.update(params)
        sql = f"UPDATE {table} SET {set_clause} WHERE {where}"
        cursor = self._db.execute(sql, merged_params)
        if not self._in_tx:
            self._db.commit()
        return cursor.rowcount

    def delete(self, table: str, where: str, params: Dict[str, Any]) -> int:
        """
        Delete rows matching condition and return deleted count.

        Commits immediately unless called inside a transaction() block.

        Args:
            table: Table name (interpolated into the SQL; must be trusted)
            where: WHERE clause with :param placeholders (without WHERE keyword)
            params: Parameters for WHERE clause

        Returns:
            Number of rows deleted

        Example:
            count = self.delete("sessions", "expires_at < :now", {"now": now})
        """
        self._require_db()
        sql = f"DELETE FROM {table} WHERE {where}"
        cursor = self._db.execute(sql, params)
        if not self._in_tx:
            self._db.commit()
        return cursor.rowcount

    @contextmanager
    def transaction(self):
        """
        Execute operations atomically.

        Auto-commits on success, rolls back on any exception (including
        KeyboardInterrupt and other BaseException subclasses).

        Nesting: a transaction() opened inside another one joins the outer
        transaction instead of committing on its own. The outer block decides
        whether everything commits or rolls back. An exception raised in an
        inner block and caught before reaching the outer block does NOT undo
        the inner block's writes.

        Example:
            with self.transaction():
                user_id = self.insert("users", {"name": "Alice"})
                self.insert("profiles", {"user_id": user_id, "bio": "Hello"})
                # If any operation fails, all are rolled back
        """
        self._require_db()
        if self._in_tx:
            # Already inside an outer transaction() block: committing or
            # clearing _in_tx here would end the outer transaction early.
            yield
            return

        self._in_tx = True
        try:
            yield
            self._db.commit()
        except BaseException:
            # BaseException (not just Exception): otherwise an interrupt would
            # skip the rollback and the next auto-committing write would
            # silently commit this block's partial work.
            self._db.rollback()
            raise
        finally:
            self._in_tx = False

    def execute(self, sql: str) -> None:
        """
        Execute raw SQL (CREATE TABLE, etc).

        Supports multiple statements separated by semicolons.

        WARNING: Do NOT call inside a transaction() block. This method uses
        executescript() which auto-commits any pending transaction.

        Args:
            sql: SQL statement(s) to execute

        Raises:
            RuntimeError: If called inside a transaction() block

        Example:
            self.execute('''
                CREATE TABLE users (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL
                );
                CREATE TABLE posts (
                    id INTEGER PRIMARY KEY,
                    user_id INTEGER REFERENCES users(id),
                    content TEXT
                );
            ''')
        """
        self._require_db()
        if self._in_tx:
            raise RuntimeError(
                "execute() cannot be called inside a transaction() block. "
                "Use query() for SELECT or individual insert/update/delete calls."
            )
        self._db.executescript(sql)

    def table_exists(self, name: str) -> bool:
        """
        Check if a table exists in the database.

        Args:
            name: Table name to check

        Returns:
            True if table exists, False otherwise

        Example:
            if not self.table_exists("users"):
                self.execute("CREATE TABLE users (...)")
        """
        result = self.query(
            "SELECT 1 FROM sqlite_master WHERE type='table' AND name=:name",
            {"name": name},
            one=True,
        )
        return result is not None
|
|
1
|
+
# Copyright(C) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""SQLite database mixin for GAIA agents."""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import sqlite3
|
|
8
|
+
from contextlib import contextmanager
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Dict, List, Optional, Union
|
|
11
|
+
|
|
12
|
+
# Module-level logger; init_db() reports the opened database path through it.
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class DatabaseMixin:
    """
    Mixin providing SQLite database access for GAIA agents.

    A lean, zero-dependency mixin that uses Python's built-in sqlite3 module.

    Note: only column *values* are bound as SQL parameters. Table names,
    column names and WHERE-clause text passed to insert()/update()/delete()
    are interpolated directly into the SQL string, so they must come from
    trusted code, never from user input.

    Example:
        class MyAgent(Agent, DatabaseMixin):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.init_db("data/app.db")

                if not self.table_exists("items"):
                    self.execute('''
                        CREATE TABLE items (
                            id INTEGER PRIMARY KEY,
                            name TEXT NOT NULL
                        )
                    ''')

            def _register_tools(self):
                @tool
                def add_item(name: str) -> dict:
                    item_id = self.insert("items", {"name": name})
                    return {"id": item_id}
    """

    # Class-level defaults so methods can detect "not initialized" before init_db().
    _db: Optional[sqlite3.Connection] = None
    # True while inside a transaction() block; suppresses per-call commits.
    _in_tx: bool = False

    def init_db(self, path: str = ":memory:") -> None:
        """
        Initialize SQLite database.

        Any connection previously opened by this mixin is closed first.

        Args:
            path: Database file path, or ":memory:" for in-memory database.
                Parent directories are created automatically.

        Example:
            self.init_db("data/myagent.db")  # File-based
            self.init_db()                   # In-memory (for testing)
        """
        if self._db is not None:
            self.close_db()

        if path != ":memory:":
            Path(path).parent.mkdir(parents=True, exist_ok=True)

        # check_same_thread=False allows use from threads other than the
        # creating one; sqlite3 then leaves serializing access to the caller.
        self._db = sqlite3.connect(path, check_same_thread=False)
        # sqlite3.Row lets query() turn each row into a column-keyed dict.
        self._db.row_factory = sqlite3.Row
        # SQLite does not enforce foreign keys unless enabled per connection.
        self._db.execute("PRAGMA foreign_keys = ON")
        self._in_tx = False
        logger.info("Database initialized: %s", path)

    def close_db(self) -> None:
        """
        Close database connection.

        Safe to call multiple times.
        """
        if self._db is not None:
            self._db.close()
            self._db = None
        self._in_tx = False

    @property
    def db_ready(self) -> bool:
        """True if database is initialized."""
        return self._db is not None

    def _require_db(self) -> None:
        """Raise RuntimeError if database not initialized."""
        if self._db is None:
            raise RuntimeError("Database not initialized. Call init_db() first.")

    def query(
        self,
        sql: str,
        params: Optional[Dict[str, Any]] = None,
        one: bool = False,
    ) -> Union[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
        """
        Execute SELECT query and return results as dicts.

        Args:
            sql: SQL query with :param_name placeholders
            params: Dictionary of parameter values
            one: If True, return single row dict or None

        Returns:
            List of row dicts, or single dict/None if one=True

        Raises:
            RuntimeError: If init_db() has not been called.

        Example:
            # Get all
            users = self.query("SELECT * FROM users")

            # Get one
            user = self.query(
                "SELECT * FROM users WHERE id = :id",
                {"id": 42},
                one=True
            )
        """
        self._require_db()
        cursor = self._db.execute(sql, params or {})
        rows = [dict(row) for row in cursor.fetchall()]
        if one:
            return rows[0] if rows else None
        return rows

    def insert(self, table: str, data: Dict[str, Any]) -> int:
        """
        Insert a row and return its ID.

        Commits immediately unless called inside a transaction() block.

        Args:
            table: Table name (interpolated into the SQL; must be trusted)
            data: Column-value dictionary. Column names are interpolated into
                the SQL; values are bound as parameters. An empty dict inserts
                a row made entirely of column defaults.

        Returns:
            The inserted row's ID (lastrowid)

        Example:
            user_id = self.insert("users", {
                "name": "Alice",
                "email": "alice@example.com"
            })
        """
        self._require_db()
        if data:
            cols = ", ".join(data.keys())
            placeholders = ", ".join(f":{k}" for k in data.keys())
            sql = f"INSERT INTO {table} ({cols}) VALUES ({placeholders})"
        else:
            # "INSERT INTO t () VALUES ()" is a syntax error in SQLite;
            # DEFAULT VALUES is the supported form for an all-defaults row.
            sql = f"INSERT INTO {table} DEFAULT VALUES"
        cursor = self._db.execute(sql, data)
        if not self._in_tx:
            self._db.commit()
        return cursor.lastrowid

    def update(
        self,
        table: str,
        data: Dict[str, Any],
        where: str,
        params: Dict[str, Any],
    ) -> int:
        """
        Update rows matching condition and return affected count.

        Commits immediately unless called inside a transaction() block.

        Args:
            table: Table name (interpolated into the SQL; must be trusted)
            data: Column-value dictionary to update (must be non-empty)
            where: WHERE clause with :param placeholders (without WHERE keyword)
            params: Parameters for WHERE clause

        Returns:
            Number of rows affected

        Example:
            count = self.update(
                "users",
                {"email": "new@example.com"},
                "id = :id",
                {"id": 42}
            )
        """
        self._require_db()
        # Prefix data params with __set_ to avoid collision with where params
        set_clause = ", ".join(f"{k} = :__set_{k}" for k in data.keys())
        merged_params = {f"__set_{k}": v for k, v in data.items()}
        merged_params.update(params)
        sql = f"UPDATE {table} SET {set_clause} WHERE {where}"
        cursor = self._db.execute(sql, merged_params)
        if not self._in_tx:
            self._db.commit()
        return cursor.rowcount

    def delete(self, table: str, where: str, params: Dict[str, Any]) -> int:
        """
        Delete rows matching condition and return deleted count.

        Commits immediately unless called inside a transaction() block.

        Args:
            table: Table name (interpolated into the SQL; must be trusted)
            where: WHERE clause with :param placeholders (without WHERE keyword)
            params: Parameters for WHERE clause

        Returns:
            Number of rows deleted

        Example:
            count = self.delete("sessions", "expires_at < :now", {"now": now})
        """
        self._require_db()
        sql = f"DELETE FROM {table} WHERE {where}"
        cursor = self._db.execute(sql, params)
        if not self._in_tx:
            self._db.commit()
        return cursor.rowcount

    @contextmanager
    def transaction(self):
        """
        Execute operations atomically.

        Auto-commits on success, rolls back on any exception (including
        KeyboardInterrupt and other BaseException subclasses).

        Nesting: a transaction() opened inside another one joins the outer
        transaction instead of committing on its own. The outer block decides
        whether everything commits or rolls back. An exception raised in an
        inner block and caught before reaching the outer block does NOT undo
        the inner block's writes.

        Example:
            with self.transaction():
                user_id = self.insert("users", {"name": "Alice"})
                self.insert("profiles", {"user_id": user_id, "bio": "Hello"})
                # If any operation fails, all are rolled back
        """
        self._require_db()
        if self._in_tx:
            # Already inside an outer transaction() block: committing or
            # clearing _in_tx here would end the outer transaction early.
            yield
            return

        self._in_tx = True
        try:
            yield
            self._db.commit()
        except BaseException:
            # BaseException (not just Exception): otherwise an interrupt would
            # skip the rollback and the next auto-committing write would
            # silently commit this block's partial work.
            self._db.rollback()
            raise
        finally:
            self._in_tx = False

    def execute(self, sql: str) -> None:
        """
        Execute raw SQL (CREATE TABLE, etc).

        Supports multiple statements separated by semicolons.

        WARNING: Do NOT call inside a transaction() block. This method uses
        executescript() which auto-commits any pending transaction.

        Args:
            sql: SQL statement(s) to execute

        Raises:
            RuntimeError: If called inside a transaction() block

        Example:
            self.execute('''
                CREATE TABLE users (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL
                );
                CREATE TABLE posts (
                    id INTEGER PRIMARY KEY,
                    user_id INTEGER REFERENCES users(id),
                    content TEXT
                );
            ''')
        """
        self._require_db()
        if self._in_tx:
            raise RuntimeError(
                "execute() cannot be called inside a transaction() block. "
                "Use query() for SELECT or individual insert/update/delete calls."
            )
        self._db.executescript(sql)

    def table_exists(self, name: str) -> bool:
        """
        Check if a table exists in the database.

        Args:
            name: Table name to check

        Returns:
            True if table exists, False otherwise

        Example:
            if not self.table_exists("users"):
                self.execute("CREATE TABLE users (...)")
        """
        result = self.query(
            "SELECT 1 FROM sqlite_master WHERE type='table' AND name=:name",
            {"name": name},
            one=True,
        )
        return result is not None
|