slide_narrator-5.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,624 @@
+ """Storage backend implementations for ThreadStore."""
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Dict, Any, Union
+ import re
+ from datetime import datetime, UTC
+ import json
+ import os
+ from pathlib import Path
+ import tempfile
+ import asyncio
+ from sqlalchemy import create_engine, select, cast, String, text, bindparam
+ from sqlalchemy.orm import sessionmaker, selectinload
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+ # Direct imports
+ from ..models.thread import Thread
+ from ..models.message import Message
+ from ..models.attachment import Attachment
+ from ..storage.file_store import FileStore
+ from ..utils.logging import get_logger
+ from .models import Base, ThreadRecord, MessageRecord
+
+ logger = get_logger(__name__)
+
+ def _sanitize_key(component: str) -> str:
+     """Allow only alphanumeric and underscore for JSON path keys to avoid SQL injection."""
+     if not re.fullmatch(r"[A-Za-z0-9_]+", component):
+         raise ValueError(f"Invalid key component: {component}")
+     return component
+
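+ # Example behaviour (illustrative sketch, derived from the regex above):
+ # identifier-style keys pass through unchanged, anything else is rejected
+ # before it can be interpolated into a JSON path.
+ #
+ #     _sanitize_key("customer_id")           # -> "customer_id"
+ #     _sanitize_key("ts")                    # -> "ts"
+ #     _sanitize_key("x'); DROP TABLE t;--")  # -> raises ValueError
+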
+ class StorageBackend(ABC):
+     """Abstract base class for thread storage backends."""
+
+     @abstractmethod
+     async def initialize(self) -> None:
+         """Initialize the storage backend."""
+         pass
+
+     @abstractmethod
+     async def save(self, thread: Thread) -> Thread:
+         """Save a thread to storage."""
+         pass
+
+     @abstractmethod
+     async def get(self, thread_id: str) -> Optional[Thread]:
+         """Get a thread by ID."""
+         pass
+
+     @abstractmethod
+     async def delete(self, thread_id: str) -> bool:
+         """Delete a thread by ID."""
+         pass
+
+     @abstractmethod
+     async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
+         """List threads with pagination."""
+         pass
+
+     @abstractmethod
+     async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
+         """Find threads by matching attributes."""
+         pass
+
+     @abstractmethod
+     async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
+         """Find threads by platform name and properties in the platforms structure."""
+         pass
+
+     @abstractmethod
+     async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
+         """List recent threads."""
+         pass
+
+     @abstractmethod
+     async def find_messages_by_attribute(self, path: str, value: Any) -> Union[List[Message], List[MessageRecord]]:
+         """
+         Find messages that have a specific attribute at a given JSON path.
+         Uses efficient SQL JSON path queries for PostgreSQL and falls back to
+         SQLite JSON functions when needed.
+
+         Args:
+             path: Dot-notation path to the attribute (e.g., "source.platform.attributes.ts")
+             value: The value to search for
+
+         Returns:
+             List of messages matching the criteria (possibly empty)
+         """
+         pass
+
+ class MemoryBackend(StorageBackend):
+     """In-memory storage backend using a dictionary."""
+
+     def __init__(self):
+         self._threads: Dict[str, Thread] = {}
+
+     async def initialize(self) -> None:
+         pass # No initialization needed for memory backend
+
+     async def save(self, thread: Thread) -> Thread:
+         self._threads[thread.id] = thread
+         return thread
+
+     async def get(self, thread_id: str) -> Optional[Thread]:
+         return self._threads.get(thread_id)
+
+     async def delete(self, thread_id: str) -> bool:
+         if thread_id in self._threads:
+             del self._threads[thread_id]
+             return True
+         return False
+
+     async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
+         threads = sorted(
+             self._threads.values(),
+             key=lambda t: t.updated_at if hasattr(t, 'updated_at') else t.created_at,
+             reverse=True
+         )
+         return threads[offset:offset + limit]
+
+     async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
+         matching_threads = []
+         for thread in self._threads.values():
+             if all(
+                 thread.attributes.get(k) == v
+                 for k, v in attributes.items()
+             ):
+                 matching_threads.append(thread)
+         return matching_threads
+
+     async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
+         matching_threads = []
+         for thread in self._threads.values():
+             platforms = getattr(thread, 'platforms', {})
+             if (
+                 isinstance(platforms, dict) and
+                 platform_name in platforms and
+                 all(platforms[platform_name].get(k) == v for k, v in properties.items())
+             ):
+                 matching_threads.append(thread)
+         return matching_threads
+
+     async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
+         threads = list(self._threads.values())
+         threads.sort(key=lambda t: t.updated_at or t.created_at, reverse=True)
+         if limit is not None:
+             threads = threads[:limit]
+         return threads
+
+     async def find_messages_by_attribute(self, path: str, value: Any) -> List[Message]:
+         """
+         Find messages that have a specific attribute at a given JSON path.
+
+         Args:
+             path: Dot-notation path to the attribute (e.g., "source.platform.attributes.ts")
+             value: The value to search for
+
+         Returns:
+             List of messages matching the criteria (possibly empty)
+         """
+         matches: List[Message] = []
+         # Traverse all threads and messages
+         for thread in self._threads.values():
+             for message in thread.messages:
+                 current: Any = message.model_dump(mode="python")
+                 # Navigate the nested structure
+                 parts = [p for p in path.split('.') if p]
+                 for part in parts:
+                     if isinstance(current, dict) and part in current:
+                         current = current[part]
+                     else:
+                         current = None
+                         break
+                 if current == value:
+                     matches.append(message)
+         return matches
+
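+ # Usage sketch (illustrative; assumes the Thread and Message models imported above
+ # provide defaults for fields not shown here):
+ #
+ #     backend = MemoryBackend()
+ #     await backend.initialize()
+ #     thread = Thread(title="Demo", attributes={"customer_id": "42"})
+ #     thread.messages.append(Message(role="user", content="hello", attributes={"tag": "urgent"}))
+ #     await backend.save(thread)
+ #     by_attr = await backend.find_by_attributes({"customer_id": "42"})
+ #     by_path = await backend.find_messages_by_attribute("attributes.tag", "urgent")
+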
+ class SQLBackend(StorageBackend):
+     """SQL storage backend supporting both SQLite and PostgreSQL with proper connection pooling."""
+
+     def __init__(self, database_url: Optional[str] = None):
+         if database_url is None:
+             # Create a temporary directory that persists until program exit
+             tmp_dir = Path(tempfile.gettempdir()) / "narrator_threads"
+             tmp_dir.mkdir(exist_ok=True)
+             database_url = f"sqlite+aiosqlite:///{tmp_dir}/threads.db"
+         elif database_url == ":memory:":
+             database_url = "sqlite+aiosqlite:///:memory:"
+
+         self.database_url = database_url
+
+         # Configure engine options with better defaults for connection pooling
+         engine_kwargs = {
+             'echo': os.environ.get("NARRATOR_DB_ECHO", "").lower() == "true"
+         }
+
+         # Add pool configuration if not using SQLite
+         if not self.database_url.startswith('sqlite'):
+             # Default connection pool settings if not specified
+             pool_size = int(os.environ.get("NARRATOR_DB_POOL_SIZE", "5"))
+             max_overflow = int(os.environ.get("NARRATOR_DB_MAX_OVERFLOW", "10"))
+             pool_timeout = int(os.environ.get("NARRATOR_DB_POOL_TIMEOUT", "30"))
+             pool_recycle = int(os.environ.get("NARRATOR_DB_POOL_RECYCLE", "300"))
+
+             engine_kwargs.update({
+                 'pool_size': pool_size,
+                 'max_overflow': max_overflow,
+                 'pool_timeout': pool_timeout,
+                 'pool_recycle': pool_recycle,
+                 'pool_pre_ping': True # Check connection validity before using from pool
+             })
+
+             logger.info(f"Configuring database connection pool: size={pool_size}, "
+                         f"max_overflow={max_overflow}, timeout={pool_timeout}, "
+                         f"recycle={pool_recycle}")
+
+         self.engine = create_async_engine(self.database_url, **engine_kwargs)
+         # Create session_maker for database operations
+         self._session_maker = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)
+
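+ # Construction sketch (illustrative). The SQLite URLs mirror the defaults built in
+ # __init__ above; the PostgreSQL URL assumes an async driver such as asyncpg is
+ # installed, which this module itself does not pin.
+ #
+ #     backend = SQLBackend()                    # SQLite file under the temp directory
+ #     backend = SQLBackend(":memory:")          # in-memory SQLite
+ #     backend = SQLBackend("postgresql+asyncpg://user:pw@db-host/narrator")
+ #
+ # For non-SQLite engines, pool sizing comes from NARRATOR_DB_POOL_SIZE,
+ # NARRATOR_DB_MAX_OVERFLOW, NARRATOR_DB_POOL_TIMEOUT and NARRATOR_DB_POOL_RECYCLE;
+ # NARRATOR_DB_ECHO="true" enables SQL echoing.
+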
+     @property
+     def async_session(self):
+         """
+         Returns the session factory for creating new database sessions.
+
+         Prefer the _get_session() method, which creates a fresh session
+         for each database operation.
+         """
+         return self._session_maker
+
+     async def initialize(self) -> None:
+         """Initialize the database by creating tables if they don't exist."""
+         async with self.engine.begin() as conn:
+             await conn.run_sync(Base.metadata.create_all)
+         logger.info(f"Database initialized with tables: {Base.metadata.tables.keys()}")
+
+     def _create_message_from_record(self, msg_record: MessageRecord) -> Message:
+         """Helper method to create a Message from a MessageRecord"""
+         message = Message(
+             id=msg_record.id,
+             role=msg_record.role,
+             sequence=msg_record.sequence,
+             turn=msg_record.turn,
+             content=msg_record.content,
+             reasoning_content=msg_record.reasoning_content,
+             name=msg_record.name,
+             tool_call_id=msg_record.tool_call_id,
+             tool_calls=msg_record.tool_calls,
+             attributes=msg_record.attributes,
+             timestamp=msg_record.timestamp,
+             source=msg_record.source,
+             platforms=msg_record.platforms or {},
+             metrics=msg_record.metrics,
+             reactions=msg_record.reactions or {}
+         )
+         if msg_record.attachments:
+             message.attachments = [Attachment(**a) for a in msg_record.attachments]
+         return message
+
+     def _create_thread_from_record(self, record: ThreadRecord) -> Thread:
+         """Helper method to create a Thread from a ThreadRecord"""
+         thread = Thread(
+             id=record.id,
+             title=record.title,
+             attributes=record.attributes,
+             platforms=record.platforms or {},
+             created_at=record.created_at,
+             updated_at=record.updated_at,
+             messages=[]
+         )
+         # Sort messages: system messages first, then others by sequence
+         sorted_messages = sorted(record.messages,
+                                  key=lambda m: (0 if m.role == "system" else 1, m.sequence or 0))
+         for msg_record in sorted_messages:
+             message = self._create_message_from_record(msg_record)
+             thread.messages.append(message)
+         return thread
+
+     def _create_message_record(self, message: Message, thread_id: str, sequence: int) -> MessageRecord:
+         """Helper method to create a MessageRecord from a Message"""
+         return MessageRecord(
+             id=message.id,
+             thread_id=thread_id,
+             sequence=sequence,
+             turn=message.turn,
+             role=message.role,
+             content=message.content,
+             reasoning_content=message.reasoning_content,
+             name=message.name,
+             tool_call_id=message.tool_call_id,
+             tool_calls=message.tool_calls,
+             attributes=message.attributes,
+             timestamp=message.timestamp,
+             source=message.source,
+             platforms=message.platforms,
+             attachments=[a.model_dump() for a in message.attachments] if message.attachments else None,
+             metrics=message.metrics,
+             reactions=message.reactions
+         )
+
+     async def _get_session(self) -> AsyncSession:
+         """Create and return a new session for database operations."""
+         return self._session_maker()
+
+     async def _cleanup_failed_attachments(self, thread: Thread) -> None:
+         """Helper to clean up attachment files if thread save fails"""
+         for message in thread.messages:
+             if message.attachments:
+                 for attachment in message.attachments:
+                     if hasattr(attachment, 'cleanup') and callable(attachment.cleanup):
+                         await attachment.cleanup()
+
+     async def save(self, thread: Thread) -> Thread:
+         """Save a thread and its messages to the database."""
+         session = await self._get_session()
+
+         # Create a FileStore instance for attachment storage
+         file_store = FileStore()
+
+         try:
+             # Log the platforms data being saved
+             logger.info(f"SQLBackend.save: Attempting to save thread {thread.id}. Platforms data: {json.dumps(thread.platforms if thread.platforms is not None else {})}")
+
+             # First process and store all attachments
+             logger.info(f"Starting to process attachments for thread {thread.id}")
+             try:
+                 for message in thread.messages:
+                     if message.attachments:
+                         logger.info(f"Processing {len(message.attachments)} attachments for message {message.id}")
+                         for attachment in message.attachments:
+                             logger.info(f"Processing attachment {attachment.filename} with status {attachment.status}")
+                             await attachment.process_and_store(file_store)
+                             logger.info(f"Finished processing attachment {attachment.filename}, new status: {attachment.status}")
+             except Exception as e:
+                 # Handle attachment processing failures
+                 logger.error(f"Failed to process attachment: {str(e)}")
+                 await self._cleanup_failed_attachments(thread)
+                 raise RuntimeError(f"Failed to save thread: {str(e)}") from e
+
+             async with session.begin():
+                 # Get existing thread if it exists
+                 stmt = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).where(ThreadRecord.id == thread.id)
+                 result = await session.execute(stmt)
+                 thread_record = result.scalar_one_or_none()
+
+                 if thread_record:
+                     # Update existing thread
+                     thread_record.title = thread.title
+                     thread_record.attributes = thread.attributes
+                     thread_record.platforms = thread.platforms
+                     thread_record.updated_at = datetime.now(UTC)
+                     thread_record.messages = [] # Clear existing messages
+                 else:
+                     # Create new thread record
+                     thread_record = ThreadRecord(
+                         id=thread.id,
+                         title=thread.title,
+                         attributes=thread.attributes,
+                         platforms=thread.platforms,
+                         created_at=thread.created_at,
+                         updated_at=thread.updated_at,
+                         messages=[]
+                     )
+
+                 # Process messages in order
+                 sequence = 1
+
+                 # First handle system messages
+                 for message in thread.messages:
+                     if message.role == "system":
+                         thread_record.messages.append(self._create_message_record(message, thread.id, 0))
+
+                 # Then handle non-system messages
+                 for message in thread.messages:
+                     if message.role != "system":
+                         thread_record.messages.append(self._create_message_record(message, thread.id, sequence))
+                         sequence += 1
+
+                 session.add(thread_record)
+             try:
+                 await session.commit()
+                 logger.info(f"Thread {thread.id} successfully committed to database.")
+             except Exception as e:
+                 # Convert database errors to RuntimeError for consistent error handling
+                 logger.error(f"Database error during commit: {str(e)}")
+                 raise RuntimeError(f"Failed to save thread: Database error - {str(e)}") from e
+             return thread
+
+         except Exception as e:
+             # If this is not already a RuntimeError, wrap it
+             if not isinstance(e, RuntimeError):
+                 raise RuntimeError(f"Failed to save thread: {str(e)}") from e
+             raise e
+         finally:
+             await session.close()
+
+     async def get(self, thread_id: str) -> Optional[Thread]:
+         """Get a thread by ID."""
+         session = await self._get_session()
+         try:
+             stmt = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).where(ThreadRecord.id == thread_id)
+             result = await session.execute(stmt)
+             thread_record = result.scalar_one_or_none()
+             return self._create_thread_from_record(thread_record) if thread_record else None
+         finally:
+             await session.close()
+
+     async def delete(self, thread_id: str) -> bool:
+         """Delete a thread by ID."""
+         session = await self._get_session()
+         try:
+             async with session.begin():
+                 record = await session.get(ThreadRecord, thread_id)
+                 if record:
+                     await session.delete(record)
+                     return True
+                 return False
+         finally:
+             await session.close()
+
+     async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
+         """List threads with pagination."""
+         session = await self._get_session()
+         try:
+             result = await session.execute(
+                 select(ThreadRecord)
+                 .options(selectinload(ThreadRecord.messages))
+                 .order_by(ThreadRecord.updated_at.desc())
+                 .limit(limit)
+                 .offset(offset)
+             )
+             return [self._create_thread_from_record(record) for record in result.scalars().all()]
+         finally:
+             await session.close()
+
+     async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
+         """Find threads by matching attributes."""
+         session = await self._get_session()
+         try:
+             query = select(ThreadRecord).options(selectinload(ThreadRecord.messages))
+
+             for key, value in attributes.items():
+                 if self.database_url.startswith('sqlite'):
+                     # Use SQLite json_extract
+                     safe_key = _sanitize_key(key)
+                     if value is None:
+                         query = query.where(text(f"json_extract(attributes, '$.{safe_key}') IS NULL"))
+                     elif isinstance(value, bool):
+                         # SQLite stores booleans as 1/0
+                         num_val = 1 if value else 0
+                         query = query.where(text(f"json_extract(attributes, '$.{safe_key}') = {num_val}"))
+                     else:
+                         query = query.where(
+                             text(f"json_extract(attributes, '$.{safe_key}') = :value").bindparams(value=str(value))
+                         )
+                 else:
+                     # Use PostgreSQL JSONB operators via text() for direct SQL control
+                     logger.info(f"Searching for attribute[{key}] = {value} (type: {type(value)})")
+
+                     # Handle different value types appropriately
+                     if value is None:
+                         # Check for null/None values
+                         safe_key = _sanitize_key(key)
+                         query = query.where(text(f"attributes->>'{safe_key}' IS NULL"))
+                     else:
+                         # Convert value to string for text comparison
+                         str_value = str(value)
+                         if isinstance(value, bool):
+                             # Convert boolean to lowercase string
+                             str_value = str(value).lower()
+
+                         # Use PostgreSQL's JSONB operators for direct string comparison
+                         safe_key = _sanitize_key(key)
+                         param_name = f"attr_{safe_key}"
+                         bp = bindparam(param_name, str_value)
+                         query = query.where(
+                             text(f"attributes->>'{safe_key}' = :{param_name}").bindparams(bp)
+                         )
+
+             # Log the final query for debugging
+             logger.info(f"Executing find_by_attributes query: {query}")
+
+             result = await session.execute(query)
+             threads = [self._create_thread_from_record(record) for record in result.scalars().all()]
+             logger.info(f"Found {len(threads)} matching threads")
+             return threads
+         except Exception as e:
+             logger.error(f"Error in find_by_attributes: {str(e)}")
+             raise
+         finally:
+             await session.close()
+
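+ # Illustrative predicates produced by the branches above (a sketch; the exact SQL
+ # depends on how SQLAlchemy compiles the text() clauses):
+ #
+ #     find_by_attributes({"customer_id": "42"})
+ #       SQLite:     json_extract(attributes, '$.customer_id') = :value
+ #       PostgreSQL: attributes->>'customer_id' = :attr_customer_id
+ #
+ # Values are compared as strings; booleans become 1/0 on SQLite and
+ # 'true'/'false' on PostgreSQL, as handled above.
+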
+     async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
+         """Find threads by platform name and properties in the platforms structure."""
+         session = await self._get_session()
+         try:
+             query = select(ThreadRecord).options(selectinload(ThreadRecord.messages))
+
+             if self.database_url.startswith('sqlite'):
+                 # Use SQLite json_extract for platform name
+                 safe_platform = _sanitize_key(platform_name)
+                 query = query.where(text(f"json_extract(platforms, '$.{safe_platform}') IS NOT NULL"))
+                 # Add property conditions
+                 for key, value in properties.items():
+                     # Convert value to string for text comparison
+                     safe_key = _sanitize_key(key)
+                     if value is None:
+                         query = query.where(text(f"json_extract(platforms, '$.{safe_platform}.{safe_key}') IS NULL"))
+                     elif isinstance(value, bool):
+                         num_val = 1 if value else 0
+                         query = query.where(text(f"json_extract(platforms, '$.{safe_platform}.{safe_key}') = {num_val}"))
+                     else:
+                         str_value = str(value)
+                         param_name = f"value_{safe_platform}_{safe_key}" # Ensure unique param name
+                         bp = bindparam(param_name, str_value)
+                         query = query.where(
+                             text(f"json_extract(platforms, '$.{safe_platform}.{safe_key}') = :{param_name}")
+                             .bindparams(bp)
+                         )
+             else:
+                 # Use PostgreSQL JSONB operators for platform checks
+                 safe_platform = _sanitize_key(platform_name)
+                 query = query.where(text(f"platforms ? '{safe_platform}'"))
+
+                 # Add property conditions with text() for proper PostgreSQL JSONB syntax
+                 for key, value in properties.items():
+                     str_value = str(value)
+                     safe_key = _sanitize_key(key)
+                     param_name = f"value_{safe_platform}_{safe_key}"
+                     bp = bindparam(param_name, str_value)
+                     query = query.where(
+                         text(f"platforms->'{safe_platform}'->>'{safe_key}' = :{param_name}")
+                         .bindparams(bp)
+                     )
+
+             result = await session.execute(query)
+             return [self._create_thread_from_record(record) for record in result.scalars().all()]
+         finally:
+             await session.close()
+
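+ # Illustrative call (sketch; "slack" and "channel" are only example inputs):
+ # find_by_platform("slack", {"channel": "C123"}) requires a "slack" entry in the
+ # platforms JSON and, per property, adds a predicate such as
+ #     SQLite:     json_extract(platforms, '$.slack.channel') = :value_slack_channel
+ #     PostgreSQL: platforms ? 'slack' AND platforms->'slack'->>'channel' = :value_slack_channel
+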
+     async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
+         """List recent threads."""
+         session = await self._get_session()
+         try:
+             query = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).order_by(ThreadRecord.updated_at.desc())
+             if limit is not None:
+                 query = query.limit(limit)
+             result = await session.execute(query)
+             return [self._create_thread_from_record(record) for record in result.scalars().all()]
+         finally:
+             await session.close()
+
+     async def find_messages_by_attribute(self, path: str, value: Any) -> List[MessageRecord]:
+         """
+         Find messages that have a specific attribute at a given JSON path.
+         Uses efficient SQL JSON path queries for PostgreSQL and falls back to
+         SQLite JSON functions when needed.
+
+         Args:
+             path: Dot-notation path to the attribute (e.g., "source.platform.attributes.ts")
+             value: The value to search for
+
+         Returns:
+             List of messages matching the criteria (possibly empty)
+         """
+         session = await self._get_session()
+         try:
+             query = select(MessageRecord)
+
+             # Normalize and sanitize path parts
+             parts = [p for p in path.split('.') if p]
+             parts = [_sanitize_key(p) for p in parts]
+             if not parts:
+                 return []
+             # Support paths prefixed with 'source.' by stripping the leading component
+             if parts and parts[0] == 'source':
+                 parts = parts[1:]
+                 if not parts:
+                     return []
+             if self.database_url.startswith('sqlite'):
+                 # Use SQLite json_extract with a proper JSON path: $.a.b.c (safe due to sanitized parts)
+                 json_path = '$.' + '.'.join(parts)
+                 query = query.where(
+                     text(f"json_extract(source, '{json_path}') = :value").bindparams(value=str(value))
+                 )
+             else:
+                 # Use PostgreSQL JSONB operators: source->'a'->'b'->>'c' (last part text)
+                 if len(parts) == 1:
+                     pg_expr = f"source->>'{parts[0]}'"
+                 else:
+                     head = parts[:-1]
+                     tail = parts[-1]
+                     pg_expr = "source" + ''.join([f"->'{h}'" for h in head]) + f"->>'{tail}'"
+                 query = query.where(
+                     text(f"{pg_expr} = :value").bindparams(value=str(value))
+                 )
+
+             result = await session.execute(query)
+             return result.scalars().all()
+         finally:
+             await session.close()
+
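+ # Path mapping sketch for the method above (the example path comes from the
+ # docstring; it is matched against the message's "source" JSON column, and the
+ # value shown is a placeholder):
+ #
+ #     find_messages_by_attribute("source.platform.attributes.ts", "1718000000.000100")
+ #       SQLite:     json_extract(source, '$.platform.attributes.ts') = :value
+ #       PostgreSQL: source->'platform'->'attributes'->>'ts' = :value
+ #
+ # The leading "source." segment is stripped first, and every segment must pass
+ # _sanitize_key, so only [A-Za-z0-9_] path components are accepted.
+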
+     async def get_thread_by_message_id(self, message_id: str) -> Optional[Thread]:
+         """
+         Find a thread containing a specific message ID.
+
+         Args:
+             message_id: The ID of the message to find
+
+         Returns:
+             The Thread containing the message, or None if not found
+         """
+         session = await self._get_session()
+         try:
+             # Query for the message and join with thread
+             stmt = (
+                 select(ThreadRecord)
+                 .options(selectinload(ThreadRecord.messages))
+                 .join(MessageRecord)
+                 .where(MessageRecord.id == message_id)
+             )
+             result = await session.execute(stmt)
+             thread_record = result.scalar_one_or_none()
+             return self._create_thread_from_record(thread_record) if thread_record else None
+         finally:
+             await session.close()
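+
+ # End-to-end sketch (illustrative; error handling omitted, and `thread` stands for
+ # any Thread instance built from the models imported above):
+ #
+ #     backend = SQLBackend("sqlite+aiosqlite:///threads.db")
+ #     await backend.initialize()                  # creates tables from Base.metadata
+ #     saved = await backend.save(thread)          # stores attachments, then upserts rows
+ #     again = await backend.get(saved.id)
+ #     owner = await backend.get_thread_by_message_id(saved.messages[0].id)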