pycityagent 2.0.0a53__cp310-cp310-macosx_11_0_arm64.whl → 2.0.0a55__cp310-cp310-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +39 -4
- pycityagent/agent/agent_base.py +39 -25
- pycityagent/cityagent/blocks/plan_block.py +1 -1
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/societyagent.py +145 -32
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +2 -3
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/memory.py +151 -113
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/metrics/mlflow_client.py +4 -0
- pycityagent/simulation/agentgroup.py +62 -14
- pycityagent/simulation/simulation.py +103 -29
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +3 -3
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a55.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a55.dist-info}/RECORD +29 -27
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a55.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a55.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a55.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a55.dist-info}/top_level.txt +0 -0
pycityagent/memory/memory.py
CHANGED
@@ -1,11 +1,11 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
3
|
from collections import defaultdict
|
4
|
-
from collections.abc import Callable, Sequence
|
4
|
+
from collections.abc import Callable, Coroutine, Sequence
|
5
5
|
from copy import deepcopy
|
6
|
-
from typing import Any, Literal, Optional, Union, Dict
|
7
6
|
from dataclasses import dataclass
|
8
7
|
from enum import Enum
|
8
|
+
from typing import Any, Dict, Literal, Optional, Union
|
9
9
|
|
10
10
|
from langchain_core.embeddings import Embeddings
|
11
11
|
from pyparsing import deque
|
@@ -19,8 +19,10 @@ from .state import StateMemory
|
|
19
19
|
|
20
20
|
logger = logging.getLogger("pycityagent")
|
21
21
|
|
22
|
+
|
22
23
|
class MemoryTag(str, Enum):
|
23
24
|
"""记忆标签枚举类"""
|
25
|
+
|
24
26
|
MOBILITY = "mobility"
|
25
27
|
SOCIAL = "social"
|
26
28
|
ECONOMY = "economy"
|
@@ -28,9 +30,11 @@ class MemoryTag(str, Enum):
|
|
28
30
|
OTHER = "other"
|
29
31
|
EVENT = "event"
|
30
32
|
|
33
|
+
|
31
34
|
@dataclass
|
32
35
|
class MemoryNode:
|
33
36
|
"""记忆节点"""
|
37
|
+
|
34
38
|
tag: MemoryTag
|
35
39
|
day: int
|
36
40
|
t: int
|
@@ -39,8 +43,10 @@ class MemoryNode:
|
|
39
43
|
cognition_id: Optional[int] = None # 关联的认知记忆ID
|
40
44
|
id: Optional[int] = None # 记忆ID
|
41
45
|
|
46
|
+
|
42
47
|
class StreamMemory:
|
43
48
|
"""用于存储时序性的流式信息"""
|
49
|
+
|
44
50
|
def __init__(self, max_len: int = 1000):
|
45
51
|
self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
|
46
52
|
self._memory_id_counter: int = 0 # 用于生成唯一ID
|
@@ -50,6 +56,20 @@ class StreamMemory:
|
|
50
56
|
self._status_memory = None
|
51
57
|
self._simulator = None
|
52
58
|
|
59
|
+
@property
|
60
|
+
def faiss_query(
|
61
|
+
self,
|
62
|
+
) -> FaissQuery:
|
63
|
+
assert self._faiss_query is not None
|
64
|
+
return self._faiss_query
|
65
|
+
|
66
|
+
@property
|
67
|
+
def status_memory(
|
68
|
+
self,
|
69
|
+
):
|
70
|
+
assert self._status_memory is not None
|
71
|
+
return self._status_memory
|
72
|
+
|
53
73
|
def set_simulator(self, simulator):
|
54
74
|
self._simulator = simulator
|
55
75
|
|
@@ -73,14 +93,14 @@ class StreamMemory:
|
|
73
93
|
else:
|
74
94
|
day = 1
|
75
95
|
t = 1
|
76
|
-
position = await self.
|
77
|
-
if
|
78
|
-
location = position[
|
79
|
-
elif
|
80
|
-
location = position[
|
96
|
+
position = await self.status_memory.get("position")
|
97
|
+
if "aoi_position" in position:
|
98
|
+
location = position["aoi_position"]["aoi_id"]
|
99
|
+
elif "lane_position" in position:
|
100
|
+
location = position["lane_position"]["lane_id"]
|
81
101
|
else:
|
82
102
|
location = "unknown"
|
83
|
-
|
103
|
+
|
84
104
|
current_id = self._memory_id_counter
|
85
105
|
self._memory_id_counter += 1
|
86
106
|
memory_node = MemoryNode(
|
@@ -92,11 +112,10 @@ class StreamMemory:
|
|
92
112
|
id=current_id,
|
93
113
|
)
|
94
114
|
self._memories.append(memory_node)
|
95
|
-
|
96
115
|
|
97
116
|
# 为新记忆创建 embedding
|
98
117
|
if self._embedding_model and self._faiss_query:
|
99
|
-
await self.
|
118
|
+
await self.faiss_query.add_documents(
|
100
119
|
agent_id=self._agent_id,
|
101
120
|
documents=description,
|
102
121
|
extra_tags={
|
@@ -106,29 +125,30 @@ class StreamMemory:
|
|
106
125
|
"time": t,
|
107
126
|
},
|
108
127
|
)
|
109
|
-
|
128
|
+
|
110
129
|
return current_id
|
111
|
-
|
130
|
+
|
131
|
+
async def add_cognition(self, description: str) -> int:
|
112
132
|
"""添加认知记忆 Add cognition memory"""
|
113
133
|
return await self._add_memory(MemoryTag.COGNITION, description)
|
114
134
|
|
115
|
-
async def add_social(self, description: str) ->
|
135
|
+
async def add_social(self, description: str) -> int:
|
116
136
|
"""添加社交记忆 Add social memory"""
|
117
137
|
return await self._add_memory(MemoryTag.SOCIAL, description)
|
118
138
|
|
119
|
-
async def add_economy(self, description: str) ->
|
139
|
+
async def add_economy(self, description: str) -> int:
|
120
140
|
"""添加经济记忆 Add economy memory"""
|
121
141
|
return await self._add_memory(MemoryTag.ECONOMY, description)
|
122
142
|
|
123
|
-
async def add_mobility(self, description: str) ->
|
143
|
+
async def add_mobility(self, description: str) -> int:
|
124
144
|
"""添加移动记忆 Add mobility memory"""
|
125
145
|
return await self._add_memory(MemoryTag.MOBILITY, description)
|
126
146
|
|
127
|
-
async def add_event(self, description: str) ->
|
147
|
+
async def add_event(self, description: str) -> int:
|
128
148
|
"""添加事件记忆 Add event memory"""
|
129
149
|
return await self._add_memory(MemoryTag.EVENT, description)
|
130
150
|
|
131
|
-
async def add_other(self, description: str) ->
|
151
|
+
async def add_other(self, description: str) -> int:
|
132
152
|
"""添加其他记忆 Add other memory"""
|
133
153
|
return await self._add_memory(MemoryTag.OTHER, description)
|
134
154
|
|
@@ -137,11 +157,13 @@ class StreamMemory:
|
|
137
157
|
for memory in self._memories:
|
138
158
|
if memory.cognition_id == memory_id:
|
139
159
|
for cognition_memory in self._memories:
|
140
|
-
if (
|
141
|
-
|
160
|
+
if (
|
161
|
+
cognition_memory.tag == MemoryTag.COGNITION
|
162
|
+
and memory.cognition_id is not None
|
163
|
+
):
|
142
164
|
return cognition_memory
|
143
165
|
return None
|
144
|
-
|
166
|
+
|
145
167
|
async def format_memory(self, memories: list[MemoryNode]) -> str:
|
146
168
|
"""格式化记忆"""
|
147
169
|
formatted_results = []
|
@@ -150,51 +172,51 @@ class StreamMemory:
|
|
150
172
|
memory_day = memory.day
|
151
173
|
memory_time_seconds = memory.t
|
152
174
|
cognition_id = memory.cognition_id
|
153
|
-
|
175
|
+
|
154
176
|
# 格式化时间
|
155
|
-
if memory_time_seconds !=
|
177
|
+
if memory_time_seconds != "unknown":
|
156
178
|
hours = memory_time_seconds // 3600
|
157
179
|
minutes = (memory_time_seconds % 3600) // 60
|
158
180
|
seconds = memory_time_seconds % 60
|
159
181
|
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
160
182
|
else:
|
161
|
-
memory_time =
|
162
|
-
|
183
|
+
memory_time = "unknown"
|
184
|
+
|
163
185
|
memory_location = memory.location
|
164
|
-
|
186
|
+
|
165
187
|
# 添加认知信息(如果存在)
|
166
188
|
cognition_info = ""
|
167
189
|
if cognition_id is not None:
|
168
190
|
cognition_memory = await self.get_related_cognition(cognition_id)
|
169
191
|
if cognition_memory:
|
170
|
-
cognition_info =
|
171
|
-
|
192
|
+
cognition_info = (
|
193
|
+
f"\n Related cognition: {cognition_memory.description}"
|
194
|
+
)
|
195
|
+
|
172
196
|
formatted_results.append(
|
173
197
|
f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
|
174
198
|
f"location: {memory_location}]{cognition_info}"
|
175
199
|
)
|
176
200
|
return "\n".join(formatted_results)
|
177
201
|
|
178
|
-
async def get_by_ids(
|
202
|
+
async def get_by_ids(
|
203
|
+
self, memory_ids: Union[int, list[int]]
|
204
|
+
) -> Coroutine[Any, Any, str]:
|
179
205
|
"""获取指定ID的记忆"""
|
180
|
-
memories =
|
181
|
-
sorted_results = sorted(
|
182
|
-
memories,
|
183
|
-
key=lambda x: (x.day, x.t),
|
184
|
-
reverse=True
|
185
|
-
)
|
206
|
+
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
207
|
+
sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
|
186
208
|
return self.format_memory(sorted_results)
|
187
209
|
|
188
210
|
async def search(
|
189
|
-
self,
|
190
|
-
query: str,
|
211
|
+
self,
|
212
|
+
query: str,
|
191
213
|
tag: Optional[MemoryTag] = None,
|
192
214
|
top_k: int = 3,
|
193
215
|
day_range: Optional[tuple[int, int]] = None, # 新增参数
|
194
|
-
time_range: Optional[tuple[int, int]] = None # 新增参数
|
216
|
+
time_range: Optional[tuple[int, int]] = None, # 新增参数
|
195
217
|
) -> str:
|
196
218
|
"""Search stream memory
|
197
|
-
|
219
|
+
|
198
220
|
Args:
|
199
221
|
query: Query text
|
200
222
|
tag: Optional memory tag for filtering specific types of memories
|
@@ -205,60 +227,62 @@ class StreamMemory:
|
|
205
227
|
if not self._embedding_model or not self._faiss_query:
|
206
228
|
return "Search components not initialized"
|
207
229
|
|
208
|
-
filter_dict = {"type": "stream"}
|
209
|
-
|
230
|
+
filter_dict: dict[str, Any] = {"type": "stream"}
|
231
|
+
|
210
232
|
if tag:
|
211
233
|
filter_dict["tag"] = tag
|
212
|
-
|
234
|
+
|
213
235
|
# 添加时间范围过滤
|
214
236
|
if day_range:
|
215
237
|
start_day, end_day = day_range
|
216
238
|
filter_dict["day"] = lambda x: start_day <= x <= end_day
|
217
|
-
|
239
|
+
|
218
240
|
if time_range:
|
219
241
|
start_time, end_time = time_range
|
220
242
|
filter_dict["time"] = lambda x: start_time <= x <= end_time
|
221
243
|
|
222
|
-
top_results = await self.
|
244
|
+
top_results = await self.faiss_query.similarity_search(
|
223
245
|
query=query,
|
224
246
|
agent_id=self._agent_id,
|
225
247
|
k=top_k,
|
226
248
|
return_score_type="similarity_score",
|
227
|
-
filter=filter_dict
|
249
|
+
filter=filter_dict,
|
228
250
|
)
|
229
251
|
|
230
252
|
# 将结果按时间排序(先按天数,再按时间)
|
231
253
|
sorted_results = sorted(
|
232
|
-
top_results,
|
233
|
-
key=lambda x: (x[2].get(
|
234
|
-
reverse=True
|
254
|
+
top_results,
|
255
|
+
key=lambda x: (x[2].get("day", 0), x[2].get("time", 0)), # type:ignore
|
256
|
+
reverse=True,
|
235
257
|
)
|
236
|
-
|
258
|
+
|
237
259
|
formatted_results = []
|
238
|
-
for content, score, metadata in sorted_results:
|
239
|
-
memory_tag = metadata.get(
|
240
|
-
memory_day = metadata.get(
|
241
|
-
memory_time_seconds = metadata.get(
|
242
|
-
cognition_id = metadata.get(
|
243
|
-
|
260
|
+
for content, score, metadata in sorted_results: # type:ignore
|
261
|
+
memory_tag = metadata.get("tag", "unknown")
|
262
|
+
memory_day = metadata.get("day", "unknown")
|
263
|
+
memory_time_seconds = metadata.get("time", "unknown")
|
264
|
+
cognition_id = metadata.get("cognition_id", None)
|
265
|
+
|
244
266
|
# 格式化时间
|
245
|
-
if memory_time_seconds !=
|
267
|
+
if memory_time_seconds != "unknown":
|
246
268
|
hours = memory_time_seconds // 3600
|
247
269
|
minutes = (memory_time_seconds % 3600) // 60
|
248
270
|
seconds = memory_time_seconds % 60
|
249
271
|
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
250
272
|
else:
|
251
|
-
memory_time =
|
252
|
-
|
253
|
-
memory_location = metadata.get(
|
254
|
-
|
273
|
+
memory_time = "unknown"
|
274
|
+
|
275
|
+
memory_location = metadata.get("location", "unknown")
|
276
|
+
|
255
277
|
# 添加认知信息(如果存在)
|
256
278
|
cognition_info = ""
|
257
279
|
if cognition_id is not None:
|
258
280
|
cognition_memory = await self.get_related_cognition(cognition_id)
|
259
281
|
if cognition_memory:
|
260
|
-
cognition_info =
|
261
|
-
|
282
|
+
cognition_info = (
|
283
|
+
f"\n Related cognition: {cognition_memory.description}"
|
284
|
+
)
|
285
|
+
|
262
286
|
formatted_results.append(
|
263
287
|
f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
|
264
288
|
f"location: {memory_location}]{cognition_info}"
|
@@ -272,50 +296,49 @@ class StreamMemory:
|
|
272
296
|
top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
|
273
297
|
) -> str:
|
274
298
|
"""Search all memory events from today
|
275
|
-
|
299
|
+
|
276
300
|
Args:
|
277
301
|
query: Optional query text, returns all memories of the day if empty
|
278
302
|
tag: Optional memory tag for filtering specific types of memories
|
279
303
|
top_k: Number of most relevant memories to return, defaults to 100
|
280
|
-
|
304
|
+
|
281
305
|
Returns:
|
282
306
|
str: Formatted text of today's memories
|
283
307
|
"""
|
284
308
|
if self._simulator is None:
|
285
309
|
return "Simulator not initialized"
|
286
|
-
|
310
|
+
|
287
311
|
current_day = int(await self._simulator.get_simulator_day())
|
288
|
-
|
312
|
+
|
289
313
|
# 使用 search 方法,设置 day_range 为当天
|
290
314
|
return await self.search(
|
291
|
-
query=query,
|
292
|
-
tag=tag,
|
293
|
-
top_k=top_k,
|
294
|
-
day_range=(current_day, current_day)
|
315
|
+
query=query, tag=tag, top_k=top_k, day_range=(current_day, current_day)
|
295
316
|
)
|
296
317
|
|
297
|
-
async def add_cognition_to_memory(
|
318
|
+
async def add_cognition_to_memory(
|
319
|
+
self, memory_id: Union[int, list[int]], cognition: str
|
320
|
+
) -> None:
|
298
321
|
"""为已存在的记忆添加认知
|
299
|
-
|
322
|
+
|
300
323
|
Args:
|
301
324
|
memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
|
302
325
|
cognition: 认知描述
|
303
326
|
"""
|
304
327
|
# 将单个ID转换为列表以统一处理
|
305
328
|
memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
|
306
|
-
|
329
|
+
|
307
330
|
# 找到所有对应的记忆
|
308
331
|
target_memories = []
|
309
332
|
for memory in self._memories:
|
310
333
|
if id(memory) in memory_ids:
|
311
334
|
target_memories.append(memory)
|
312
|
-
|
335
|
+
|
313
336
|
if not target_memories:
|
314
337
|
raise ValueError(f"No memories found with ids {memory_ids}")
|
315
|
-
|
338
|
+
|
316
339
|
# 添加认知记忆
|
317
340
|
cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
|
318
|
-
|
341
|
+
|
319
342
|
# 更新所有原记忆的认知ID
|
320
343
|
for target_memory in target_memories:
|
321
344
|
target_memory.cognition_id = cognition_id
|
@@ -324,9 +347,13 @@ class StreamMemory:
|
|
324
347
|
"""获取所有流式信息"""
|
325
348
|
return list(self._memories)
|
326
349
|
|
350
|
+
|
327
351
|
class StatusMemory:
|
328
352
|
"""组合现有的三种记忆类型"""
|
329
|
-
|
353
|
+
|
354
|
+
def __init__(
|
355
|
+
self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory
|
356
|
+
):
|
330
357
|
self.profile = profile
|
331
358
|
self.state = state
|
332
359
|
self.dynamic = dynamic
|
@@ -340,23 +367,32 @@ class StatusMemory:
|
|
340
367
|
self.watchers = {} # 新增
|
341
368
|
self._lock = asyncio.Lock() # 新增
|
342
369
|
|
370
|
+
@property
|
371
|
+
def faiss_query(
|
372
|
+
self,
|
373
|
+
) -> FaissQuery:
|
374
|
+
assert self._faiss_query is not None
|
375
|
+
return self._faiss_query
|
376
|
+
|
343
377
|
def set_simulator(self, simulator):
|
344
378
|
self._simulator = simulator
|
345
379
|
|
346
380
|
async def initialize_embeddings(self) -> None:
|
347
381
|
"""初始化所有需要 embedding 的字段"""
|
348
382
|
if not self._embedding_model or not self._faiss_query:
|
349
|
-
logger.warning(
|
383
|
+
logger.warning(
|
384
|
+
"Search components not initialized, skipping embeddings initialization"
|
385
|
+
)
|
350
386
|
return
|
351
387
|
|
352
388
|
# 获取所有状态信息
|
353
389
|
profile, state, dynamic = await self.export()
|
354
|
-
|
390
|
+
|
355
391
|
# 为每个需要 embedding 的字段创建 embedding
|
356
392
|
for key, value in profile[0].items():
|
357
393
|
if self.should_embed(key):
|
358
394
|
semantic_text = self._generate_semantic_text(key, value)
|
359
|
-
doc_ids = await self.
|
395
|
+
doc_ids = await self.faiss_query.add_documents(
|
360
396
|
agent_id=self._agent_id,
|
361
397
|
documents=semantic_text,
|
362
398
|
extra_tags={
|
@@ -369,7 +405,7 @@ class StatusMemory:
|
|
369
405
|
for key, value in state[0].items():
|
370
406
|
if self.should_embed(key):
|
371
407
|
semantic_text = self._generate_semantic_text(key, value)
|
372
|
-
doc_ids = await self.
|
408
|
+
doc_ids = await self.faiss_query.add_documents(
|
373
409
|
agent_id=self._agent_id,
|
374
410
|
documents=semantic_text,
|
375
411
|
extra_tags={
|
@@ -378,11 +414,11 @@ class StatusMemory:
|
|
378
414
|
},
|
379
415
|
)
|
380
416
|
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
381
|
-
|
417
|
+
|
382
418
|
for key, value in dynamic[0].items():
|
383
419
|
if self.should_embed(key):
|
384
420
|
semantic_text = self._generate_semantic_text(key, value)
|
385
|
-
doc_ids = await self.
|
421
|
+
doc_ids = await self.faiss_query.add_documents(
|
386
422
|
agent_id=self._agent_id,
|
387
423
|
documents=semantic_text,
|
388
424
|
extra_tags={
|
@@ -415,7 +451,7 @@ class StatusMemory:
|
|
415
451
|
|
416
452
|
def set_semantic_templates(self, templates: Dict[str, str]):
|
417
453
|
"""设置语义模板
|
418
|
-
|
454
|
+
|
419
455
|
Args:
|
420
456
|
templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
|
421
457
|
"""
|
@@ -423,14 +459,14 @@ class StatusMemory:
|
|
423
459
|
|
424
460
|
def _generate_semantic_text(self, key: str, value: Any) -> str:
|
425
461
|
"""生成语义文本
|
426
|
-
|
462
|
+
|
427
463
|
如果key存在于模板中,使用自定义模板
|
428
464
|
否则使用默认模板 "my {key} is {value}"
|
429
465
|
"""
|
430
466
|
if key in self._semantic_templates:
|
431
467
|
return self._semantic_templates[key].format(value)
|
432
468
|
return f"Your {key} is {value}"
|
433
|
-
|
469
|
+
|
434
470
|
@lock_decorator
|
435
471
|
async def search(
|
436
472
|
self, query: str, top_k: int = 3, filter: Optional[dict] = None
|
@@ -447,12 +483,12 @@ class StatusMemory:
|
|
447
483
|
"""
|
448
484
|
if not self._embedding_model:
|
449
485
|
return "Embedding model not initialized"
|
450
|
-
|
486
|
+
|
451
487
|
filter_dict = {"type": "profile_state"}
|
452
488
|
if filter is not None:
|
453
489
|
filter_dict.update(filter)
|
454
490
|
top_results: list[tuple[str, float, dict]] = (
|
455
|
-
await self.
|
491
|
+
await self.faiss_query.similarity_search( # type:ignore
|
456
492
|
query=query,
|
457
493
|
agent_id=self._agent_id,
|
458
494
|
k=top_k,
|
@@ -463,9 +499,7 @@ class StatusMemory:
|
|
463
499
|
# 格式化输出
|
464
500
|
formatted_results = []
|
465
501
|
for content, score, metadata in top_results:
|
466
|
-
formatted_results.append(
|
467
|
-
f"- {content} "
|
468
|
-
)
|
502
|
+
formatted_results.append(f"- {content} ")
|
469
503
|
|
470
504
|
return "\n".join(formatted_results)
|
471
505
|
|
@@ -478,8 +512,11 @@ class StatusMemory:
|
|
478
512
|
return self._embedding_fields.get(key, False)
|
479
513
|
|
480
514
|
@lock_decorator
|
481
|
-
async def get(
|
482
|
-
|
515
|
+
async def get(
|
516
|
+
self,
|
517
|
+
key: Any,
|
518
|
+
mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
|
519
|
+
) -> Any:
|
483
520
|
"""从记忆中获取值
|
484
521
|
|
485
522
|
Args:
|
@@ -499,7 +536,7 @@ class StatusMemory:
|
|
499
536
|
process_func = lambda x: x
|
500
537
|
else:
|
501
538
|
raise ValueError(f"Invalid get mode `{mode}`!")
|
502
|
-
|
539
|
+
|
503
540
|
for mem in [self.state, self.profile, self.dynamic]:
|
504
541
|
try:
|
505
542
|
value = await mem.get(key)
|
@@ -509,16 +546,20 @@ class StatusMemory:
|
|
509
546
|
raise KeyError(f"No attribute `{key}` in memories!")
|
510
547
|
|
511
548
|
@lock_decorator
|
512
|
-
async def update(
|
513
|
-
|
514
|
-
|
515
|
-
|
549
|
+
async def update(
|
550
|
+
self,
|
551
|
+
key: Any,
|
552
|
+
value: Any,
|
553
|
+
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
554
|
+
store_snapshot: bool = False,
|
555
|
+
protect_llm_read_only_fields: bool = True,
|
556
|
+
) -> None:
|
516
557
|
"""更新记忆值并在必要时更新embedding"""
|
517
558
|
if protect_llm_read_only_fields:
|
518
559
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
519
560
|
logger.warning(f"Trying to write protected key `{key}`!")
|
520
561
|
return
|
521
|
-
|
562
|
+
|
522
563
|
for mem in [self.state, self.profile, self.dynamic]:
|
523
564
|
try:
|
524
565
|
original_value = await mem.get(key)
|
@@ -526,16 +567,16 @@ class StatusMemory:
|
|
526
567
|
await mem.update(key, value, store_snapshot)
|
527
568
|
if self.should_embed(key) and self._embedding_model:
|
528
569
|
semantic_text = self._generate_semantic_text(key, value)
|
529
|
-
|
570
|
+
|
530
571
|
# 删除旧的 embedding
|
531
572
|
orig_doc_id = self._embedding_field_to_doc_id[key]
|
532
573
|
if orig_doc_id:
|
533
|
-
await self.
|
574
|
+
await self.faiss_query.delete_documents(
|
534
575
|
to_delete_ids=[orig_doc_id],
|
535
576
|
)
|
536
|
-
|
577
|
+
|
537
578
|
# 添加新的 embedding
|
538
|
-
doc_ids = await self.
|
579
|
+
doc_ids = await self.faiss_query.add_documents(
|
539
580
|
agent_id=self._agent_id,
|
540
581
|
documents=semantic_text,
|
541
582
|
extra_tags={
|
@@ -544,11 +585,11 @@ class StatusMemory:
|
|
544
585
|
},
|
545
586
|
)
|
546
587
|
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
547
|
-
|
588
|
+
|
548
589
|
if key in self.watchers:
|
549
590
|
for callback in self.watchers[key]:
|
550
591
|
asyncio.create_task(callback())
|
551
|
-
|
592
|
+
|
552
593
|
elif mode == "merge":
|
553
594
|
if isinstance(original_value, set):
|
554
595
|
original_value.update(set(value))
|
@@ -565,7 +606,7 @@ class StatusMemory:
|
|
565
606
|
await mem.update(key, value, store_snapshot)
|
566
607
|
if self.should_embed(key) and self._embedding_model:
|
567
608
|
semantic_text = self._generate_semantic_text(key, value)
|
568
|
-
doc_ids = await self.
|
609
|
+
doc_ids = await self.faiss_query.add_documents(
|
569
610
|
agent_id=self._agent_id,
|
570
611
|
documents=f"{key}: {str(original_value)}",
|
571
612
|
extra_tags={
|
@@ -635,6 +676,7 @@ class StatusMemory:
|
|
635
676
|
if _snapshot:
|
636
677
|
await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
|
637
678
|
|
679
|
+
|
638
680
|
class Memory:
|
639
681
|
"""
|
640
682
|
A class to manage different types of memory (state, profile, dynamic).
|
@@ -745,7 +787,6 @@ class Memory:
|
|
745
787
|
if k not in PROFILE_ATTRIBUTES:
|
746
788
|
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
747
789
|
continue
|
748
|
-
|
749
790
|
try:
|
750
791
|
# 处理配置元组格式
|
751
792
|
if isinstance(v, tuple):
|
@@ -787,7 +828,6 @@ class Memory:
|
|
787
828
|
self._profile = ProfileMemory(
|
788
829
|
msg=_profile_config, activate_timestamp=activate_timestamp
|
789
830
|
)
|
790
|
-
|
791
831
|
if base is not None:
|
792
832
|
for k, v in base.items():
|
793
833
|
if k not in STATE_ATTRIBUTES:
|
@@ -798,12 +838,10 @@ class Memory:
|
|
798
838
|
self._state = StateMemory(
|
799
839
|
msg=_state_config, activate_timestamp=activate_timestamp
|
800
840
|
)
|
801
|
-
|
841
|
+
|
802
842
|
# 组合 StatusMemory,并传递 embedding_fields 信息
|
803
843
|
self._status = StatusMemory(
|
804
|
-
profile=self._profile,
|
805
|
-
state=self._state,
|
806
|
-
dynamic=self._dynamic
|
844
|
+
profile=self._profile, state=self._state, dynamic=self._dynamic
|
807
845
|
)
|
808
846
|
self._status.set_embedding_fields(self._embedding_fields)
|
809
847
|
self._status.set_search_components(self._faiss_query, self._embedding_model)
|
@@ -839,7 +877,7 @@ class Memory:
|
|
839
877
|
@property
|
840
878
|
def status(self) -> StatusMemory:
|
841
879
|
return self._status
|
842
|
-
|
880
|
+
|
843
881
|
@property
|
844
882
|
def stream(self) -> StreamMemory:
|
845
883
|
return self._stream
|
@@ -872,7 +910,7 @@ class Memory:
|
|
872
910
|
f"FaissQuery access before assignment, please `set_faiss_query` first!"
|
873
911
|
)
|
874
912
|
return self._faiss_query
|
875
|
-
|
913
|
+
|
876
914
|
async def initialize_embeddings(self):
|
877
915
|
"""初始化embedding"""
|
878
916
|
await self._status.initialize_embeddings()
|
pycityagent/message/__init__.py
CHANGED
@@ -1,3 +1,10 @@
|
|
1
|
+
from .message_interceptor import (MessageBlockBase, MessageBlockListenerBase,
|
2
|
+
MessageInterceptor)
|
1
3
|
from .messager import Messager
|
2
4
|
|
3
|
-
__all__ = [
|
5
|
+
__all__ = [
|
6
|
+
"Messager",
|
7
|
+
"MessageBlockBase",
|
8
|
+
"MessageBlockListenerBase",
|
9
|
+
"MessageInterceptor",
|
10
|
+
]
|