loom-agent 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of loom-agent might be problematic. Click here for more details.

Files changed (89) hide show
  1. loom/__init__.py +77 -0
  2. loom/agent.py +217 -0
  3. loom/agents/__init__.py +10 -0
  4. loom/agents/refs.py +28 -0
  5. loom/agents/registry.py +50 -0
  6. loom/builtin/compression/__init__.py +4 -0
  7. loom/builtin/compression/structured.py +79 -0
  8. loom/builtin/embeddings/__init__.py +9 -0
  9. loom/builtin/embeddings/openai_embedding.py +135 -0
  10. loom/builtin/embeddings/sentence_transformers_embedding.py +145 -0
  11. loom/builtin/llms/__init__.py +8 -0
  12. loom/builtin/llms/mock.py +34 -0
  13. loom/builtin/llms/openai.py +168 -0
  14. loom/builtin/llms/rule.py +102 -0
  15. loom/builtin/memory/__init__.py +5 -0
  16. loom/builtin/memory/in_memory.py +21 -0
  17. loom/builtin/memory/persistent_memory.py +278 -0
  18. loom/builtin/retriever/__init__.py +9 -0
  19. loom/builtin/retriever/chroma_store.py +265 -0
  20. loom/builtin/retriever/in_memory.py +106 -0
  21. loom/builtin/retriever/milvus_store.py +307 -0
  22. loom/builtin/retriever/pinecone_store.py +237 -0
  23. loom/builtin/retriever/qdrant_store.py +274 -0
  24. loom/builtin/retriever/vector_store.py +128 -0
  25. loom/builtin/retriever/vector_store_config.py +217 -0
  26. loom/builtin/tools/__init__.py +32 -0
  27. loom/builtin/tools/calculator.py +49 -0
  28. loom/builtin/tools/document_search.py +111 -0
  29. loom/builtin/tools/glob.py +27 -0
  30. loom/builtin/tools/grep.py +56 -0
  31. loom/builtin/tools/http_request.py +86 -0
  32. loom/builtin/tools/python_repl.py +73 -0
  33. loom/builtin/tools/read_file.py +32 -0
  34. loom/builtin/tools/task.py +158 -0
  35. loom/builtin/tools/web_search.py +64 -0
  36. loom/builtin/tools/write_file.py +31 -0
  37. loom/callbacks/base.py +9 -0
  38. loom/callbacks/logging.py +12 -0
  39. loom/callbacks/metrics.py +27 -0
  40. loom/callbacks/observability.py +248 -0
  41. loom/components/agent.py +107 -0
  42. loom/core/agent_executor.py +450 -0
  43. loom/core/circuit_breaker.py +178 -0
  44. loom/core/compression_manager.py +329 -0
  45. loom/core/context_retriever.py +185 -0
  46. loom/core/error_classifier.py +193 -0
  47. loom/core/errors.py +66 -0
  48. loom/core/message_queue.py +167 -0
  49. loom/core/permission_store.py +62 -0
  50. loom/core/permissions.py +69 -0
  51. loom/core/scheduler.py +125 -0
  52. loom/core/steering_control.py +47 -0
  53. loom/core/structured_logger.py +279 -0
  54. loom/core/subagent_pool.py +232 -0
  55. loom/core/system_prompt.py +141 -0
  56. loom/core/system_reminders.py +283 -0
  57. loom/core/tool_pipeline.py +113 -0
  58. loom/core/types.py +269 -0
  59. loom/interfaces/compressor.py +59 -0
  60. loom/interfaces/embedding.py +51 -0
  61. loom/interfaces/llm.py +33 -0
  62. loom/interfaces/memory.py +29 -0
  63. loom/interfaces/retriever.py +179 -0
  64. loom/interfaces/tool.py +27 -0
  65. loom/interfaces/vector_store.py +80 -0
  66. loom/llm/__init__.py +14 -0
  67. loom/llm/config.py +228 -0
  68. loom/llm/factory.py +111 -0
  69. loom/llm/model_health.py +235 -0
  70. loom/llm/model_pool_advanced.py +305 -0
  71. loom/llm/pool.py +170 -0
  72. loom/llm/registry.py +201 -0
  73. loom/mcp/__init__.py +4 -0
  74. loom/mcp/client.py +86 -0
  75. loom/mcp/registry.py +58 -0
  76. loom/mcp/tool_adapter.py +48 -0
  77. loom/observability/__init__.py +5 -0
  78. loom/patterns/__init__.py +5 -0
  79. loom/patterns/multi_agent.py +123 -0
  80. loom/patterns/rag.py +262 -0
  81. loom/plugins/registry.py +55 -0
  82. loom/resilience/__init__.py +5 -0
  83. loom/tooling.py +72 -0
  84. loom/utils/agent_loader.py +218 -0
  85. loom/utils/token_counter.py +19 -0
  86. loom_agent-0.0.1.dist-info/METADATA +457 -0
  87. loom_agent-0.0.1.dist-info/RECORD +89 -0
  88. loom_agent-0.0.1.dist-info/WHEEL +4 -0
  89. loom_agent-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,278 @@
1
+ """US6: Three-Tier Memory System
2
+
3
+ Provides a practical memory system with automatic persistence for agent conversations.
4
+
5
+ Tiers:
6
+ 1. Short-term: In-memory message array (current session)
7
+ 2. Mid-term: Compression summaries with metadata (managed by CompressionManager)
8
+ 3. Long-term: JSON file persistence for cross-session recall
9
+
10
+ Design goals:
11
+ - Simple API for developers
12
+ - Automatic backup and recovery
13
+ - Zero-config defaults with customization options
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import os
20
+ from pathlib import Path
21
+ from typing import List, Optional
22
+ from datetime import datetime
23
+ import asyncio
24
+
25
+ from loom.core.types import Message
26
+ from loom.interfaces.memory import BaseMemory
27
+
28
+
29
class PersistentMemory(BaseMemory):
    """Three-tier conversation memory with automatic JSON persistence.

    Tiers:
        1. Short-term: the in-memory message list for the current session.
        2. Mid-term: compression metadata records appended through
           ``add_compression_metadata``.
        3. Long-term: a per-session JSON file under ``persist_dir`` that is
           rewritten after every mutation.

    Example:
        # Simple usage - auto-creates .loom directory
        memory = PersistentMemory()

        # Custom persistence path
        memory = PersistentMemory(persist_dir=".my_agent_memory")

        # Disable persistence
        memory = PersistentMemory(enable_persistence=False)
    """

    def __init__(
        self,
        persist_dir: str = ".loom",
        session_id: Optional[str] = None,
        enable_persistence: bool = True,
        auto_backup: bool = True,
        max_backup_files: int = 5,
    ):
        """Initialize persistent memory.

        Args:
            persist_dir: Directory for persisting memory (default: .loom)
            session_id: Session identifier (default: auto-generated timestamp)
            enable_persistence: Enable file persistence (default: True)
            auto_backup: Create backup before overwriting (default: True)
            max_backup_files: Maximum backup files to keep (default: 5)
        """
        self.persist_dir = Path(persist_dir)
        self.session_id = session_id or datetime.now().strftime("%Y%m%d_%H%M%S")
        self.enable_persistence = enable_persistence
        self.auto_backup = auto_backup
        self.max_backup_files = max_backup_files

        # Tier 1: Short-term (in-memory)
        self._messages: List[Message] = []

        # Tier 2: Mid-term (compression metadata - managed externally)
        self._compression_metadata: List[dict] = []

        # Serializes the async mutators below.
        self._lock = asyncio.Lock()

        # Setup persistence
        if self.enable_persistence:
            self._ensure_persist_dir()
            self._load_from_disk()

    def _ensure_persist_dir(self) -> None:
        """Create persistence directory if it doesn't exist."""
        self.persist_dir.mkdir(parents=True, exist_ok=True)

    def _get_memory_file(self) -> Path:
        """Get path to the primary memory file of this session."""
        return self.persist_dir / f"session_{self.session_id}.json"

    def _get_backup_file(self, index: int) -> Path:
        """Get path to backup file number ``index`` (1 is the most recent)."""
        return self.persist_dir / f"session_{self.session_id}.backup{index}.json"

    def _load_from_disk(self) -> None:
        """Load memory from disk, falling back to backups when needed.

        If the primary file is missing, backups are still consulted, because
        a save interrupted between backup rotation and the final write leaves
        only backups behind. If the primary file is unreadable and no backup
        can be loaded, a warning is printed and memory starts empty.
        """
        memory_file = self._get_memory_file()
        if not memory_file.exists():
            self._recover_from_backup()
            return

        try:
            with open(memory_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Parse everything before touching instance state so a failure
            # half-way through cannot leave a partially loaded memory.
            messages = [
                Message(**msg_data) for msg_data in data.get('messages', [])
            ]
            compression_metadata = data.get('compression_metadata', [])

        except Exception as e:
            # Try to recover from backup
            if self._recover_from_backup():
                return
            # If recovery fails, start fresh
            print(f"Warning: Failed to load memory from disk: {e}")
            self._messages = []
            self._compression_metadata = []
        else:
            self._messages = messages
            self._compression_metadata = compression_metadata

    def _save_to_disk(self) -> None:
        """Save memory to disk with optional backup.

        The snapshot is first serialized to a sibling ``.tmp`` file and only
        then swapped into place with ``os.replace``. A serialization or write
        failure therefore never truncates the last good snapshot and never
        rotates a good backup away. Failures are reported with a printed
        warning and are not raised.
        """
        if not self.enable_persistence:
            return

        memory_file = self._get_memory_file()
        tmp_file = memory_file.with_name(memory_file.name + ".tmp")

        try:
            data = {
                'session_id': self.session_id,
                'timestamp': datetime.now().isoformat(),
                'messages': [self._message_to_dict(m) for m in self._messages],
                'compression_metadata': self._compression_metadata,
            }

            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)

            # Rotate backups only once the new snapshot is safely on disk.
            if self.auto_backup and memory_file.exists():
                self._create_backup()

            os.replace(tmp_file, memory_file)

        except Exception as e:
            print(f"Warning: Failed to save memory to disk: {e}")
            # Best-effort removal of a partially written temp file.
            try:
                tmp_file.unlink()
            except OSError:
                pass

    def _message_to_dict(self, message: Message) -> dict:
        """Convert Message to JSON-serializable dict."""
        return {
            'role': message.role,
            'content': message.content,
            'tool_call_id': message.tool_call_id,
            'metadata': message.metadata,
        }

    def _create_backup(self) -> None:
        """Move the current memory file into backup slot 1, rotating older ones.

        ``Path.replace`` is used instead of ``Path.rename`` because ``rename``
        raises ``FileExistsError`` on Windows when the destination already
        exists, which happens as soon as the highest backup slot is occupied.
        """
        memory_file = self._get_memory_file()
        if not memory_file.exists():
            return

        if self.max_backup_files < 1:
            # No backups are kept: leave the primary file in place (it is
            # overwritten by the caller) and only purge stale backups.
            self._cleanup_old_backups()
            return

        # Rotate existing backups, oldest first so nothing is overwritten
        # except the slot that falls off the end.
        for i in range(self.max_backup_files - 1, 0, -1):
            old_backup = self._get_backup_file(i)
            if old_backup.exists():
                old_backup.replace(self._get_backup_file(i + 1))

        # Create new backup
        memory_file.replace(self._get_backup_file(1))

        # Clean up old backups
        self._cleanup_old_backups()

    def _cleanup_old_backups(self) -> None:
        """Remove backups exceeding max_backup_files.

        Only the nine slots directly above the limit are checked.
        """
        for i in range(self.max_backup_files + 1, self.max_backup_files + 10):
            backup_file = self._get_backup_file(i)
            if backup_file.exists():
                backup_file.unlink()

    def _recover_from_backup(self) -> bool:
        """Attempt to recover from the most recent readable backup.

        Backups are tried from slot 1 upwards; the first one that parses
        completely replaces the in-memory state.

        Returns:
            True if recovery successful, False otherwise
        """
        for i in range(1, self.max_backup_files + 1):
            backup_file = self._get_backup_file(i)
            if not backup_file.exists():
                continue

            try:
                with open(backup_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                messages = [
                    Message(**msg_data) for msg_data in data.get('messages', [])
                ]
                compression_metadata = data.get('compression_metadata', [])

            except Exception as e:
                print(f"Failed to recover from backup {i}: {e}")
                continue

            self._messages = messages
            self._compression_metadata = compression_metadata
            print(f"Successfully recovered from backup {i}")
            return True

        return False

    async def add_message(self, message: Message) -> None:
        """Add message to memory and persist."""
        async with self._lock:
            self._messages.append(message)
            self._save_to_disk()

    async def get_messages(self) -> List[Message]:
        """Get a shallow copy of all messages in memory."""
        async with self._lock:
            return self._messages.copy()

    async def clear(self) -> None:
        """Clear all messages and compression metadata, then persist."""
        async with self._lock:
            self._messages.clear()
            self._compression_metadata.clear()
            self._save_to_disk()

    async def set_messages(self, messages: List[Message]) -> None:
        """Replace all messages in memory.

        Used by CompressionManager when compressing history.
        """
        async with self._lock:
            self._messages = messages.copy()
            self._save_to_disk()

    def add_compression_metadata(self, metadata: dict) -> None:
        """Add compression metadata (Tier 2).

        Called by CompressionManager to track compression events. A
        ``timestamp`` key is added first, so a ``timestamp`` supplied in
        ``metadata`` takes precedence.
        """
        self._compression_metadata.append({
            'timestamp': datetime.now().isoformat(),
            **metadata
        })
        self._save_to_disk()

    def get_compression_history(self) -> List[dict]:
        """Get a shallow copy of the compression history metadata."""
        return self._compression_metadata.copy()

    def get_persistence_info(self) -> dict:
        """Get information about persistence state.

        Useful for debugging and monitoring.
        """
        memory_file = self._get_memory_file()

        backup_files = []
        for i in range(1, self.max_backup_files + 1):
            backup = self._get_backup_file(i)
            if not backup.exists():
                continue
            # One stat() call keeps size and mtime from the same snapshot.
            info = backup.stat()
            backup_files.append({
                'index': i,
                'path': str(backup),
                'size_bytes': info.st_size,
                'modified': datetime.fromtimestamp(info.st_mtime).isoformat(),
            })

        return {
            'enabled': self.enable_persistence,
            'session_id': self.session_id,
            'persist_dir': str(self.persist_dir),
            'memory_file': str(memory_file),
            'memory_file_exists': memory_file.exists(),
            'message_count': len(self._messages),
            'compression_event_count': len(self._compression_metadata),
            'backups': backup_files,
        }
@@ -0,0 +1,9 @@
1
+ """内置检索器实现"""
2
+
3
+ from loom.builtin.retriever.in_memory import InMemoryRetriever
4
+
5
+ try:
6
+ from loom.builtin.retriever.vector_store import VectorStoreRetriever
7
+ __all__ = ["InMemoryRetriever", "VectorStoreRetriever"]
8
+ except ImportError:
9
+ __all__ = ["InMemoryRetriever"]
@@ -0,0 +1,265 @@
1
+ """ChromaDB 向量存储适配器"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+ import uuid
7
+
8
+ from loom.interfaces.retriever import Document
9
+ from loom.interfaces.vector_store import BaseVectorStore
10
+
11
+ try:
12
+ import chromadb
13
+ from chromadb.config import Settings
14
+ CHROMA_AVAILABLE = True
15
+ except ImportError:
16
+ CHROMA_AVAILABLE = False
17
+
18
+
19
class ChromaVectorStore(BaseVectorStore):
    """
    ChromaDB vector store adapter.

    Features:
    - Open-source embedded vector database
    - Minimal API
    - Optional local persistence
    - Suited to quick prototyping

    Example:
        from loom.builtin.retriever.chroma_store import ChromaVectorStore
        from loom.builtin.retriever.vector_store_config import ChromaConfig

        # Local persistent mode
        config = ChromaConfig.create_local(
            persist_directory="./chroma_db",
            collection_name="loom_docs"
        )

        # Remote server mode
        config = ChromaConfig.create_remote(
            host="localhost",
            port=8000,
            collection_name="loom_docs"
        )

        vector_store = ChromaVectorStore(config)
        await vector_store.initialize()
    """

    def __init__(self, config: Dict[str, Any] | Any):
        """
        Parameters:
            config: ChromaConfig object (anything exposing ``model_dump``)
                or a plain configuration dictionary

        Raises:
            ImportError: if the ``chromadb`` package is not installed
        """
        if not CHROMA_AVAILABLE:
            raise ImportError(
                "ChromaDB is not installed. "
                "Install with: pip install chromadb"
            )

        # Accept either a dictionary or a Pydantic model
        if hasattr(config, "model_dump"):
            self.config = config.model_dump()
        else:
            self.config = config

        self.collection_name = self.config.get("collection_name", "loom_documents")
        self.dimension = self.config.get("dimension", 1536)
        self.client_type = self.config.get("client_type", "local")
        self.persist_directory = self.config.get("persist_directory")
        self.host = self.config.get("host")
        self.port = self.config.get("port", 8000)

        self.client: Optional[Any] = None
        self.collection: Optional[Any] = None
        self._initialized = False

    def _collection_metadata(self) -> Dict[str, Any]:
        """Metadata used whenever this adapter creates its collection.

        ``hnsw:space`` is set to cosine because ``search`` converts distances
        to scores as ``1 - distance``, which is only meaningful for cosine
        distance. Collections that already exist keep their own metric.
        """
        return {"dimension": self.dimension, "hnsw:space": "cosine"}

    async def initialize(self) -> None:
        """Create the ChromaDB client and get or create the collection."""
        if self._initialized:
            return

        # Build the client
        if self.client_type == "local":
            if self.persist_directory:
                # Local persistent mode
                self.client = chromadb.PersistentClient(
                    path=self.persist_directory
                )
            else:
                # In-memory mode (nothing is persisted)
                self.client = chromadb.Client()
        else:
            # Remote HTTP mode
            self.client = chromadb.HttpClient(
                host=self.host,
                port=self.port
            )

        # Get or create the collection. The lookup failure is caught broadly
        # because the exception type raised for a missing collection is not
        # visible from this file; any failure falls through to creation.
        try:
            self.collection = self.client.get_collection(
                name=self.collection_name
            )
        except Exception:
            self.collection = self.client.create_collection(
                name=self.collection_name,
                metadata=self._collection_metadata()
            )

        self._initialized = True

    async def add_vectors(
        self,
        vectors: List[List[float]],
        documents: List[Document]
    ) -> None:
        """
        Add vectors and their documents to ChromaDB.

        Parameters:
            vectors: list of embedding vectors
            documents: documents matching ``vectors`` position by position

        Raises:
            ValueError: if ``vectors`` and ``documents`` differ in length
        """
        if len(vectors) != len(documents):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f"vectors and documents must have the same length "
                f"(got {len(vectors)} and {len(documents)})"
            )
        if not documents:
            # Nothing to add; skip the call rather than send an empty batch.
            return

        if not self._initialized:
            await self.initialize()

        # Build the parallel lists ChromaDB expects
        ids = []
        embeddings = []
        documents_text = []
        metadatas = []

        for vector, doc in zip(vectors, documents):
            # Use the document's ID or generate one
            ids.append(doc.doc_id or str(uuid.uuid4()))
            embeddings.append(vector)
            documents_text.append(doc.content)

            # Copy so that adding "score" never mutates the caller's dict.
            metadata = dict(doc.metadata) if doc.metadata else {}
            if doc.score is not None:
                metadata["score"] = doc.score
            # Chroma rejects empty metadata dicts; pass None instead.
            metadatas.append(metadata or None)

        self.collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=documents_text,
            metadatas=metadatas
        )

    async def search(
        self,
        query_vector: List[float],
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Tuple[Document, float]]:
        """
        Search for the most similar vectors.

        Parameters:
            query_vector: query embedding
            top_k: number of results to return
            filters: metadata filter conditions

        Returns:
            List of ``(Document, score)`` tuples, where ``score`` is
            ``1 - distance`` as reported by ChromaDB
        """
        if top_k <= 0:
            return []

        if not self._initialized:
            await self.initialize()

        # Build the where clause (ChromaDB format)
        where = self._build_chroma_filter(filters) if filters else None

        results = self.collection.query(
            query_embeddings=[query_vector],
            n_results=top_k,
            where=where,
            include=["documents", "metadatas", "distances"]
        )

        documents_with_scores = []
        if not results["ids"]:
            return documents_with_scores

        # One query vector was sent, so only the first row of each field is
        # relevant.
        rows = zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
        for doc_id, content, metadata, distance in rows:
            # ChromaDB returns distances; for cosine distance,
            # similarity = 1 - distance.
            score = 1.0 - distance
            doc = Document(
                content=content,
                metadata=metadata or {},
                score=score,
                doc_id=doc_id
            )
            documents_with_scores.append((doc, score))

        return documents_with_scores

    async def delete(self, doc_ids: List[str]) -> None:
        """
        Delete documents by ID.

        Parameters:
            doc_ids: list of document IDs; an empty list is a no-op
        """
        if not doc_ids:
            # Never forward an empty ID list: the intent is "delete nothing".
            return

        if not self._initialized:
            await self.initialize()

        self.collection.delete(ids=doc_ids)

    async def clear(self) -> None:
        """Empty the collection by deleting and recreating it."""
        if not self._initialized:
            await self.initialize()

        self.client.delete_collection(name=self.collection_name)
        self.collection = self.client.create_collection(
            name=self.collection_name,
            metadata=self._collection_metadata()
        )

    def _build_chroma_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build a ChromaDB ``where`` filter.

        Each entry of ``filters`` is either a plain equality
        (``{"field": "value"}``) or an operator expression
        (``{"price": {"$gte": 100}}``) and is passed through unchanged.
        ChromaDB accepts only one top-level condition, so several entries
        are combined under ``$and``.

        Supported operators: $eq, $ne, $gt, $gte, $lt, $lte, $in, $nin

        Returns:
            ``{}`` for empty input, the single condition as-is, or
            ``{"$and": [...]}`` for several conditions
        """
        conditions = [{key: value} for key, value in filters.items()]
        if not conditions:
            return {}
        if len(conditions) == 1:
            return conditions[0]
        return {"$and": conditions}

    async def close(self) -> None:
        """Mark the store as uninitialized; no explicit disconnect is made."""
        self._initialized = False
@@ -0,0 +1,106 @@
1
+ """简单的内存检索器 - 无需外部依赖"""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ from loom.interfaces.retriever import BaseRetriever, Document
8
+
9
+
10
class InMemoryRetriever(BaseRetriever):
    """
    Minimal in-memory retriever based on keyword overlap.

    Characteristics:
    - No external dependencies
    - Intended for development and testing
    - Plain keyword matching (not vector retrieval)

    Example:
        retriever = InMemoryRetriever()
        await retriever.add_texts([
            "Python is a programming language",
            "JavaScript is used for web development"
        ])

        docs = await retriever.retrieve("programming")
        # returns the documents containing "programming"
    """

    def __init__(self):
        self.documents: List[Document] = []

    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[Document]:
        """
        Retrieve documents by keyword overlap with ``query``.

        Steps:
        1. Lower-case the query and split it on whitespace
        2. Score each document as the fraction of query terms it contains
        3. Return the highest-scoring documents first, at most ``top_k``

        Documents that fail ``filters`` or score zero are left out. Each
        result is a new ``Document`` carrying the score.
        """
        if not self.documents:
            return []

        wanted = set(query.lower().split())
        hits = []

        for candidate in self.documents:
            # Apply metadata filters first
            if filters and not self._match_filters(candidate, filters):
                continue

            overlap = wanted & set(candidate.content.lower().split())
            relevance = len(overlap) / len(wanted) if wanted else 0.0
            if not relevance > 0:
                continue

            # Build a scored copy instead of touching the stored document
            hits.append(
                Document(
                    content=candidate.content,
                    metadata=candidate.metadata,
                    score=relevance,
                    doc_id=candidate.doc_id
                )
            )

        # Stable descending sort keeps insertion order among equal scores
        ranked = sorted(hits, key=lambda hit: hit.score or 0, reverse=True)
        return ranked[:top_k]

    async def add_documents(self, documents: List[Document]) -> None:
        """Append documents, assigning an ID to any document without one."""
        store = self.documents
        for entry in documents:
            if entry.doc_id is None:
                # The ID is the document's position at insertion time
                entry.doc_id = str(len(store))
            store.append(entry)

    def _match_filters(self, doc: Document, filters: Dict[str, Any]) -> bool:
        """Return True when every filter key equals the document's metadata value."""
        if not doc.metadata:
            return False

        mismatch = any(
            doc.metadata.get(field) != expected
            for field, expected in filters.items()
        )
        return not mismatch

    def clear(self) -> None:
        """Remove all stored documents."""
        self.documents.clear()

    def __len__(self) -> int:
        """Return the number of stored documents."""
        return len(self.documents)