vector-memory-mcp 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- main.py +339 -0
- src/__init__.py +45 -0
- src/embeddings.py +235 -0
- src/memory_store.py +512 -0
- src/models.py +202 -0
- src/security.py +277 -0
- vector_memory_mcp-1.0.0.dist-info/METADATA +586 -0
- vector_memory_mcp-1.0.0.dist-info/RECORD +12 -0
- vector_memory_mcp-1.0.0.dist-info/WHEEL +5 -0
- vector_memory_mcp-1.0.0.dist-info/entry_points.txt +2 -0
- vector_memory_mcp-1.0.0.dist-info/licenses/LICENSE +21 -0
- vector_memory_mcp-1.0.0.dist-info/top_level.txt +2 -0
main.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
#!/usr/bin/env -S uv run --script
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
# /// script
|
|
4
|
+
# dependencies = [
|
|
5
|
+
# "mcp>=0.3.0",
|
|
6
|
+
# "sqlite-vec>=0.1.6",
|
|
7
|
+
# "sentence-transformers>=2.2.2"
|
|
8
|
+
# ]
|
|
9
|
+
# requires-python = ">=3.8"
|
|
10
|
+
# ///
|
|
11
|
+
|
|
12
|
+
"""
|
|
13
|
+
Vector Memory MCP Server - Main Entry Point
|
|
14
|
+
===========================================
|
|
15
|
+
|
|
16
|
+
A secure, vector-based memory server using sqlite-vec for semantic search.
|
|
17
|
+
Stores and retrieves coding memories, experiences, and knowledge using
|
|
18
|
+
384-dimensional embeddings generated by sentence-transformers.
|
|
19
|
+
|
|
20
|
+
Usage:
|
|
21
|
+
python main.py --working-dir /path/to/project
|
|
22
|
+
|
|
23
|
+
Memory files stored in: {working_dir}/memory/vector_memory.db
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
import sys
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Dict, Any
|
|
29
|
+
|
|
30
|
+
# Add src to path for imports
|
|
31
|
+
sys.path.insert(0, str(Path(__file__).parent / "src"))
|
|
32
|
+
|
|
33
|
+
from mcp.server.fastmcp import FastMCP
|
|
34
|
+
|
|
35
|
+
# Import our modules
|
|
36
|
+
from src.models import Config
|
|
37
|
+
from src.security import validate_working_dir, SecurityError
|
|
38
|
+
from src.memory_store import VectorMemoryStore
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_working_dir() -> Path:
    """Resolve the directory to use from the command line.

    Only the form ``<script> --working-dir <path> ...`` is recognised: the
    flag must be the first argument and must be followed by a value.  In
    every other case the current directory (``"."``) is used.

    Returns:
        Path: whatever ``validate_working_dir`` returns for the chosen path.
        Callers in this file join ``Config.DB_NAME`` onto it.
    """
    cli_args = sys.argv[1:]
    flag_given = len(cli_args) >= 2 and cli_args[0] == "--working-dir"

    # Fall back to the current directory when the flag is absent or has no value.
    requested = cli_args[1] if flag_given else "."
    return validate_working_dir(requested)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def create_server() -> FastMCP:
    """Create the FastMCP server and register the memory tools on it.

    The directory returned by ``get_working_dir()`` is combined with
    ``Config.DB_NAME`` and a ``VectorMemoryStore`` is opened on that path.
    The tool functions defined below are closures over that store.  If the
    store cannot be created, the error is printed to stderr and the process
    exits with status 1.

    Every tool returns a dict.  Failures are reported as
    ``{"success": False, "error": <label>, "message": <details>}`` instead
    of being raised to the caller.

    Returns:
        FastMCP: the configured server instance (not yet running).
    """
    # Local import: the module header only pulls Dict and Any from typing,
    # and the tool signatures below need Optional for their None defaults.
    from typing import Optional

    # Open the memory store shared by all tools registered below.
    try:
        memory_dir = get_working_dir()
        db_path = memory_dir / Config.DB_NAME
        memory_store = VectorMemoryStore(db_path)
        print(f"Memory database initialized: {db_path}", file=sys.stderr)
    except Exception as e:
        print(f"Failed to initialize memory store: {e}", file=sys.stderr)
        sys.exit(1)

    # Create FastMCP server
    mcp = FastMCP(Config.SERVER_NAME)

    def _failure(error: str, message: object) -> dict[str, Any]:
        """Build the failure payload shared by every tool."""
        return {
            "success": False,
            "error": error,
            "message": str(message)
        }

    # ===============================================================================
    # MCP TOOLS IMPLEMENTATION
    # ===============================================================================
    # NOTE: the tool docstrings are kept verbatim on purpose; they are the
    # descriptions the tool decorator has access to.

    @mcp.tool()
    def store_memory(
        content: str,
        category: str = "other",
        tags: Optional[list[str]] = None
    ) -> dict[str, Any]:
        """
        Store coding memory with vector embedding for semantic search.

        Args:
            content: Memory content (max 10K chars)
            category: code-solution, bug-fix, architecture, learning, tool-usage, debugging, performance, security, other
            tags: Tags for organization (max 10)
        """
        try:
            # None means "no tags"; use a fresh list per call.
            if tags is None:
                tags = []

            return memory_store.store_memory(content, category, tags)

        except SecurityError as e:
            return _failure("Security validation failed", e)
        except Exception as e:
            return _failure("Storage failed", e)

    @mcp.tool()
    def search_memories(
        query: str,
        limit: int = 10,
        category: Optional[str] = None
    ) -> dict[str, Any]:
        """
        Search memories using semantic similarity (vector search).

        Args:
            query: Search query
            limit: Max results (1-50, default 10)
            category: Optional category filter
        """
        try:
            search_results = memory_store.search_memories(query, limit, category)

            if not search_results:
                return {
                    "success": True,
                    "results": [],
                    "message": "No matching memories found. Try different keywords or broader terms."
                }

            # Convert SearchResult objects to dictionaries
            results = [result.to_dict() for result in search_results]

            return {
                "success": True,
                "query": query,
                "results": results,
                "count": len(results),
                "message": f"Found {len(results)} relevant memories"
            }

        except SecurityError as e:
            return _failure("Security validation failed", e)
        except Exception as e:
            return _failure("Search failed", e)

    @mcp.tool()
    def list_recent_memories(limit: int = 10) -> dict[str, Any]:
        """
        List recent memories in chronological order.

        Args:
            limit: Max results (1-50, default 10)
        """
        try:
            # Clamp the requested limit into [1, Config.MAX_MEMORIES_PER_SEARCH].
            limit = min(max(1, limit), Config.MAX_MEMORIES_PER_SEARCH)
            memories = memory_store.get_recent_memories(limit)

            # Convert MemoryEntry objects to dictionaries
            memory_dicts = [memory.to_dict() for memory in memories]

            return {
                "success": True,
                "memories": memory_dicts,
                "count": len(memory_dicts),
                "message": f"Retrieved {len(memory_dicts)} recent memories"
            }

        except Exception as e:
            return _failure("Failed to get recent memories", e)

    @mcp.tool()
    def get_memory_stats() -> dict[str, Any]:
        """Get database statistics (total memories, categories, usage, health)."""
        try:
            stats = memory_store.get_stats()
            result = stats.to_dict()
            result["success"] = True
            return result

        except Exception as e:
            return _failure("Failed to get statistics", e)

    @mcp.tool()
    def clear_old_memories(
        days_old: int = 30,
        max_to_keep: int = 1000
    ) -> dict[str, Any]:
        """
        Clear old memories to free space (keeps frequently accessed).

        Args:
            days_old: Min age in days (default 30)
            max_to_keep: Max total memories (default 1000)
        """
        try:
            if days_old < 1:
                return _failure("Invalid parameter", "days_old must be at least 1")

            return memory_store.clear_old_memories(days_old, max_to_keep)

        except SecurityError as e:
            return _failure("Security validation failed", e)
        except Exception as e:
            return _failure("Cleanup failed", e)

    @mcp.tool()
    def get_by_memory_id(memory_id: int) -> dict[str, Any]:
        """
        Get specific memory by ID.

        Args:
            memory_id: Memory ID to retrieve
        """
        try:
            if not isinstance(memory_id, int) or memory_id < 1:
                return _failure("Invalid parameter", "memory_id must be a positive integer")

            memory = memory_store.get_memory_by_id(memory_id)

            if memory is None:
                return _failure("Not found", f"Memory with ID {memory_id} not found")

            return {
                "success": True,
                "memory": memory.to_dict(),
                "message": "Memory retrieved successfully"
            }

        except Exception as e:
            return _failure("Retrieval failed", e)

    @mcp.tool()
    def delete_by_memory_id(memory_id: int) -> dict[str, Any]:
        """
        Delete memory by ID (permanent, cannot be undone).

        Args:
            memory_id: Memory ID to delete
        """
        try:
            if not isinstance(memory_id, int) or memory_id < 1:
                return _failure("Invalid parameter", "memory_id must be a positive integer")

            deleted = memory_store.delete_memory(memory_id)

            if not deleted:
                return _failure("Not found", f"Memory with ID {memory_id} not found")

            return {
                "success": True,
                "memory_id": memory_id,
                "message": "Memory deleted successfully from both metadata and vector tables"
            }

        except Exception as e:
            return _failure("Deletion failed", e)

    return mcp
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def main():
    """Print a startup banner to stderr, then build and run the server.

    A ``KeyboardInterrupt`` is reported as a normal stop.  Any other
    ``Exception`` is reported as a startup failure and the process exits
    with status 1.
    """
    def _log(text):
        # All diagnostics in this function go to stderr.
        print(text, file=sys.stderr)

    _log(f"Starting {Config.SERVER_NAME} v{Config.SERVER_VERSION}")

    try:
        # Resolve the paths once here purely for the banner below.
        memory_dir = get_working_dir()
        db_path = memory_dir / Config.DB_NAME

        _log(f"Working directory: {memory_dir.parent}")
        _log(f"Memory database: {db_path}")
        _log(f"Embedding model: {Config.EMBEDDING_MODEL}")
        _log("=" * 50)

        # Build the server, announce readiness, then run it.
        server = create_server()
        _log("Server ready for connections...")
        server.run()

    except KeyboardInterrupt:
        _log("\nServer stopped by user")
    except Exception as e:
        _log(f"Server failed to start: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|
src/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Vector Memory MCP Server - Core Package
|
|
3
|
+
=======================================
|
|
4
|
+
|
|
5
|
+
This package provides vector-based memory capabilities for Claude Desktop
|
|
6
|
+
using sqlite-vec and sentence-transformers.
|
|
7
|
+
|
|
8
|
+
Modules:
|
|
9
|
+
models: Data models and type definitions
|
|
10
|
+
security: Security utilities and validation
|
|
11
|
+
embeddings: Sentence transformer wrapper (requires sentence-transformers)
|
|
12
|
+
memory_store: SQLite-vec operations and storage (requires sqlite-vec)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
__version__ = "1.0.0"
|
|
16
|
+
__author__ = "Vector Memory MCP Server"
|
|
17
|
+
|
|
18
|
+
# Core modules that don't require external dependencies
|
|
19
|
+
from .models import MemoryEntry, MemoryCategory, SearchResult, Config
|
|
20
|
+
from .security import SecurityError, validate_working_dir, sanitize_input
|
|
21
|
+
|
|
22
|
+
# Optional imports that require external dependencies
|
|
23
|
+
# These are imported lazily to avoid import errors when dependencies aren't available
|
|
24
|
+
|
|
25
|
+
def get_embedding_model():
    """Return the ``EmbeddingModel`` class (not an instance).

    The import is performed inside the function so that importing this
    package does not import ``.embeddings`` until this is called.
    """
    from .embeddings import EmbeddingModel as model_cls
    return model_cls
|
|
29
|
+
|
|
30
|
+
def get_vector_memory_store():
    """Return the ``VectorMemoryStore`` class (not an instance).

    The import is performed inside the function so that importing this
    package does not import ``.memory_store`` until this is called.
    """
    from .memory_store import VectorMemoryStore as store_cls
    return store_cls
|
|
34
|
+
|
|
35
|
+
__all__ = [
|
|
36
|
+
"MemoryEntry",
|
|
37
|
+
"MemoryCategory",
|
|
38
|
+
"SearchResult",
|
|
39
|
+
"Config",
|
|
40
|
+
"SecurityError",
|
|
41
|
+
"validate_working_dir",
|
|
42
|
+
"sanitize_input",
|
|
43
|
+
"get_embedding_model",
|
|
44
|
+
"get_vector_memory_store"
|
|
45
|
+
]
|
src/embeddings.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Embeddings Module
|
|
3
|
+
=================
|
|
4
|
+
|
|
5
|
+
Provides text embedding generation using sentence-transformers.
|
|
6
|
+
Handles model initialization, caching, and vector operations.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
from typing import List, Optional
|
|
12
|
+
import numpy as np
|
|
13
|
+
from sentence_transformers import SentenceTransformer
|
|
14
|
+
|
|
15
|
+
from .models import Config
|
|
16
|
+
from .security import SecurityError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class EmbeddingModel:
    """
    Lazy wrapper around a sentence-transformers model with input validation.

    The underlying ``SentenceTransformer`` is not created in ``__init__``.
    It is loaded on the first call that needs it (``encode``,
    ``embedding_dim``, ``get_model_info``) and then reused by this instance.
    """

    def __init__(self, model_name: Optional[str] = None, cache_dir: Optional[str] = None):
        """
        Record the model settings; the model itself is loaded lazily.

        Args:
            model_name: Name of the sentence-transformers model; falls back to
                ``Config.EMBEDDING_MODEL`` when None or empty
            cache_dir: Directory to cache the model (optional)
        """
        self.model_name = model_name or Config.EMBEDDING_MODEL
        self.cache_dir = cache_dir
        self.model: Optional[SentenceTransformer] = None
        self._embedding_dim: Optional[int] = None

    def _initialize_model(self) -> None:
        """
        Load the sentence transformer model and verify its output dimension.

        The model and its dimension are stored on the instance only after
        the dimension check has passed, so a failed initialization never
        leaves a mismatched model behind for later calls to pick up.

        Raises:
            RuntimeError: If loading fails or the model's dimension differs
                from ``Config.EMBEDDING_DIM``
        """
        try:
            # Set cache directory if provided
            if self.cache_dir:
                os.environ['SENTENCE_TRANSFORMERS_HOME'] = self.cache_dir

            print(f"Loading embedding model: {self.model_name}", file=sys.stderr)
            model = SentenceTransformer(self.model_name)

            # Probe the model once to learn its output width.
            test_embedding = model.encode(["test"], normalize_embeddings=True)
            embedding_dim = test_embedding.shape[1]

            if embedding_dim != Config.EMBEDDING_DIM:
                raise ValueError(
                    f"Model dimension mismatch: expected {Config.EMBEDDING_DIM}, "
                    f"got {embedding_dim}"
                )

        except Exception as e:
            raise RuntimeError(f"Failed to initialize embedding model: {e}") from e

        # Commit only a fully validated model.
        self.model = model
        self._embedding_dim = embedding_dim
        print(f"Model loaded successfully. Dimensions: {self._embedding_dim}", file=sys.stderr)

    @property
    def embedding_dim(self) -> int:
        """Get embedding dimensions, loading the model first if nothing is loaded yet."""
        if self._embedding_dim is None and self.model is None:
            self._initialize_model()
        return self._embedding_dim

    def encode(self, texts: List[str], normalize: bool = True) -> np.ndarray:
        """
        Generate embeddings for a list of texts.

        Args:
            texts: Non-empty list of non-blank strings to encode
            normalize: Whether to normalize embeddings to unit length

        Returns:
            np.ndarray: The array returned by the model's ``encode`` call
            (one embedding per input text)

        Raises:
            SecurityError: If input validation fails
            RuntimeError: If model loading or encoding fails
        """
        if not isinstance(texts, list):
            raise SecurityError("Input must be a list of strings")

        if not texts:
            raise SecurityError("Input list cannot be empty")

        # Validate each text
        for i, text in enumerate(texts):
            if not isinstance(text, str):
                raise SecurityError(f"Text at index {i} must be a string")
            if not text.strip():
                raise SecurityError(f"Text at index {i} cannot be empty")

        # Initialize model if needed
        if self.model is None:
            self._initialize_model()

        try:
            return self.model.encode(
                texts,
                normalize_embeddings=normalize,
                convert_to_numpy=True
            )
        except Exception as e:
            raise RuntimeError(f"Failed to generate embeddings: {e}") from e

    def encode_single(self, text: str, normalize: bool = True) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Text string to encode
            normalize: Whether to normalize embedding to unit length

        Returns:
            List[float]: Embedding vector as list of floats

        Raises:
            SecurityError: If input validation fails
            RuntimeError: If encoding fails
        """
        embeddings = self.encode([text], normalize=normalize)
        return embeddings[0].tolist()

    def similarity(self, text1: str, text2: str) -> float:
        """
        Calculate cosine similarity between two texts.

        Args:
            text1: First text
            text2: Second text

        Returns:
            float: Cosine similarity score.  This is the dot product of two
            unit-length vectors, so it lies in [-1, 1].
        """
        embeddings = self.encode([text1, text2], normalize=True)

        # Cosine similarity with normalized vectors is just dot product
        return float(np.dot(embeddings[0], embeddings[1]))

    def batch_similarity(self, query: str, texts: List[str]) -> List[float]:
        """
        Calculate similarity between a query and multiple texts.

        The query and all texts are encoded in a single ``encode`` call.

        Args:
            query: Query text
            texts: List of texts to compare against

        Returns:
            List[float]: Similarity scores, one per entry of ``texts`` and in
            the same order
        """
        embeddings = self.encode([query] + texts, normalize=True)

        query_embedding = embeddings[0]
        text_embeddings = embeddings[1:]

        # Calculate dot products (cosine similarity with normalized vectors)
        return [
            float(np.dot(query_embedding, text_emb))
            for text_emb in text_embeddings
        ]

    def get_model_info(self) -> dict:
        """
        Get information about the loaded model, loading it first if needed.

        Returns:
            dict: ``model_name``, ``embedding_dimensions``,
            ``max_sequence_length``, ``device`` and ``cache_dir``; attributes
            the model does not expose are reported as ``'Unknown'``
        """
        if self.model is None:
            self._initialize_model()

        return {
            "model_name": self.model_name,
            "embedding_dimensions": self.embedding_dim,
            "max_sequence_length": getattr(self.model, 'max_seq_length', 'Unknown'),
            "device": str(self.model.device) if hasattr(self.model, 'device') else 'Unknown',
            "cache_dir": self.cache_dir or 'Default'
        }

    def validate_embedding(self, embedding: List[float]) -> bool:
        """
        Validate that an embedding has the correct type and dimensions.

        Reading ``self.embedding_dim`` may load the model if nothing has
        been loaded yet.

        Args:
            embedding: Embedding vector to validate

        Returns:
            bool: True if ``embedding`` is a list of ints/floats whose length
            equals the model's embedding dimension, False otherwise
        """
        return (
            isinstance(embedding, list) and
            len(embedding) == self.embedding_dim and
            all(isinstance(x, (int, float)) for x in embedding)
        )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Global model instance for efficient reuse
|
|
210
|
+
_global_model: Optional[EmbeddingModel] = None
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def get_embedding_model(model_name: str = None, cache_dir: str = None) -> EmbeddingModel:
    """
    Return the module-wide ``EmbeddingModel``, creating it on first use.

    Once an instance exists it is returned as-is, so the arguments of later
    calls are ignored.

    Args:
        model_name: Name of the model (only used on first call)
        cache_dir: Cache directory (only used on first call)

    Returns:
        EmbeddingModel: The shared model instance
    """
    global _global_model

    # Fast path: an instance was already created by an earlier call.
    if _global_model is not None:
        return _global_model

    _global_model = EmbeddingModel(model_name, cache_dir)
    return _global_model
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def reset_embedding_model() -> None:
    """Discard the shared model instance (useful for testing).

    Rebinds the module-level ``_global_model`` to ``None``.
    """
    global _global_model
    _global_model = None
|