codetether-1.2.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- a2a_server/__init__.py +29 -0
- a2a_server/a2a_agent_card.py +365 -0
- a2a_server/a2a_errors.py +1133 -0
- a2a_server/a2a_executor.py +926 -0
- a2a_server/a2a_router.py +1033 -0
- a2a_server/a2a_types.py +344 -0
- a2a_server/agent_card.py +408 -0
- a2a_server/agents_server.py +271 -0
- a2a_server/auth_api.py +349 -0
- a2a_server/billing_api.py +638 -0
- a2a_server/billing_service.py +712 -0
- a2a_server/billing_webhooks.py +501 -0
- a2a_server/config.py +96 -0
- a2a_server/database.py +2165 -0
- a2a_server/email_inbound.py +398 -0
- a2a_server/email_notifications.py +486 -0
- a2a_server/enhanced_agents.py +919 -0
- a2a_server/enhanced_server.py +160 -0
- a2a_server/hosted_worker.py +1049 -0
- a2a_server/integrated_agents_server.py +347 -0
- a2a_server/keycloak_auth.py +750 -0
- a2a_server/livekit_bridge.py +439 -0
- a2a_server/marketing_tools.py +1364 -0
- a2a_server/mcp_client.py +196 -0
- a2a_server/mcp_http_server.py +2256 -0
- a2a_server/mcp_server.py +191 -0
- a2a_server/message_broker.py +725 -0
- a2a_server/mock_mcp.py +273 -0
- a2a_server/models.py +494 -0
- a2a_server/monitor_api.py +5904 -0
- a2a_server/opencode_bridge.py +1594 -0
- a2a_server/redis_task_manager.py +518 -0
- a2a_server/server.py +726 -0
- a2a_server/task_manager.py +668 -0
- a2a_server/task_queue.py +742 -0
- a2a_server/tenant_api.py +333 -0
- a2a_server/tenant_middleware.py +219 -0
- a2a_server/tenant_service.py +760 -0
- a2a_server/user_auth.py +721 -0
- a2a_server/vault_client.py +576 -0
- a2a_server/worker_sse.py +873 -0
- agent_worker/__init__.py +8 -0
- agent_worker/worker.py +4877 -0
- codetether/__init__.py +10 -0
- codetether/__main__.py +4 -0
- codetether/cli.py +112 -0
- codetether/worker_cli.py +57 -0
- codetether-1.2.2.dist-info/METADATA +570 -0
- codetether-1.2.2.dist-info/RECORD +66 -0
- codetether-1.2.2.dist-info/WHEEL +5 -0
- codetether-1.2.2.dist-info/entry_points.txt +4 -0
- codetether-1.2.2.dist-info/licenses/LICENSE +202 -0
- codetether-1.2.2.dist-info/top_level.txt +5 -0
- codetether_voice_agent/__init__.py +6 -0
- codetether_voice_agent/agent.py +445 -0
- codetether_voice_agent/codetether_mcp.py +345 -0
- codetether_voice_agent/config.py +16 -0
- codetether_voice_agent/functiongemma_caller.py +380 -0
- codetether_voice_agent/session_playback.py +247 -0
- codetether_voice_agent/tools/__init__.py +21 -0
- codetether_voice_agent/tools/definitions.py +135 -0
- codetether_voice_agent/tools/handlers.py +380 -0
- run_server.py +314 -0
- ui/monitor-tailwind.html +1790 -0
- ui/monitor.html +1775 -0
- ui/monitor.js +2662 -0
a2a_server/database.py
ADDED
@@ -0,0 +1,2165 @@
+"""
+PostgreSQL database persistence layer for A2A Server.
+
+Provides durable storage for workers, codebases, tasks, and sessions
+that survives server restarts and works across multiple replicas.
+
+Configuration:
+    DATABASE_URL: PostgreSQL connection string
+        Format: postgresql://user:password@host:port/database
+        Example: postgresql://a2a:secret@localhost:5432/a2a_server
+
+Row-Level Security (RLS):
+    RLS_ENABLED: Enable database-level tenant isolation (default: false)
+    RLS_STRICT_MODE: Require tenant context for all queries (default: false)
+
+When RLS is enabled, the database enforces tenant isolation at the row level.
+Use the tenant_scope() context manager or set_tenant_context() to set the
+tenant context before executing queries.
+
+See a2a_server/rls.py for RLS utilities and documentation.
+"""
+
+import asyncio
+import json
+import logging
+import os
+import uuid
+from contextlib import asynccontextmanager
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+# Database URL from environment
+# Use config default if no environment variable set
+DEFAULT_DATABASE_URL = (
+    'postgresql://postgres:spike2@192.168.50.70:5432/a2a_server'
+)
+DATABASE_URL = os.environ.get(
+    'DATABASE_URL', os.environ.get('A2A_DATABASE_URL', DEFAULT_DATABASE_URL)
+)
+
+# Module-level state
+_pool = None
+_pool_lock = asyncio.Lock()
+_initialized = False
+
+# RLS Configuration (can be overridden by environment)
+RLS_ENABLED = os.environ.get('RLS_ENABLED', 'false').lower() == 'true'
+RLS_STRICT_MODE = os.environ.get('RLS_STRICT_MODE', 'false').lower() == 'true'
+
+
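The connection string above is resolved once at import time, so DATABASE_URL has to be exported before this module is first imported. A minimal usage sketch, with placeholder credentials rather than anything shipped in the package:

import asyncio
import os

# Placeholder DSN; substitute your own deployment's credentials.
os.environ['DATABASE_URL'] = 'postgresql://a2a:secret@localhost:5432/a2a_server'

from a2a_server import database  # import after the env var is set

async def main() -> None:
    # get_pool() returns None when asyncpg is missing or the DB is unreachable.
    pool = await database.get_pool()
    print(await database.db_health_check())

asyncio.run(main())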
+def _parse_timestamp(value: Union[str, datetime, None]) -> Optional[datetime]:
+    """Parse a timestamp from string or datetime to datetime object."""
+    if value is None:
+        return None
+    if isinstance(value, datetime):
+        return value
+    if isinstance(value, str):
+        try:
+            # Try ISO format first
+            return datetime.fromisoformat(value.replace('Z', '+00:00'))
+        except ValueError:
+            try:
+                # Try common formats
+                return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%f')
+            except ValueError:
+                try:
+                    return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S')
+                except ValueError:
+                    return datetime.utcnow()
+    return datetime.utcnow()
+
+
+async def get_pool():
+    """Get or create the asyncpg connection pool."""
+    global _pool, _initialized
+
+    if _pool is not None:
+        return _pool
+
+    if not DATABASE_URL:
+        return None
+
+    async with _pool_lock:
+        if _pool is not None:
+            return _pool
+
+        try:
+            import asyncpg
+
+            _pool = await asyncpg.create_pool(
+                DATABASE_URL,
+                min_size=1,
+                max_size=10,
+                command_timeout=30,
+            )
+            logger.info(f'✓ PostgreSQL connection pool created')
+
+            # Initialize schema if needed
+            if not _initialized:
+                await _init_schema()
+                _initialized = True
+
+            # Initialize task queue for hosted workers
+            try:
+                from .task_queue import TaskQueue, set_task_queue
+
+                task_queue = TaskQueue(_pool)
+                set_task_queue(task_queue)
+                logger.info('✓ Task queue initialized for hosted workers')
+            except ImportError:
+                logger.debug('Task queue module not available')
+            except Exception as e:
+                logger.warning(f'Failed to initialize task queue: {e}')
+
+            return _pool
+        except ImportError:
+            logger.warning(
+                'asyncpg not installed, PostgreSQL persistence disabled'
+            )
+            return None
+        except Exception as e:
+            logger.error(f'Failed to create PostgreSQL pool: {e}')
+            return None
+
+
+async def _init_schema():
+    """Initialize database schema if tables don't exist."""
+    pool = await get_pool()
+    if not pool:
+        return
+
+    async with pool.acquire() as conn:
+        # Tenants table (multi-tenant support)
+        await conn.execute("""
+            CREATE TABLE IF NOT EXISTS tenants (
+                id TEXT PRIMARY KEY,
+                realm_name TEXT UNIQUE NOT NULL,
+                display_name TEXT,
+                plan TEXT DEFAULT 'free',
+                stripe_customer_id TEXT,
+                stripe_subscription_id TEXT,
+                created_at TIMESTAMPTZ DEFAULT NOW(),
+                updated_at TIMESTAMPTZ DEFAULT NOW()
+            )
+        """)
+
+        # Workers table
+        await conn.execute("""
+            CREATE TABLE IF NOT EXISTS workers (
+                worker_id TEXT PRIMARY KEY,
+                name TEXT NOT NULL,
+                capabilities JSONB DEFAULT '[]'::jsonb,
+                hostname TEXT,
+                models JSONB DEFAULT '[]'::jsonb,
+                global_codebase_id TEXT,
+                registered_at TIMESTAMPTZ DEFAULT NOW(),
+                last_seen TIMESTAMPTZ DEFAULT NOW(),
+                status TEXT DEFAULT 'active',
+                tenant_id TEXT REFERENCES tenants(id)
+            )
+        """)
+
+        # Migration: Add models column if it doesn't exist
+        try:
+            await conn.execute(
+                "ALTER TABLE workers ADD COLUMN IF NOT EXISTS models JSONB DEFAULT '[]'::jsonb"
+            )
+        except Exception:
+            pass
+
+        # Migration: Add global_codebase_id column if it doesn't exist
+        try:
+            await conn.execute(
+                'ALTER TABLE workers ADD COLUMN IF NOT EXISTS global_codebase_id TEXT'
+            )
+        except Exception:
+            pass
+
+        # Migration: Add tenant_id column to workers if it doesn't exist
+        try:
+            await conn.execute(
+                'ALTER TABLE workers ADD COLUMN IF NOT EXISTS tenant_id TEXT REFERENCES tenants(id)'
+            )
+        except Exception:
+            pass
+
+        # Codebases table
+        await conn.execute("""
+            CREATE TABLE IF NOT EXISTS codebases (
+                id TEXT PRIMARY KEY,
+                name TEXT NOT NULL,
+                path TEXT NOT NULL,
+                description TEXT DEFAULT '',
+                worker_id TEXT REFERENCES workers(worker_id) ON DELETE SET NULL,
+                agent_config JSONB DEFAULT '{}'::jsonb,
+                created_at TIMESTAMPTZ DEFAULT NOW(),
+                updated_at TIMESTAMPTZ DEFAULT NOW(),
+                status TEXT DEFAULT 'active',
+                session_id TEXT,
+                opencode_port INTEGER,
+                tenant_id TEXT REFERENCES tenants(id)
+            )
+        """)
+
+        # Migration: Add tenant_id column to codebases if it doesn't exist
+        try:
+            await conn.execute(
+                'ALTER TABLE codebases ADD COLUMN IF NOT EXISTS tenant_id TEXT REFERENCES tenants(id)'
+            )
+        except Exception:
+            pass
+
+        # Tasks table
+        await conn.execute("""
+            CREATE TABLE IF NOT EXISTS tasks (
+                id TEXT PRIMARY KEY,
+                codebase_id TEXT REFERENCES codebases(id) ON DELETE CASCADE,
+                title TEXT NOT NULL,
+                prompt TEXT NOT NULL,
+                agent_type TEXT DEFAULT 'build',
+                status TEXT DEFAULT 'pending',
+                priority INTEGER DEFAULT 0,
+                worker_id TEXT REFERENCES workers(worker_id) ON DELETE SET NULL,
+                result TEXT,
+                error TEXT,
+                metadata JSONB DEFAULT '{}'::jsonb,
+                created_at TIMESTAMPTZ DEFAULT NOW(),
+                updated_at TIMESTAMPTZ DEFAULT NOW(),
+                started_at TIMESTAMPTZ,
+                completed_at TIMESTAMPTZ,
+                tenant_id TEXT REFERENCES tenants(id)
+            )
+        """)
+
+        # Migration: Add tenant_id column to tasks if it doesn't exist
+        try:
+            await conn.execute(
+                'ALTER TABLE tasks ADD COLUMN IF NOT EXISTS tenant_id TEXT REFERENCES tenants(id)'
+            )
+        except Exception:
+            pass
+
+        # Sessions table (for worker-synced OpenCode sessions)
+        await conn.execute("""
+            CREATE TABLE IF NOT EXISTS sessions (
+                id TEXT PRIMARY KEY,
+                codebase_id TEXT REFERENCES codebases(id) ON DELETE CASCADE,
+                project_id TEXT,
+                directory TEXT,
+                title TEXT,
+                version TEXT,
+                summary JSONB DEFAULT '{}'::jsonb,
+                created_at TIMESTAMPTZ DEFAULT NOW(),
+                updated_at TIMESTAMPTZ DEFAULT NOW(),
+                tenant_id TEXT REFERENCES tenants(id)
+            )
+        """)
+
+        # Migration: Add tenant_id column to sessions if it doesn't exist
+        try:
+            await conn.execute(
+                'ALTER TABLE sessions ADD COLUMN IF NOT EXISTS tenant_id TEXT REFERENCES tenants(id)'
+            )
+        except Exception:
+            pass
+
+        # Messages table (for session messages)
+        await conn.execute("""
+            CREATE TABLE IF NOT EXISTS session_messages (
+                id TEXT PRIMARY KEY,
+                session_id TEXT REFERENCES sessions(id) ON DELETE CASCADE,
+                role TEXT,
+                content TEXT,
+                model TEXT,
+                cost REAL,
+                tokens JSONB DEFAULT '{}'::jsonb,
+                tool_calls JSONB DEFAULT '[]'::jsonb,
+                created_at TIMESTAMPTZ DEFAULT NOW()
+            )
+        """)
+
+        # Monitor messages table (for agent monitoring)
+        await conn.execute("""
+            CREATE TABLE IF NOT EXISTS monitor_messages (
+                id TEXT PRIMARY KEY,
+                timestamp TIMESTAMPTZ DEFAULT NOW(),
+                type TEXT NOT NULL,
+                agent_name TEXT NOT NULL,
+                content TEXT NOT NULL,
+                metadata JSONB DEFAULT '{}'::jsonb,
+                response_time REAL,
+                tokens INTEGER,
+                error TEXT
+            )
+        """)
+
+        # Create indexes
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_workers_status ON workers(status)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_workers_last_seen ON workers(last_seen)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_codebases_worker ON codebases(worker_id)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_codebases_status ON codebases(status)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_codebases_path ON codebases(path)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_tasks_codebase ON tasks(codebase_id)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_tasks_worker ON tasks(worker_id)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_tasks_priority ON tasks(priority DESC, created_at ASC)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_sessions_codebase ON sessions(codebase_id)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_messages_session ON session_messages(session_id)'
+        )
+
+        # Tenant indexes
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_tenants_realm ON tenants(realm_name)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_workers_tenant ON workers(tenant_id)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_codebases_tenant ON codebases(tenant_id)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_tasks_tenant ON tasks(tenant_id)'
+        )
+        await conn.execute(
+            'CREATE INDEX IF NOT EXISTS idx_sessions_tenant ON sessions(tenant_id)'
+        )
+
+    logger.info('✓ PostgreSQL schema initialized')
+
+
+async def close_pool():
+    """Close the database connection pool."""
+    global _pool
+    if _pool:
+        await _pool.close()
+        _pool = None
+        logger.info('PostgreSQL connection pool closed')
+
+
+# ========================================
+# Worker Operations
+# ========================================
+
+
+async def db_upsert_worker(
+    worker_info: Dict[str, Any], tenant_id: Optional[str] = None
+) -> bool:
+    """Insert or update a worker in the database.
+
+    Args:
+        worker_info: The worker data dict
+        tenant_id: Optional tenant ID for multi-tenant isolation
+    """
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        # Use provided tenant_id or fall back to worker_info dict
+        effective_tenant_id = tenant_id or worker_info.get('tenant_id')
+
+        async with pool.acquire() as conn:
+            await conn.execute(
+                """
+                INSERT INTO workers (worker_id, name, capabilities, hostname, models, global_codebase_id, registered_at, last_seen, status, tenant_id)
+                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+                ON CONFLICT (worker_id)
+                DO UPDATE SET
+                    name = EXCLUDED.name,
+                    capabilities = EXCLUDED.capabilities,
+                    hostname = EXCLUDED.hostname,
+                    models = EXCLUDED.models,
+                    global_codebase_id = EXCLUDED.global_codebase_id,
+                    last_seen = EXCLUDED.last_seen,
+                    status = EXCLUDED.status,
+                    tenant_id = COALESCE(EXCLUDED.tenant_id, workers.tenant_id)
+                """,
+                worker_info.get('worker_id'),
+                worker_info.get('name'),
+                json.dumps(worker_info.get('capabilities', [])),
+                worker_info.get('hostname'),
+                json.dumps(worker_info.get('models', [])),
+                worker_info.get('global_codebase_id'),
+                _parse_timestamp(worker_info.get('registered_at'))
+                or datetime.utcnow(),
+                _parse_timestamp(worker_info.get('last_seen'))
+                or datetime.utcnow(),
+                worker_info.get('status', 'active'),
+                effective_tenant_id,
+            )
+        return True
+    except Exception as e:
+        logger.error(f'Failed to upsert worker: {e}')
+        return False
+
+
+async def db_delete_worker(worker_id: str) -> bool:
+    """Delete a worker from the database."""
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        async with pool.acquire() as conn:
+            await conn.execute(
+                'DELETE FROM workers WHERE worker_id = $1', worker_id
+            )
+        return True
+    except Exception as e:
+        logger.error(f'Failed to delete worker: {e}')
+        return False
+
+
+async def db_get_worker(worker_id: str) -> Optional[Dict[str, Any]]:
+    """Get a worker by ID."""
+    pool = await get_pool()
+    if not pool:
+        return None
+
+    try:
+        async with pool.acquire() as conn:
+            row = await conn.fetchrow(
+                'SELECT * FROM workers WHERE worker_id = $1', worker_id
+            )
+            if row:
+                return _row_to_worker(row)
+            return None
+    except Exception as e:
+        logger.error(f'Failed to get worker: {e}')
+        return None
+
+
+async def db_list_workers(
+    status: Optional[str] = None,
+    tenant_id: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+    """List all workers, optionally filtered by status or tenant."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            query = 'SELECT * FROM workers WHERE 1=1'
+            params = []
+            param_idx = 1
+
+            if status:
+                query += f' AND status = ${param_idx}'
+                params.append(status)
+                param_idx += 1
+
+            if tenant_id:
+                query += f' AND tenant_id = ${param_idx}'
+                params.append(tenant_id)
+                param_idx += 1
+
+            query += ' ORDER BY last_seen DESC'
+
+            rows = await conn.fetch(query, *params)
+            return [_row_to_worker(row) for row in rows]
+    except Exception as e:
+        logger.error(f'Failed to list workers: {e}')
+        return []
+
+
+async def db_update_worker_heartbeat(worker_id: str) -> bool:
+    """Update worker's last_seen timestamp."""
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        async with pool.acquire() as conn:
+            result = await conn.execute(
+                'UPDATE workers SET last_seen = NOW() WHERE worker_id = $1',
+                worker_id,
+            )
+            return 'UPDATE 1' in result
+    except Exception as e:
+        logger.error(f'Failed to update worker heartbeat: {e}')
+        return False
+
+
+def _row_to_worker(row) -> Dict[str, Any]:
+    """Convert a database row to a worker dict."""
+    return {
+        'worker_id': row['worker_id'],
+        'name': row['name'],
+        'capabilities': json.loads(row['capabilities'])
+        if isinstance(row['capabilities'], str)
+        else row['capabilities'],
+        'hostname': row['hostname'],
+        'models': json.loads(row['models'])
+        if isinstance(row['models'], str)
+        else row['models'],
+        'global_codebase_id': row['global_codebase_id'],
+        'registered_at': row['registered_at'].isoformat()
+        if row['registered_at']
+        else None,
+        'last_seen': row['last_seen'].isoformat() if row['last_seen'] else None,
+        'status': row['status'],
+    }
+
+
+# ========================================
+# Codebase Operations
+# ========================================
+
+
+async def db_upsert_codebase(
+    codebase: Dict[str, Any], tenant_id: Optional[str] = None
+) -> bool:
+    """Insert or update a codebase in the database.
+
+    Args:
+        codebase: The codebase data dict
+        tenant_id: Optional tenant ID for multi-tenant isolation
+    """
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        # Handle both 'created_at' and 'registered_at' field names
+        created_at = codebase.get('created_at') or codebase.get('registered_at')
+        # Use provided tenant_id or fall back to codebase dict
+        effective_tenant_id = tenant_id or codebase.get('tenant_id')
+
+        async with pool.acquire() as conn:
+            await conn.execute(
+                """
+                INSERT INTO codebases (id, name, path, description, worker_id, agent_config, created_at, updated_at, status, session_id, opencode_port, tenant_id)
+                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+                ON CONFLICT (id)
+                DO UPDATE SET
+                    name = EXCLUDED.name,
+                    path = EXCLUDED.path,
+                    description = EXCLUDED.description,
+                    worker_id = EXCLUDED.worker_id,
+                    agent_config = EXCLUDED.agent_config,
+                    updated_at = NOW(),
+                    status = EXCLUDED.status,
+                    session_id = EXCLUDED.session_id,
+                    opencode_port = EXCLUDED.opencode_port,
+                    tenant_id = COALESCE(EXCLUDED.tenant_id, codebases.tenant_id)
+                """,
+                codebase.get('id'),
+                codebase.get('name'),
+                codebase.get('path'),
+                codebase.get('description', ''),
+                codebase.get('worker_id'),
+                json.dumps(codebase.get('agent_config', {})),
+                _parse_timestamp(created_at) or datetime.utcnow(),
+                _parse_timestamp(codebase.get('updated_at'))
+                or datetime.utcnow(),
+                codebase.get('status', 'active'),
+                codebase.get('session_id'),
+                codebase.get('opencode_port'),
+                effective_tenant_id,
+            )
+        return True
+    except Exception as e:
+        logger.error(f'Failed to upsert codebase: {e}')
+        return False
+
+
+async def db_delete_codebase(codebase_id: str) -> bool:
+    """Delete a codebase from the database."""
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        async with pool.acquire() as conn:
+            await conn.execute(
+                'DELETE FROM codebases WHERE id = $1', codebase_id
+            )
+        return True
+    except Exception as e:
+        logger.error(f'Failed to delete codebase: {e}')
+        return False
+
+
+async def db_get_codebase(codebase_id: str) -> Optional[Dict[str, Any]]:
+    """Get a codebase by ID."""
+    pool = await get_pool()
+    if not pool:
+        return None
+
+    try:
+        async with pool.acquire() as conn:
+            row = await conn.fetchrow(
+                'SELECT * FROM codebases WHERE id = $1', codebase_id
+            )
+            if row:
+                return _row_to_codebase(row)
+            return None
+    except Exception as e:
+        logger.error(f'Failed to get codebase: {e}')
+        return None
+
+
+async def db_list_codebases(
+    worker_id: Optional[str] = None,
+    status: Optional[str] = None,
+    tenant_id: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+    """List all codebases, optionally filtered by worker, status, or tenant."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            query = 'SELECT * FROM codebases WHERE 1=1'
+            params = []
+            param_idx = 1
+
+            if worker_id:
+                query += f' AND worker_id = ${param_idx}'
+                params.append(worker_id)
+                param_idx += 1
+
+            if status:
+                query += f' AND status = ${param_idx}'
+                params.append(status)
+                param_idx += 1
+
+            if tenant_id:
+                query += f' AND tenant_id = ${param_idx}'
+                params.append(tenant_id)
+                param_idx += 1
+
+            query += ' ORDER BY updated_at DESC'
+
+            rows = await conn.fetch(query, *params)
+            return [_row_to_codebase(row) for row in rows]
+    except Exception as e:
+        logger.error(f'Failed to list codebases: {e}')
+        return []
+
+
+async def db_list_codebases_by_path(path: str) -> List[Dict[str, Any]]:
+    """List all codebases matching a specific normalized path."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            rows = await conn.fetch(
+                'SELECT * FROM codebases WHERE path = $1 ORDER BY updated_at DESC',
+                path,
+            )
+            return [_row_to_codebase(row) for row in rows]
+    except Exception as e:
+        logger.error(f'Failed to list codebases by path: {e}')
+        return []
+
+
+def _row_to_codebase(row) -> Dict[str, Any]:
+    """Convert a database row to a codebase dict."""
+    agent_config = row['agent_config']
+    if isinstance(agent_config, str):
+        agent_config = json.loads(agent_config)
+    elif agent_config is None:
+        agent_config = {}
+
+    return {
+        'id': row['id'],
+        'name': row['name'],
+        'path': row['path'],
+        'description': row['description'],
+        'worker_id': row['worker_id'],
+        'agent_config': agent_config,
+        'created_at': row['created_at'].isoformat()
+        if row['created_at']
+        else None,
+        'updated_at': row['updated_at'].isoformat()
+        if row['updated_at']
+        else None,
+        'status': row['status'],
+        'session_id': row['session_id'],
+        'opencode_port': row['opencode_port'],
+    }
+
+
+# ========================================
+# Task Operations
+# ========================================
+
+
+async def db_upsert_task(
+    task: Dict[str, Any], tenant_id: Optional[str] = None
+) -> bool:
+    """Insert or update a task in the database.
+
+    Args:
+        task: The task data dict
+        tenant_id: Optional tenant ID for multi-tenant isolation
+    """
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        # Use provided tenant_id or fall back to task dict
+        effective_tenant_id = tenant_id or task.get('tenant_id')
+
+        async with pool.acquire() as conn:
+            await conn.execute(
+                """
+                INSERT INTO tasks (id, codebase_id, title, prompt, agent_type, status, priority, worker_id, result, error, metadata, created_at, updated_at, started_at, completed_at, tenant_id)
+                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
+                ON CONFLICT (id)
+                DO UPDATE SET
+                    status = EXCLUDED.status,
+                    worker_id = EXCLUDED.worker_id,
+                    result = EXCLUDED.result,
+                    error = EXCLUDED.error,
+                    updated_at = NOW(),
+                    started_at = COALESCE(tasks.started_at, EXCLUDED.started_at),
+                    completed_at = EXCLUDED.completed_at,
+                    tenant_id = COALESCE(EXCLUDED.tenant_id, tasks.tenant_id)
+                """,
+                task.get('id'),
+                task.get('codebase_id'),
+                task.get('title'),
+                task.get('prompt'),
+                task.get('agent_type', 'build'),
+                task.get('status', 'pending'),
+                task.get('priority', 0),
+                task.get('worker_id'),
+                task.get('result'),
+                task.get('error'),
+                json.dumps(task.get('metadata', {})),
+                _parse_timestamp(task.get('created_at')) or datetime.utcnow(),
+                _parse_timestamp(task.get('updated_at')) or datetime.utcnow(),
+                _parse_timestamp(task.get('started_at')),
+                _parse_timestamp(task.get('completed_at')),
+                effective_tenant_id,
+            )
+        return True
+    except Exception as e:
+        logger.error(f'Failed to upsert task: {e}')
+        return False
+
+
+async def db_get_task(task_id: str) -> Optional[Dict[str, Any]]:
+    """Get a task by ID."""
+    pool = await get_pool()
+    if not pool:
+        return None
+
+    try:
+        async with pool.acquire() as conn:
+            row = await conn.fetchrow(
+                'SELECT * FROM tasks WHERE id = $1', task_id
+            )
+            if row:
+                return _row_to_task(row)
+            return None
+    except Exception as e:
+        logger.error(f'Failed to get task: {e}')
+        return None
+
+
+async def db_list_tasks(
+    codebase_id: Optional[str] = None,
+    status: Optional[str] = None,
+    worker_id: Optional[str] = None,
+    limit: int = 100,
+    tenant_id: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+    """List tasks with optional filters including tenant isolation."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            query = 'SELECT * FROM tasks WHERE 1=1'
+            params = []
+            param_idx = 1
+
+            if codebase_id:
+                query += f' AND codebase_id = ${param_idx}'
+                params.append(codebase_id)
+                param_idx += 1
+
+            if status:
+                query += f' AND status = ${param_idx}'
+                params.append(status)
+                param_idx += 1
+
+            if worker_id:
+                query += f' AND worker_id = ${param_idx}'
+                params.append(worker_id)
+                param_idx += 1
+
+            if tenant_id:
+                query += f' AND tenant_id = ${param_idx}'
+                params.append(tenant_id)
+                param_idx += 1
+
+            query += (
+                f' ORDER BY priority DESC, created_at ASC LIMIT ${param_idx}'
+            )
+            params.append(limit)
+
+            rows = await conn.fetch(query, *params)
+            return [_row_to_task(row) for row in rows]
+    except Exception as e:
+        logger.error(f'Failed to list tasks: {e}')
+        return []
+
+
+async def db_get_next_pending_task(
+    codebase_id: Optional[str] = None,
+) -> Optional[Dict[str, Any]]:
+    """Get the next pending task (highest priority, oldest first)."""
+    pool = await get_pool()
+    if not pool:
+        return None
+
+    try:
+        async with pool.acquire() as conn:
+            if codebase_id:
+                row = await conn.fetchrow(
+                    """
+                    SELECT * FROM tasks
+                    WHERE status = 'pending' AND codebase_id = $1
+                    ORDER BY priority DESC, created_at ASC
+                    LIMIT 1
+                    """,
+                    codebase_id,
+                )
+            else:
+                row = await conn.fetchrow("""
+                    SELECT * FROM tasks
+                    WHERE status = 'pending'
+                    ORDER BY priority DESC, created_at ASC
+                    LIMIT 1
+                """)
+            if row:
+                return _row_to_task(row)
+            return None
+    except Exception as e:
+        logger.error(f'Failed to get next pending task: {e}')
+        return None
+
+
+async def db_update_task_status(
+    task_id: str,
+    status: str,
+    worker_id: Optional[str] = None,
+    result: Optional[str] = None,
+    error: Optional[str] = None,
+) -> bool:
+    """Update task status and optionally result/error."""
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        async with pool.acquire() as conn:
+            updates = ['status = $2', 'updated_at = NOW()']
+            params = [task_id, status]
+            param_idx = 3
+
+            if worker_id:
+                updates.append(f'worker_id = ${param_idx}')
+                params.append(worker_id)
+                param_idx += 1
+
+            if result is not None:
+                updates.append(f'result = ${param_idx}')
+                params.append(result)
+                param_idx += 1
+
+            if error is not None:
+                updates.append(f'error = ${param_idx}')
+                params.append(error)
+                param_idx += 1
+
+            if status == 'running':
+                updates.append('started_at = NOW()')
+            elif status in ('completed', 'failed', 'cancelled'):
+                updates.append('completed_at = NOW()')
+
+            query = f'UPDATE tasks SET {", ".join(updates)} WHERE id = $1'
+            result_msg = await conn.execute(query, *params)
+            return 'UPDATE 1' in result_msg
+    except Exception as e:
+        logger.error(f'Failed to update task status: {e}')
+        return False
+
+
+def _row_to_task(row) -> Dict[str, Any]:
+    """Convert a database row to a task dict."""
+    metadata = row['metadata']
+    if isinstance(metadata, str):
+        metadata = json.loads(metadata)
+    elif metadata is None:
+        metadata = {}
+
+    return {
+        'id': row['id'],
+        'codebase_id': row['codebase_id'],
+        'title': row['title'],
+        'prompt': row['prompt'],
+        'agent_type': row['agent_type'],
+        'status': row['status'],
+        'priority': row['priority'],
+        'worker_id': row['worker_id'],
+        'result': row['result'],
+        'error': row['error'],
+        'metadata': metadata,
+        'created_at': row['created_at'].isoformat()
+        if row['created_at']
+        else None,
+        'updated_at': row['updated_at'].isoformat()
+        if row['updated_at']
+        else None,
+        'started_at': row['started_at'].isoformat()
+        if row['started_at']
+        else None,
+        'completed_at': row['completed_at'].isoformat()
+        if row['completed_at']
+        else None,
+    }
+
+
+# ========================================
+# Session Operations
+# ========================================
+
+
+async def db_upsert_session(
+    session: Dict[str, Any], tenant_id: Optional[str] = None
+) -> bool:
+    """Insert or update a session in the database.
+
+    Args:
+        session: The session data dict
+        tenant_id: Optional tenant ID for multi-tenant isolation
+    """
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        # Use provided tenant_id or fall back to session dict
+        effective_tenant_id = tenant_id or session.get('tenant_id')
+
+        async with pool.acquire() as conn:
+            await conn.execute(
+                """
+                INSERT INTO sessions (id, codebase_id, project_id, directory, title, version, summary, created_at, updated_at, tenant_id)
+                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
+                ON CONFLICT (id)
+                DO UPDATE SET
+                    title = EXCLUDED.title,
+                    version = EXCLUDED.version,
+                    summary = EXCLUDED.summary,
+                    updated_at = NOW(),
+                    tenant_id = COALESCE(EXCLUDED.tenant_id, sessions.tenant_id)
+                """,
+                session.get('id'),
+                session.get('codebase_id'),
+                session.get('project_id'),
+                session.get('directory'),
+                session.get('title'),
+                session.get('version'),
+                json.dumps(session.get('summary', {})),
+                _parse_timestamp(session.get('created_at'))
+                or datetime.utcnow(),
+                _parse_timestamp(session.get('updated_at'))
+                or datetime.utcnow(),
+                effective_tenant_id,
+            )
+        return True
+    except Exception as e:
+        logger.error(f'Failed to upsert session: {e}')
+        return False
+
+
+async def db_list_sessions(
+    codebase_id: str,
+    limit: int = 50,
+    tenant_id: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+    """List sessions for a codebase, optionally filtered by tenant."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            if tenant_id:
+                rows = await conn.fetch(
+                    """
+                    SELECT * FROM sessions
+                    WHERE codebase_id = $1 AND tenant_id = $3
+                    ORDER BY updated_at DESC
+                    LIMIT $2
+                    """,
+                    codebase_id,
+                    limit,
+                    tenant_id,
+                )
+            else:
+                rows = await conn.fetch(
+                    """
+                    SELECT * FROM sessions
+                    WHERE codebase_id = $1
+                    ORDER BY updated_at DESC
+                    LIMIT $2
+                    """,
+                    codebase_id,
+                    limit,
+                )
+            return [_row_to_session(row) for row in rows]
+    except Exception as e:
+        logger.error(f'Failed to list sessions: {e}')
+        return []
+
+
+async def db_list_all_sessions(
+    limit: int = 100,
+    offset: int = 0,
+    tenant_id: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+    """List all sessions across all codebases, optionally filtered by tenant."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            if tenant_id:
+                rows = await conn.fetch(
+                    """
+                    SELECT s.*, c.name as codebase_name, c.path as codebase_path
+                    FROM sessions s
+                    LEFT JOIN codebases c ON s.codebase_id = c.id
+                    WHERE s.tenant_id = $3
+                    ORDER BY s.updated_at DESC
+                    LIMIT $1 OFFSET $2
+                    """,
+                    limit,
+                    offset,
+                    tenant_id,
+                )
+            else:
+                rows = await conn.fetch(
+                    """
+                    SELECT s.*, c.name as codebase_name, c.path as codebase_path
+                    FROM sessions s
+                    LEFT JOIN codebases c ON s.codebase_id = c.id
+                    ORDER BY s.updated_at DESC
+                    LIMIT $1 OFFSET $2
+                    """,
+                    limit,
+                    offset,
+                )
+            sessions = []
+            for row in rows:
+                session = _row_to_session(row)
+                session['codebase_name'] = row.get('codebase_name')
+                session['codebase_path'] = row.get('codebase_path')
+                sessions.append(session)
+            return sessions
+    except Exception as e:
+        logger.error(f'Failed to list all sessions: {e}')
+        return []
+
+
+async def db_get_session(session_id: str) -> Optional[Dict[str, Any]]:
+    """Get a session by ID."""
+    pool = await get_pool()
+    if not pool:
+        return None
+
+    try:
+        async with pool.acquire() as conn:
+            row = await conn.fetchrow(
+                'SELECT * FROM sessions WHERE id = $1', session_id
+            )
+            if row:
+                return _row_to_session(row)
+            return None
+    except Exception as e:
+        logger.error(f'Failed to get session: {e}')
+        return None
+
+
+def _row_to_session(row) -> Dict[str, Any]:
+    """Convert a database row to a session dict."""
+    summary = row['summary']
+    if isinstance(summary, str):
+        summary = json.loads(summary)
+    elif summary is None:
+        summary = {}
+
+    return {
+        'id': row['id'],
+        'codebase_id': row['codebase_id'],
+        'project_id': row['project_id'],
+        'directory': row['directory'],
+        'title': row['title'],
+        'version': row['version'],
+        'summary': summary,
+        'created_at': row['created_at'].isoformat()
+        if row['created_at']
+        else None,
+        'updated_at': row['updated_at'].isoformat()
+        if row['updated_at']
+        else None,
+    }
+
+
+# ========================================
+# Session Message Operations
+# ========================================
+
+
+async def db_upsert_message(message: Dict[str, Any]) -> bool:
+    """Insert or update a session message."""
+    pool = await get_pool()
+    if not pool:
+        return False
+
+    try:
+        async with pool.acquire() as conn:
+            await conn.execute(
+                """
+                INSERT INTO session_messages (id, session_id, role, content, model, cost, tokens, tool_calls, created_at)
+                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
+                ON CONFLICT (id)
+                DO UPDATE SET
+                    content = EXCLUDED.content,
+                    tokens = EXCLUDED.tokens,
+                    tool_calls = EXCLUDED.tool_calls
+                """,
+                message.get('id'),
+                message.get('session_id'),
+                message.get('role'),
+                message.get('content'),
+                message.get('model'),
+                message.get('cost'),
+                json.dumps(message.get('tokens', {})),
+                json.dumps(message.get('tool_calls', [])),
+                _parse_timestamp(message.get('created_at'))
+                or datetime.utcnow(),
+            )
+        return True
+    except Exception as e:
+        logger.error(f'Failed to upsert message: {e}')
+        return False
+
+
+async def db_list_messages(
+    session_id: str, limit: int = 50
+) -> List[Dict[str, Any]]:
+    """List messages for a session."""
+    pool = await get_pool()
+    if not pool:
+        return []
+
+    try:
+        async with pool.acquire() as conn:
+            rows = await conn.fetch(
+                """
+                SELECT * FROM session_messages
+                WHERE session_id = $1
+                ORDER BY created_at ASC
+                LIMIT $2
+                """,
+                session_id,
+                limit,
+            )
+            return [_row_to_message(row) for row in rows]
+    except Exception as e:
+        logger.error(f'Failed to list messages: {e}')
+        return []
+
+
+def _row_to_message(row) -> Dict[str, Any]:
+    """Convert a database row to a message dict."""
+    tokens = row['tokens']
+    if isinstance(tokens, str):
+        tokens = json.loads(tokens)
+    elif tokens is None:
+        tokens = {}
+
+    tool_calls = row['tool_calls']
+    if isinstance(tool_calls, str):
+        tool_calls = json.loads(tool_calls)
+    elif tool_calls is None:
+        tool_calls = []
+
+    return {
+        'id': row['id'],
+        'session_id': row['session_id'],
+        'role': row['role'],
+        'content': row['content'],
+        'model': row['model'],
+        'cost': row['cost'],
+        'tokens': tokens,
+        'tool_calls': tool_calls,
+        'created_at': row['created_at'].isoformat()
+        if row['created_at']
+        else None,
+    }
+
+
+# ========================================
+# Health Check
+# ========================================
+
+
+async def db_health_check() -> Dict[str, Any]:
+    """Check database health and return stats."""
+    pool = await get_pool()
+
+    if not pool:
+        return {
+            'available': False,
+            'message': 'PostgreSQL not configured (set DATABASE_URL environment variable)',
+        }
+
+    try:
+        async with pool.acquire() as conn:
+            # Check connectivity
+            await conn.fetchval('SELECT 1')
+
+            # Get counts
+            worker_count = await conn.fetchval('SELECT COUNT(*) FROM workers')
+            codebase_count = await conn.fetchval(
+                'SELECT COUNT(*) FROM codebases'
+            )
+            task_count = await conn.fetchval('SELECT COUNT(*) FROM tasks')
+            session_count = await conn.fetchval('SELECT COUNT(*) FROM sessions')
+
+            return {
+                'available': True,
+                'message': 'PostgreSQL connected',
+                'stats': {
+                    'workers': worker_count,
+                    'codebases': codebase_count,
+                    'tasks': task_count,
+                    'sessions': session_count,
+                },
+                'pool_size': pool.get_size(),
+                'pool_idle': pool.get_idle_size(),
+            }
+    except Exception as e:
+        return {
+            'available': False,
+            'message': f'PostgreSQL error: {e}',
+        }
+
+
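The CRUD helpers above fail soft when no pool is available, returning False, None, or an empty list rather than raising (the tenant helpers below raise instead), so a caller can gate on db_health_check() first. A minimal sketch of queueing and claiming a task; the codebase ID is a made-up placeholder and must reference an existing codebases row:

import asyncio
import uuid

from a2a_server import database

async def main() -> None:
    health = await database.db_health_check()
    if not health['available']:
        print(health['message'])
        return

    # 'cb-demo' is hypothetical; tasks.codebase_id has a foreign key to codebases(id).
    await database.db_upsert_task({
        'id': str(uuid.uuid4()),
        'codebase_id': 'cb-demo',
        'title': 'Example task',
        'prompt': 'Describe the work for the agent here.',
        'priority': 1,
    })
    task = await database.db_get_next_pending_task()
    if task:
        # Transitioning to 'running' also stamps started_at server-side.
        await database.db_update_task_status(task['id'], 'running')

asyncio.run(main())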
|
|
1289
|
+
# ========================================
|
|
1290
|
+
# Monitor/Messages Operations
|
|
1291
|
+
# ========================================
|
|
1292
|
+
|
|
1293
|
+
|
|
1294
|
+
async def db_save_monitor_message(message: Dict[str, Any]) -> bool:
|
|
1295
|
+
"""Save a monitor message to the database."""
|
|
1296
|
+
pool = await get_pool()
|
|
1297
|
+
if not pool:
|
|
1298
|
+
return False
|
|
1299
|
+
|
|
1300
|
+
try:
|
|
1301
|
+
async with pool.acquire() as conn:
|
|
1302
|
+
await conn.execute(
|
|
1303
|
+
"""
|
|
1304
|
+
INSERT INTO monitor_messages
|
|
1305
|
+
(id, timestamp, type, agent_name, content, metadata, response_time, tokens, error)
|
|
1306
|
+
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
|
1307
|
+
ON CONFLICT (id) DO NOTHING
|
|
1308
|
+
""",
|
|
1309
|
+
message.get('id'),
|
|
1310
|
+
_parse_timestamp(message.get('timestamp')) or datetime.utcnow(),
|
|
1311
|
+
message.get('type'),
|
|
1312
|
+
message.get('agent_name'),
|
|
1313
|
+
message.get('content'),
|
|
1314
|
+
json.dumps(message.get('metadata', {})),
|
|
1315
|
+
message.get('response_time'),
|
|
1316
|
+
message.get('tokens'),
|
|
1317
|
+
message.get('error'),
|
|
1318
|
+
)
|
|
1319
|
+
return True
|
|
1320
|
+
except Exception as e:
|
|
1321
|
+
logger.error(f'Failed to save monitor message: {e}')
|
|
1322
|
+
return False
|
|
1323
|
+
|
|
1324
|
+
|
|
1325
|
+
async def db_list_monitor_messages(
|
|
1326
|
+
limit: int = 100,
|
|
1327
|
+
agent_name: Optional[str] = None,
|
|
1328
|
+
msg_type: Optional[str] = None,
|
|
1329
|
+
) -> List[Dict[str, Any]]:
|
|
1330
|
+
"""List monitor messages from the database."""
|
|
1331
|
+
pool = await get_pool()
|
|
1332
|
+
if not pool:
|
|
1333
|
+
return []
|
|
1334
|
+
|
|
1335
|
+
try:
|
|
1336
|
+
async with pool.acquire() as conn:
|
|
1337
|
+
query = 'SELECT * FROM monitor_messages WHERE 1=1'
|
|
1338
|
+
params = []
|
|
1339
|
+
param_idx = 1
|
|
1340
|
+
|
|
1341
|
+
if agent_name:
|
|
1342
|
+
query += f' AND agent_name = ${param_idx}'
|
|
1343
|
+
params.append(agent_name)
|
|
1344
|
+
param_idx += 1
|
|
1345
|
+
|
|
1346
|
+
if msg_type:
|
|
1347
|
+
query += f' AND type = ${param_idx}'
|
|
1348
|
+
params.append(msg_type)
|
|
1349
|
+
param_idx += 1
|
|
1350
|
+
|
|
1351
|
+
query += f' ORDER BY timestamp DESC LIMIT ${param_idx}'
|
|
1352
|
+
params.append(limit)
|
|
1353
|
+
|
|
1354
|
+
rows = await conn.fetch(query, *params)
|
|
1355
|
+
return [_row_to_monitor_message(row) for row in rows]
|
|
1356
|
+
except Exception as e:
|
|
1357
|
+
logger.error(f'Failed to list monitor messages: {e}')
|
|
1358
|
+
return []
|
|
1359
|
+
|
|
1360
|
+
|
|
1361
|
+
async def db_get_monitor_stats() -> Dict[str, Any]:
|
|
1362
|
+
"""Get monitor statistics from the database."""
|
|
1363
|
+
pool = await get_pool()
|
|
1364
|
+
if not pool:
|
|
1365
|
+
return {}
|
|
1366
|
+
|
|
1367
|
+
try:
|
|
1368
|
+
async with pool.acquire() as conn:
|
|
1369
|
+
stats = await conn.fetchrow("""
|
|
1370
|
+
SELECT
|
|
1371
|
+
COUNT(*) as total_messages,
|
|
1372
|
+
SUM(CASE WHEN type = 'tool' THEN 1 ELSE 0 END) as tool_calls,
|
|
1373
|
+
SUM(CASE WHEN type = 'error' THEN 1 ELSE 0 END) as errors,
|
|
1374
|
+
COALESCE(SUM(tokens), 0) as total_tokens,
|
|
1375
|
+
COUNT(DISTINCT agent_name) as unique_agents
|
|
1376
|
+
FROM monitor_messages
|
|
1377
|
+
""")
|
|
1378
|
+
return dict(stats) if stats else {}
|
|
1379
|
+
except Exception as e:
|
|
1380
|
+
logger.error(f'Failed to get monitor stats: {e}')
|
|
1381
|
+
return {}
|
|
1382
|
+
|
|
1383
|
+
|
|
1384
|
+
async def db_count_monitor_messages() -> int:
|
|
1385
|
+
"""Count total monitor messages."""
|
|
1386
|
+
pool = await get_pool()
|
|
1387
|
+
if not pool:
|
|
1388
|
+
return 0
|
|
1389
|
+
|
|
1390
|
+
try:
|
|
1391
|
+
async with pool.acquire() as conn:
|
|
1392
|
+
return await conn.fetchval('SELECT COUNT(*) FROM monitor_messages')
|
|
1393
|
+
except Exception:
|
|
1394
|
+
return 0
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
def _row_to_monitor_message(row) -> Dict[str, Any]:
|
|
1398
|
+
"""Convert a database row to a monitor message dict."""
|
|
1399
|
+
metadata = row['metadata']
|
|
1400
|
+
if isinstance(metadata, str):
|
|
1401
|
+
metadata = json.loads(metadata)
|
|
1402
|
+
elif metadata is None:
|
|
1403
|
+
metadata = {}
|
|
1404
|
+
|
|
1405
|
+
return {
|
|
1406
|
+
'id': row['id'],
|
|
1407
|
+
'timestamp': row['timestamp'].isoformat() if row['timestamp'] else None,
|
|
1408
|
+
'type': row['type'],
|
|
1409
|
+
'agent_name': row['agent_name'],
|
|
1410
|
+
'content': row['content'],
|
|
1411
|
+
'metadata': metadata,
|
|
1412
|
+
'response_time': row['response_time'],
|
|
1413
|
+
'tokens': row['tokens'],
|
|
1414
|
+
'error': row['error'],
|
|
1415
|
+
}
|
|
1416
|
+
|
|
1417
|
+
|
|
1418
|
+
# ========================================
|
|
1419
|
+
# Tenant Operations (Multi-tenant support)
|
|
1420
|
+
# ========================================
|
|
1421
|
+
|
|
1422
|
+
|
|
1423
|
+
async def create_tenant(
|
|
1424
|
+
realm_name: str, display_name: str, plan: str = 'free'
|
|
1425
|
+
) -> dict:
|
|
1426
|
+
"""Create a new tenant.
|
|
1427
|
+
|
|
1428
|
+
Args:
|
|
1429
|
+
realm_name: Unique realm identifier (e.g., "acme.codetether.run")
|
|
1430
|
+
display_name: Human-readable tenant name
|
|
1431
|
+
plan: Subscription plan ('free', 'pro', 'enterprise')
|
|
1432
|
+
|
|
1433
|
+
Returns:
|
|
1434
|
+
The created tenant dict
|
|
1435
|
+
"""
|
|
1436
|
+
pool = await get_pool()
|
|
1437
|
+
if not pool:
|
|
1438
|
+
raise RuntimeError('Database not available')
|
|
1439
|
+
|
|
1440
|
+
tenant_id = str(uuid.uuid4())
|
|
1441
|
+
now = datetime.utcnow()
|
|
1442
|
+
|
|
1443
|
+
try:
|
|
1444
|
+
async with pool.acquire() as conn:
|
|
1445
|
+
await conn.execute(
|
|
1446
|
+
"""
|
|
1447
|
+
INSERT INTO tenants (id, realm_name, display_name, plan, created_at, updated_at)
|
|
1448
|
+
VALUES ($1, $2, $3, $4, $5, $6)
|
|
1449
|
+
""",
|
|
1450
|
+
tenant_id,
|
|
1451
|
+
realm_name,
|
|
1452
|
+
display_name,
|
|
1453
|
+
plan,
|
|
1454
|
+
now,
|
|
1455
|
+
now,
|
|
1456
|
+
)
|
|
1457
|
+
|
|
1458
|
+
return {
|
|
1459
|
+
'id': tenant_id,
|
|
1460
|
+
'realm_name': realm_name,
|
|
1461
|
+
'display_name': display_name,
|
|
1462
|
+
'plan': plan,
|
|
1463
|
+
'stripe_customer_id': None,
|
|
1464
|
+
'stripe_subscription_id': None,
|
|
1465
|
+
'created_at': now.isoformat(),
|
|
1466
|
+
'updated_at': now.isoformat(),
|
|
1467
|
+
}
|
|
1468
|
+
except Exception as e:
|
|
1469
|
+
logger.error(f'Failed to create tenant: {e}')
|
|
1470
|
+
raise


async def get_tenant_by_realm(realm_name: str) -> Optional[dict]:
    """Get a tenant by realm name.

    Args:
        realm_name: The realm identifier (e.g., "acme.codetether.run")

    Returns:
        Tenant dict or None if not found
    """
    pool = await get_pool()
    if not pool:
        return None

    try:
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                'SELECT * FROM tenants WHERE realm_name = $1', realm_name
            )
            if row:
                return _row_to_tenant(row)
            return None
    except Exception as e:
        logger.error(f'Failed to get tenant by realm: {e}')
        return None


async def get_tenant_by_id(tenant_id: str) -> Optional[dict]:
    """Get a tenant by ID.

    Args:
        tenant_id: The tenant UUID

    Returns:
        Tenant dict or None if not found
    """
    pool = await get_pool()
    if not pool:
        return None

    try:
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                'SELECT * FROM tenants WHERE id = $1', tenant_id
            )
            if row:
                return _row_to_tenant(row)
            return None
    except Exception as e:
        logger.error(f'Failed to get tenant by id: {e}')
        return None


async def list_tenants(limit: int = 100, offset: int = 0) -> List[dict]:
    """List all tenants with pagination.

    Args:
        limit: Maximum number of tenants to return
        offset: Number of tenants to skip

    Returns:
        List of tenant dicts
    """
    pool = await get_pool()
    if not pool:
        return []

    try:
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT * FROM tenants
                ORDER BY created_at DESC
                LIMIT $1 OFFSET $2
                """,
                limit,
                offset,
            )
            return [_row_to_tenant(row) for row in rows]
    except Exception as e:
        logger.error(f'Failed to list tenants: {e}')
        return []


async def update_tenant(tenant_id: str, **kwargs) -> dict:
    """Update tenant fields.

    Args:
        tenant_id: The tenant UUID
        **kwargs: Fields to update (realm_name, display_name, plan)

    Returns:
        The updated tenant dict

    Raises:
        ValueError: If tenant not found
    """
    pool = await get_pool()
    if not pool:
        raise RuntimeError('Database not available')

    allowed_fields = {'realm_name', 'display_name', 'plan'}
    updates = []
    params = [tenant_id]
    param_idx = 2

    for field, value in kwargs.items():
        if field in allowed_fields:
            updates.append(f'{field} = ${param_idx}')
            params.append(value)
            param_idx += 1

    if not updates:
        # No valid fields to update, just return current tenant
        tenant = await get_tenant_by_id(tenant_id)
        if not tenant:
            raise ValueError(f'Tenant {tenant_id} not found')
        return tenant

    updates.append('updated_at = NOW()')

    try:
        async with pool.acquire() as conn:
            result = await conn.execute(
                f'UPDATE tenants SET {", ".join(updates)} WHERE id = $1',
                *params,
            )
            if 'UPDATE 0' in result:
                raise ValueError(f'Tenant {tenant_id} not found')

            tenant = await get_tenant_by_id(tenant_id)
            if not tenant:
                raise ValueError(f'Tenant {tenant_id} not found')
            return tenant
    except ValueError:
        raise
    except Exception as e:
        logger.error(f'Failed to update tenant: {e}')
        raise
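

# --- Illustrative usage sketch (not part of the original module) ---
# update_tenant builds its SET clause dynamically, but only from the
# allow-listed columns, so unexpected kwargs are silently dropped rather
# than interpolated into SQL.
async def _example_update_tenant(tenant_id: str) -> dict:
    # 'plan' is applied; 'hacker_field' is ignored by the allow-list.
    return await update_tenant(tenant_id, plan='enterprise', hacker_field='x')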


async def update_tenant_stripe(
    tenant_id: str, customer_id: str, subscription_id: str
) -> dict:
    """Update tenant Stripe billing information.

    Args:
        tenant_id: The tenant UUID
        customer_id: Stripe customer ID
        subscription_id: Stripe subscription ID

    Returns:
        The updated tenant dict

    Raises:
        ValueError: If tenant not found
    """
    pool = await get_pool()
    if not pool:
        raise RuntimeError('Database not available')

    try:
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE tenants
                SET stripe_customer_id = $2,
                    stripe_subscription_id = $3,
                    updated_at = NOW()
                WHERE id = $1
                """,
                tenant_id,
                customer_id,
                subscription_id,
            )
            if 'UPDATE 0' in result:
                raise ValueError(f'Tenant {tenant_id} not found')

            tenant = await get_tenant_by_id(tenant_id)
            if not tenant:
                raise ValueError(f'Tenant {tenant_id} not found')
            return tenant
    except ValueError:
        raise
    except Exception as e:
        logger.error(f'Failed to update tenant Stripe info: {e}')
        raise


def _row_to_tenant(row) -> dict:
    """Convert a database row to a tenant dict."""
    return {
        'id': row['id'],
        'realm_name': row['realm_name'],
        'display_name': row['display_name'],
        'plan': row['plan'],
        'stripe_customer_id': row['stripe_customer_id'],
        'stripe_subscription_id': row['stripe_subscription_id'],
        'created_at': row['created_at'].isoformat()
        if row['created_at']
        else None,
        'updated_at': row['updated_at'].isoformat()
        if row['updated_at']
        else None,
    }


# ========================================
# Row-Level Security (RLS) Support
# ========================================
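

# --- Illustrative only (not part of the original module) ---
# The helpers below assume RLS policies keyed on the session variable
# 'app.current_tenant_id', with a NULL context acting as the admin bypass.
# The real definitions live in migrations/enable_rls.sql (not shown here);
# a policy of roughly this shape would match that contract:
_EXAMPLE_RLS_POLICY_SQL = """
ALTER TABLE workers ENABLE ROW LEVEL SECURITY;
CREATE POLICY workers_tenant_isolation ON workers
    USING (
        current_setting('app.current_tenant_id', true) IS NULL
        OR tenant_id::text = current_setting('app.current_tenant_id', true)
    );
"""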


async def set_tenant_context(conn, tenant_id: str) -> None:
    """Set the tenant context for the current database connection.

    This sets the PostgreSQL session variable 'app.current_tenant_id' which
    is used by RLS policies to filter rows by tenant.

    Args:
        conn: asyncpg connection object
        tenant_id: The tenant UUID to set as context

    Example:
        async with pool.acquire() as conn:
            await set_tenant_context(conn, tenant_id)
            results = await conn.fetch("SELECT * FROM workers")
            await clear_tenant_context(conn)
    """
    if not tenant_id:
        logger.warning('set_tenant_context called with empty tenant_id')
        return

    if not RLS_ENABLED:
        logger.debug(f'RLS disabled, skipping tenant context: {tenant_id}')
        return

    try:
        await conn.execute(
            "SELECT set_config('app.current_tenant_id', $1, false)", tenant_id
        )
        logger.debug(f'Set tenant context: {tenant_id}')
    except Exception as e:
        logger.error(f'Failed to set tenant context: {e}')
        raise


async def clear_tenant_context(conn) -> None:
    """Clear the tenant context for the current database connection.

    This resets the PostgreSQL session variable 'app.current_tenant_id' to NULL,
    which allows access to all rows when RLS policies check for NULL context.

    Args:
        conn: asyncpg connection object
    """
    if not RLS_ENABLED:
        return

    try:
        await conn.execute('RESET app.current_tenant_id')
        logger.debug('Cleared tenant context')
    except Exception as e:
        logger.warning(f'Failed to clear tenant context: {e}')


async def get_tenant_context(conn) -> Optional[str]:
    """Get the current tenant context from the database connection.

    Args:
        conn: asyncpg connection object

    Returns:
        The current tenant ID or None if not set
    """
    if not RLS_ENABLED:
        return None

    try:
        result = await conn.fetchval(
            "SELECT current_setting('app.current_tenant_id', true)"
        )
        return result
    except Exception as e:
        logger.debug(f'Failed to get tenant context: {e}')
        return None


@asynccontextmanager
async def tenant_scope(tenant_id: str):
    """Context manager for tenant-scoped database operations.

    Acquires a connection from the pool, sets the tenant context,
    yields the connection, and ensures the context is cleared afterward.

    This is the recommended way to perform tenant-scoped operations
    when RLS is enabled.

    Args:
        tenant_id: The tenant UUID to scope operations to

    Yields:
        asyncpg connection with tenant context set

    Example:
        async with tenant_scope("tenant-uuid") as conn:
            results = await conn.fetch("SELECT * FROM workers")

    Raises:
        RuntimeError: If database pool is not available
    """
    pool = await get_pool()
    if not pool:
        raise RuntimeError('Database pool not available')

    conn = await pool.acquire()
    try:
        await set_tenant_context(conn, tenant_id)
        yield conn
    finally:
        try:
            await clear_tenant_context(conn)
        finally:
            await pool.release(conn)


@asynccontextmanager
async def admin_scope():
    """Context manager for admin operations that bypass RLS.

    This acquires a connection without setting tenant context, which
    allows access to all rows when RLS policies allow NULL context.

    WARNING: Use sparingly and only for legitimate administrative operations
    like migrations, auditing, or cross-tenant reporting.

    Yields:
        asyncpg connection with admin-level access

    Example:
        async with admin_scope() as conn:
            # Can access all tenants' data
            results = await conn.fetch("SELECT COUNT(*) FROM workers")
    """
    pool = await get_pool()
    if not pool:
        raise RuntimeError('Database pool not available')

    conn = await pool.acquire()
    try:
        # Clear any existing tenant context to use admin bypass
        if RLS_ENABLED:
            await conn.execute('RESET app.current_tenant_id')
        yield conn
    finally:
        await pool.release(conn)
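

# --- Illustrative usage sketch (not part of the original module) ---
# Contrasts tenant-scoped and admin-scoped reads; assumes RLS is enabled
# and the workers table carries a tenant_id column covered by a policy.
async def _example_scoped_counts(tenant_id: str) -> None:
    async with tenant_scope(tenant_id) as conn:
        # RLS restricts this count to the current tenant's rows
        mine = await conn.fetchval('SELECT COUNT(*) FROM workers')
    async with admin_scope() as conn:
        # NULL context bypasses the tenant filter
        everyone = await conn.fetchval('SELECT COUNT(*) FROM workers')
    logger.info(f'tenant rows: {mine}, all rows: {everyone}')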


async def db_execute_as_tenant(tenant_id: str, query: str, *args) -> Any:
    """Execute a query with tenant context.

    Convenience function for executing a single query within tenant scope.

    Args:
        tenant_id: The tenant UUID
        query: SQL query to execute
        *args: Query parameters

    Returns:
        Query result
    """
    async with tenant_scope(tenant_id) as conn:
        return await conn.execute(query, *args)


async def db_fetch_as_tenant(tenant_id: str, query: str, *args) -> List[Any]:
    """Fetch rows with tenant context.

    Convenience function for fetching rows within tenant scope.

    Args:
        tenant_id: The tenant UUID
        query: SQL query to execute
        *args: Query parameters

    Returns:
        List of rows
    """
    async with tenant_scope(tenant_id) as conn:
        return await conn.fetch(query, *args)


async def db_fetchrow_as_tenant(
    tenant_id: str, query: str, *args
) -> Optional[Any]:
    """Fetch a single row with tenant context.

    Args:
        tenant_id: The tenant UUID
        query: SQL query to execute
        *args: Query parameters

    Returns:
        Single row or None
    """
    async with tenant_scope(tenant_id) as conn:
        return await conn.fetchrow(query, *args)


async def db_fetchval_as_tenant(
    tenant_id: str, query: str, *args
) -> Optional[Any]:
    """Fetch a single value with tenant context.

    Args:
        tenant_id: The tenant UUID
        query: SQL query to execute
        *args: Query parameters

    Returns:
        Single value or None
    """
    async with tenant_scope(tenant_id) as conn:
        return await conn.fetchval(query, *args)
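

# --- Illustrative usage sketch (not part of the original module) ---
# Each wrapper opens its own tenant_scope, so every call pays one pool
# acquire/release; batch related queries under a single tenant_scope when
# that overhead matters. (Assumes tasks has a created_at column.)
async def _example_convenience_wrappers(tenant_id: str) -> None:
    worker_count = await db_fetchval_as_tenant(
        tenant_id, 'SELECT COUNT(*) FROM workers'
    )
    recent_tasks = await db_fetch_as_tenant(
        tenant_id, 'SELECT * FROM tasks ORDER BY created_at DESC LIMIT $1', 10
    )
    logger.info(f'{worker_count} workers, {len(recent_tasks)} recent tasks')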


# ========================================
# RLS Migration Support
# ========================================


async def db_run_migrations(
    migrations_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Run SQL migration files from the migrations directory.

    Executes all .sql files in the migrations directory that haven't been
    applied yet, tracking them in the schema_migrations table.

    Args:
        migrations_dir: Path to migrations directory (default: a2a_server/migrations)

    Returns:
        Dict with migration results:
        - applied: List of newly applied migrations
        - skipped: List of already applied migrations
        - failed: List of failed migrations with errors
    """
    import hashlib
    from pathlib import Path

    if migrations_dir is None:
        migrations_path = Path(__file__).parent / 'migrations'
    else:
        migrations_path = Path(migrations_dir)

    if not migrations_path.exists():
        logger.warning(f'Migrations directory not found: {migrations_path}')
        return {'applied': [], 'skipped': [], 'failed': []}

    pool = await get_pool()
    if not pool:
        raise RuntimeError('Database pool not available')

    results: Dict[str, List[Any]] = {'applied': [], 'skipped': [], 'failed': []}

    async with pool.acquire() as conn:
        # Ensure schema_migrations table exists
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS schema_migrations (
                id SERIAL PRIMARY KEY,
                migration_name TEXT NOT NULL UNIQUE,
                applied_at TIMESTAMPTZ DEFAULT NOW(),
                checksum TEXT
            )
        """)

        # Get already applied migrations
        applied = await conn.fetch(
            'SELECT migration_name FROM schema_migrations'
        )
        applied_set = {row['migration_name'] for row in applied}

        # Get all .sql files sorted by name
        migration_files = sorted(migrations_path.glob('*.sql'))

        for migration_file in migration_files:
            migration_name = migration_file.stem

            if migration_name in applied_set:
                results['skipped'].append(migration_name)
                logger.info(
                    f'Skipping already applied migration: {migration_name}'
                )
                continue

            try:
                logger.info(f'Applying migration: {migration_name}')

                # Read the migration SQL
                sql_content = migration_file.read_text()

                # Apply and record in one transaction so a migration is
                # never applied without being tracked
                async with conn.transaction():
                    await conn.execute(sql_content)

                    # Record migration with a stable content checksum;
                    # builtin hash() is salted per process and cannot be
                    # compared across runs
                    await conn.execute(
                        """
                        INSERT INTO schema_migrations (migration_name, checksum)
                        VALUES ($1, $2)
                        ON CONFLICT (migration_name) DO NOTHING
                        """,
                        migration_name,
                        hashlib.sha256(sql_content.encode()).hexdigest(),
                    )

                results['applied'].append(migration_name)
                logger.info(f'Successfully applied migration: {migration_name}')

            except Exception as e:
                logger.error(f'Failed to apply migration {migration_name}: {e}')
                results['failed'].append(
                    {'name': migration_name, 'error': str(e)}
                )

    return results
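

# --- Illustrative usage sketch (not part of the original module) ---
# A startup hook that runs pending migrations and refuses to continue if
# any migration failed.
async def _example_migrate_on_startup() -> None:
    results = await db_run_migrations()
    logger.info(
        f"migrations applied={results['applied']} skipped={len(results['skipped'])}"
    )
    if results['failed']:
        raise RuntimeError(f"migrations failed: {results['failed']}")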


async def db_enable_rls() -> Dict[str, Any]:
    """Enable RLS by running the enable_rls.sql migration.

    This enables Row-Level Security on all tenant-scoped tables:
    - workers
    - codebases
    - tasks
    - sessions

    Returns:
        Dict with migration result
    """
    from pathlib import Path

    migrations_path = Path(__file__).parent / 'migrations'
    enable_rls_file = migrations_path / 'enable_rls.sql'

    if not enable_rls_file.exists():
        return {
            'status': 'error',
            'message': f'RLS migration not found: {enable_rls_file}',
        }

    pool = await get_pool()
    if not pool:
        return {'status': 'error', 'message': 'Database pool not available'}

    try:
        async with pool.acquire() as conn:
            sql_content = enable_rls_file.read_text()
            await conn.execute(sql_content)

        logger.info('RLS enabled successfully')
        return {
            'status': 'success',
            'message': 'RLS enabled on all tenant-scoped tables',
        }
    except Exception as e:
        logger.error(f'Failed to enable RLS: {e}')
        return {'status': 'error', 'message': str(e)}


async def db_disable_rls() -> Dict[str, Any]:
    """Disable RLS by running the disable_rls.sql migration.

    This disables Row-Level Security and removes all policies.

    Returns:
        Dict with migration result
    """
    from pathlib import Path

    migrations_path = Path(__file__).parent / 'migrations'
    disable_rls_file = migrations_path / 'disable_rls.sql'

    if not disable_rls_file.exists():
        return {
            'status': 'error',
            'message': f'RLS rollback not found: {disable_rls_file}',
        }

    pool = await get_pool()
    if not pool:
        return {'status': 'error', 'message': 'Database pool not available'}

    try:
        async with pool.acquire() as conn:
            sql_content = disable_rls_file.read_text()
            await conn.execute(sql_content)

        logger.info('RLS disabled successfully')
        return {
            'status': 'success',
            'message': 'RLS disabled on all tenant-scoped tables',
        }
    except Exception as e:
        logger.error(f'Failed to disable RLS: {e}')
        return {'status': 'error', 'message': str(e)}


async def get_rls_status() -> Dict[str, Any]:
    """Get the current RLS status for all tenant-scoped tables.

    Returns:
        Dict with RLS status for each table and overall configuration
    """
    pool = await get_pool()
    if not pool:
        return {'enabled': False, 'database_available': False, 'tables': {}}

    try:
        async with pool.acquire() as conn:
            # Check RLS status on each table
            rows = await conn.fetch("""
                SELECT
                    schemaname,
                    tablename,
                    rowsecurity as rls_enabled,
                    forcerowsecurity as rls_forced
                FROM pg_tables
                WHERE schemaname = 'public'
                AND tablename IN ('workers', 'codebases', 'tasks', 'sessions')
            """)

            tables = {}
            all_enabled = True

            for row in rows:
                tables[row['tablename']] = {
                    'rls_enabled': row['rls_enabled'],
                    'rls_forced': row['rls_forced'],
                }
                if not row['rls_enabled']:
                    all_enabled = False

            # Check for policies
            policies = await conn.fetch("""
                SELECT tablename, policyname
                FROM pg_policies
                WHERE schemaname = 'public'
                AND tablename IN ('workers', 'codebases', 'tasks', 'sessions')
            """)

            policy_count: Dict[str, int] = {}
            for policy in policies:
                table = policy['tablename']
                if table not in policy_count:
                    policy_count[table] = 0
                policy_count[table] += 1

            for table in tables:
                tables[table]['policy_count'] = policy_count.get(table, 0)

            return {
                'enabled': all_enabled and len(tables) == 4,
                'database_available': True,
                'rls_env_enabled': RLS_ENABLED,
                'strict_mode': RLS_STRICT_MODE,
                'tables': tables,
            }

    except Exception as e:
        logger.error(f'Failed to get RLS status: {e}')
        return {
            'enabled': False,
            'database_available': True,
            'error': str(e),
            'tables': {},
        }
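

# --- Illustrative only (not part of the original module) ---
# Shape of a healthy get_rls_status() result; the values below are
# examples, not guaranteed output:
#
#     {
#         'enabled': True,
#         'database_available': True,
#         'rls_env_enabled': True,
#         'strict_mode': False,
#         'tables': {
#             'workers': {'rls_enabled': True, 'rls_forced': False, 'policy_count': 1},
#             ...
#         },
#     }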


def init_rls_config() -> None:
    """Initialize RLS configuration from environment.

    Call this at application startup to configure RLS settings.
    Updates the module-level RLS_ENABLED and RLS_STRICT_MODE variables.
    """
    global RLS_ENABLED, RLS_STRICT_MODE

    RLS_ENABLED = os.environ.get('RLS_ENABLED', 'false').lower() == 'true'
    RLS_STRICT_MODE = (
        os.environ.get('RLS_STRICT_MODE', 'false').lower() == 'true'
    )

    logger.info(
        f'RLS Configuration: enabled={RLS_ENABLED}, strict={RLS_STRICT_MODE}'
    )
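

# --- Illustrative startup order sketch (not part of the original module) ---
# Reads RLS flags from the environment before any tenant-scoped query runs,
# then applies pending migrations and, only if RLS_ENABLED is true, turns on
# the database-level policies.
async def _example_startup() -> None:
    init_rls_config()
    await db_run_migrations()
    if RLS_ENABLED:
        status = await db_enable_rls()
        logger.info(f"RLS enable: {status['status']}")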