loom_agent-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- loom/__init__.py +77 -0
- loom/agent.py +217 -0
- loom/agents/__init__.py +10 -0
- loom/agents/refs.py +28 -0
- loom/agents/registry.py +50 -0
- loom/builtin/compression/__init__.py +4 -0
- loom/builtin/compression/structured.py +79 -0
- loom/builtin/embeddings/__init__.py +9 -0
- loom/builtin/embeddings/openai_embedding.py +135 -0
- loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
- loom/builtin/llms/__init__.py +8 -0
- loom/builtin/llms/mock.py +34 -0
- loom/builtin/llms/openai.py +168 -0
- loom/builtin/llms/rule.py +102 -0
- loom/builtin/memory/__init__.py +5 -0
- loom/builtin/memory/in_memory.py +21 -0
- loom/builtin/memory/persistent_memory.py +278 -0
- loom/builtin/retriever/__init__.py +9 -0
- loom/builtin/retriever/chroma_store.py +265 -0
- loom/builtin/retriever/in_memory.py +106 -0
- loom/builtin/retriever/milvus_store.py +307 -0
- loom/builtin/retriever/pinecone_store.py +237 -0
- loom/builtin/retriever/qdrant_store.py +274 -0
- loom/builtin/retriever/vector_store.py +128 -0
- loom/builtin/retriever/vector_store_config.py +217 -0
- loom/builtin/tools/__init__.py +32 -0
- loom/builtin/tools/calculator.py +49 -0
- loom/builtin/tools/document_search.py +111 -0
- loom/builtin/tools/glob.py +27 -0
- loom/builtin/tools/grep.py +56 -0
- loom/builtin/tools/http_request.py +86 -0
- loom/builtin/tools/python_repl.py +73 -0
- loom/builtin/tools/read_file.py +32 -0
- loom/builtin/tools/task.py +158 -0
- loom/builtin/tools/web_search.py +64 -0
- loom/builtin/tools/write_file.py +31 -0
- loom/callbacks/base.py +9 -0
- loom/callbacks/logging.py +12 -0
- loom/callbacks/metrics.py +27 -0
- loom/callbacks/observability.py +248 -0
- loom/components/agent.py +107 -0
- loom/core/agent_executor.py +450 -0
- loom/core/circuit_breaker.py +178 -0
- loom/core/compression_manager.py +329 -0
- loom/core/context_retriever.py +185 -0
- loom/core/error_classifier.py +193 -0
- loom/core/errors.py +66 -0
- loom/core/message_queue.py +167 -0
- loom/core/permission_store.py +62 -0
- loom/core/permissions.py +69 -0
- loom/core/scheduler.py +125 -0
- loom/core/steering_control.py +47 -0
- loom/core/structured_logger.py +279 -0
- loom/core/subagent_pool.py +232 -0
- loom/core/system_prompt.py +141 -0
- loom/core/system_reminders.py +283 -0
- loom/core/tool_pipeline.py +113 -0
- loom/core/types.py +269 -0
- loom/interfaces/compressor.py +59 -0
- loom/interfaces/embedding.py +51 -0
- loom/interfaces/llm.py +33 -0
- loom/interfaces/memory.py +29 -0
- loom/interfaces/retriever.py +179 -0
- loom/interfaces/tool.py +27 -0
- loom/interfaces/vector_store.py +80 -0
- loom/llm/__init__.py +14 -0
- loom/llm/config.py +228 -0
- loom/llm/factory.py +111 -0
- loom/llm/model_health.py +235 -0
- loom/llm/model_pool_advanced.py +305 -0
- loom/llm/pool.py +170 -0
- loom/llm/registry.py +201 -0
- loom/mcp/__init__.py +4 -0
- loom/mcp/client.py +86 -0
- loom/mcp/registry.py +58 -0
- loom/mcp/tool_adapter.py +48 -0
- loom/observability/__init__.py +5 -0
- loom/patterns/__init__.py +5 -0
- loom/patterns/multi_agent.py +123 -0
- loom/patterns/rag.py +262 -0
- loom/plugins/registry.py +55 -0
- loom/resilience/__init__.py +5 -0
- loom/tooling.py +72 -0
- loom/utils/agent_loader.py +218 -0
- loom/utils/token_counter.py +19 -0
- loom_agent-0.0.1.dist-info/METADATA +457 -0
- loom_agent-0.0.1.dist-info/RECORD +89 -0
- loom_agent-0.0.1.dist-info/WHEEL +4 -0
- loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
loom/builtin/memory/persistent_memory.py
@@ -0,0 +1,278 @@
"""US6: Three-Tier Memory System

Provides a practical memory system with automatic persistence for agent conversations.

Tiers:
1. Short-term: In-memory message array (current session)
2. Mid-term: Compression summaries with metadata (managed by CompressionManager)
3. Long-term: JSON file persistence for cross-session recall

Design goals:
- Simple API for developers
- Automatic backup and recovery
- Zero-config defaults with customization options
"""

from __future__ import annotations

import json
import os
from pathlib import Path
from typing import List, Optional
from datetime import datetime
import asyncio

from loom.core.types import Message
from loom.interfaces.memory import BaseMemory


class PersistentMemory(BaseMemory):
    """Three-tier memory with automatic persistence.

    Example:
        # Simple usage - auto-creates .loom directory
        memory = PersistentMemory()

        # Custom persistence path
        memory = PersistentMemory(persist_dir=".my_agent_memory")

        # Disable persistence
        memory = PersistentMemory(enable_persistence=False)
    """

    def __init__(
        self,
        persist_dir: str = ".loom",
        session_id: Optional[str] = None,
        enable_persistence: bool = True,
        auto_backup: bool = True,
        max_backup_files: int = 5,
    ):
        """Initialize persistent memory.

        Args:
            persist_dir: Directory for persisting memory (default: .loom)
            session_id: Session identifier (default: auto-generated timestamp)
            enable_persistence: Enable file persistence (default: True)
            auto_backup: Create backup before overwriting (default: True)
            max_backup_files: Maximum backup files to keep (default: 5)
        """
        self.persist_dir = Path(persist_dir)
        self.session_id = session_id or datetime.now().strftime("%Y%m%d_%H%M%S")
        self.enable_persistence = enable_persistence
        self.auto_backup = auto_backup
        self.max_backup_files = max_backup_files

        # Tier 1: Short-term (in-memory)
        self._messages: List[Message] = []

        # Tier 2: Mid-term (compression metadata - managed externally)
        self._compression_metadata: List[dict] = []

        # Set up persistence
        if self.enable_persistence:
            self._ensure_persist_dir()
            self._load_from_disk()

        self._lock = asyncio.Lock()

    def _ensure_persist_dir(self) -> None:
        """Create the persistence directory if it doesn't exist."""
        self.persist_dir.mkdir(parents=True, exist_ok=True)

    def _get_memory_file(self) -> Path:
        """Get the path to the memory file."""
        return self.persist_dir / f"session_{self.session_id}.json"

    def _get_backup_file(self, index: int) -> Path:
        """Get the path to a backup file."""
        return self.persist_dir / f"session_{self.session_id}.backup{index}.json"

    def _load_from_disk(self) -> None:
        """Load memory from disk if it exists."""
        memory_file = self._get_memory_file()
        if not memory_file.exists():
            return

        try:
            with open(memory_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Load messages
            self._messages = [
                Message(**msg_data) for msg_data in data.get('messages', [])
            ]

            # Load compression metadata
            self._compression_metadata = data.get('compression_metadata', [])

        except Exception as e:
            # Try to recover from a backup
            if self._recover_from_backup():
                return
            # If recovery fails, start fresh
            print(f"Warning: Failed to load memory from disk: {e}")
            self._messages = []
            self._compression_metadata = []

    def _save_to_disk(self) -> None:
        """Save memory to disk with optional backup."""
        if not self.enable_persistence:
            return

        memory_file = self._get_memory_file()

        try:
            # Create a backup if the file exists
            if self.auto_backup and memory_file.exists():
                self._create_backup()

            # Save current state
            data = {
                'session_id': self.session_id,
                'timestamp': datetime.now().isoformat(),
                'messages': [self._message_to_dict(m) for m in self._messages],
                'compression_metadata': self._compression_metadata,
            }

            with open(memory_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)

        except Exception as e:
            print(f"Warning: Failed to save memory to disk: {e}")

    def _message_to_dict(self, message: Message) -> dict:
        """Convert a Message to a JSON-serializable dict."""
        return {
            'role': message.role,
            'content': message.content,
            'tool_call_id': message.tool_call_id,
            'metadata': message.metadata,
        }

    def _create_backup(self) -> None:
        """Create a backup of the current memory file."""
        memory_file = self._get_memory_file()
        if not memory_file.exists():
            return

        # Rotate existing backups; replace() overwrites the target,
        # unlike rename(), which raises on Windows if the target exists
        for i in range(self.max_backup_files - 1, 0, -1):
            old_backup = self._get_backup_file(i)
            new_backup = self._get_backup_file(i + 1)
            if old_backup.exists():
                old_backup.replace(new_backup)

        # Create the new backup
        backup_file = self._get_backup_file(1)
        memory_file.replace(backup_file)

        # Clean up old backups
        self._cleanup_old_backups()

    def _cleanup_old_backups(self) -> None:
        """Remove backups exceeding max_backup_files."""
        for i in range(self.max_backup_files + 1, self.max_backup_files + 10):
            backup_file = self._get_backup_file(i)
            if backup_file.exists():
                backup_file.unlink()

    def _recover_from_backup(self) -> bool:
        """Attempt to recover from the most recent backup.

        Returns:
            True if recovery succeeded, False otherwise
        """
        for i in range(1, self.max_backup_files + 1):
            backup_file = self._get_backup_file(i)
            if not backup_file.exists():
                continue

            try:
                with open(backup_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                self._messages = [
                    Message(**msg_data) for msg_data in data.get('messages', [])
                ]
                self._compression_metadata = data.get('compression_metadata', [])

                print(f"Successfully recovered from backup {i}")
                return True

            except Exception as e:
                print(f"Failed to recover from backup {i}: {e}")
                continue

        return False

    async def add_message(self, message: Message) -> None:
        """Add a message to memory and persist it."""
        async with self._lock:
            self._messages.append(message)
            self._save_to_disk()

    async def get_messages(self) -> List[Message]:
        """Get all messages from memory."""
        async with self._lock:
            return self._messages.copy()

    async def clear(self) -> None:
        """Clear all messages from memory."""
        async with self._lock:
            self._messages.clear()
            self._compression_metadata.clear()
            self._save_to_disk()

    async def set_messages(self, messages: List[Message]) -> None:
        """Replace all messages in memory.

        Used by CompressionManager when compressing history.
        """
        async with self._lock:
            self._messages = messages.copy()
            self._save_to_disk()

    def add_compression_metadata(self, metadata: dict) -> None:
        """Add compression metadata (Tier 2).

        Called by CompressionManager to track compression events.
        """
        self._compression_metadata.append({
            'timestamp': datetime.now().isoformat(),
            **metadata
        })
        self._save_to_disk()

    def get_compression_history(self) -> List[dict]:
        """Get compression history metadata."""
        return self._compression_metadata.copy()

    def get_persistence_info(self) -> dict:
        """Get information about the persistence state.

        Useful for debugging and monitoring.
        """
        memory_file = self._get_memory_file()

        backup_files = []
        for i in range(1, self.max_backup_files + 1):
            backup = self._get_backup_file(i)
            if backup.exists():
                backup_files.append({
                    'index': i,
                    'path': str(backup),
                    'size_bytes': backup.stat().st_size,
                    'modified': datetime.fromtimestamp(backup.stat().st_mtime).isoformat(),
                })

        return {
            'enabled': self.enable_persistence,
            'session_id': self.session_id,
            'persist_dir': str(self.persist_dir),
            'memory_file': str(memory_file),
            'memory_file_exists': memory_file.exists(),
            'message_count': len(self._messages),
            'compression_event_count': len(self._compression_metadata),
            'backups': backup_files,
        }
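A quick way to see all three tiers together is a small driver script. The sketch below is hypothetical and grounded only in the class above: it assumes Message accepts the same role/content keyword arguments that _message_to_dict serializes, which is not confirmed against loom/core/types.py.

# Hypothetical usage sketch for PersistentMemory (Message fields assumed
# from _message_to_dict above, not verified against loom.core.types).
import asyncio

from loom.builtin.memory.persistent_memory import PersistentMemory
from loom.core.types import Message


async def main() -> None:
    memory = PersistentMemory(persist_dir=".loom_demo", session_id="demo")

    # Tier 1: append to the in-memory array; each call also persists (Tier 3)
    await memory.add_message(Message(role="user", content="hello"))
    await memory.add_message(Message(role="assistant", content="hi there"))

    # Tier 2: record a compression event, as CompressionManager would
    memory.add_compression_metadata({"messages_compressed": 2})

    # A new instance with the same session_id reloads the JSON file from disk
    restored = PersistentMemory(persist_dir=".loom_demo", session_id="demo")
    print(len(await restored.get_messages()))           # 2
    print(restored.get_persistence_info()["backups"])   # rotated .backupN.json files


asyncio.run(main())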
loom/builtin/retriever/__init__.py
@@ -0,0 +1,9 @@
"""Built-in retriever implementations."""

from loom.builtin.retriever.in_memory import InMemoryRetriever

try:
    from loom.builtin.retriever.vector_store import VectorStoreRetriever
    __all__ = ["InMemoryRetriever", "VectorStoreRetriever"]
except ImportError:
    __all__ = ["InMemoryRetriever"]
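The try/except above is an optional-dependency gate: VectorStoreRetriever is only exported when its heavier dependencies import cleanly. A consumer can feature-detect it at runtime; a minimal sketch, assuming only the module shown here:

# Sketch: fall back to the dependency-free retriever when the
# vector-store extras are not installed.
import loom.builtin.retriever as retrievers

if "VectorStoreRetriever" in retrievers.__all__:
    retriever_cls = retrievers.VectorStoreRetriever  # vector path available
else:
    retriever_cls = retrievers.InMemoryRetriever     # keyword-matching fallback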
loom/builtin/retriever/chroma_store.py
@@ -0,0 +1,265 @@
"""ChromaDB vector store adapter."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple
import uuid

from loom.interfaces.retriever import Document
from loom.interfaces.vector_store import BaseVectorStore

try:
    import chromadb
    from chromadb.config import Settings
    CHROMA_AVAILABLE = True
except ImportError:
    CHROMA_AVAILABLE = False


class ChromaVectorStore(BaseVectorStore):
    """
    ChromaDB vector store adapter.

    Features:
    - ✅ Open-source embedded vector database
    - ✅ Minimal API
    - ✅ Supports local persistence
    - ✅ Built-in embedding support (optional)
    - ✅ Well suited to rapid prototyping

    Example:
        from loom.builtin.retriever.chroma_store import ChromaVectorStore
        from loom.builtin.retriever.vector_store_config import ChromaConfig

        # Local persistence mode
        config = ChromaConfig.create_local(
            persist_directory="./chroma_db",
            collection_name="loom_docs"
        )

        # Remote server mode
        config = ChromaConfig.create_remote(
            host="localhost",
            port=8000,
            collection_name="loom_docs"
        )

        vector_store = ChromaVectorStore(config)
        await vector_store.initialize()
    """

    def __init__(self, config: Dict[str, Any] | Any):
        """
        Parameters:
            config: ChromaConfig object or configuration dict
        """
        if not CHROMA_AVAILABLE:
            raise ImportError(
                "ChromaDB is not installed. "
                "Install with: pip install chromadb"
            )

        # Accept either a dict or a Pydantic model
        if hasattr(config, "model_dump"):
            self.config = config.model_dump()
        else:
            self.config = config

        self.collection_name = self.config.get("collection_name", "loom_documents")
        self.dimension = self.config.get("dimension", 1536)
        self.client_type = self.config.get("client_type", "local")
        self.persist_directory = self.config.get("persist_directory")
        self.host = self.config.get("host")
        self.port = self.config.get("port", 8000)

        self.client: Optional[Any] = None
        self.collection: Optional[Any] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the ChromaDB client and collection."""
        if self._initialized:
            return

        # Initialize the client
        if self.client_type == "local":
            # Local persistent mode
            if self.persist_directory:
                self.client = chromadb.PersistentClient(
                    path=self.persist_directory
                )
            else:
                # In-memory mode (no persistence)
                self.client = chromadb.Client()
        else:
            # Remote HTTP mode
            self.client = chromadb.HttpClient(
                host=self.host,
                port=self.port
            )

        # Get or create the collection
        try:
            self.collection = self.client.get_collection(
                name=self.collection_name
            )
        except Exception:
            # Collection does not exist; create it
            self.collection = self.client.create_collection(
                name=self.collection_name,
                metadata={"dimension": self.dimension}
            )

        self._initialized = True

    async def add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document]
    ) -> None:
        """
        Add vectors to ChromaDB.

        Parameters:
            vectors: List of embedding vectors
            documents: Corresponding documents
        """
        if not self._initialized:
            await self.initialize()

        # Build the ChromaDB payloads
        ids = []
        embeddings = []
        documents_text = []
        metadatas = []

        for vector, doc in zip(vectors, documents):
            # Use the document ID, or generate one
            doc_id = doc.doc_id or str(uuid.uuid4())

            ids.append(doc_id)
            embeddings.append(vector)
            documents_text.append(doc.content)

            # Build the metadata
            metadata = doc.metadata or {}
            if doc.score is not None:
                metadata["score"] = doc.score
            metadatas.append(metadata)

        # Bulk add (ChromaDB handles batching automatically)
        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents_text,
            metadatas=metadatas
        )

    async def search(
        self,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[Document, float]]:
        """
        Search for similar vectors.

        Parameters:
            query_vector: Query vector
            top_k: Number of results to return
            filters: Metadata filter conditions

        Returns:
            List of (Document, score) tuples
        """
        if not self._initialized:
            await self.initialize()

        # Build the where filter (ChromaDB format)
        where = None
        if filters:
            where = self._build_chroma_filter(filters)

        # Run the query
        results = self.collection.query(
            query_embeddings=[query_vector],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"]
        )

        # Convert the results
        documents_with_scores = []

        if results["ids"]:
            for i, doc_id in enumerate(results["ids"][0]):
                content = results["documents"][0][i]
                metadata = results["metadatas"][0][i] or {}
                distance = results["distances"][0][i]

                # Convert distance to a similarity score (ChromaDB returns distances)
                # For cosine distance, similarity = 1 - distance
                score = 1.0 - distance

                doc = Document(
                    content=content,
                    metadata=metadata,
                    score=score,
                    doc_id=doc_id
                )
                documents_with_scores.append((doc, score))

        return documents_with_scores

    async def delete(self, doc_ids: List[str]) -> None:
        """
        Delete documents.

        Parameters:
            doc_ids: Document IDs to delete
        """
        if not self._initialized:
            await self.initialize()

        self.collection.delete(ids=doc_ids)

    async def clear(self) -> None:
        """Clear the collection."""
        if not self._initialized:
            await self.initialize()

        # ChromaDB: delete and recreate the collection
        self.client.delete_collection(name=self.collection_name)
        self.collection = self.client.create_collection(
            name=self.collection_name,
            metadata={"dimension": self.dimension}
        )

    def _build_chroma_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build a ChromaDB where filter.

        ChromaDB filter syntax:
        {
            "field": "value",
            "numeric_field": {"$gte": 10}
        }

        Supported operators: $eq, $ne, $gt, $gte, $lt, $lte, $in, $nin
        """
        where = {}

        for key, value in filters.items():
            if isinstance(value, dict):
                # Complex query (e.g. {"price": {"$gte": 100}})
                where[key] = value
            else:
                # Simple equality query
                where[key] = value

        return where

    async def close(self) -> None:
        """Close the connection."""
        # The ChromaDB client manages its own connections
        self._initialized = False
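A round trip through the adapter shows the expected call order: construct, initialize, add, filtered search. The sketch below is hypothetical; it uses toy 4-dimensional vectors in place of real embeddings, assumes Document's score and doc_id default to None (as the None checks in add_vectors suggest), and relies on the in-memory chromadb.Client() path that initialize() takes when no persist_directory is configured.

# Hypothetical round trip through ChromaVectorStore with toy vectors.
import asyncio

from loom.builtin.retriever.chroma_store import ChromaVectorStore
from loom.interfaces.retriever import Document


async def main() -> None:
    # client_type defaults to "local"; with no persist_directory this
    # uses the non-persistent chromadb.Client()
    store = ChromaVectorStore({"collection_name": "demo", "dimension": 4})
    await store.initialize()

    await store.add_vectors(
        vectors=[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
        documents=[
            Document(content="intro to Python", metadata={"lang": "python"}),
            Document(content="intro to JavaScript", metadata={"lang": "js"}),
        ],
    )

    # filters pass through _build_chroma_filter into ChromaDB's where clause
    hits = await store.search([1.0, 0.0, 0.0, 0.0], top_k=1, filters={"lang": "python"})
    for doc, score in hits:
        print(doc.doc_id, doc.content, score)


asyncio.run(main())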
loom/builtin/retriever/in_memory.py
@@ -0,0 +1,106 @@
"""Simple in-memory retriever - no external dependencies."""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from loom.interfaces.retriever import BaseRetriever, Document


class InMemoryRetriever(BaseRetriever):
    """
    Simple in-memory retriever based on keyword matching.

    Features:
    - No external dependencies
    - Suitable for development/testing
    - Simple keyword matching (not vector retrieval)

    Example:
        retriever = InMemoryRetriever()
        await retriever.add_documents([
            Document(content="Python is a programming language"),
            Document(content="JavaScript is used for web development"),
        ])

        docs = await retriever.retrieve("programming")
        # Returns documents containing "programming"
    """

    def __init__(self):
        self.documents: List[Document] = []

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """
        Retrieve documents by keyword matching.

        Simple algorithm:
        1. Tokenize the query
        2. Count how many query terms each document contains
        3. Sort by match score
        """
        if not self.documents:
            return []

        # Naive tokenization (split on whitespace)
        query_terms = set(query.lower().split())

        # Compute a match score for each document
        scored_docs = []
        for doc in self.documents:
            # Apply filters
            if filters and not self._match_filters(doc, filters):
                continue

            # Compute the match score
            doc_terms = set(doc.content.lower().split())
            matches = query_terms.intersection(doc_terms)
            score = len(matches) / len(query_terms) if query_terms else 0.0

            if score > 0:
                # Copy the document and attach the score
                doc_with_score = Document(
                    content=doc.content,
                    metadata=doc.metadata,
                    score=score,
                    doc_id=doc.doc_id
                )
                scored_docs.append(doc_with_score)

        # Sort by score
        scored_docs.sort(key=lambda d: d.score or 0, reverse=True)

        return scored_docs[:top_k]

    async def add_documents(self, documents: List[Document]) -> None:
        """Add documents to memory."""
        for doc in documents:
            # Assign an ID if missing
            if doc.doc_id is None:
                doc.doc_id = str(len(self.documents))

            self.documents.append(doc)

    def _match_filters(self, doc: Document, filters: Dict[str, Any]) -> bool:
        """Check whether a document matches the filter conditions."""
        if not doc.metadata:
            return False

        for key, value in filters.items():
            if doc.metadata.get(key) != value:
                return False

        return True

    def clear(self) -> None:
        """Clear all documents."""
        self.documents.clear()

    def __len__(self) -> int:
        """Return the number of documents."""
        return len(self.documents)
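Because the score is plain term overlap (matched terms / query terms), a short script makes the ranking concrete. A minimal sketch, assuming only the methods defined above and that Document's score and doc_id default to None:

# Sketch: keyword ranking and metadata filtering with InMemoryRetriever.
import asyncio

from loom.builtin.retriever.in_memory import InMemoryRetriever
from loom.interfaces.retriever import Document


async def main() -> None:
    retriever = InMemoryRetriever()
    await retriever.add_documents([
        Document(content="Python is a programming language", metadata={"lang": "python"}),
        Document(content="JavaScript is used for web development", metadata={"lang": "js"}),
    ])

    # "programming language" matches 2/2 query terms in the first doc -> score 1.0
    for doc in await retriever.retrieve("programming language", top_k=2):
        print(doc.score, doc.content)

    # filters require exact metadata equality
    js_only = await retriever.retrieve("development", filters={"lang": "js"})
    print(len(js_only))  # 1


asyncio.run(main())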