dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,802 @@
1
+ """
2
+ Abstract base class for persistent (database-backed) context-aware Genie cache implementations.
3
+
4
+ This module provides the foundational abstract base class for database-backed
5
+ cache implementations. It adds:
6
+ - Connection pooling management
7
+ - Transaction handling with retry logic
8
+ - Prompt history storage and retrieval
9
+ - Database error handling with exponential backoff
10
+
11
+ Subclasses must implement database-specific methods:
12
+ - _create_table_if_not_exists(): Create database schema
13
+ - _get_pool(): Get database connection pool
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from abc import abstractmethod
19
+ from typing import Any, Callable, TypeVar
20
+
21
+ from loguru import logger
22
+
23
+ from dao_ai.config import DatabaseModel
24
+ from dao_ai.genie.cache.context_aware.base import (
25
+ ContextAwareGenieService,
26
+ get_conversation_history,
27
+ )
28
+
29
+ # Type variable for return types
30
+ T = TypeVar("T")
31
+
32
+ # Type alias for database row (dict due to row_factory=dict_row)
33
+ DbRow = dict[str, Any]
34
+
35
+
36
class PersistentContextAwareGenieCacheService(ContextAwareGenieService):
    """
    Abstract base class for database-backed context-aware Genie cache implementations.

    Extends ContextAwareGenieService with database-specific functionality:
    - Connection pool management
    - Prompt history tracking for conversation context
    - Retry logic for transient database failures
    - Schema creation and management

    Subclasses must implement:
    - _get_pool(): Return the database connection pool
    - _create_table_if_not_exists(): Create required database tables
    - Database-specific _find_similar() and _store_entry() implementations

    Thread Safety:
        Uses connection pooling for thread-safe database access.
        All database operations use connection context managers.
    """

    # Connection pool used by all database operations; populated by the
    # concrete subclass (e.g. a psycopg ConnectionPool).
    _pool: Any  # ConnectionPool

    @property
    @abstractmethod
    def database(self) -> DatabaseModel:
        """The database used for storing cache entries."""
        pass

    @property
    @abstractmethod
    def table_name(self) -> str:
        """Name of the cache table."""
        pass

    @property
    @abstractmethod
    def prompt_history_table(self) -> str:
        """Name of the prompt history table."""
        pass

    @property
    @abstractmethod
    def context_window_size(self) -> int:
        """Number of previous prompts to include in context."""
        pass

    @property
    @abstractmethod
    def max_context_tokens(self) -> int:
        """Maximum tokens for context string."""
        pass

    @property
    @abstractmethod
    def context_similarity_threshold(self) -> float:
        """Minimum similarity for context matching."""
        pass

    @property
    @abstractmethod
    def question_weight(self) -> float:
        """Weight for question similarity in combined score."""
        pass

    @property
    @abstractmethod
    def context_weight(self) -> float:
        """Weight for context similarity in combined score."""
        pass

    @property
    @abstractmethod
    def max_prompt_history_length(self) -> int:
        """Maximum number of prompts to keep per conversation."""
        pass

    @property
    @abstractmethod
    def time_to_live_seconds(self) -> int | None:
        """TTL in seconds (None or negative = never expires)."""
        pass

    @abstractmethod
    def _create_table_if_not_exists(self) -> None:
        """
        Create the cache and prompt history tables if they don't exist.

        Implementations should handle:
        - Creating the cache table with vector columns
        - Creating indexes for efficient similarity search
        - Creating the prompt history table
        - Handling schema migrations if needed
        """
        pass
131
+
132
+ def _execute_with_retry(
133
+ self,
134
+ operation: Callable[[], T],
135
+ max_attempts: int = 3,
136
+ base_delay: float = 1.0,
137
+ max_delay: float = 10.0,
138
+ ) -> T:
139
+ """
140
+ Execute a database operation with exponential backoff retry.
141
+
142
+ Args:
143
+ operation: The database operation to execute
144
+ max_attempts: Maximum number of retry attempts
145
+ base_delay: Initial delay between retries (seconds)
146
+ max_delay: Maximum delay between retries (seconds)
147
+
148
+ Returns:
149
+ The result of the operation
150
+
151
+ Raises:
152
+ The last exception if all retries fail
153
+ """
154
+ import time
155
+
156
+ last_exception: Exception | None = None
157
+ delay = base_delay
158
+
159
+ for attempt in range(max_attempts):
160
+ try:
161
+ return operation()
162
+ except Exception as e:
163
+ last_exception = e
164
+ error_str = str(e).lower()
165
+
166
+ # Check if this is a retryable error
167
+ retryable_errors = [
168
+ "connection",
169
+ "timeout",
170
+ "temporarily unavailable",
171
+ "too many connections",
172
+ "connection refused",
173
+ "operational error",
174
+ ]
175
+ is_retryable = any(err in error_str for err in retryable_errors)
176
+
177
+ if not is_retryable or attempt == max_attempts - 1:
178
+ raise
179
+
180
+ logger.warning(
181
+ f"Database operation failed (attempt {attempt + 1}/{max_attempts}), retrying",
182
+ layer=self.name,
183
+ error=str(e),
184
+ delay=delay,
185
+ )
186
+
187
+ time.sleep(delay)
188
+ delay = min(delay * 2, max_delay)
189
+
190
+ # Should not reach here, but just in case
191
+ if last_exception:
192
+ raise last_exception
193
+ raise RuntimeError("Unexpected state in retry logic")
194
+
195
+ def _index_exists(self, cur: Any, index_name: str) -> bool:
196
+ """
197
+ Check if an index already exists in the database.
198
+
199
+ Args:
200
+ cur: Database cursor to execute SQL statements
201
+ index_name: Name of the index to check
202
+
203
+ Returns:
204
+ True if the index exists, False otherwise
205
+ """
206
+ cur.execute(
207
+ "SELECT 1 FROM pg_indexes WHERE indexname = %s",
208
+ (index_name,),
209
+ )
210
+ return cur.fetchone() is not None
211
+
212
+ def _store_user_prompt(
213
+ self,
214
+ prompt: str,
215
+ conversation_id: str,
216
+ cache_hit: bool = False,
217
+ ) -> bool:
218
+ """
219
+ Store user prompt in local conversation history.
220
+
221
+ This is called after embeddings are generated to ensure the current prompt
222
+ is not included in its own context.
223
+
224
+ Prompt history is non-critical; failures are logged but don't crash the request.
225
+
226
+ Args:
227
+ prompt: The user's question/prompt
228
+ conversation_id: The conversation ID
229
+ cache_hit: Whether this prompt resulted in a cache hit
230
+
231
+ Returns:
232
+ True if prompt was stored successfully, False otherwise
233
+ """
234
+ prompt_table_name = self.prompt_history_table
235
+ insert_sql: str = f"""
236
+ INSERT INTO {prompt_table_name}
237
+ (genie_space_id, conversation_id, prompt, cache_hit)
238
+ VALUES (%s, %s, %s, %s)
239
+ """
240
+
241
+ logger.debug(
242
+ "Inserting prompt into history",
243
+ layer=self.name,
244
+ table=prompt_table_name,
245
+ space_id=self.space_id,
246
+ conversation_id=conversation_id,
247
+ prompt_preview=prompt[:80] if len(prompt) > 80 else prompt,
248
+ prompt_length=len(prompt),
249
+ cache_hit=cache_hit,
250
+ )
251
+
252
+ try:
253
+ with self._pool.connection() as conn:
254
+ with conn.cursor() as cur:
255
+ cur.execute(
256
+ insert_sql, (self.space_id, conversation_id, prompt, cache_hit)
257
+ )
258
+
259
+ logger.info(
260
+ "Stored user prompt in history",
261
+ layer=self.name,
262
+ table=prompt_table_name,
263
+ conversation_id=conversation_id,
264
+ prompt_preview=prompt[:50],
265
+ cache_hit=cache_hit,
266
+ )
267
+
268
+ # Enforce max_prompt_history_length per conversation
269
+ self._enforce_prompt_history_limit(conversation_id)
270
+
271
+ return True
272
+ except Exception as e:
273
+ logger.warning(
274
+ f"Failed to store prompt in history (non-critical): {e}",
275
+ layer=self.name,
276
+ table=prompt_table_name,
277
+ conversation_id=conversation_id,
278
+ )
279
+ return False
280
+
281
+ def _enforce_prompt_history_limit(self, conversation_id: str) -> int:
282
+ """
283
+ Delete oldest prompts if conversation exceeds max_prompt_history_length.
284
+
285
+ This is called after inserting a new prompt to keep history bounded.
286
+ Uses a single DELETE with subquery for efficiency.
287
+
288
+ Args:
289
+ conversation_id: The conversation ID to enforce limit for
290
+
291
+ Returns:
292
+ Number of prompts deleted (0 if within limit)
293
+ """
294
+ max_length = self.max_prompt_history_length
295
+ prompt_table_name = self.prompt_history_table
296
+
297
+ # Delete prompts beyond the limit, keeping the most recent ones
298
+ delete_sql: str = f"""
299
+ DELETE FROM {prompt_table_name}
300
+ WHERE genie_space_id = %s
301
+ AND conversation_id = %s
302
+ AND created_at < (
303
+ SELECT created_at FROM {prompt_table_name}
304
+ WHERE genie_space_id = %s
305
+ AND conversation_id = %s
306
+ ORDER BY created_at DESC
307
+ LIMIT 1 OFFSET %s
308
+ )
309
+ """
310
+
311
+ try:
312
+ with self._pool.connection() as conn:
313
+ with conn.cursor() as cur:
314
+ cur.execute(
315
+ delete_sql,
316
+ (
317
+ self.space_id,
318
+ conversation_id,
319
+ self.space_id,
320
+ conversation_id,
321
+ max_length - 1,
322
+ ),
323
+ )
324
+ deleted = cur.rowcount if isinstance(cur.rowcount, int) else 0
325
+
326
+ if deleted > 0:
327
+ logger.debug(
328
+ "Enforced prompt history limit",
329
+ layer=self.name,
330
+ table=prompt_table_name,
331
+ conversation_id=conversation_id,
332
+ max_length=max_length,
333
+ deleted=deleted,
334
+ )
335
+ return deleted
336
+ except Exception as e:
337
+ logger.debug(
338
+ f"Failed to enforce prompt history limit (non-critical): {e}",
339
+ layer=self.name,
340
+ conversation_id=conversation_id,
341
+ )
342
+ return 0
343
+
344
+ def _get_local_prompt_history(
345
+ self,
346
+ conversation_id: str,
347
+ max_prompts: int | None = None,
348
+ ) -> list[str]:
349
+ """
350
+ Retrieve recent user prompts from local storage.
351
+
352
+ Uses SQL LIMIT for efficiency - only retrieves exactly the number
353
+ of prompts needed for the context window, not all prompts.
354
+
355
+ Args:
356
+ conversation_id: The conversation ID to retrieve prompts for
357
+ max_prompts: Maximum number of prompts to retrieve
358
+
359
+ Returns:
360
+ List of prompt strings in chronological order (oldest to newest)
361
+ """
362
+ if max_prompts is None:
363
+ max_prompts = self.context_window_size
364
+
365
+ prompt_table_name = self.prompt_history_table
366
+ query_sql: str = f"""
367
+ SELECT prompt
368
+ FROM {prompt_table_name}
369
+ WHERE genie_space_id = %s
370
+ AND conversation_id = %s
371
+ ORDER BY created_at DESC
372
+ LIMIT %s
373
+ """
374
+
375
+ logger.debug(
376
+ "Querying prompt history",
377
+ layer=self.name,
378
+ table=prompt_table_name,
379
+ space_id=self.space_id,
380
+ conversation_id=conversation_id,
381
+ max_prompts=max_prompts,
382
+ )
383
+
384
+ with self._pool.connection() as conn:
385
+ with conn.cursor() as cur:
386
+ # LIMIT ensures we only fetch exactly what's needed
387
+ cur.execute(query_sql, (self.space_id, conversation_id, max_prompts))
388
+ rows: list[DbRow] = cur.fetchall()
389
+ # Reverse to get chronological order (oldest to newest)
390
+ prompts = [row["prompt"] for row in reversed(rows)]
391
+
392
+ logger.info(
393
+ "Retrieved prompt history from database",
394
+ layer=self.name,
395
+ table=prompt_table_name,
396
+ conversation_id=conversation_id,
397
+ requested=max_prompts,
398
+ returned=len(prompts),
399
+ prompts_preview=[
400
+ p[:40] + "..." if len(p) > 40 else p for p in prompts
401
+ ],
402
+ )
403
+
404
+ return prompts
405
+
406
+ def _update_prompt_cache_hit(
407
+ self,
408
+ conversation_id: str,
409
+ prompt: str,
410
+ cache_hit: bool,
411
+ cache_entry_id: int | None = None,
412
+ ) -> bool:
413
+ """
414
+ Update the cache_hit flag and cache_entry_id for a previously stored prompt.
415
+
416
+ This is called after determining whether the prompt resulted in a cache hit.
417
+ Updates the most recent prompt matching the given text.
418
+
419
+ Args:
420
+ conversation_id: The conversation ID
421
+ prompt: The prompt text to update
422
+ cache_hit: The cache hit status to set
423
+ cache_entry_id: The ID of the cache entry that served this hit (for traceability)
424
+
425
+ Returns:
426
+ True if update was successful, False otherwise
427
+ """
428
+ prompt_table_name = self.prompt_history_table
429
+ update_sql: str = f"""
430
+ UPDATE {prompt_table_name}
431
+ SET cache_hit = %s, cache_entry_id = %s
432
+ WHERE genie_space_id = %s
433
+ AND conversation_id = %s
434
+ AND prompt = %s
435
+ AND created_at = (
436
+ SELECT MAX(created_at)
437
+ FROM {prompt_table_name}
438
+ WHERE genie_space_id = %s
439
+ AND conversation_id = %s
440
+ AND prompt = %s
441
+ )
442
+ """
443
+
444
+ logger.debug(
445
+ "Updating prompt cache_hit flag and cache_entry_id",
446
+ layer=self.name,
447
+ table=prompt_table_name,
448
+ space_id=self.space_id,
449
+ conversation_id=conversation_id,
450
+ prompt_preview=prompt[:50],
451
+ new_cache_hit=cache_hit,
452
+ cache_entry_id=cache_entry_id,
453
+ )
454
+
455
+ try:
456
+ with self._pool.connection() as conn:
457
+ with conn.cursor() as cur:
458
+ cur.execute(
459
+ update_sql,
460
+ (
461
+ cache_hit,
462
+ cache_entry_id,
463
+ self.space_id,
464
+ conversation_id,
465
+ prompt,
466
+ self.space_id,
467
+ conversation_id,
468
+ prompt,
469
+ ),
470
+ )
471
+ # Handle rowcount safely (may be Mock in tests or None)
472
+ updated_rows = getattr(cur, "rowcount", 0)
473
+ if not isinstance(updated_rows, int):
474
+ updated_rows = 0
475
+
476
+ if updated_rows > 0:
477
+ logger.info(
478
+ "Updated prompt cache_hit flag and cache_entry_id in history",
479
+ layer=self.name,
480
+ table=prompt_table_name,
481
+ conversation_id=conversation_id,
482
+ prompt_preview=prompt[:50],
483
+ cache_hit=cache_hit,
484
+ cache_entry_id=cache_entry_id,
485
+ rows_updated=updated_rows,
486
+ )
487
+ return True
488
+ else:
489
+ logger.debug(
490
+ "No prompt found to update cache_hit flag (may be expected)",
491
+ layer=self.name,
492
+ table=prompt_table_name,
493
+ conversation_id=conversation_id,
494
+ prompt_preview=prompt[:50],
495
+ )
496
+ return False
497
+ except Exception as e:
498
+ logger.warning(
499
+ f"Failed to update prompt cache_hit flag and cache_entry_id (non-critical): {e}",
500
+ layer=self.name,
501
+ table=prompt_table_name,
502
+ conversation_id=conversation_id,
503
+ cache_entry_id=cache_entry_id,
504
+ )
505
+ return False
506
+
507
+ def _embed_question(
508
+ self, question: str, conversation_id: str | None = None
509
+ ) -> tuple[list[float], list[float], str]:
510
+ """
511
+ Generate dual embeddings using local prompt history for context.
512
+
513
+ This method retrieves conversation history from local storage first,
514
+ falling back to Genie API if local history is empty.
515
+
516
+ Args:
517
+ question: The question to embed
518
+ conversation_id: Optional conversation ID for retrieving context
519
+
520
+ Returns:
521
+ Tuple of (question_embedding, context_embedding, conversation_context_string)
522
+ """
523
+ conversation_context = ""
524
+
525
+ # If conversation context is enabled and available
526
+ if conversation_id is not None and self.context_window_size > 0:
527
+ try:
528
+ # Try local prompt history first (FASTER, includes cache hits)
529
+ recent_prompts = self._get_local_prompt_history(
530
+ conversation_id=conversation_id,
531
+ max_prompts=self.context_window_size,
532
+ )
533
+
534
+ logger.trace(
535
+ "Retrieved local prompt history",
536
+ layer=self.name,
537
+ prompts_count=len(recent_prompts),
538
+ conversation_id=conversation_id,
539
+ )
540
+
541
+ # Fallback to Genie API if local history empty and API available
542
+ if not recent_prompts and self.workspace_client is not None:
543
+ logger.debug(
544
+ "Local prompt history empty, falling back to Genie API",
545
+ layer=self.name,
546
+ conversation_id=conversation_id,
547
+ )
548
+
549
+ conversation_messages = get_conversation_history(
550
+ workspace_client=self.workspace_client,
551
+ space_id=self.space_id,
552
+ conversation_id=conversation_id,
553
+ max_messages=self.context_window_size * 2,
554
+ )
555
+
556
+ if conversation_messages:
557
+ recent_messages = (
558
+ conversation_messages[-self.context_window_size :]
559
+ if len(conversation_messages) > self.context_window_size
560
+ else conversation_messages
561
+ )
562
+ recent_prompts = [
563
+ msg.content for msg in recent_messages if msg.content
564
+ ]
565
+
566
+ # Build context string from prompts
567
+ if recent_prompts:
568
+ context_parts: list[str] = []
569
+ for prompt in recent_prompts:
570
+ content: str = prompt
571
+ if len(content) > 500:
572
+ content = content[:500] + "..."
573
+ context_parts.append(f"Previous: {content}")
574
+
575
+ conversation_context = "\n".join(context_parts)
576
+
577
+ # Truncate if too long
578
+ estimated_tokens = len(conversation_context) / 4
579
+ if estimated_tokens > self.max_context_tokens:
580
+ target_chars = self.max_context_tokens * 4
581
+ conversation_context = (
582
+ conversation_context[:target_chars] + "..."
583
+ )
584
+
585
+ logger.trace(
586
+ "Using conversation context",
587
+ layer=self.name,
588
+ prompts_count=len(recent_prompts),
589
+ window_size=self.context_window_size,
590
+ source="local_db",
591
+ )
592
+ except Exception as e:
593
+ logger.warning(
594
+ "Failed to build conversation context, using question only",
595
+ layer=self.name,
596
+ error=str(e),
597
+ )
598
+ conversation_context = ""
599
+
600
+ return self._generate_dual_embeddings(question, conversation_context)
601
+
602
+ def get_prompt_history(
603
+ self,
604
+ conversation_id: str,
605
+ max_prompts: int | None = None,
606
+ include_cache_hits: bool = True,
607
+ ) -> list[dict[str, Any]]:
608
+ """
609
+ Retrieve prompt history for a conversation with metadata.
610
+
611
+ Public utility method for inspecting conversation history.
612
+
613
+ Args:
614
+ conversation_id: The conversation ID to retrieve
615
+ max_prompts: Maximum number of prompts (None = all prompts)
616
+ include_cache_hits: Whether to include prompts that hit cache
617
+
618
+ Returns:
619
+ List of prompt records with metadata (prompt, cache_hit, created_at)
620
+ """
621
+ self._setup()
622
+
623
+ prompt_table_name = self.prompt_history_table
624
+
625
+ cache_filter = "" if include_cache_hits else "AND cache_hit = false"
626
+ limit_clause = f"LIMIT {max_prompts}" if max_prompts else ""
627
+
628
+ query_sql: str = f"""
629
+ SELECT prompt, cache_hit, created_at
630
+ FROM {prompt_table_name}
631
+ WHERE genie_space_id = %s
632
+ AND conversation_id = %s
633
+ {cache_filter}
634
+ ORDER BY created_at ASC
635
+ {limit_clause}
636
+ """
637
+
638
+ with self._pool.connection() as conn:
639
+ with conn.cursor() as cur:
640
+ cur.execute(query_sql, (self.space_id, conversation_id))
641
+ rows: list[DbRow] = cur.fetchall()
642
+
643
+ return [
644
+ {
645
+ "prompt": row["prompt"],
646
+ "cache_hit": row["cache_hit"],
647
+ "created_at": row["created_at"],
648
+ }
649
+ for row in rows
650
+ ]
651
+
652
+ def export_prompt_history(
653
+ self,
654
+ conversation_id: str,
655
+ output_format: str = "text",
656
+ ) -> str:
657
+ """
658
+ Export prompt history for a conversation in various formats.
659
+
660
+ Args:
661
+ conversation_id: The conversation ID to export
662
+ output_format: Format for export ("text", "json", "markdown")
663
+
664
+ Returns:
665
+ Formatted prompt history string
666
+ """
667
+ self._setup()
668
+
669
+ history = self.get_prompt_history(conversation_id)
670
+
671
+ if not history:
672
+ return "No prompt history found."
673
+
674
+ if output_format == "json":
675
+ import json
676
+
677
+ return json.dumps(history, indent=2, default=str)
678
+
679
+ elif output_format == "markdown":
680
+ lines = ["# Conversation History", ""]
681
+ for i, entry in enumerate(history, 1):
682
+ cache_mark = "HIT" if entry["cache_hit"] else "MISS"
683
+ lines.append(f"## Prompt {i} [{cache_mark}]")
684
+ lines.append(f"**Prompt**: {entry['prompt']}")
685
+ lines.append(f"**Cache Hit**: {entry['cache_hit']}")
686
+ lines.append(f"**Timestamp**: {entry['created_at']}")
687
+ lines.append("")
688
+ return "\n".join(lines)
689
+
690
+ else: # text format
691
+ lines = [f"Conversation: {conversation_id}", ""]
692
+ for i, entry in enumerate(history, 1):
693
+ cache_mark = "[CACHE HIT]" if entry["cache_hit"] else "[GENIE]"
694
+ lines.append(f"{i}. {cache_mark} {entry['prompt']}")
695
+ lines.append(f" Timestamp: {entry['created_at']}")
696
+ return "\n".join(lines)
697
+
698
+ def clear_prompt_history(self, conversation_id: str | None = None) -> int:
699
+ """
700
+ Clear prompt history for a conversation or entire space.
701
+
702
+ Args:
703
+ conversation_id: Specific conversation to clear (None = clear all for space)
704
+
705
+ Returns:
706
+ Number of prompts deleted
707
+ """
708
+ self._setup()
709
+
710
+ prompt_table_name = self.prompt_history_table
711
+
712
+ if conversation_id:
713
+ delete_sql: str = f"""
714
+ DELETE FROM {prompt_table_name}
715
+ WHERE genie_space_id = %s AND conversation_id = %s
716
+ """
717
+ params = (self.space_id, conversation_id)
718
+ else:
719
+ delete_sql = f"""
720
+ DELETE FROM {prompt_table_name}
721
+ WHERE genie_space_id = %s
722
+ """
723
+ params = (self.space_id,)
724
+
725
+ with self._pool.connection() as conn:
726
+ with conn.cursor() as cur:
727
+ cur.execute(delete_sql, params)
728
+ deleted: int = cur.rowcount
729
+
730
+ logger.info(
731
+ "Cleared prompt history",
732
+ layer=self.name,
733
+ conversation_id=conversation_id or "all",
734
+ deleted_count=deleted,
735
+ )
736
+
737
+ return deleted
738
+
739
+ def drop_tables(self) -> dict[str, bool]:
740
+ """
741
+ Drop both cache and prompt history tables.
742
+
743
+ This is useful for test cleanup to avoid accumulating test tables.
744
+
745
+ Returns:
746
+ Dict with 'cache' and 'prompt_history' keys indicating success
747
+ """
748
+ self._setup()
749
+
750
+ results: dict[str, bool] = {"cache": False, "prompt_history": False}
751
+
752
+ with self._pool.connection() as conn:
753
+ with conn.cursor() as cur:
754
+ # Drop cache table
755
+ try:
756
+ cur.execute(f"DROP TABLE IF EXISTS {self.table_name} CASCADE")
757
+ results["cache"] = True
758
+ logger.info(
759
+ "Dropped cache table",
760
+ layer=self.name,
761
+ table_name=self.table_name,
762
+ )
763
+ except Exception as e:
764
+ logger.warning(
765
+ f"Failed to drop cache table: {e}",
766
+ layer=self.name,
767
+ table_name=self.table_name,
768
+ )
769
+
770
+ # Drop prompt history table
771
+ try:
772
+ cur.execute(
773
+ f"DROP TABLE IF EXISTS {self.prompt_history_table} CASCADE"
774
+ )
775
+ results["prompt_history"] = True
776
+ logger.info(
777
+ "Dropped prompt history table",
778
+ layer=self.name,
779
+ table_name=self.prompt_history_table,
780
+ )
781
+ except Exception as e:
782
+ logger.warning(
783
+ f"Failed to drop prompt history table: {e}",
784
+ layer=self.name,
785
+ table_name=self.prompt_history_table,
786
+ )
787
+
788
+ return results
789
+
790
+ @property
791
+ def size(self) -> int:
792
+ """Current number of entries in the cache for this Genie space."""
793
+ self._setup()
794
+ count_sql: str = (
795
+ f"SELECT COUNT(*) as count FROM {self.table_name} WHERE genie_space_id = %s"
796
+ )
797
+
798
+ with self._pool.connection() as conn:
799
+ with conn.cursor() as cur:
800
+ cur.execute(count_sql, (self.space_id,))
801
+ row: DbRow | None = cur.fetchone()
802
+ return row.get("count", 0) if row else 0