mem1 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mem1/__init__.py +5 -4
- mem1/config.py +14 -5
- mem1/llm.py +54 -5
- mem1/{memory_es.py → memory.py} +225 -382
- mem1/prompts.py +22 -0
- mem1/storage.py +399 -0
- mem1-0.0.8.dist-info/METADATA +290 -0
- mem1-0.0.8.dist-info/RECORD +12 -0
- mem1-0.0.6.dist-info/METADATA +0 -191
- mem1-0.0.6.dist-info/RECORD +0 -11
- {mem1-0.0.6.dist-info → mem1-0.0.8.dist-info}/WHEEL +0 -0
mem1/prompts.py
CHANGED
|
@@ -207,6 +207,28 @@ ASSISTANT_SUMMARY_PROMPT = """你是对话摘要专家。将助手的长回复
|
|
|
207
207
|
直接输出摘要内容,开头标注 [摘要]"""
|
|
208
208
|
|
|
209
209
|
|
|
210
|
+
# ============ Progressive-retrieval sufficiency prompt ============
# Asks the LLM whether the user profile plus the conversations retrieved so
# far are enough to answer the query. The prompt tells the model to reply with
# only `true` (enough information) or `false` (older records must be fetched).
# Placeholders: {query}, {profile}, {days}, {conversations} — presumably filled
# with str.format by the caller (not visible here); literal braces would need
# doubling if ever added to this template.

CONTEXT_SUFFICIENT_PROMPT = """判断当前信息是否足够回答用户问题。

## 用户问题
{query}

## 用户画像
{profile}

## 已检索的对话记录(最近 {days} 天)
{conversations}

## 判断标准
- 如果画像或对话中包含回答问题所需的信息,输出 `true`
- 如果问题涉及的时间、事件、数据在已有信息中找不到,输出 `false`
- 如果是通用问题(不依赖历史记录),输出 `true`

## 输出
只输出:`true`(信息足够)或 `false`(需要检索更早的记录)"""
|
|
230
|
+
|
|
231
|
+
|
|
210
232
|
# ============ 图片搜索提示词(通用) ============
|
|
211
233
|
|
|
212
234
|
IMAGE_SEARCH_PROMPT = """根据用户查询,从图片列表中找出匹配的图片。
|
mem1/storage.py
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""可插拔存储层抽象
|
|
2
|
+
|
|
3
|
+
设计目标:
|
|
4
|
+
- 将存储操作从 Mem1Memory 中解耦
|
|
5
|
+
- 支持 ES/SQLite/MySQL 等多种后端
|
|
6
|
+
- 保持接口简洁,只抽象必要操作
|
|
7
|
+
|
|
8
|
+
使用方式:
|
|
9
|
+
from mem1.storage import ESStorage
|
|
10
|
+
storage = ESStorage(hosts=["http://localhost:9200"], index_name="mem1_conversations")
|
|
11
|
+
|
|
12
|
+
# 或未来实现
|
|
13
|
+
from mem1.storage import SQLiteStorage
|
|
14
|
+
storage = SQLiteStorage(db_path="mem1.db")
|
|
15
|
+
"""
|
|
16
|
+
from abc import ABC, abstractmethod
|
|
17
|
+
from datetime import datetime
|
|
18
|
+
from typing import List, Dict, Any, Optional
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class StorageBackend(ABC):
    """Abstract base class for mem1 storage backends.

    A concrete backend must implement every abstract method below:

    - Conversations: save_conversation, get_conversations, delete_conversations
    - User profile: get_profile, save_profile, delete_profile
    - User state: get_user_state, save_user_state, delete_user_state
    - Aggregations: get_user_list, get_topic_list
    - Initialisation: ensure_schema
    """

    # ========== Conversations ==========

    @abstractmethod
    def save_conversation(self, conversation: Dict[str, Any]) -> str:
        """Persist one conversation record.

        Args:
            conversation: Record with the keys
                ``user_id`` (str), ``topic_id`` (str),
                ``timestamp`` (str, format '%Y-%m-%d %H:%M:%S'),
                ``messages`` (List[Dict]), ``metadata`` (Dict) and,
                optionally, ``images`` (List[Dict]).

        Returns:
            The id the backend assigned to the stored record.
        """
        ...

    @abstractmethod
    def get_conversations(
        self,
        user_id: str,
        topic_id: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        metadata_filter: Optional[Dict[str, Any]] = None,
        limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """Query conversation records of one user.

        Args:
            user_id: Owner of the conversations.
            topic_id: Restrict to one topic; None means all topics.
            start_time: Inclusive lower bound on the record timestamp.
            end_time: Upper bound on the record timestamp.
            metadata_filter: Exact-match constraints on ``metadata`` keys.
            limit: Maximum number of records to return.

        Returns:
            Matching records in ascending time order.
        """
        ...

    @abstractmethod
    def delete_conversations(
        self,
        user_id: str,
        topic_id: Optional[str] = None
    ) -> int:
        """Delete conversation records of one user.

        Args:
            user_id: Owner of the conversations.
            topic_id: Restrict to one topic; None deletes all topics.

        Returns:
            The number of deleted records.
        """
        ...

    # ========== User profile ==========

    @abstractmethod
    def get_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a user's profile.

        Returns:
            A dict containing ``content`` (str) and ``updated_at`` (str),
            or None when the user has no profile.
        """
        ...

    @abstractmethod
    def save_profile(self, user_id: str, content: str) -> None:
        """Store ``content`` as the profile text of ``user_id``."""
        ...

    @abstractmethod
    def delete_profile(self, user_id: str) -> bool:
        """Delete a user's profile.

        Returns:
            True if a profile was deleted, False otherwise.
        """
        ...

    # ========== User state ==========

    @abstractmethod
    def get_user_state(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a user's state.

        Returns:
            A dict containing ``rounds`` (int) and ``last_update`` (str),
            or None when no state is stored.
        """
        ...

    @abstractmethod
    def save_user_state(self, user_id: str, rounds: int, last_update: Optional[str] = None) -> None:
        """Store the round counter (and optionally the last-update time) of a user."""
        ...

    @abstractmethod
    def delete_user_state(self, user_id: str) -> bool:
        """Delete a user's state.

        Returns:
            True if a state record was deleted, False otherwise.
        """
        ...

    # ========== Aggregations ==========

    @abstractmethod
    def get_user_list(self) -> List[str]:
        """Return the ids of all users that have conversations."""
        ...

    @abstractmethod
    def get_topic_list(self, user_id: str) -> List[Dict[str, Any]]:
        """Return the topics of one user.

        Returns:
            ``[{"topic_id": str, "conversation_count": int, "last_active": str}, ...]``
        """
        ...

    # ========== Initialisation ==========

    @abstractmethod
    def ensure_schema(self) -> None:
        """Make sure the backing storage structures (indices / tables) exist."""
        ...
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class ESStorage(StorageBackend):
    """Elasticsearch storage backend.

    Conversations live in the index passed to ``__init__``. User state and user
    profiles live in the fixed indices ``USER_STATE_INDEX`` and
    ``USER_PROFILE_INDEX``, one document per user with the user id as the
    document id. Every write passes ``refresh=True`` so the change is visible
    to the next search.
    """

    # Index names for the per-user documents.
    USER_STATE_INDEX = "mem1_user_state"
    USER_PROFILE_INDEX = "mem1_user_profile"

    # Python-side timestamp format; must agree with the first pattern of _ES_DATE_FORMAT.
    _TIME_FORMAT = '%Y-%m-%d %H:%M:%S'
    # Elasticsearch mapping format used for every date field.
    _ES_DATE_FORMAT = "yyyy-MM-dd HH:mm:ss||epoch_millis"
    # Page size used when paging composite aggregations.
    _AGG_PAGE_SIZE = 1000

    def __init__(self, hosts: List[str], index_name: str):
        """Connect to Elasticsearch and create any missing indices.

        Args:
            hosts: Elasticsearch host addresses.
            index_name: Name of the conversation index.
        """
        # Imported here (presumably) so the elasticsearch package is only
        # needed when this backend is actually instantiated.
        from elasticsearch import Elasticsearch
        self.es = Elasticsearch(hosts)
        self.index_name = index_name
        self.ensure_schema()

    def ensure_schema(self) -> None:
        """Create the conversation, user-state and user-profile indices if missing."""
        date_field = {"type": "date", "format": self._ES_DATE_FORMAT}
        schemas = [
            (self.index_name, {
                "user_id": {"type": "keyword"},
                "topic_id": {"type": "keyword"},
                "timestamp": date_field,
                "messages": {"type": "nested"},
                "metadata": {"type": "object"},
                "images": {"type": "nested"},
            }),
            (self.USER_STATE_INDEX, {
                "user_id": {"type": "keyword"},
                "rounds": {"type": "integer"},
                "last_update": date_field,
            }),
            (self.USER_PROFILE_INDEX, {
                "user_id": {"type": "keyword"},
                "content": {"type": "text"},
                "updated_at": date_field,
            }),
        ]
        for index, properties in schemas:
            self._create_index_if_missing(index, properties)

    def _create_index_if_missing(self, index: str, properties: Dict[str, Any]) -> None:
        """Create ``index`` with the given field mappings unless it already exists."""
        if self.es.indices.exists(index=index):
            return
        try:
            self.es.indices.create(index=index, body={"mappings": {"properties": properties}})
        except Exception:
            # Another process may have created the index between exists() and
            # create(); that race is harmless. Any other failure is re-raised.
            if not self.es.indices.exists(index=index):
                raise

    # ========== Conversations ==========

    def save_conversation(self, conversation: Dict[str, Any]) -> str:
        """Index one conversation record and return its Elasticsearch ``_id``."""
        response = self.es.index(
            index=self.index_name,
            document=conversation,
            refresh=True
        )
        return response["_id"]

    def get_conversations(
        self,
        user_id: str,
        topic_id: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        metadata_filter: Optional[Dict[str, Any]] = None,
        limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """Return a user's conversation records in ascending timestamp order.

        ``start_time`` is inclusive and ``end_time`` exclusive. Because results
        are sorted ascending before ``limit`` is applied, the *earliest*
        ``limit`` matches are returned.
        """
        filters: List[Dict[str, Any]] = [{"term": {"user_id": user_id}}]

        if topic_id:
            filters.append({"term": {"topic_id": topic_id}})

        if start_time or end_time:
            time_range: Dict[str, str] = {}
            if start_time:
                time_range["gte"] = start_time.strftime(self._TIME_FORMAT)
            if end_time:
                time_range["lt"] = end_time.strftime(self._TIME_FORMAT)
            filters.append({"range": {"timestamp": time_range}})

        for key, value in (metadata_filter or {}).items():
            filters.append(self._metadata_clause(key, value))

        response = self.es.search(
            index=self.index_name,
            query={"bool": {"filter": filters}},
            size=limit,
            sort=[{"timestamp": {"order": "asc"}}]
        )
        return [hit["_source"] for hit in response["hits"]["hits"]]

    @staticmethod
    def _metadata_clause(key: str, value: Any) -> Dict[str, Any]:
        """Build an exact-match clause for ``metadata.<key> == value``.

        ``metadata`` is mapped as a plain object, so its sub-fields get
        Elasticsearch's dynamic mapping, under which strings become analyzed
        ``text`` fields with a ``keyword`` sub-field. A ``term`` query on the
        analyzed field only matches single lower-cased tokens, so string values
        are additionally matched against ``.keyword``.
        """
        field = f"metadata.{key}"
        if not isinstance(value, str):
            return {"term": {field: value}}
        return {
            "bool": {
                "should": [
                    {"term": {f"{field}.keyword": value}},
                    {"term": {field: value}},
                ],
                "minimum_should_match": 1,
            }
        }

    def delete_conversations(self, user_id: str, topic_id: Optional[str] = None) -> int:
        """Delete a user's conversations (optionally one topic); return the count deleted.

        Returns 0 when the conversation index does not exist. Other
        Elasticsearch errors propagate instead of being reported as "0 deleted".
        """
        from elasticsearch import NotFoundError

        filters: List[Dict[str, Any]] = [{"term": {"user_id": user_id}}]
        if topic_id:
            filters.append({"term": {"topic_id": topic_id}})

        try:
            response = self.es.delete_by_query(
                index=self.index_name,
                query={"bool": {"filter": filters}},
                conflicts="proceed",  # skip docs changed concurrently instead of aborting
                refresh=True
            )
        except NotFoundError:
            return 0
        return response.get("deleted", 0)

    # ========== Per-user document helpers ==========

    def _get_source(self, index: str, doc_id: str) -> Optional[Dict[str, Any]]:
        """Return the ``_source`` of ``doc_id`` in ``index``, or None if it is missing.

        Only "not found" (missing document or index) maps to None. Connection
        and server errors propagate, so callers cannot mistake an outage for
        an absent record and overwrite existing data.
        """
        from elasticsearch import NotFoundError
        try:
            response = self.es.get(index=index, id=doc_id)
        except NotFoundError:
            return None
        return response["_source"]

    def _delete_doc(self, index: str, doc_id: str) -> bool:
        """Delete ``doc_id`` from ``index``; return False if it did not exist."""
        from elasticsearch import NotFoundError
        try:
            self.es.delete(index=index, id=doc_id, refresh=True)
        except NotFoundError:
            return False
        return True

    # ========== User profile ==========

    def get_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored profile document of ``user_id``, or None."""
        return self._get_source(self.USER_PROFILE_INDEX, user_id)

    def save_profile(self, user_id: str, content: str) -> None:
        """Create or replace the profile of ``user_id``, stamping ``updated_at`` with local time."""
        self.es.index(
            index=self.USER_PROFILE_INDEX,
            id=user_id,
            document={
                "user_id": user_id,
                "content": content,
                "updated_at": datetime.now().strftime(self._TIME_FORMAT)
            },
            refresh=True
        )

    def delete_profile(self, user_id: str) -> bool:
        """Delete the profile of ``user_id``; return False if none existed."""
        return self._delete_doc(self.USER_PROFILE_INDEX, user_id)

    # ========== User state ==========

    def get_user_state(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored state document of ``user_id``, or None."""
        return self._get_source(self.USER_STATE_INDEX, user_id)

    def save_user_state(self, user_id: str, rounds: int, last_update: Optional[str] = None) -> None:
        """Create or replace the state document of ``user_id``.

        NOTE: the whole document is replaced, so calling this without
        ``last_update`` drops any previously stored ``last_update``.
        """
        doc: Dict[str, Any] = {"user_id": user_id, "rounds": rounds}
        if last_update:
            doc["last_update"] = last_update

        self.es.index(
            index=self.USER_STATE_INDEX,
            id=user_id,
            document=doc,
            refresh=True
        )

    def delete_user_state(self, user_id: str) -> bool:
        """Delete the state document of ``user_id``; return False if none existed."""
        return self._delete_doc(self.USER_STATE_INDEX, user_id)

    # ========== Aggregations ==========

    def _composite_buckets(
        self,
        field: str,
        query: Optional[Dict[str, Any]] = None,
        sub_aggs: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Return every bucket of a grouping on ``field`` over the conversation index.

        Pages through a composite aggregation via ``after_key``. Unlike a
        ``terms`` aggregation with a fixed ``size``, no bucket is silently
        dropped. Buckets arrive in ascending key order; each bucket's ``key``
        is ``{field: value}``.
        """
        buckets: List[Dict[str, Any]] = []
        after_key: Optional[Dict[str, Any]] = None
        while True:
            composite: Dict[str, Any] = {
                "size": self._AGG_PAGE_SIZE,
                "sources": [{field: {"terms": {"field": field}}}],
            }
            if after_key is not None:
                composite["after"] = after_key
            aggregation: Dict[str, Any] = {"composite": composite}
            if sub_aggs:
                aggregation["aggs"] = sub_aggs
            body: Dict[str, Any] = {"size": 0, "aggs": {"groups": aggregation}}
            if query is not None:
                body["query"] = query

            response = self.es.search(index=self.index_name, body=body)
            result = response["aggregations"]["groups"]
            page = result["buckets"]
            buckets.extend(page)
            after_key = result.get("after_key")
            if len(page) < self._AGG_PAGE_SIZE or after_key is None:
                return buckets

    def get_user_list(self) -> List[str]:
        """Return the ids of all users with conversations, most active first."""
        buckets = self._composite_buckets("user_id")
        # Restore the terms-aggregation default order (doc_count desc, key asc);
        # list.sort is stable even with reverse=True, so ties keep key order.
        buckets.sort(key=lambda bucket: bucket["doc_count"], reverse=True)
        return [bucket["key"]["user_id"] for bucket in buckets]

    def get_topic_list(self, user_id: str) -> List[Dict[str, Any]]:
        """Return ``user_id``'s topics with conversation counts and last activity time."""
        buckets = self._composite_buckets(
            "topic_id",
            query={"term": {"user_id": user_id}},
            sub_aggs={"latest": {"max": {"field": "timestamp"}}},
        )

        topics: List[Dict[str, Any]] = []
        for bucket in buckets:
            latest = bucket["latest"]
            topics.append({
                "topic_id": bucket["key"]["topic_id"],
                "conversation_count": bucket["doc_count"],
                # value_as_string uses the first mapping format ('yyyy-MM-dd HH:mm:ss').
                "last_active": latest.get("value_as_string") if latest.get("value") is not None else None
            })
        # Same ordering as the default terms aggregation: count desc, topic_id asc.
        topics.sort(key=lambda topic: topic["conversation_count"], reverse=True)
        return topics

    def get_conversations_with_images(self, user_id: str, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return up to ``limit`` of ``user_id``'s conversations that carry images, oldest first."""
        response = self.es.search(
            index=self.index_name,
            query={
                "bool": {
                    "filter": [
                        {"term": {"user_id": user_id}},
                        # `images` is mapped as nested, so its existence has to be
                        # tested inside a nested query; a top-level exists query
                        # cannot see nested sub-documents.
                        {"nested": {"path": "images", "query": {"exists": {"field": "images"}}}}
                    ]
                }
            },
            size=limit,
            sort=[{"timestamp": {"order": "asc"}}]
        )
        return [hit["_source"] for hit in response["hits"]["hits"]]
|