hindsight-api 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hindsight_api/__init__.py +38 -0
- hindsight_api/api/__init__.py +105 -0
- hindsight_api/api/http.py +1872 -0
- hindsight_api/api/mcp.py +157 -0
- hindsight_api/engine/__init__.py +47 -0
- hindsight_api/engine/cross_encoder.py +97 -0
- hindsight_api/engine/db_utils.py +93 -0
- hindsight_api/engine/embeddings.py +113 -0
- hindsight_api/engine/entity_resolver.py +575 -0
- hindsight_api/engine/llm_wrapper.py +269 -0
- hindsight_api/engine/memory_engine.py +3095 -0
- hindsight_api/engine/query_analyzer.py +519 -0
- hindsight_api/engine/response_models.py +222 -0
- hindsight_api/engine/retain/__init__.py +50 -0
- hindsight_api/engine/retain/bank_utils.py +423 -0
- hindsight_api/engine/retain/chunk_storage.py +82 -0
- hindsight_api/engine/retain/deduplication.py +104 -0
- hindsight_api/engine/retain/embedding_processing.py +62 -0
- hindsight_api/engine/retain/embedding_utils.py +54 -0
- hindsight_api/engine/retain/entity_processing.py +90 -0
- hindsight_api/engine/retain/fact_extraction.py +1027 -0
- hindsight_api/engine/retain/fact_storage.py +176 -0
- hindsight_api/engine/retain/link_creation.py +121 -0
- hindsight_api/engine/retain/link_utils.py +651 -0
- hindsight_api/engine/retain/orchestrator.py +405 -0
- hindsight_api/engine/retain/types.py +206 -0
- hindsight_api/engine/search/__init__.py +15 -0
- hindsight_api/engine/search/fusion.py +122 -0
- hindsight_api/engine/search/observation_utils.py +132 -0
- hindsight_api/engine/search/reranking.py +103 -0
- hindsight_api/engine/search/retrieval.py +503 -0
- hindsight_api/engine/search/scoring.py +161 -0
- hindsight_api/engine/search/temporal_extraction.py +64 -0
- hindsight_api/engine/search/think_utils.py +255 -0
- hindsight_api/engine/search/trace.py +215 -0
- hindsight_api/engine/search/tracer.py +447 -0
- hindsight_api/engine/search/types.py +160 -0
- hindsight_api/engine/task_backend.py +223 -0
- hindsight_api/engine/utils.py +203 -0
- hindsight_api/metrics.py +227 -0
- hindsight_api/migrations.py +163 -0
- hindsight_api/models.py +309 -0
- hindsight_api/pg0.py +425 -0
- hindsight_api/web/__init__.py +12 -0
- hindsight_api/web/server.py +143 -0
- hindsight_api-0.0.13.dist-info/METADATA +41 -0
- hindsight_api-0.0.13.dist-info/RECORD +48 -0
- hindsight_api-0.0.13.dist-info/WHEEL +4 -0
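To reproduce the environment this diff describes, the wheel can be installed straight from the registry (assuming PyPI; pin the exact version to match this diff):

pip install hindsight-api==0.0.13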
hindsight_api/api/mcp.py
ADDED
@@ -0,0 +1,157 @@
+"""Hindsight MCP Server implementation using FastMCP."""
+
+import json
+import logging
+import os
+
+from fastmcp import FastMCP
+from hindsight_api import MemoryEngine
+
+# Configure logging from HINDSIGHT_API_LOG_LEVEL environment variable
+_log_level_str = os.environ.get("HINDSIGHT_API_LOG_LEVEL", "info").lower()
+_log_level_map = {"critical": logging.CRITICAL, "error": logging.ERROR, "warning": logging.WARNING,
+                  "info": logging.INFO, "debug": logging.DEBUG, "trace": logging.DEBUG}
+logging.basicConfig(
+    level=_log_level_map.get(_log_level_str, logging.INFO),
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+def create_mcp_server(memory: MemoryEngine) -> FastMCP:
+    """
+    Create and configure the Hindsight MCP server.
+
+    Args:
+        memory: MemoryEngine instance (required)
+
+    Returns:
+        Configured FastMCP server instance
+    """
+    # Create FastMCP server
+    mcp = FastMCP("hindsight-mcp-server")
+
+    @mcp.tool()
+    async def hindsight_put(bank_id: str, content: str, context: str, explanation: str = "") -> str:
+        """
+        **CRITICAL: Store important user information to long-term memory.**
+
+        **⚠️ PER-USER TOOL - REQUIRES USER IDENTIFICATION:**
+        - This tool is STRICTLY per-user. Each user MUST have a unique `bank_id`.
+        - ONLY use this tool if you have a valid user identifier (user ID, email, session ID, etc.) to map to `bank_id`.
+        - DO NOT use this tool if you cannot identify the specific user.
+        - DO NOT share memories between different users - each user's memories are isolated by their `bank_id`.
+        - If you don't have a user identifier, DO NOT use this tool at all.
+
+        Use this tool PROACTIVELY whenever the user shares:
+        - Personal facts, preferences, or interests (e.g., "I love hiking", "I'm a vegetarian")
+        - Important events or milestones (e.g., "I got promoted", "My birthday is June 15")
+        - User history, experiences, or background (e.g., "I used to work at Google", "I studied CS at MIT")
+        - Decisions, opinions, or stated preferences (e.g., "I prefer Python over JavaScript")
+        - Goals, plans, or future intentions (e.g., "I'm planning to visit Japan next year")
+        - Relationships or people mentioned (e.g., "My manager Sarah", "My wife Alice")
+        - Work context, projects, or responsibilities
+        - Any other information the user would want remembered for future conversations
+
+        **When to use**: Immediately after user shares personal information. Don't ask permission - just store it naturally.
+
+        **Context guidelines**: Use descriptive contexts like "personal_preferences", "work_history", "family", "hobbies",
+        "career_goals", "project_details", etc. This helps organize and retrieve related memories later.
+
+        Args:
+            bank_id: **REQUIRED** - The unique, persistent identifier for this specific user (e.g., user_id, email, session_id).
+                This MUST be consistent across all interactions with the same user.
+                Example: "user_12345", "alice@example.com", "session_abc123"
+            content: The fact/memory to store (be specific and include relevant details)
+            context: Categorize the memory (e.g., 'personal_preferences', 'work_history', 'hobbies', 'family')
+            explanation: Optional explanation for why this memory is being stored
+        """
+        try:
+            # Log explanation if provided
+            if explanation:
+                pass  # Explanation provided
+
+            # Store memory using put_batch_async
+            await memory.put_batch_async(
+                bank_id=bank_id,
+                contents=[{"content": content, "context": context}]
+            )
+            return "Fact stored successfully"
+        except Exception as e:
+            logger.error(f"Error storing fact: {e}", exc_info=True)
+            return f"Error: {str(e)}"
+
+    @mcp.tool()
+    async def hindsight_search(bank_id: str, query: str, max_tokens: int = 4096, explanation: str = "") -> str:
+        """
+        **CRITICAL: Search user's memory to provide personalized, context-aware responses.**
+
+        **⚠️ PER-USER TOOL - REQUIRES USER IDENTIFICATION:**
+        - This tool is STRICTLY per-user. Each user MUST have a unique `bank_id`.
+        - ONLY use this tool if you have a valid user identifier (user ID, email, session ID, etc.) to map to `bank_id`.
+        - DO NOT use this tool if you cannot identify the specific user.
+        - DO NOT search across multiple users - each user's memories are isolated by their `bank_id`.
+        - If you don't have a user identifier, DO NOT use this tool at all.
+
+        Use this tool PROACTIVELY at the start of conversations or when making recommendations to:
+        - Check user's preferences before making suggestions (e.g., "what foods does the user like?")
+        - Recall user's history to provide continuity (e.g., "what projects has the user worked on?")
+        - Remember user's goals and context (e.g., "what is the user trying to accomplish?")
+        - Avoid repeating information or asking questions you should already know
+        - Personalize responses based on user's background, interests, and past interactions
+        - Reference past conversations or events the user mentioned
+
+        **When to use**:
+        - Start of conversation: Search for relevant context about the user
+        - Before recommendations: Check user preferences and past experiences
+        - When user asks about something they may have mentioned before
+        - To provide continuity across conversations
+
+        **Search tips**: Use natural language queries like "user's programming language preferences",
+        "user's work experience", "user's dietary restrictions", "what does the user know about X?"
+
+        Args:
+            bank_id: **REQUIRED** - The unique, persistent identifier for this specific user (e.g., user_id, email, session_id).
+                This MUST be consistent across all interactions with the same user.
+                Example: "user_12345", "alice@example.com", "session_abc123"
+            query: Natural language search query to find relevant memories
+            max_tokens: Maximum tokens for search context (default: 4096)
+            explanation: Optional explanation for why this search is being performed
+        """
+        try:
+            # Log all parameters for debugging
+            logger.info(f"hindsight_search called with: query={query!r}, max_tokens={max_tokens}, explanation={explanation!r}")
+
+            # Log explanation if provided
+            if explanation:
+                pass  # Explanation provided
+
+            # Search using recall_async
+            from hindsight_api.engine.memory_engine import Budget
+            search_result = await memory.recall_async(
+                bank_id=bank_id,
+                query=query,
+                fact_type=["world", "bank", "opinion"],  # Search all fact types
+                max_tokens=max_tokens,
+                budget=Budget.LOW
+            )
+
+            # Convert results to dict format
+            results = [
+                {
+                    "id": fact.id,
+                    "text": fact.text,
+                    "type": fact.fact_type,
+                    "context": fact.context,
+                    "event_date": fact.event_date,  # Already a string from the database
+                    "document_id": fact.document_id
+                }
+                for fact in search_result.results
+            ]
+
+            return json.dumps({"results": results}, indent=2)
+        except Exception as e:
+            logger.error(f"Error searching: {e}", exc_info=True)
+            return json.dumps({"error": str(e), "results": []})
+
+    return mcp
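For context, a minimal sketch of serving these tools, assuming a configured MemoryEngine instance (its constructor arguments live in hindsight_api/engine/memory_engine.py and are not shown in this hunk); FastMCP's run() serves the registered tools over stdio by default:

from hindsight_api import MemoryEngine
from hindsight_api.api.mcp import create_mcp_server

# Hypothetical engine setup -- replace the placeholder with real
# MemoryEngine configuration (database URL, models, etc.).
memory = MemoryEngine(...)

mcp = create_mcp_server(memory)
mcp.run()  # exposes hindsight_put and hindsight_search to MCP clients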
hindsight_api/engine/__init__.py
ADDED
@@ -0,0 +1,47 @@
+"""
+Memory Engine - Core implementation of the memory system.
+
+This package contains all the implementation details of the memory engine:
+- MemoryEngine: Main class for memory operations
+- Utility modules: embedding_utils, link_utils, think_utils, bank_utils
+- Supporting modules: embeddings, cross_encoder, entity_resolver, etc.
+"""
+
+from .memory_engine import MemoryEngine
+from .db_utils import acquire_with_retry
+from .embeddings import Embeddings, SentenceTransformersEmbeddings
+from .search.trace import (
+    SearchTrace,
+    QueryInfo,
+    EntryPoint,
+    NodeVisit,
+    WeightComponents,
+    LinkInfo,
+    PruningDecision,
+    SearchSummary,
+    SearchPhaseMetrics,
+)
+from .search.tracer import SearchTracer
+from .llm_wrapper import LLMConfig
+from .response_models import RecallResult, ReflectResult, MemoryFact
+
+__all__ = [
+    "MemoryEngine",
+    "acquire_with_retry",
+    "Embeddings",
+    "SentenceTransformersEmbeddings",
+    "SearchTrace",
+    "SearchTracer",
+    "QueryInfo",
+    "EntryPoint",
+    "NodeVisit",
+    "WeightComponents",
+    "LinkInfo",
+    "PruningDecision",
+    "SearchSummary",
+    "SearchPhaseMetrics",
+    "LLMConfig",
+    "RecallResult",
+    "ReflectResult",
+    "MemoryFact",
+]
hindsight_api/engine/cross_encoder.py
ADDED
@@ -0,0 +1,97 @@
+"""
+Cross-encoder abstraction for reranking.
+
+Provides an interface for reranking with different backends.
+"""
+from abc import ABC, abstractmethod
+from typing import List, Tuple
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class CrossEncoderModel(ABC):
+    """
+    Abstract base class for cross-encoder reranking.
+
+    Cross-encoders take query-document pairs and return relevance scores.
+    """
+
+    @abstractmethod
+    def load(self) -> None:
+        """
+        Load the cross-encoder model.
+
+        This should be called during initialization to load the model
+        and avoid cold start latency on first predict() call.
+        """
+        pass
+
+    @abstractmethod
+    def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
+        """
+        Score query-document pairs for relevance.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores (higher = more relevant)
+        """
+        pass
+
+
+class SentenceTransformersCrossEncoder(CrossEncoderModel):
+    """
+    Cross-encoder implementation using SentenceTransformers.
+
+    Call load() during initialization to load the model and avoid cold starts.
+
+    Default model is cross-encoder/ms-marco-MiniLM-L-6-v2:
+    - Fast inference (~80ms for 100 pairs on CPU)
+    - Small model (80MB)
+    - Trained for passage re-ranking
+    """
+
+    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+        """
+        Initialize SentenceTransformers cross-encoder.
+
+        Args:
+            model_name: Name of the CrossEncoder model to use.
+                Default: cross-encoder/ms-marco-MiniLM-L-6-v2
+        """
+        self.model_name = model_name
+        self._model = None
+
+    def load(self) -> None:
+        """Load the cross-encoder model."""
+        if self._model is not None:
+            return
+
+        try:
+            from sentence_transformers import CrossEncoder
+        except ImportError:
+            raise ImportError(
+                "sentence-transformers is required for SentenceTransformersCrossEncoder. "
+                "Install it with: pip install sentence-transformers"
+            )
+
+        logger.info(f"Loading cross-encoder model: {self.model_name}...")
+        self._model = CrossEncoder(self.model_name)
+        logger.info("Cross-encoder model loaded")
+
+    def predict(self, pairs: List[Tuple[str, str]]) -> List[float]:
+        """
+        Score query-document pairs for relevance.
+
+        Args:
+            pairs: List of (query, document) tuples to score
+
+        Returns:
+            List of relevance scores (raw logits from the model)
+        """
+        if self._model is None:
+            self.load()
+        scores = self._model.predict(pairs, show_progress_bar=False)
+        return scores.tolist() if hasattr(scores, 'tolist') else list(scores)
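A small usage sketch of the reranker above (the query and documents are illustrative):

from hindsight_api.engine.cross_encoder import SentenceTransformersCrossEncoder

reranker = SentenceTransformersCrossEncoder()
reranker.load()  # eager load; predict() would also lazy-load on first call

query = "what pets does the user have?"
docs = ["The user has two cats.", "The user works at Acme Corp."]
scores = reranker.predict([(query, d) for d in docs])

# Scores are raw logits: rank documents from most to least relevant.
ranked = sorted(zip(docs, scores), key=lambda p: p[1], reverse=True)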
hindsight_api/engine/db_utils.py
ADDED
@@ -0,0 +1,93 @@
+"""
+Database utility functions for connection management with retry logic.
+"""
+import asyncio
+import logging
+from contextlib import asynccontextmanager
+import asyncpg
+
+logger = logging.getLogger(__name__)
+
+# Default retry configuration for database operations
+DEFAULT_MAX_RETRIES = 3
+DEFAULT_BASE_DELAY = 0.5  # seconds
+DEFAULT_MAX_DELAY = 5.0  # seconds
+
+# Exceptions that indicate transient connection issues worth retrying
+RETRYABLE_EXCEPTIONS = (
+    asyncpg.exceptions.InterfaceError,
+    asyncpg.exceptions.ConnectionDoesNotExistError,
+    asyncpg.exceptions.TooManyConnectionsError,
+    OSError,
+    ConnectionError,
+    asyncio.TimeoutError,
+)
+
+
+async def retry_with_backoff(
+    func,
+    max_retries: int = DEFAULT_MAX_RETRIES,
+    base_delay: float = DEFAULT_BASE_DELAY,
+    max_delay: float = DEFAULT_MAX_DELAY,
+    retryable_exceptions: tuple = RETRYABLE_EXCEPTIONS,
+):
+    """
+    Execute an async function with exponential backoff retry.
+
+    Args:
+        func: Async function to execute
+        max_retries: Maximum number of retry attempts
+        base_delay: Initial delay between retries (seconds)
+        max_delay: Maximum delay between retries (seconds)
+        retryable_exceptions: Tuple of exception types to retry on
+
+    Returns:
+        Result of the function
+
+    Raises:
+        The last exception if all retries fail
+    """
+    last_exception = None
+    for attempt in range(max_retries + 1):
+        try:
+            return await func()
+        except retryable_exceptions as e:
+            last_exception = e
+            if attempt < max_retries:
+                delay = min(base_delay * (2 ** attempt), max_delay)
+                logger.warning(
+                    f"Database operation failed (attempt {attempt + 1}/{max_retries + 1}): {e}. "
+                    f"Retrying in {delay:.1f}s..."
+                )
+                await asyncio.sleep(delay)
+            else:
+                logger.error(
+                    f"Database operation failed after {max_retries + 1} attempts: {e}"
+                )
+                raise last_exception
+
+
+@asynccontextmanager
+async def acquire_with_retry(pool: asyncpg.Pool, max_retries: int = DEFAULT_MAX_RETRIES):
+    """
+    Async context manager to acquire a connection with retry logic.
+
+    Usage:
+        async with acquire_with_retry(pool) as conn:
+            await conn.execute(...)
+
+    Args:
+        pool: The asyncpg connection pool
+        max_retries: Maximum number of retry attempts
+
+    Yields:
+        An asyncpg connection
+    """
+    async def acquire():
+        return await pool.acquire()
+
+    conn = await retry_with_backoff(acquire, max_retries=max_retries)
+    try:
+        yield conn
+    finally:
+        await pool.release(conn)
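A sketch of the intended call pattern, assuming an existing asyncpg pool. Note that only the acquire is retried with backoff, not the statements run on the yielded connection; the facts table name here is hypothetical:

import asyncpg
from hindsight_api.engine.db_utils import acquire_with_retry

async def count_facts(pool: asyncpg.Pool) -> int:
    # pool.acquire() is retried up to DEFAULT_MAX_RETRIES times with
    # exponential backoff on the RETRYABLE_EXCEPTIONS listed above.
    async with acquire_with_retry(pool) as conn:
        return await conn.fetchval("SELECT count(*) FROM facts")  # hypothetical table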
hindsight_api/engine/embeddings.py
ADDED
@@ -0,0 +1,113 @@
+"""
+Embeddings abstraction for the memory system.
+
+Provides an interface for generating embeddings with different backends.
+
+IMPORTANT: All embeddings must produce 384-dimensional vectors to match
+the database schema (pgvector column defined as vector(384)).
+"""
+from abc import ABC, abstractmethod
+from typing import List
+import logging
+
+logger = logging.getLogger(__name__)
+
+# Fixed embedding dimension required by database schema
+EMBEDDING_DIMENSION = 384
+
+
+class Embeddings(ABC):
+    """
+    Abstract base class for embedding generation.
+
+    All implementations MUST generate 384-dimensional embeddings to match
+    the database schema.
+    """
+
+    @abstractmethod
+    def load(self) -> None:
+        """
+        Load the embedding model.
+
+        This should be called during initialization to load the model
+        and avoid cold start latency on first encode() call.
+        """
+        pass
+
+    @abstractmethod
+    def encode(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate 384-dimensional embeddings for a list of texts.
+
+        Args:
+            texts: List of text strings to encode
+
+        Returns:
+            List of 384-dimensional embedding vectors (each is a list of floats)
+        """
+        pass
+
+
+class SentenceTransformersEmbeddings(Embeddings):
+    """
+    Embeddings implementation using SentenceTransformers.
+
+    Call load() during initialization to load the model and avoid cold starts.
+
+    Default model is BAAI/bge-small-en-v1.5 which produces 384-dimensional
+    embeddings matching the database schema.
+    """
+
+    def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"):
+        """
+        Initialize SentenceTransformers embeddings.
+
+        Args:
+            model_name: Name of the SentenceTransformer model to use.
+                Must produce 384-dimensional embeddings.
+                Default: BAAI/bge-small-en-v1.5
+        """
+        self.model_name = model_name
+        self._model = None
+
+    def load(self) -> None:
+        """Load the embedding model."""
+        if self._model is not None:
+            return
+
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError:
+            raise ImportError(
+                "sentence-transformers is required for SentenceTransformersEmbeddings. "
+                "Install it with: pip install sentence-transformers"
+            )
+
+        logger.info(f"Loading embedding model: {self.model_name}...")
+        self._model = SentenceTransformer(self.model_name)
+
+        # Validate dimension matches database schema
+        model_dim = self._model.get_sentence_embedding_dimension()
+        if model_dim != EMBEDDING_DIMENSION:
+            raise ValueError(
+                f"Model {self.model_name} produces {model_dim}-dimensional embeddings, "
+                f"but database schema requires {EMBEDDING_DIMENSION} dimensions. "
+                f"Use a model that produces {EMBEDDING_DIMENSION}-dimensional embeddings."
+            )
+
+        logger.info(f"Model loaded (embedding dim: {model_dim})")
+
+    def encode(self, texts: List[str]) -> List[List[float]]:
+        """
+        Generate 384-dimensional embeddings for a list of texts.
+
+        Args:
+            texts: List of text strings to encode
+
+        Returns:
+            List of 384-dimensional embedding vectors
+        """
+        if self._model is None:
+            self.load()
+        embeddings = self._model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
+        return [emb.tolist() for emb in embeddings]
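Finally, a usage sketch of the embeddings backend; load() validates the model's output dimension against the schema before any encoding happens:

from hindsight_api.engine.embeddings import (
    EMBEDDING_DIMENSION,
    SentenceTransformersEmbeddings,
)

embedder = SentenceTransformersEmbeddings()
embedder.load()  # raises ValueError if the model is not 384-dimensional

vectors = embedder.encode(["I love hiking", "My manager is Sarah"])
assert all(len(v) == EMBEDDING_DIMENSION for v in vectors)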