slide-narrator 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of slide-narrator might be problematic. Click here for more details.
- narrator/__init__.py +18 -0
- narrator/database/__init__.py +8 -0
- narrator/database/cli.py +66 -0
- narrator/database/migrations/__init__.py +6 -0
- narrator/database/models.py +69 -0
- narrator/database/storage_backend.py +580 -0
- narrator/database/thread_store.py +280 -0
- narrator/models/__init__.py +9 -0
- narrator/models/attachment.py +363 -0
- narrator/models/message.py +507 -0
- narrator/models/thread.py +469 -0
- narrator/storage/__init__.py +7 -0
- narrator/storage/file_store.py +535 -0
- narrator/utils/__init__.py +9 -0
- narrator/utils/logging.py +58 -0
- slide_narrator-0.2.1.dist-info/METADATA +531 -0
- slide_narrator-0.2.1.dist-info/RECORD +20 -0
- slide_narrator-0.2.1.dist-info/WHEEL +4 -0
- slide_narrator-0.2.1.dist-info/entry_points.txt +2 -0
- slide_narrator-0.2.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,580 @@
|
|
|
1
|
+
"""Storage backend implementations for ThreadStore."""
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import List, Optional, Dict, Any
|
|
4
|
+
from datetime import datetime, UTC
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import tempfile
|
|
9
|
+
import asyncio
|
|
10
|
+
from sqlalchemy import create_engine, select, cast, String, text, bindparam
|
|
11
|
+
from sqlalchemy.orm import sessionmaker, selectinload
|
|
12
|
+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
13
|
+
# Direct imports
|
|
14
|
+
from ..models.thread import Thread
|
|
15
|
+
from ..models.message import Message
|
|
16
|
+
from ..models.attachment import Attachment
|
|
17
|
+
from ..storage.file_store import FileStore
|
|
18
|
+
from ..utils.logging import get_logger
|
|
19
|
+
from .models import Base, ThreadRecord, MessageRecord
|
|
20
|
+
|
|
21
|
+
# Module-level logger for this file, obtained from the package's logging helper
# (narrator.utils.logging.get_logger) and named after this module.
logger = get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
class StorageBackend(ABC):
    """Contract shared by every thread storage backend.

    Every operation is a coroutine. A concrete backend must override each
    abstract method below; ``abc`` refuses to instantiate a subclass that
    leaves any of them unimplemented.
    """

    @abstractmethod
    async def initialize(self) -> None:
        """Prepare the backend for use (e.g. create database tables)."""

    @abstractmethod
    async def save(self, thread: Thread) -> Thread:
        """Persist *thread* and return it."""

    @abstractmethod
    async def get(self, thread_id: str) -> Optional[Thread]:
        """Return the thread with *thread_id*, or ``None`` if it is not stored."""

    @abstractmethod
    async def delete(self, thread_id: str) -> bool:
        """Remove the thread with *thread_id*; report whether anything was removed."""

    @abstractmethod
    async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
        """Return one page of threads: at most *limit* items, skipping *offset*."""

    @abstractmethod
    async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
        """Return the threads whose attributes match every key/value in *attributes*."""

    @abstractmethod
    async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
        """Return threads that have a *platform_name* entry in their ``platforms``
        mapping whose fields match every key/value in *properties*."""

    @abstractmethod
    async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
        """Return threads ordered most-recent first, optionally capped at *limit*."""

    @abstractmethod
    async def find_messages_by_attribute(self, path: str, value: Any) -> List[MessageRecord]:
        """
        Return every stored message whose value at a dotted JSON path equals *value*.

        How *path* is resolved (and how the comparison is carried out) is up to
        each backend.

        Args:
            path: Dot-separated path to the attribute (e.g. "source.platform.attributes.ts")
            value: The value the attribute must equal

        Returns:
            The matching messages as ``MessageRecord`` objects; empty when none match.
        """
|
|
81
|
+
|
|
82
|
+
class MemoryBackend(StorageBackend):
    """In-memory storage backend using a dictionary.

    Threads live in a plain dict keyed by thread ID, so nothing survives the
    process. ``save`` stores the given ``Thread`` object itself and the query
    methods hand back those same objects (no copies are made).
    """

    def __init__(self):
        # thread_id -> Thread
        self._threads: Dict[str, Thread] = {}

    @staticmethod
    def _recency_key(thread: Thread):
        """Sort key for "most recent first": last update, else creation time."""
        return thread.updated_at or thread.created_at

    @staticmethod
    def _message_to_record(message: Message, thread_id: str) -> MessageRecord:
        """Build a transient (never persisted) MessageRecord mirroring *message*."""
        return MessageRecord(
            id=message.id,
            thread_id=thread_id,
            sequence=message.sequence or 0,
            turn=message.turn,
            role=message.role,
            content=message.content,
            name=message.name,
            tool_call_id=message.tool_call_id,
            tool_calls=message.tool_calls,
            attributes=message.attributes,
            timestamp=message.timestamp,
            source=message.source,
            platforms=message.platforms,
            attachments=[a.model_dump() for a in message.attachments] if message.attachments else None,
            metrics=message.metrics,
            reactions=message.reactions
        )

    async def initialize(self) -> None:
        pass  # No initialization needed for memory backend

    async def save(self, thread: Thread) -> Thread:
        """Store (or replace) *thread* under its ID and return it."""
        self._threads[thread.id] = thread
        return thread

    async def get(self, thread_id: str) -> Optional[Thread]:
        """Return the stored thread with *thread_id*, or ``None``."""
        return self._threads.get(thread_id)

    async def delete(self, thread_id: str) -> bool:
        """Remove the thread with *thread_id*; return ``True`` if it existed."""
        if thread_id not in self._threads:
            return False
        del self._threads[thread_id]
        return True

    async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
        """Return a page of threads, most recently updated first.

        Uses the same recency key as ``list_recent``, so a thread whose
        ``updated_at`` is unset falls back to ``created_at`` instead of
        putting ``None`` into the sort.
        """
        threads = sorted(self._threads.values(), key=self._recency_key, reverse=True)
        return threads[offset:offset + limit]

    async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
        """Return threads whose ``attributes`` contain every given key/value pair.

        A missing key reads as ``None``; a thread with no attributes dict is
        treated as having none rather than raising.
        """
        return [
            thread
            for thread in self._threads.values()
            if all((thread.attributes or {}).get(k) == v for k, v in attributes.items())
        ]

    async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
        """Return threads whose ``platforms[platform_name]`` matches all *properties*.

        The platform entry must be present. When *properties* is empty, presence
        alone is enough. Otherwise the entry must be a dict whose values equal
        the requested ones (a non-dict entry cannot match and no longer raises).
        """
        matching_threads = []
        for thread in self._threads.values():
            platforms = getattr(thread, 'platforms', {})
            if not isinstance(platforms, dict) or platform_name not in platforms:
                continue
            platform_data = platforms[platform_name]
            if isinstance(platform_data, dict):
                is_match = all(platform_data.get(k) == v for k, v in properties.items())
            else:
                # Nothing to compare against; only a pure presence check can succeed.
                is_match = not properties
            if is_match:
                matching_threads.append(thread)
        return matching_threads

    async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
        """Return threads most recently updated first, optionally capped at *limit*."""
        threads = sorted(self._threads.values(), key=self._recency_key, reverse=True)
        if limit is not None:
            threads = threads[:limit]
        return threads

    async def find_messages_by_attribute(self, path: str, value: Any) -> List[MessageRecord]:
        """
        Find every stored message whose value at a dotted path equals *value*.

        The path is resolved against ``message.model_dump(mode="python")``, so it
        starts at the message's top-level fields. A path that cannot be followed
        resolves to ``None``.

        Args:
            path: Dot-notation path to the attribute (e.g., "source.platform.attributes.ts")
            value: The value to search for

        Returns:
            One transient ``MessageRecord`` per matching message, across all
            threads (empty list when nothing matches).
        """
        parts = path.split('.')
        matches: List[MessageRecord] = []
        for thread in self._threads.values():
            for message in thread.messages:
                current: Any = message.model_dump(mode="python")
                for part in parts:
                    if isinstance(current, dict) and part in current:
                        current = current[part]
                    else:
                        current = None
                        break
                if current == value:
                    # Built locally: the previous code called a helper that only
                    # exists on SQLBackend and crashed with AttributeError here.
                    matches.append(self._message_to_record(message, thread.id))
        return matches
|
|
172
|
+
|
|
173
|
+
class SQLBackend(StorageBackend):
    """SQL storage backend supporting both SQLite and PostgreSQL with proper connection pooling.

    ``database_url`` handling:

    * ``None``: a SQLite file at ``<system temp dir>/narrator_threads/threads.db``
      (the directory is created if missing).
    * ``":memory:"``: an in-memory SQLite database through ``aiosqlite``.
    * anything else: passed to ``create_async_engine`` unchanged.

    For non-SQLite URLs the pool is sized from the environment variables
    ``NARRATOR_DB_POOL_SIZE`` (default 5), ``NARRATOR_DB_MAX_OVERFLOW`` (10),
    ``NARRATOR_DB_POOL_TIMEOUT`` (30) and ``NARRATOR_DB_POOL_RECYCLE`` (300).
    ``NARRATOR_DB_ECHO=true`` enables SQLAlchemy statement echoing.
    """

    def __init__(self, database_url: Optional[str] = None):
        if database_url is None:
            # Fixed directory under the system temp dir. Nothing here deletes it,
            # so the database file outlives the process.
            tmp_dir = Path(tempfile.gettempdir()) / "narrator_threads"
            tmp_dir.mkdir(exist_ok=True)
            database_url = f"sqlite+aiosqlite:///{tmp_dir}/threads.db"
        elif database_url == ":memory:":
            database_url = "sqlite+aiosqlite:///:memory:"

        self.database_url = database_url

        # Configure engine options with better defaults for connection pooling
        engine_kwargs = {
            'echo': os.environ.get("NARRATOR_DB_ECHO", "").lower() == "true"
        }

        # Explicit pool sizing only for server databases; SQLite keeps SQLAlchemy's default pool
        if not self.database_url.startswith('sqlite'):
            pool_size = int(os.environ.get("NARRATOR_DB_POOL_SIZE", "5"))
            max_overflow = int(os.environ.get("NARRATOR_DB_MAX_OVERFLOW", "10"))
            pool_timeout = int(os.environ.get("NARRATOR_DB_POOL_TIMEOUT", "30"))
            pool_recycle = int(os.environ.get("NARRATOR_DB_POOL_RECYCLE", "300"))

            engine_kwargs.update({
                'pool_size': pool_size,
                'max_overflow': max_overflow,
                'pool_timeout': pool_timeout,
                'pool_recycle': pool_recycle,
                'pool_pre_ping': True  # Check connection validity before using from pool
            })

            logger.info(f"Configuring database connection pool: size={pool_size}, "
                        f"max_overflow={max_overflow}, timeout={pool_timeout}, "
                        f"recycle={pool_recycle}")

        self.engine = create_async_engine(self.database_url, **engine_kwargs)
        # Session factory; expire_on_commit=False keeps loaded attributes from being
        # expired when a transaction commits.
        self._session_maker = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)

    @property
    def async_session(self):
        """
        Returns the session factory for creating new database sessions.

        Use _get_session() method instead which properly creates a session
        for each database operation.
        """
        return self._session_maker

    async def initialize(self) -> None:
        """Initialize the database by creating tables if they don't exist."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.info(f"Database initialized with tables: {Base.metadata.tables.keys()}")

    def _is_sqlite(self) -> bool:
        """True when the configured URL targets SQLite (any driver)."""
        return self.database_url.startswith('sqlite')

    @staticmethod
    def _json_text(value: Any) -> str:
        """Render *value* the way PostgreSQL's ``->>`` renders a JSON scalar.

        Booleans become ``true``/``false``; everything else uses ``str()``.
        """
        if isinstance(value, bool):
            return str(value).lower()
        return str(value)

    @staticmethod
    def _sqlite_json_text(column: str, path_param: str) -> str:
        """SQLite expression for the text form of *column* at the JSON path bound to ``:path_param``.

        ``json_extract`` returns native SQL values (JSON true -> 1, numbers ->
        INTEGER/REAL), which never equal a bound string. Mapping JSON booleans
        to 'true'/'false' and casting everything else to TEXT mirrors
        PostgreSQL's ``->>``, so both dialects compare against the same string.
        *column* must be a trusted, hard-coded column name.
        """
        return (
            f"CASE json_type({column}, :{path_param}) "
            f"WHEN 'true' THEN 'true' WHEN 'false' THEN 'false' "
            f"ELSE CAST(json_extract({column}, :{path_param}) AS TEXT) END"
        )

    def _create_message_from_record(self, msg_record: MessageRecord) -> Message:
        """Helper method to create a Message from a MessageRecord"""
        message = Message(
            id=msg_record.id,
            role=msg_record.role,
            sequence=msg_record.sequence,
            turn=msg_record.turn,
            content=msg_record.content,
            name=msg_record.name,
            tool_call_id=msg_record.tool_call_id,
            tool_calls=msg_record.tool_calls,
            attributes=msg_record.attributes,
            timestamp=msg_record.timestamp,
            source=msg_record.source,
            platforms=msg_record.platforms or {},
            metrics=msg_record.metrics,
            reactions=msg_record.reactions or {}
        )
        if msg_record.attachments:
            message.attachments = [Attachment(**a) for a in msg_record.attachments]
        return message

    def _create_thread_from_record(self, record: ThreadRecord) -> Thread:
        """Helper method to create a Thread from a ThreadRecord"""
        thread = Thread(
            id=record.id,
            title=record.title,
            attributes=record.attributes,
            platforms=record.platforms or {},
            created_at=record.created_at,
            updated_at=record.updated_at,
            messages=[]
        )
        # Sort messages: system messages first, then others by sequence
        sorted_messages = sorted(record.messages,
                                 key=lambda m: (0 if m.role == "system" else 1, m.sequence or 0))
        for msg_record in sorted_messages:
            message = self._create_message_from_record(msg_record)
            thread.messages.append(message)
        return thread

    def _create_message_record(self, message: Message, thread_id: str, sequence: int) -> MessageRecord:
        """Helper method to create a MessageRecord from a Message"""
        return MessageRecord(
            id=message.id,
            thread_id=thread_id,
            sequence=sequence,
            turn=message.turn,
            role=message.role,
            content=message.content,
            name=message.name,
            tool_call_id=message.tool_call_id,
            tool_calls=message.tool_calls,
            attributes=message.attributes,
            timestamp=message.timestamp,
            source=message.source,
            platforms=message.platforms,
            attachments=[a.model_dump() for a in message.attachments] if message.attachments else None,
            metrics=message.metrics,
            reactions=message.reactions
        )

    async def _get_session(self) -> AsyncSession:
        """Create and return a new session for database operations."""
        return self._session_maker()

    async def _cleanup_failed_attachments(self, thread: Thread) -> None:
        """Helper to clean up attachment files if thread save fails"""
        for message in thread.messages:
            if message.attachments:
                for attachment in message.attachments:
                    if hasattr(attachment, 'cleanup') and callable(attachment.cleanup):
                        await attachment.cleanup()

    async def save(self, thread: Thread) -> Thread:
        """Save a thread and its messages to the database.

        Attachments are processed and stored through a ``FileStore`` first. The
        thread row is then inserted or updated, and its message rows are
        rebuilt: system messages get sequence 0, the rest get 1, 2, ... in order.

        Raises:
            RuntimeError: wrapping any failure (attachment processing, database).
        """
        session = await self._get_session()

        # Create a FileStore instance for attachment storage
        file_store = FileStore()

        try:
            # default=str: a non-JSON-serialisable value in platforms must not make
            # this diagnostic log line abort the whole save.
            platforms_json = json.dumps(thread.platforms if thread.platforms is not None else {}, default=str)
            logger.info(f"SQLBackend.save: Attempting to save thread {thread.id}. Platforms data: {platforms_json}")

            # First process and store all attachments
            logger.info(f"Starting to process attachments for thread {thread.id}")
            try:
                for message in thread.messages:
                    if message.attachments:
                        logger.info(f"Processing {len(message.attachments)} attachments for message {message.id}")
                        for attachment in message.attachments:
                            logger.info(f"Processing attachment {attachment.filename} with status {attachment.status}")
                            await attachment.process_and_store(file_store)
                            logger.info(f"Finished processing attachment {attachment.filename}, new status: {attachment.status}")
            except Exception as e:
                # Handle attachment processing failures
                logger.error(f"Failed to process attachment: {str(e)}")
                await self._cleanup_failed_attachments(thread)
                raise RuntimeError(f"Failed to save thread: {str(e)}") from e

            async with session.begin():
                # Get existing thread if it exists
                stmt = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).where(ThreadRecord.id == thread.id)
                result = await session.execute(stmt)
                thread_record = result.scalar_one_or_none()

                if thread_record:
                    # Update existing thread
                    thread_record.title = thread.title
                    thread_record.attributes = thread.attributes
                    thread_record.platforms = thread.platforms
                    thread_record.updated_at = datetime.now(UTC)
                    thread_record.messages = []  # Clear existing messages
                else:
                    # Create new thread record
                    thread_record = ThreadRecord(
                        id=thread.id,
                        title=thread.title,
                        attributes=thread.attributes,
                        platforms=thread.platforms,
                        created_at=thread.created_at,
                        updated_at=thread.updated_at,
                        messages=[]
                    )

                # System messages first (sequence 0), then the rest numbered from 1
                for message in thread.messages:
                    if message.role == "system":
                        thread_record.messages.append(self._create_message_record(message, thread.id, 0))

                sequence = 1
                for message in thread.messages:
                    if message.role != "system":
                        thread_record.messages.append(self._create_message_record(message, thread.id, sequence))
                        sequence += 1

                session.add(thread_record)
                try:
                    await session.commit()
                    logger.info(f"Thread {thread.id} successfully committed to database.")
                except Exception as e:
                    # Convert database errors to RuntimeError for consistent error handling
                    logger.error(f"Database error during commit: {str(e)}")
                    raise RuntimeError(f"Failed to save thread: Database error - {str(e)}") from e
            return thread

        except Exception as e:
            # If this is not already a RuntimeError, wrap it
            if not isinstance(e, RuntimeError):
                raise RuntimeError(f"Failed to save thread: {str(e)}") from e
            raise
        finally:
            await session.close()

    async def get(self, thread_id: str) -> Optional[Thread]:
        """Get a thread by ID."""
        session = await self._get_session()
        try:
            stmt = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).where(ThreadRecord.id == thread_id)
            result = await session.execute(stmt)
            thread_record = result.scalar_one_or_none()
            return self._create_thread_from_record(thread_record) if thread_record else None
        finally:
            await session.close()

    async def delete(self, thread_id: str) -> bool:
        """Delete a thread by ID; return ``True`` if a row was deleted."""
        session = await self._get_session()
        try:
            async with session.begin():
                record = await session.get(ThreadRecord, thread_id)
                if record:
                    await session.delete(record)
                    return True
                return False
        finally:
            await session.close()

    async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
        """List threads with pagination, most recently updated first."""
        session = await self._get_session()
        try:
            result = await session.execute(
                select(ThreadRecord)
                .options(selectinload(ThreadRecord.messages))
                .order_by(ThreadRecord.updated_at.desc())
                .limit(limit)
                .offset(offset)
            )
            return [self._create_thread_from_record(record) for record in result.scalars().all()]
        finally:
            await session.close()

    async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
        """Find threads whose ``attributes`` JSON matches every given key/value pair.

        Both dialects compare the text form of the stored value with the text
        form of the requested one (booleans as ``true``/``false``). A ``None``
        value matches a key that is missing or JSON ``null``. Keys are passed as
        bind parameters, never spliced into SQL, and every key gets its own
        uniquely named parameters.
        """
        session = await self._get_session()
        try:
            query = select(ThreadRecord).options(selectinload(ThreadRecord.messages))

            for index, (key, value) in enumerate(attributes.items()):
                # Index-based names stay unique across keys (the old SQLite branch
                # reused ":value" for every key) and are always valid identifiers.
                value_param = f"attr_value_{index}"
                if self._is_sqlite():
                    path_param = f"attr_path_{index}"
                    path_bind = bindparam(path_param, f"$.{key}")
                    if value is None:
                        clause = text(
                            f"json_extract(attributes, :{path_param}) IS NULL"
                        ).bindparams(path_bind)
                    else:
                        clause = text(
                            f"{self._sqlite_json_text('attributes', path_param)} = :{value_param}"
                        ).bindparams(path_bind, bindparam(value_param, self._json_text(value)))
                else:
                    # PostgreSQL JSON operators; ->> yields the value as text
                    logger.info(f"Searching for attribute[{key}] = {value} (type: {type(value)})")
                    key_param = f"attr_key_{index}"
                    key_bind = bindparam(key_param, str(key))
                    if value is None:
                        clause = text(f"attributes->>:{key_param} IS NULL").bindparams(key_bind)
                    else:
                        clause = text(
                            f"attributes->>:{key_param} = :{value_param}"
                        ).bindparams(key_bind, bindparam(value_param, self._json_text(value)))
                query = query.where(clause)

            # Log the final query for debugging
            logger.info(f"Executing find_by_attributes query: {query}")

            result = await session.execute(query)
            threads = [self._create_thread_from_record(record) for record in result.scalars().all()]
            logger.info(f"Found {len(threads)} matching threads")
            return threads
        except Exception as e:
            logger.error(f"Error in find_by_attributes: {str(e)}")
            raise
        finally:
            await session.close()

    async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
        """Find threads by platform name and properties in the platforms structure.

        The thread's ``platforms`` JSON must contain *platform_name*, and each
        property must match using the same text comparison as
        ``find_by_attributes`` (``None`` matches missing/null). Names and keys
        are bound as parameters, so punctuation in them cannot break the SQL.
        """
        session = await self._get_session()
        try:
            query = select(ThreadRecord).options(selectinload(ThreadRecord.messages))

            if self._is_sqlite():
                query = query.where(
                    text("json_extract(platforms, :platform_path) IS NOT NULL")
                    .bindparams(bindparam("platform_path", f"$.{platform_name}"))
                )
                for index, (key, value) in enumerate(properties.items()):
                    path_param = f"prop_path_{index}"
                    value_param = f"prop_value_{index}"
                    path_bind = bindparam(path_param, f"$.{platform_name}.{key}")
                    if value is None:
                        clause = text(
                            f"json_extract(platforms, :{path_param}) IS NULL"
                        ).bindparams(path_bind)
                    else:
                        clause = text(
                            f"{self._sqlite_json_text('platforms', path_param)} = :{value_param}"
                        ).bindparams(path_bind, bindparam(value_param, self._json_text(value)))
                    query = query.where(clause)
            else:
                # PostgreSQL JSONB: "?" tests key presence, -> / ->> navigate
                query = query.where(
                    text("platforms ? :platform_name")
                    .bindparams(bindparam("platform_name", platform_name))
                )
                for index, (key, value) in enumerate(properties.items()):
                    name_param = f"prop_platform_{index}"
                    key_param = f"prop_key_{index}"
                    value_param = f"prop_value_{index}"
                    binds = [bindparam(name_param, platform_name), bindparam(key_param, str(key))]
                    accessor = f"platforms->:{name_param}->>:{key_param}"
                    if value is None:
                        sql = f"{accessor} IS NULL"
                    else:
                        sql = f"{accessor} = :{value_param}"
                        binds.append(bindparam(value_param, self._json_text(value)))
                    query = query.where(text(sql).bindparams(*binds))

            result = await session.execute(query)
            return [self._create_thread_from_record(record) for record in result.scalars().all()]
        finally:
            await session.close()

    async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
        """List threads most recently updated first, optionally capped at *limit*."""
        session = await self._get_session()
        try:
            query = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).order_by(ThreadRecord.updated_at.desc())
            if limit is not None:
                query = query.limit(limit)
            result = await session.execute(query)
            return [self._create_thread_from_record(record) for record in result.scalars().all()]
        finally:
            await session.close()

    async def find_messages_by_attribute(self, path: str, value: Any) -> List[MessageRecord]:
        """
        Find messages that have a specific attribute at a given JSON path.

        The lookup runs against the message ``source`` JSON column. A leading
        ``source.`` segment (as in the example) names that column and is
        stripped; a path without it is resolved inside ``source`` directly.
        Values compare as text, as in ``find_by_attributes``; ``None`` matches
        a missing or null value.

        Args:
            path: Dot-notation path to the attribute (e.g., "source.platform.attributes.ts")
            value: The value to search for

        Returns:
            List of messages matching the criteria (possibly empty)
        """
        source_path = path[len("source."):] if path.startswith("source.") else path
        session = await self._get_session()
        try:
            query = select(MessageRecord)

            if self._is_sqlite():
                path_bind = bindparam("msg_path", f"$.{source_path}")
                if value is None:
                    clause = text("json_extract(source, :msg_path) IS NULL").bindparams(path_bind)
                else:
                    clause = text(
                        f"{self._sqlite_json_text('source', 'msg_path')} = :msg_value"
                    ).bindparams(path_bind, bindparam("msg_value", self._json_text(value)))
            else:
                # source->'a'->'b'->>'c': navigate as JSON, take the final key as text
                parts = source_path.split('.')
                binds = [bindparam(f"msg_path_{i}", part) for i, part in enumerate(parts)]
                accessor = "source" + "".join(f"->:msg_path_{i}" for i in range(len(parts) - 1))
                accessor += f"->>:msg_path_{len(parts) - 1}"
                if value is None:
                    clause = text(f"{accessor} IS NULL").bindparams(*binds)
                else:
                    clause = text(f"{accessor} = :msg_value").bindparams(
                        *binds, bindparam("msg_value", self._json_text(value))
                    )
            query = query.where(clause)

            result = await session.execute(query)
            return list(result.scalars().all())
        finally:
            await session.close()

    async def get_thread_by_message_id(self, message_id: str) -> Optional[Thread]:
        """
        Find a thread containing a specific message ID.

        Args:
            message_id: The ID of the message to find

        Returns:
            The Thread containing the message, or None if not found
        """
        session = await self._get_session()
        try:
            # Query for the message and join with thread
            stmt = (
                select(ThreadRecord)
                .options(selectinload(ThreadRecord.messages))
                .join(MessageRecord)
                .where(MessageRecord.id == message_id)
            )
            result = await session.execute(stmt)
            thread_record = result.scalar_one_or_none()
            return self._create_thread_from_record(thread_record) if thread_record else None
        finally:
            await session.close()
|