mem1 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mem1/prompts.py CHANGED
@@ -207,6 +207,28 @@ ASSISTANT_SUMMARY_PROMPT = """你是对话摘要专家。将助手的长回复
207
207
  直接输出摘要内容,开头标注 [摘要]"""
208
208
 
209
209
 
210
# ============ Progressive-retrieval sufficiency prompt ============
# Prompt template asking the model whether the information gathered so far is
# enough to answer the user's question. It carries four placeholders --
# {query}, {profile}, {days} and {conversations} -- and instructs the model to
# reply with only `true` (enough) or `false` (older records must be fetched).
# NOTE(review): presumably filled with str.format() by the caller -- confirm.

CONTEXT_SUFFICIENT_PROMPT = """判断当前信息是否足够回答用户问题。

## 用户问题
{query}

## 用户画像
{profile}

## 已检索的对话记录(最近 {days} 天)
{conversations}

## 判断标准
- 如果画像或对话中包含回答问题所需的信息,输出 `true`
- 如果问题涉及的时间、事件、数据在已有信息中找不到,输出 `false`
- 如果是通用问题(不依赖历史记录),输出 `true`

## 输出
只输出:`true`(信息足够)或 `false`(需要检索更早的记录)"""
230
+
231
+
210
232
  # ============ 图片搜索提示词(通用) ============
211
233
 
212
234
  IMAGE_SEARCH_PROMPT = """根据用户查询,从图片列表中找出匹配的图片。
mem1/storage.py ADDED
@@ -0,0 +1,399 @@
1
+ """可插拔存储层抽象
2
+
3
+ 设计目标:
4
+ - 将存储操作从 Mem1Memory 中解耦
5
+ - 支持 ES/SQLite/MySQL 等多种后端
6
+ - 保持接口简洁,只抽象必要操作
7
+
8
+ 使用方式:
9
+ from mem1.storage import ESStorage
10
+ storage = ESStorage(hosts=["http://localhost:9200"], index_name="conversations")
11
+
12
+ # 或未来实现
13
+ from mem1.storage import SQLiteStorage
14
+ storage = SQLiteStorage(db_path="mem1.db")
15
+ """
16
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Dict, Any, Optional
19
+
20
+
21
class StorageBackend(ABC):
    """Abstract interface every storage backend has to implement.

    Required operations, by group:
    - conversations: save_conversation, get_conversations, delete_conversations
    - user profile:  get_profile, save_profile, delete_profile
    - user state:    get_user_state, save_user_state, delete_user_state
    - aggregations:  get_user_list, get_topic_list
    - setup:         ensure_schema
    """

    # ---------- conversations ----------

    @abstractmethod
    def save_conversation(self, conversation: Dict[str, Any]) -> str:
        """Persist one conversation record and return its record id.

        Args:
            conversation: a mapping shaped like::

                {
                    "user_id": str,
                    "topic_id": str,
                    "timestamp": str,        # format '%Y-%m-%d %H:%M:%S'
                    "messages": List[Dict],
                    "metadata": Dict,
                    "images": List[Dict],    # optional
                }

        Returns:
            The id of the stored record.
        """

    @abstractmethod
    def get_conversations(
        self,
        user_id: str,
        topic_id: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        metadata_filter: Optional[Dict[str, Any]] = None,
        limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """Query a user's conversation records.

        Args:
            user_id: owner of the records.
            topic_id: topic to restrict to; ``None`` means every topic.
            start_time: start of the time window.
            end_time: end of the time window.
            metadata_filter: metadata key/value pairs to filter on.
            limit: maximum number of records to return.

        Returns:
            Matching records, ordered by time ascending.
        """

    @abstractmethod
    def delete_conversations(
        self,
        user_id: str,
        topic_id: Optional[str] = None
    ) -> int:
        """Delete a user's conversation records.

        Args:
            user_id: owner of the records.
            topic_id: topic to delete; ``None`` deletes every topic.

        Returns:
            Number of records deleted.
        """

    # ---------- user profile ----------

    @abstractmethod
    def get_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a user's profile.

        Returns:
            ``{"content": str, "updated_at": str}`` or ``None``.
        """

    @abstractmethod
    def save_profile(self, user_id: str, content: str) -> None:
        """Store a user's profile."""

    @abstractmethod
    def delete_profile(self, user_id: str) -> bool:
        """Delete a user's profile."""

    # ---------- user state ----------

    @abstractmethod
    def get_user_state(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a user's state.

        Returns:
            ``{"rounds": int, "last_update": str}`` or ``None``.
        """

    @abstractmethod
    def save_user_state(self, user_id: str, rounds: int, last_update: Optional[str] = None) -> None:
        """Store a user's state."""

    @abstractmethod
    def delete_user_state(self, user_id: str) -> bool:
        """Delete a user's state."""

    # ---------- aggregations ----------

    @abstractmethod
    def get_user_list(self) -> List[str]:
        """Return the ids of all users."""

    @abstractmethod
    def get_topic_list(self, user_id: str) -> List[Dict[str, Any]]:
        """Return the topics of one user.

        Returns:
            ``[{"topic_id": str, "conversation_count": int, "last_active": str}, ...]``
        """

    # ---------- setup ----------

    @abstractmethod
    def ensure_schema(self) -> None:
        """Make sure the storage structures (indices / tables) exist."""
158
+
159
+
160
+
161
_logger = logging.getLogger(__name__)


class ESStorage(StorageBackend):
    """Elasticsearch implementation of ``StorageBackend``.

    Conversation records are stored in the index named by ``index_name``.
    User state and user profiles are stored in two fixed indices, one
    document per user, using ``user_id`` as the document id.
    """

    # Fixed index names for the per-user documents.
    USER_STATE_INDEX = "mem1_user_state"
    USER_PROFILE_INDEX = "mem1_user_profile"

    # Mapping format shared by every date field, and the matching strftime
    # pattern used when this class renders datetimes itself.
    _ES_DATE_FORMAT = "yyyy-MM-dd HH:mm:ss||epoch_millis"
    _TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'

    def __init__(self, hosts: List[str], index_name: str):
        """Connect to Elasticsearch and create any missing index.

        Args:
            hosts: list of Elasticsearch addresses.
            index_name: name of the conversation index.
        """
        # Imported here so that importing this module does not itself import
        # the elasticsearch client.
        from elasticsearch import Elasticsearch
        self.es = Elasticsearch(hosts)
        self.index_name = index_name
        self.ensure_schema()

    def ensure_schema(self) -> None:
        """Create the conversation, user-state and user-profile indices if absent."""
        def date_field() -> Dict[str, str]:
            return {"type": "date", "format": self._ES_DATE_FORMAT}

        # A list (not a dict) so that, should index_name coincide with one of
        # the fixed names, the first definition still wins as before.
        index_properties = [
            (self.index_name, {
                "user_id": {"type": "keyword"},
                "topic_id": {"type": "keyword"},
                "timestamp": date_field(),
                "messages": {"type": "nested"},
                "metadata": {"type": "object"},
                "images": {"type": "nested"},
            }),
            (self.USER_STATE_INDEX, {
                "user_id": {"type": "keyword"},
                "rounds": {"type": "integer"},
                "last_update": date_field(),
            }),
            (self.USER_PROFILE_INDEX, {
                "user_id": {"type": "keyword"},
                "content": {"type": "text"},
                "updated_at": date_field(),
            }),
        ]
        for index, properties in index_properties:
            if not self.es.indices.exists(index=index):
                # Keyword form instead of the deprecated body=, matching the
                # document=/query= style used by the other calls.
                self.es.indices.create(
                    index=index,
                    mappings={"properties": properties},
                )

    # ========== private helpers ==========

    @classmethod
    def _conversation_query(
        cls,
        user_id: str,
        topic_id: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        metadata_filter: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build the bool/must query selecting a user's conversations.

        Only ``None`` disables a filter. An empty-string ``topic_id`` is
        matched literally instead of silently widening to every topic.
        The time window is ``start_time <= timestamp < end_time``.
        """
        clauses: List[Dict[str, Any]] = [{"term": {"user_id": user_id}}]

        if topic_id is not None:
            clauses.append({"term": {"topic_id": topic_id}})

        bounds: Dict[str, str] = {}
        if start_time is not None:
            bounds["gte"] = start_time.strftime(cls._TIMESTAMP_FORMAT)
        if end_time is not None:
            bounds["lt"] = end_time.strftime(cls._TIMESTAMP_FORMAT)
        if bounds:
            clauses.append({"range": {"timestamp": bounds}})

        # NOTE(review): `metadata` is dynamically mapped, so string values may
        # be indexed as analyzed text; an exact `term` match may then need
        # `metadata.<key>.keyword` -- confirm against the live mapping.
        for key, value in (metadata_filter or {}).items():
            clauses.append({"term": {f"metadata.{key}": value}})

        return {"bool": {"must": clauses}}

    def _get_source(self, index: str, doc_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``_source`` of one document, or ``None``.

        A missing document is the expected "no data" case and stays quiet.
        Any other failure still yields ``None`` for backward compatibility,
        but is logged so it is no longer indistinguishable from "not found".
        """
        from elasticsearch import NotFoundError
        try:
            response = self.es.get(index=index, id=doc_id)
            return response["_source"]
        except NotFoundError:
            return None
        except Exception:
            _logger.exception("Failed to read document %r from index %r", doc_id, index)
            return None

    def _delete_doc(self, index: str, doc_id: str) -> bool:
        """Delete one document by id; return whether the delete call succeeded.

        A missing document returns ``False`` quietly; any other failure also
        returns ``False`` but is logged.
        """
        from elasticsearch import NotFoundError
        try:
            self.es.delete(index=index, id=doc_id, refresh=True)
        except NotFoundError:
            return False
        except Exception:
            _logger.exception("Failed to delete document %r from index %r", doc_id, index)
            return False
        return True

    # ========== conversations ==========

    def save_conversation(self, conversation: Dict[str, Any]) -> str:
        """Index one conversation record and return the id Elasticsearch assigned.

        ``refresh=True`` is passed so the record is searchable as soon as the
        call returns.
        """
        response = self.es.index(
            index=self.index_name,
            document=conversation,
            refresh=True
        )
        return response["_id"]

    def get_conversations(
        self,
        user_id: str,
        topic_id: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        metadata_filter: Optional[Dict[str, Any]] = None,
        limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """Return a user's conversation records sorted by timestamp ascending.

        Args:
            user_id: owner of the records.
            topic_id: topic to restrict to; ``None`` means every topic.
            start_time: inclusive lower bound on ``timestamp``.
            end_time: exclusive upper bound on ``timestamp``.
            metadata_filter: exact-match filters applied to ``metadata.<key>``.
            limit: maximum number of records, passed to the search as ``size``.

        Returns:
            The ``_source`` of every hit.
        """
        response = self.es.search(
            index=self.index_name,
            query=self._conversation_query(
                user_id, topic_id, start_time, end_time, metadata_filter
            ),
            size=limit,
            sort=[{"timestamp": {"order": "asc"}}]
        )
        return [hit["_source"] for hit in response["hits"]["hits"]]

    def delete_conversations(self, user_id: str, topic_id: Optional[str] = None) -> int:
        """Delete a user's conversations and return how many were removed.

        Args:
            user_id: owner of the records.
            topic_id: topic to delete; only ``None`` deletes every topic.

        Returns:
            The ``deleted`` count reported by Elasticsearch, or 0 when the
            request failed (the failure is logged, not raised).
        """
        try:
            response = self.es.delete_by_query(
                index=self.index_name,
                query=self._conversation_query(user_id, topic_id),
                refresh=True
            )
            return response.get("deleted", 0)
        except Exception:
            _logger.exception(
                "Failed to delete conversations for user %r (topic %r)", user_id, topic_id
            )
            return 0

    # ========== user profile ==========

    def get_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored profile document for ``user_id``, or ``None``."""
        return self._get_source(self.USER_PROFILE_INDEX, user_id)

    def save_profile(self, user_id: str, content: str) -> None:
        """Store the profile, stamping ``updated_at`` with the current local time."""
        self.es.index(
            index=self.USER_PROFILE_INDEX,
            id=user_id,
            document={
                "user_id": user_id,
                "content": content,
                "updated_at": datetime.now().strftime(self._TIMESTAMP_FORMAT)
            },
            refresh=True
        )

    def delete_profile(self, user_id: str) -> bool:
        """Delete the profile of ``user_id``; ``False`` if the delete did not succeed."""
        return self._delete_doc(self.USER_PROFILE_INDEX, user_id)

    # ========== user state ==========

    def get_user_state(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored state document for ``user_id``, or ``None``."""
        return self._get_source(self.USER_STATE_INDEX, user_id)

    def save_user_state(self, user_id: str, rounds: int, last_update: Optional[str] = None) -> None:
        """Store the user's state document.

        ``last_update`` is written only when it is truthy; otherwise the
        stored document carries just ``user_id`` and ``rounds``.
        """
        doc = {"user_id": user_id, "rounds": rounds}
        if last_update:
            doc["last_update"] = last_update

        self.es.index(
            index=self.USER_STATE_INDEX,
            id=user_id,
            document=doc,
            refresh=True
        )

    def delete_user_state(self, user_id: str) -> bool:
        """Delete the state of ``user_id``; ``False`` if the delete did not succeed."""
        return self._delete_doc(self.USER_STATE_INDEX, user_id)

    # ========== aggregations ==========

    def get_user_list(self) -> List[str]:
        """Return the distinct ``user_id`` values found in the conversation index.

        The terms aggregation is capped at 10000 buckets, so at most that
        many user ids are returned.
        """
        response = self.es.search(
            index=self.index_name,
            size=0,
            aggs={"users": {"terms": {"field": "user_id", "size": 10000}}}
        )
        return [bucket["key"] for bucket in response["aggregations"]["users"]["buckets"]]

    def get_topic_list(self, user_id: str) -> List[Dict[str, Any]]:
        """Return one summary per topic of ``user_id`` (at most 1000 topics).

        Returns:
            ``[{"topic_id": str, "conversation_count": int, "last_active": str | None}, ...]``
            where ``conversation_count`` is the bucket's document count and
            ``last_active`` is the formatted maximum ``timestamp``.
        """
        response = self.es.search(
            index=self.index_name,
            size=0,
            query={"term": {"user_id": user_id}},
            aggs={
                "topics": {
                    "terms": {"field": "topic_id", "size": 1000},
                    # The bucket's doc_count already gives the conversation
                    # count, so only the latest timestamp is aggregated.
                    "aggs": {"latest": {"max": {"field": "timestamp"}}}
                }
            }
        )

        topics = []
        for bucket in response["aggregations"]["topics"]["buckets"]:
            latest = bucket["latest"]
            # Compare against None: a timestamp at the epoch has value 0,
            # which is falsy but still a real date.
            has_value = latest.get("value") is not None
            topics.append({
                "topic_id": bucket["key"],
                "conversation_count": bucket["doc_count"],
                "last_active": latest.get("value_as_string") if has_value else None
            })
        return topics

    def get_conversations_with_images(self, user_id: str, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return a user's conversations that carry images, oldest first.

        Args:
            user_id: owner of the records.
            limit: maximum number of records to return (default 1000, the
                previously hard-coded value).
        """
        # NOTE(review): `images` is mapped as `nested`; a top-level `exists`
        # query may not see nested sub-documents and might need wrapping in a
        # `nested` query -- verify against a live index before relying on it.
        response = self.es.search(
            index=self.index_name,
            query={
                "bool": {
                    "must": [
                        {"term": {"user_id": user_id}},
                        {"exists": {"field": "images"}}
                    ]
                }
            },
            size=limit,
            sort=[{"timestamp": {"order": "asc"}}]
        )
        return [hit["_source"] for hit in response["hits"]["hits"]]