slide-narrator 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of slide-narrator might be problematic. Click here for more details.

@@ -0,0 +1,580 @@
1
+ """Storage backend implementations for ThreadStore."""
2
+ from abc import ABC, abstractmethod
3
+ from typing import List, Optional, Dict, Any
4
+ from datetime import datetime, UTC
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ import tempfile
9
+ import asyncio
10
+ from sqlalchemy import create_engine, select, cast, String, text, bindparam
11
+ from sqlalchemy.orm import sessionmaker, selectinload
12
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
13
+ # Direct imports
14
+ from ..models.thread import Thread
15
+ from ..models.message import Message
16
+ from ..models.attachment import Attachment
17
+ from ..storage.file_store import FileStore
18
+ from ..utils.logging import get_logger
19
+ from .models import Base, ThreadRecord, MessageRecord
20
+
21
# Module-level logger for this file, named after the module (used by SQLBackend).
logger = get_logger(__name__)
22
+
23
class StorageBackend(ABC):
    """Interface that every ThreadStore persistence backend implements.

    Every operation is a coroutine, so in-process and database-backed
    implementations share one calling convention.
    """

    @abstractmethod
    async def initialize(self) -> None:
        """Prepare the backend for use."""

    @abstractmethod
    async def save(self, thread: Thread) -> Thread:
        """Persist ``thread`` and return it."""

    @abstractmethod
    async def get(self, thread_id: str) -> Optional[Thread]:
        """Return the thread stored under ``thread_id``, or ``None``."""

    @abstractmethod
    async def delete(self, thread_id: str) -> bool:
        """Remove the thread stored under ``thread_id``; report whether one was removed."""

    @abstractmethod
    async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
        """Return one page of threads: skip ``offset`` entries, then take up to ``limit``."""

    @abstractmethod
    async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
        """Return threads whose attributes match every given key/value pair."""

    @abstractmethod
    async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
        """Return threads whose ``platforms[platform_name]`` entry matches every given property."""

    @abstractmethod
    async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
        """Return the most recent threads, at most ``limit`` of them when given."""

    @abstractmethod
    async def find_messages_by_attribute(self, path: str, value: Any) -> List[MessageRecord]:
        """Return messages whose attribute at a dot-notation ``path`` equals ``value``.

        Args:
            path: Dot-separated attribute path, e.g. ``"source.platform.attributes.ts"``.
            value: The value the attribute must equal.

        Returns:
            Every matching message; an empty list when nothing matches.
        """
81
+
82
class MemoryBackend(StorageBackend):
    """In-memory storage backend using a dictionary.

    Threads are stored by reference in a dict keyed by ``thread.id``; nothing is
    copied on save or get, and nothing outlives this object.
    """

    def __init__(self):
        self._threads: Dict[str, Thread] = {}

    async def initialize(self) -> None:
        pass  # No initialization needed for memory backend

    async def save(self, thread: Thread) -> Thread:
        """Store ``thread`` under its id (replacing any previous entry) and return it."""
        self._threads[thread.id] = thread
        return thread

    async def get(self, thread_id: str) -> Optional[Thread]:
        """Return the stored thread for ``thread_id``, or ``None`` if absent."""
        return self._threads.get(thread_id)

    async def delete(self, thread_id: str) -> bool:
        """Remove ``thread_id``; return True if it was present, False otherwise."""
        if thread_id in self._threads:
            del self._threads[thread_id]
            return True
        return False

    @staticmethod
    def _recency_key(thread: Thread) -> Any:
        """Sort key for newest-first ordering: ``updated_at``, else ``created_at``.

        Falls back to ``created_at`` when ``updated_at`` is missing or falsy
        (e.g. ``None``), so ``list`` and ``list_recent`` order threads the same
        way and never compare ``None`` with a datetime.
        """
        return getattr(thread, 'updated_at', None) or thread.created_at

    async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
        """Return threads newest first, skipping ``offset`` and taking up to ``limit``."""
        threads = sorted(self._threads.values(), key=self._recency_key, reverse=True)
        return threads[offset:offset + limit]

    async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
        """Return threads whose ``attributes`` equal every given key/value pair.

        A missing key compares as ``None``; an empty ``attributes`` mapping
        matches every thread. A thread whose ``attributes`` is ``None`` is
        treated as having no attributes.
        """
        return [
            thread
            for thread in self._threads.values()
            if all((thread.attributes or {}).get(k) == v for k, v in attributes.items())
        ]

    async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
        """Return threads whose ``platforms[platform_name]`` entry matches every property.

        Threads without a dict ``platforms`` or without ``platform_name`` in it
        never match. A non-dict platform entry matches only when ``properties``
        is empty (it has no properties to compare).
        """
        matching_threads = []
        for thread in self._threads.values():
            platforms = getattr(thread, 'platforms', {})
            if not isinstance(platforms, dict) or platform_name not in platforms:
                continue
            entry = platforms[platform_name]
            # Guard the .get() so a malformed (non-dict) entry is a non-match, not a crash.
            if all(isinstance(entry, dict) and entry.get(k) == v for k, v in properties.items()):
                matching_threads.append(thread)
        return matching_threads

    async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
        """Return threads newest first, truncated to ``limit`` when it is not ``None``."""
        threads = sorted(self._threads.values(), key=self._recency_key, reverse=True)
        return threads if limit is None else threads[:limit]

    @staticmethod
    def _to_message_record(message: Message, thread_id: str) -> MessageRecord:
        """Build a MessageRecord (not attached to any session) mirroring ``message``."""
        return MessageRecord(
            id=message.id,
            thread_id=thread_id,
            sequence=getattr(message, 'sequence', 0),
            turn=message.turn,
            role=message.role,
            content=message.content,
            name=message.name,
            tool_call_id=message.tool_call_id,
            tool_calls=message.tool_calls,
            attributes=message.attributes,
            timestamp=message.timestamp,
            source=message.source,
            platforms=message.platforms,
            attachments=[a.model_dump() for a in message.attachments] if message.attachments else None,
            metrics=message.metrics,
            reactions=message.reactions
        )

    async def find_messages_by_attribute(self, path: str, value: Any) -> List[MessageRecord]:
        """
        Find messages that have a specific attribute at a given JSON path.

        The path is resolved against each message's ``model_dump(mode="python")``
        dict, descending only through dicts; a missing segment resolves to
        ``None`` (so ``value=None`` also matches messages lacking the path).

        Args:
            path: Dot-notation path to the attribute (e.g., "source.platform.attributes.ts")
            value: The value to search for

        Returns:
            A MessageRecord for every matching message, in storage iteration
            order (possibly empty).
        """
        parts = path.split('.')
        matches: List[MessageRecord] = []
        for thread in self._threads.values():
            for message in thread.messages:
                current: Any = message.model_dump(mode="python")
                for part in parts:
                    if isinstance(current, dict) and part in current:
                        current = current[part]
                    else:
                        current = None
                        break
                if current == value:
                    matches.append(self._to_message_record(message, thread.id))
        return matches
172
+
173
class SQLBackend(StorageBackend):
    """SQL storage backend supporting both SQLite and PostgreSQL with proper connection pooling.

    JSON queries branch on whether ``database_url`` starts with ``"sqlite"``;
    every other URL is queried with PostgreSQL JSON operators.
    """

    def __init__(self, database_url: Optional[str] = None):
        """Create the async engine and session factory.

        Args:
            database_url: SQLAlchemy async URL. ``None`` selects a SQLite file at
                ``<system temp dir>/narrator_threads/threads.db``; ``":memory:"``
                selects an in-memory aiosqlite database.

        Environment:
            NARRATOR_DB_ECHO: ``"true"`` (case-insensitive) enables SQL echo.
            NARRATOR_DB_POOL_SIZE / NARRATOR_DB_MAX_OVERFLOW /
            NARRATOR_DB_POOL_TIMEOUT / NARRATOR_DB_POOL_RECYCLE: pool settings
                applied only to non-SQLite URLs (defaults 5 / 10 / 30 / 300).
        """
        if database_url is None:
            # Fixed directory under the system temp dir. It is not removed when
            # the program exits, so stored threads survive restarts (unless the
            # OS cleans its temp dir).
            tmp_dir = Path(tempfile.gettempdir()) / "narrator_threads"
            tmp_dir.mkdir(exist_ok=True)
            database_url = f"sqlite+aiosqlite:///{tmp_dir}/threads.db"
        elif database_url == ":memory:":
            database_url = "sqlite+aiosqlite:///:memory:"

        self.database_url = database_url

        # Configure engine options with better defaults for connection pooling
        engine_kwargs = {
            'echo': os.environ.get("NARRATOR_DB_ECHO", "").lower() == "true"
        }

        # Add pool configuration if not using SQLite
        if not self.database_url.startswith('sqlite'):
            pool_size = int(os.environ.get("NARRATOR_DB_POOL_SIZE", "5"))
            max_overflow = int(os.environ.get("NARRATOR_DB_MAX_OVERFLOW", "10"))
            pool_timeout = int(os.environ.get("NARRATOR_DB_POOL_TIMEOUT", "30"))
            pool_recycle = int(os.environ.get("NARRATOR_DB_POOL_RECYCLE", "300"))

            engine_kwargs.update({
                'pool_size': pool_size,
                'max_overflow': max_overflow,
                'pool_timeout': pool_timeout,
                'pool_recycle': pool_recycle,
                'pool_pre_ping': True  # Check connection validity before using from pool
            })

            logger.info(f"Configuring database connection pool: size={pool_size}, "
                        f"max_overflow={max_overflow}, timeout={pool_timeout}, "
                        f"recycle={pool_recycle}")

        self.engine = create_async_engine(self.database_url, **engine_kwargs)
        # Create session_maker for database operations
        self._session_maker = sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)

    @property
    def async_session(self):
        """
        Returns the session factory for creating new database sessions.

        Use _get_session() method instead which properly creates a session
        for each database operation.
        """
        return self._session_maker

    async def initialize(self) -> None:
        """Initialize the database by creating tables if they don't exist."""
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.info(f"Database initialized with tables: {Base.metadata.tables.keys()}")

    def _create_message_from_record(self, msg_record: MessageRecord) -> Message:
        """Helper method to create a Message from a MessageRecord"""
        message = Message(
            id=msg_record.id,
            role=msg_record.role,
            sequence=msg_record.sequence,
            turn=msg_record.turn,
            content=msg_record.content,
            name=msg_record.name,
            tool_call_id=msg_record.tool_call_id,
            tool_calls=msg_record.tool_calls,
            attributes=msg_record.attributes,
            timestamp=msg_record.timestamp,
            source=msg_record.source,
            platforms=msg_record.platforms or {},
            metrics=msg_record.metrics,
            reactions=msg_record.reactions or {}
        )
        if msg_record.attachments:
            message.attachments = [Attachment(**a) for a in msg_record.attachments]
        return message

    def _create_thread_from_record(self, record: ThreadRecord) -> Thread:
        """Helper method to create a Thread from a ThreadRecord"""
        thread = Thread(
            id=record.id,
            title=record.title,
            attributes=record.attributes,
            platforms=record.platforms or {},
            created_at=record.created_at,
            updated_at=record.updated_at,
            messages=[]
        )
        # Sort messages: system messages first, then others by sequence
        sorted_messages = sorted(record.messages,
                                 key=lambda m: (0 if m.role == "system" else 1, m.sequence or 0))
        for msg_record in sorted_messages:
            message = self._create_message_from_record(msg_record)
            thread.messages.append(message)
        return thread

    def _create_message_record(self, message: Message, thread_id: str, sequence: int) -> MessageRecord:
        """Helper method to create a MessageRecord from a Message"""
        return MessageRecord(
            id=message.id,
            thread_id=thread_id,
            sequence=sequence,
            turn=message.turn,
            role=message.role,
            content=message.content,
            name=message.name,
            tool_call_id=message.tool_call_id,
            tool_calls=message.tool_calls,
            attributes=message.attributes,
            timestamp=message.timestamp,
            source=message.source,
            platforms=message.platforms,
            attachments=[a.model_dump() for a in message.attachments] if message.attachments else None,
            metrics=message.metrics,
            reactions=message.reactions
        )

    async def _get_session(self) -> AsyncSession:
        """Create and return a new session for database operations."""
        return self._session_maker()

    async def _cleanup_failed_attachments(self, thread: Thread) -> None:
        """Helper to clean up attachment files if thread save fails"""
        for message in thread.messages:
            if message.attachments:
                for attachment in message.attachments:
                    if hasattr(attachment, 'cleanup') and callable(attachment.cleanup):
                        await attachment.cleanup()

    # ------------------------------------------------------------------
    # JSON query helpers (shared by the find_* methods)
    # ------------------------------------------------------------------

    def _is_sqlite(self) -> bool:
        """Return True when the configured URL targets SQLite (any driver)."""
        return self.database_url.startswith('sqlite')

    @staticmethod
    def _sqlite_bind_value(value: Any) -> Any:
        """Convert ``value`` to the form SQLite's ``json_extract()`` yields for it.

        ``json_extract()`` returns INTEGER/REAL for JSON numbers and 1/0 for JSON
        booleans, and SQLite does not coerce between numbers and TEXT when neither
        side of ``=`` has column affinity. Numbers and booleans are therefore
        bound natively; other non-string values fall back to ``str(value)``.
        """
        if isinstance(value, bool):
            return int(value)
        if isinstance(value, (int, float, str)):
            return value
        return str(value)

    @staticmethod
    def _pg_text_value(value: Any) -> str:
        """Render ``value`` as text for comparison with PostgreSQL ``->>`` output.

        ``->>`` renders JSON booleans as ``true``/``false``, so Python booleans are
        lowercased; everything else uses ``str(value)``.
        """
        if isinstance(value, bool):
            return str(value).lower()
        return str(value)

    def _json_value_expr(self, column: str, parts: List[str], prefix: str):
        """Build SQL reading the JSON value at key path ``parts`` inside ``column``.

        Keys are bound parameters rather than spliced into the SQL text, so quotes
        or ``:`` inside keys cannot corrupt the statement.

        Args:
            column: Trusted column name (e.g. ``"attributes"``).
            parts: Key path, outermost first. Must be non-empty.
            prefix: Bind-name prefix; must consist of word characters and be
                unique within the query.

        Returns:
            ``(sql_fragment, bind_params)``.
        """
        if self._is_sqlite():
            path_param = f"{prefix}_path"
            return (
                f"json_extract({column}, :{path_param})",
                [bindparam(path_param, '$.' + '.'.join(parts))],
            )
        key_params = [f"{prefix}_k{j}" for j in range(len(parts))]
        # '->' keeps intermediate values as JSON; the final '->>' yields text.
        operators = ['->'] * (len(parts) - 1) + ['->>']
        fragment = column + ''.join(
            f" {op} CAST(:{name} AS TEXT)" for op, name in zip(operators, key_params)
        )
        return f"({fragment})", [bindparam(name, part) for name, part in zip(key_params, parts)]

    def _json_equals_clause(self, column: str, parts: List[str], value: Any, prefix: str):
        """Return a ``text()`` predicate: JSON value at ``parts`` in ``column`` equals ``value``.

        ``value=None`` becomes ``IS NULL``, which matches both a missing key and an
        explicit JSON null on SQLite and PostgreSQL alike.
        """
        expr, binds = self._json_value_expr(column, parts, prefix)
        if value is None:
            return text(f"{expr} IS NULL").bindparams(*binds)
        value_param = f"{prefix}_value"
        bound = self._sqlite_bind_value(value) if self._is_sqlite() else self._pg_text_value(value)
        return text(f"{expr} = :{value_param}").bindparams(*binds, bindparam(value_param, bound))

    # ------------------------------------------------------------------
    # StorageBackend API
    # ------------------------------------------------------------------

    async def save(self, thread: Thread) -> Thread:
        """Save a thread and its messages to the database, replacing any stored version.

        Attachments are first processed through a new FileStore; if that fails,
        each attachment's ``cleanup()`` (when present) is awaited. The thread row
        and its messages are then written in one transaction: system messages get
        sequence 0, the others get 1..n in their original order.

        Returns:
            The ``thread`` object that was passed in.

        Raises:
            RuntimeError: wrapping any attachment, database or other failure.
        """
        session = await self._get_session()

        # Create a FileStore instance for attachment storage
        file_store = FileStore()

        try:
            # default=str: a non-JSON-serializable platform value must not turn a
            # diagnostic log line into a failed save.
            platforms_json = json.dumps(thread.platforms if thread.platforms is not None else {}, default=str)
            logger.info(f"SQLBackend.save: Attempting to save thread {thread.id}. Platforms data: {platforms_json}")

            # First process and store all attachments
            logger.info(f"Starting to process attachments for thread {thread.id}")
            try:
                for message in thread.messages:
                    if message.attachments:
                        logger.info(f"Processing {len(message.attachments)} attachments for message {message.id}")
                        for attachment in message.attachments:
                            logger.info(f"Processing attachment {attachment.filename} with status {attachment.status}")
                            await attachment.process_and_store(file_store)
                            logger.info(f"Finished processing attachment {attachment.filename}, new status: {attachment.status}")
            except Exception as e:
                # Handle attachment processing failures
                logger.error(f"Failed to process attachment: {str(e)}")
                await self._cleanup_failed_attachments(thread)
                raise RuntimeError(f"Failed to save thread: {str(e)}") from e

            async with session.begin():
                # Get existing thread if it exists
                stmt = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).where(ThreadRecord.id == thread.id)
                result = await session.execute(stmt)
                thread_record = result.scalar_one_or_none()

                if thread_record:
                    # Update existing thread
                    thread_record.title = thread.title
                    thread_record.attributes = thread.attributes
                    thread_record.platforms = thread.platforms
                    thread_record.updated_at = datetime.now(UTC)
                    thread_record.messages = []  # Clear existing messages
                else:
                    # Create new thread record
                    thread_record = ThreadRecord(
                        id=thread.id,
                        title=thread.title,
                        attributes=thread.attributes,
                        platforms=thread.platforms,
                        created_at=thread.created_at,
                        updated_at=thread.updated_at,
                        messages=[]
                    )

                # System messages first (all sequence 0), then the rest numbered from 1
                for message in thread.messages:
                    if message.role == "system":
                        thread_record.messages.append(self._create_message_record(message, thread.id, 0))

                sequence = 1
                for message in thread.messages:
                    if message.role != "system":
                        thread_record.messages.append(self._create_message_record(message, thread.id, sequence))
                        sequence += 1

                session.add(thread_record)
                try:
                    await session.commit()
                    logger.info(f"Thread {thread.id} successfully committed to database.")
                except Exception as e:
                    # Convert database errors to RuntimeError for consistent error handling
                    logger.error(f"Database error during commit: {str(e)}")
                    raise RuntimeError(f"Failed to save thread: Database error - {str(e)}") from e
                return thread

        except RuntimeError:
            raise
        except Exception as e:
            # Anything not already wrapped becomes a RuntimeError for callers
            raise RuntimeError(f"Failed to save thread: {str(e)}") from e
        finally:
            await session.close()

    async def get(self, thread_id: str) -> Optional[Thread]:
        """Get a thread (with its messages) by ID, or ``None`` if it does not exist."""
        session = await self._get_session()
        try:
            stmt = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).where(ThreadRecord.id == thread_id)
            result = await session.execute(stmt)
            thread_record = result.scalar_one_or_none()
            return self._create_thread_from_record(thread_record) if thread_record else None
        finally:
            await session.close()

    async def delete(self, thread_id: str) -> bool:
        """Delete a thread by ID; return True if a row was found and deleted."""
        session = await self._get_session()
        try:
            async with session.begin():
                record = await session.get(ThreadRecord, thread_id)
                if record:
                    await session.delete(record)
                    return True
                return False
        finally:
            await session.close()

    async def list(self, limit: int = 100, offset: int = 0) -> List[Thread]:
        """List threads ordered by ``updated_at`` descending, with limit/offset pagination."""
        session = await self._get_session()
        try:
            result = await session.execute(
                select(ThreadRecord)
                .options(selectinload(ThreadRecord.messages))
                .order_by(ThreadRecord.updated_at.desc())
                .limit(limit)
                .offset(offset)
            )
            return [self._create_thread_from_record(record) for record in result.scalars().all()]
        finally:
            await session.close()

    async def find_by_attributes(self, attributes: Dict[str, Any]) -> List[Thread]:
        """Find threads whose ``attributes`` JSON matches every key/value pair.

        Each key is looked up at the top level of ``attributes`` (on SQLite a
        dotted key navigates nested objects). ``None`` matches a missing key or
        JSON null. An empty mapping applies no filter.
        """
        session = await self._get_session()
        try:
            query = select(ThreadRecord).options(selectinload(ThreadRecord.messages))

            # Index-based bind names: every filter needs its own parameter (a
            # shared name makes all filters compare against the last value), and
            # keys may contain characters that are invalid in bind names.
            for i, (key, value) in enumerate(attributes.items()):
                logger.info(f"Searching for attribute[{key}] = {value} (type: {type(value)})")
                query = query.where(self._json_equals_clause("attributes", [key], value, f"attr_{i}"))

            # Log the final query for debugging
            logger.info(f"Executing find_by_attributes query: {query}")

            result = await session.execute(query)
            threads = [self._create_thread_from_record(record) for record in result.scalars().all()]
            logger.info(f"Found {len(threads)} matching threads")
            return threads
        except Exception as e:
            logger.error(f"Error in find_by_attributes: {str(e)}")
            raise
        finally:
            await session.close()

    async def find_by_platform(self, platform_name: str, properties: Dict[str, Any]) -> List[Thread]:
        """Find threads by platform name and properties in the platforms structure.

        A thread matches when ``platforms`` has ``platform_name`` and every
        property equals the value stored under that platform entry.
        """
        session = await self._get_session()
        try:
            query = select(ThreadRecord).options(selectinload(ThreadRecord.messages))

            if self._is_sqlite():
                expr, binds = self._json_value_expr("platforms", [platform_name], "plat")
                query = query.where(text(f"{expr} IS NOT NULL").bindparams(*binds))
            else:
                # PostgreSQL JSONB key-existence operator
                query = query.where(
                    text("platforms ? CAST(:plat_name AS TEXT)")
                    .bindparams(bindparam("plat_name", platform_name))
                )

            for i, (key, value) in enumerate(properties.items()):
                query = query.where(
                    self._json_equals_clause("platforms", [platform_name, key], value, f"prop_{i}")
                )

            result = await session.execute(query)
            return [self._create_thread_from_record(record) for record in result.scalars().all()]
        finally:
            await session.close()

    async def list_recent(self, limit: Optional[int] = None) -> List[Thread]:
        """List threads by ``updated_at`` descending, at most ``limit`` when given."""
        session = await self._get_session()
        try:
            query = select(ThreadRecord).options(selectinload(ThreadRecord.messages)).order_by(ThreadRecord.updated_at.desc())
            if limit is not None:
                query = query.limit(limit)
            result = await session.execute(query)
            return [self._create_thread_from_record(record) for record in result.scalars().all()]
        finally:
            await session.close()

    async def find_messages_by_attribute(self, path: str, value: Any) -> List[MessageRecord]:
        """
        Find messages that have a specific attribute at a given JSON path.

        The dot-notation path is resolved inside the ``source`` column: for
        example "platform.ts" reads ``source -> platform -> ts``. Uses
        ``json_extract`` on SQLite and ``->``/``->>`` on PostgreSQL. ``None``
        matches a missing path or JSON null.

        NOTE(review): MemoryBackend resolves the same path from the message root
        (``model_dump()``), so a path starting with "source." behaves differently
        between the two backends — confirm which form callers use.

        Args:
            path: Dot-notation path to the attribute (e.g., "source.platform.attributes.ts")
            value: The value to search for

        Returns:
            List of messages matching the criteria (possibly empty)
        """
        session = await self._get_session()
        try:
            clause = self._json_equals_clause("source", path.split('.'), value, "msg")
            result = await session.execute(select(MessageRecord).where(clause))
            return list(result.scalars().all())
        finally:
            await session.close()

    async def get_thread_by_message_id(self, message_id: str) -> Optional[Thread]:
        """
        Find a thread containing a specific message ID.

        Args:
            message_id: The ID of the message to find

        Returns:
            The Thread containing the message, or None if not found
        """
        session = await self._get_session()
        try:
            # Query for the message and join with thread
            stmt = (
                select(ThreadRecord)
                .options(selectinload(ThreadRecord.messages))
                .join(MessageRecord)
                .where(MessageRecord.id == message_id)
            )
            result = await session.execute(stmt)
            thread_record = result.scalar_one_or_none()
            return self._create_thread_from_record(thread_record) if thread_record else None
        finally:
            await session.close()