pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +83 -62
- pycityagent/agent/agent_base.py +81 -54
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +45 -57
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +325 -127
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +18 -14
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +733 -247
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +137 -96
- pycityagent/simulation/simulation.py +184 -38
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/tools/tool.py +7 -9
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +14 -7
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
pycityagent/memory/memory.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
3
|
from collections import defaultdict
|
4
|
-
from collections.abc import Callable, Sequence
|
4
|
+
from collections.abc import Callable, Coroutine, Sequence
|
5
5
|
from copy import deepcopy
|
6
|
-
from
|
7
|
-
from
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from enum import Enum
|
8
|
+
from typing import Any, Dict, Literal, Optional, Union
|
8
9
|
|
9
|
-
import numpy as np
|
10
10
|
from langchain_core.embeddings import Embeddings
|
11
11
|
from pyparsing import deque
|
12
12
|
|
@@ -20,176 +20,496 @@ from .state import StateMemory
|
|
20
20
|
logger = logging.getLogger("pycityagent")
|
21
21
|
|
22
22
|
|
23
|
-
class
|
24
|
-
"""
|
25
|
-
A class to manage different types of memory (state, profile, dynamic).
|
23
|
+
class MemoryTag(str, Enum):
|
24
|
+
"""记忆标签枚举类"""
|
26
25
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
""
|
26
|
+
MOBILITY = "mobility"
|
27
|
+
SOCIAL = "social"
|
28
|
+
ECONOMY = "economy"
|
29
|
+
COGNITION = "cognition"
|
30
|
+
OTHER = "other"
|
31
|
+
EVENT = "event"
|
32
32
|
|
33
|
-
|
33
|
+
|
34
|
+
@dataclass
|
35
|
+
class MemoryNode:
|
36
|
+
"""记忆节点"""
|
37
|
+
|
38
|
+
tag: MemoryTag
|
39
|
+
day: int
|
40
|
+
t: int
|
41
|
+
location: str
|
42
|
+
description: str
|
43
|
+
cognition_id: Optional[int] = None # 关联的认知记忆ID
|
44
|
+
id: Optional[int] = None # 记忆ID
|
45
|
+
|
46
|
+
|
47
|
+
class StreamMemory:
|
48
|
+
"""用于存储时序性的流式信息"""
|
49
|
+
|
50
|
+
def __init__(self, max_len: int = 1000):
|
51
|
+
self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
|
52
|
+
self._memory_id_counter: int = 0 # 用于生成唯一ID
|
53
|
+
self._faiss_query = None
|
54
|
+
self._embedding_model = None
|
55
|
+
self._agent_id = -1
|
56
|
+
self._status_memory = None
|
57
|
+
self._simulator = None
|
58
|
+
|
59
|
+
@property
|
60
|
+
def faiss_query(
|
34
61
|
self,
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
motion: Optional[dict[Any, Any]] = None,
|
39
|
-
activate_timestamp: bool = False,
|
40
|
-
embedding_model: Optional[Embeddings] = None,
|
41
|
-
faiss_query: Optional[FaissQuery] = None,
|
42
|
-
) -> None:
|
43
|
-
"""
|
44
|
-
Initializes the Memory with optional configuration.
|
62
|
+
) -> FaissQuery:
|
63
|
+
assert self._faiss_query is not None
|
64
|
+
return self._faiss_query
|
45
65
|
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
2. A callable that returns the default value when invoked (useful for complex default values).
|
53
|
-
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
54
|
-
Defaults to None.
|
55
|
-
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
56
|
-
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
57
|
-
motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
|
58
|
-
activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
|
59
|
-
embedding_model (Embeddings): The embedding model for memory search.
|
60
|
-
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
61
|
-
"""
|
62
|
-
self.watchers: dict[str, list[Callable]] = {}
|
63
|
-
self._lock = asyncio.Lock()
|
64
|
-
self._agent_id: int = -1
|
65
|
-
self._embedding_model = embedding_model
|
66
|
+
@property
|
67
|
+
def status_memory(
|
68
|
+
self,
|
69
|
+
):
|
70
|
+
assert self._status_memory is not None
|
71
|
+
return self._status_memory
|
66
72
|
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
self.
|
72
|
-
|
73
|
+
def set_simulator(self, simulator):
|
74
|
+
self._simulator = simulator
|
75
|
+
|
76
|
+
def set_status_memory(self, status_memory):
|
77
|
+
self._status_memory = status_memory
|
78
|
+
|
79
|
+
def set_search_components(self, faiss_query, embedding_model):
|
80
|
+
"""设置搜索所需的组件"""
|
73
81
|
self._faiss_query = faiss_query
|
82
|
+
self._embedding_model = embedding_model
|
74
83
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
# 处理新的三元组格式
|
79
|
-
if isinstance(v, tuple) and len(v) == 3:
|
80
|
-
_type, _value, enable_embedding = v
|
81
|
-
self._embedding_fields[k] = enable_embedding
|
82
|
-
else:
|
83
|
-
_type, _value = v
|
84
|
-
self._embedding_fields[k] = False
|
84
|
+
def set_agent_id(self, agent_id: int):
|
85
|
+
"""设置agent_id"""
|
86
|
+
self._agent_id = agent_id
|
85
87
|
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
88
|
+
async def _add_memory(self, tag: MemoryTag, description: str) -> int:
|
89
|
+
"""添加记忆节点的通用方法,返回记忆节点ID"""
|
90
|
+
if self._simulator is not None:
|
91
|
+
day = int(await self._simulator.get_simulator_day())
|
92
|
+
t = int(await self._simulator.get_time())
|
93
|
+
else:
|
94
|
+
day = 1
|
95
|
+
t = 1
|
96
|
+
position = await self.status_memory.get("position")
|
97
|
+
if "aoi_position" in position:
|
98
|
+
location = position["aoi_position"]["aoi_id"]
|
99
|
+
elif "lane_position" in position:
|
100
|
+
location = position["lane_position"]["lane_id"]
|
101
|
+
else:
|
102
|
+
location = "unknown"
|
103
|
+
|
104
|
+
current_id = self._memory_id_counter
|
105
|
+
self._memory_id_counter += 1
|
106
|
+
memory_node = MemoryNode(
|
107
|
+
tag=tag,
|
108
|
+
day=day,
|
109
|
+
t=t,
|
110
|
+
location=location,
|
111
|
+
description=description,
|
112
|
+
id=current_id,
|
113
|
+
)
|
114
|
+
self._memories.append(memory_node)
|
115
|
+
|
116
|
+
# 为新记忆创建 embedding
|
117
|
+
if self._embedding_model and self._faiss_query:
|
118
|
+
await self.faiss_query.add_documents(
|
119
|
+
agent_id=self._agent_id,
|
120
|
+
documents=description,
|
121
|
+
extra_tags={
|
122
|
+
"type": "stream",
|
123
|
+
"tag": tag,
|
124
|
+
"day": day,
|
125
|
+
"time": t,
|
126
|
+
},
|
127
|
+
)
|
104
128
|
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
129
|
+
return current_id
|
130
|
+
|
131
|
+
async def add_cognition(self, description: str) -> int:
|
132
|
+
"""添加认知记忆 Add cognition memory"""
|
133
|
+
return await self._add_memory(MemoryTag.COGNITION, description)
|
134
|
+
|
135
|
+
async def add_social(self, description: str) -> int:
|
136
|
+
"""添加社交记忆 Add social memory"""
|
137
|
+
return await self._add_memory(MemoryTag.SOCIAL, description)
|
138
|
+
|
139
|
+
async def add_economy(self, description: str) -> int:
|
140
|
+
"""添加经济记忆 Add economy memory"""
|
141
|
+
return await self._add_memory(MemoryTag.ECONOMY, description)
|
142
|
+
|
143
|
+
async def add_mobility(self, description: str) -> int:
|
144
|
+
"""添加移动记忆 Add mobility memory"""
|
145
|
+
return await self._add_memory(MemoryTag.MOBILITY, description)
|
146
|
+
|
147
|
+
async def add_event(self, description: str) -> int:
|
148
|
+
"""添加事件记忆 Add event memory"""
|
149
|
+
return await self._add_memory(MemoryTag.EVENT, description)
|
150
|
+
|
151
|
+
async def add_other(self, description: str) -> int:
|
152
|
+
"""添加其他记忆 Add other memory"""
|
153
|
+
return await self._add_memory(MemoryTag.OTHER, description)
|
154
|
+
|
155
|
+
async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
|
156
|
+
"""获取关联的认知记忆 Get related cognition memory"""
|
157
|
+
for memory in self._memories:
|
158
|
+
if memory.cognition_id == memory_id:
|
159
|
+
for cognition_memory in self._memories:
|
160
|
+
if (
|
161
|
+
cognition_memory.tag == MemoryTag.COGNITION
|
162
|
+
and memory.cognition_id is not None
|
163
|
+
):
|
164
|
+
return cognition_memory
|
165
|
+
return None
|
166
|
+
|
167
|
+
async def format_memory(self, memories: list[MemoryNode]) -> str:
|
168
|
+
"""格式化记忆"""
|
169
|
+
formatted_results = []
|
170
|
+
for memory in memories:
|
171
|
+
memory_tag = memory.tag
|
172
|
+
memory_day = memory.day
|
173
|
+
memory_time_seconds = memory.t
|
174
|
+
cognition_id = memory.cognition_id
|
175
|
+
|
176
|
+
# 格式化时间
|
177
|
+
if memory_time_seconds != "unknown":
|
178
|
+
hours = memory_time_seconds // 3600
|
179
|
+
minutes = (memory_time_seconds % 3600) // 60
|
180
|
+
seconds = memory_time_seconds % 60
|
181
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
182
|
+
else:
|
183
|
+
memory_time = "unknown"
|
184
|
+
|
185
|
+
memory_location = memory.location
|
186
|
+
|
187
|
+
# 添加认知信息(如果存在)
|
188
|
+
cognition_info = ""
|
189
|
+
if cognition_id is not None:
|
190
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
191
|
+
if cognition_memory:
|
192
|
+
cognition_info = (
|
193
|
+
f"\n Related cognition: {cognition_memory.description}"
|
194
|
+
)
|
112
195
|
|
113
|
-
|
196
|
+
formatted_results.append(
|
197
|
+
f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
|
198
|
+
f"location: {memory_location}]{cognition_info}"
|
199
|
+
)
|
200
|
+
return "\n".join(formatted_results)
|
114
201
|
|
115
|
-
|
116
|
-
self
|
117
|
-
|
118
|
-
|
202
|
+
async def get_by_ids(
|
203
|
+
self, memory_ids: Union[int, list[int]]
|
204
|
+
) -> Coroutine[Any, Any, str]:
|
205
|
+
"""获取指定ID的记忆"""
|
206
|
+
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
207
|
+
sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
|
208
|
+
return self.format_memory(sorted_results)
|
119
209
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
210
|
+
async def search(
|
211
|
+
self,
|
212
|
+
query: str,
|
213
|
+
tag: Optional[MemoryTag] = None,
|
214
|
+
top_k: int = 3,
|
215
|
+
day_range: Optional[tuple[int, int]] = None, # 新增参数
|
216
|
+
time_range: Optional[tuple[int, int]] = None, # 新增参数
|
217
|
+
) -> str:
|
218
|
+
"""Search stream memory
|
219
|
+
|
220
|
+
Args:
|
221
|
+
query: Query text
|
222
|
+
tag: Optional memory tag for filtering specific types of memories
|
223
|
+
top_k: Number of most relevant memories to return
|
224
|
+
day_range: Optional tuple of start and end days (start_day, end_day)
|
225
|
+
time_range: Optional tuple of start and end times (start_time, end_time)
|
226
|
+
"""
|
227
|
+
if not self._embedding_model or not self._faiss_query:
|
228
|
+
return "Search components not initialized"
|
229
|
+
|
230
|
+
filter_dict: dict[str, Any] = {"type": "stream"}
|
231
|
+
|
232
|
+
if tag:
|
233
|
+
filter_dict["tag"] = tag
|
234
|
+
|
235
|
+
# 添加时间范围过滤
|
236
|
+
if day_range:
|
237
|
+
start_day, end_day = day_range
|
238
|
+
filter_dict["day"] = lambda x: start_day <= x <= end_day
|
239
|
+
|
240
|
+
if time_range:
|
241
|
+
start_time, end_time = time_range
|
242
|
+
filter_dict["time"] = lambda x: start_time <= x <= end_time
|
243
|
+
|
244
|
+
top_results = await self.faiss_query.similarity_search(
|
245
|
+
query=query,
|
246
|
+
agent_id=self._agent_id,
|
247
|
+
k=top_k,
|
248
|
+
return_score_type="similarity_score",
|
249
|
+
filter=filter_dict,
|
140
250
|
)
|
141
|
-
|
142
|
-
|
251
|
+
|
252
|
+
# 将结果按时间排序(先按天数,再按时间)
|
253
|
+
sorted_results = sorted(
|
254
|
+
top_results,
|
255
|
+
key=lambda x: (x[2].get("day", 0), x[2].get("time", 0)), # type:ignore
|
256
|
+
reverse=True,
|
143
257
|
)
|
144
|
-
# self.memories = [] # 存储记忆内容
|
145
|
-
# self.embeddings = [] # 存储记忆的向量表示
|
146
258
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
259
|
+
formatted_results = []
|
260
|
+
for content, score, metadata in sorted_results: # type:ignore
|
261
|
+
memory_tag = metadata.get("tag", "unknown")
|
262
|
+
memory_day = metadata.get("day", "unknown")
|
263
|
+
memory_time_seconds = metadata.get("time", "unknown")
|
264
|
+
cognition_id = metadata.get("cognition_id", None)
|
265
|
+
|
266
|
+
# 格式化时间
|
267
|
+
if memory_time_seconds != "unknown":
|
268
|
+
hours = memory_time_seconds // 3600
|
269
|
+
minutes = (memory_time_seconds % 3600) // 60
|
270
|
+
seconds = memory_time_seconds % 60
|
271
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
272
|
+
else:
|
273
|
+
memory_time = "unknown"
|
274
|
+
|
275
|
+
memory_location = metadata.get("location", "unknown")
|
276
|
+
|
277
|
+
# 添加认知信息(如果存在)
|
278
|
+
cognition_info = ""
|
279
|
+
if cognition_id is not None:
|
280
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
281
|
+
if cognition_memory:
|
282
|
+
cognition_info = (
|
283
|
+
f"\n Related cognition: {cognition_memory.description}"
|
284
|
+
)
|
152
285
|
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
):
|
157
|
-
if self._embedding_model is None:
|
158
|
-
raise RuntimeError(
|
159
|
-
f"embedding_model before assignment, please `set_embedding_model` first!"
|
286
|
+
formatted_results.append(
|
287
|
+
f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
|
288
|
+
f"location: {memory_location}]{cognition_info}"
|
160
289
|
)
|
161
|
-
return
|
290
|
+
return "\n".join(formatted_results)
|
162
291
|
|
163
|
-
def
|
292
|
+
async def search_today(
|
293
|
+
self,
|
294
|
+
query: str = "", # 可选的查询文本
|
295
|
+
tag: Optional[MemoryTag] = None,
|
296
|
+
top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
|
297
|
+
) -> str:
|
298
|
+
"""Search all memory events from today
|
299
|
+
|
300
|
+
Args:
|
301
|
+
query: Optional query text, returns all memories of the day if empty
|
302
|
+
tag: Optional memory tag for filtering specific types of memories
|
303
|
+
top_k: Number of most relevant memories to return, defaults to 100
|
304
|
+
|
305
|
+
Returns:
|
306
|
+
str: Formatted text of today's memories
|
164
307
|
"""
|
165
|
-
|
308
|
+
if self._simulator is None:
|
309
|
+
return "Simulator not initialized"
|
310
|
+
|
311
|
+
current_day = int(await self._simulator.get_simulator_day())
|
312
|
+
|
313
|
+
# 使用 search 方法,设置 day_range 为当天
|
314
|
+
return await self.search(
|
315
|
+
query=query, tag=tag, top_k=top_k, day_range=(current_day, current_day)
|
316
|
+
)
|
317
|
+
|
318
|
+
async def add_cognition_to_memory(
|
319
|
+
self, memory_id: Union[int, list[int]], cognition: str
|
320
|
+
) -> None:
|
321
|
+
"""为已存在的记忆添加认知
|
322
|
+
|
323
|
+
Args:
|
324
|
+
memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
|
325
|
+
cognition: 认知描述
|
166
326
|
"""
|
167
|
-
|
327
|
+
# 将单个ID转换为列表以统一处理
|
328
|
+
memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
|
329
|
+
|
330
|
+
# 找到所有对应的记忆
|
331
|
+
target_memories = []
|
332
|
+
for memory in self._memories:
|
333
|
+
if id(memory) in memory_ids:
|
334
|
+
target_memories.append(memory)
|
335
|
+
|
336
|
+
if not target_memories:
|
337
|
+
raise ValueError(f"No memories found with ids {memory_ids}")
|
338
|
+
|
339
|
+
# 添加认知记忆
|
340
|
+
cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
|
341
|
+
|
342
|
+
# 更新所有原记忆的认知ID
|
343
|
+
for target_memory in target_memories:
|
344
|
+
target_memory.cognition_id = cognition_id
|
345
|
+
|
346
|
+
async def get_all(self) -> list[MemoryNode]:
|
347
|
+
"""获取所有流式信息"""
|
348
|
+
return list(self._memories)
|
349
|
+
|
350
|
+
|
351
|
+
class StatusMemory:
|
352
|
+
"""组合现有的三种记忆类型"""
|
353
|
+
|
354
|
+
def __init__(
|
355
|
+
self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory
|
356
|
+
):
|
357
|
+
self.profile = profile
|
358
|
+
self.state = state
|
359
|
+
self.dynamic = dynamic
|
360
|
+
self._faiss_query = None
|
361
|
+
self._embedding_model = None
|
362
|
+
self._simulator = None
|
363
|
+
self._agent_id = -1
|
364
|
+
self._semantic_templates = {} # 用户可配置的模板
|
365
|
+
self._embedding_fields = {} # 需要 embedding 的字段
|
366
|
+
self._embedding_field_to_doc_id = defaultdict(str) # 新增
|
367
|
+
self.watchers = {} # 新增
|
368
|
+
self._lock = asyncio.Lock() # 新增
|
168
369
|
|
169
370
|
@property
|
170
|
-
def
|
371
|
+
def faiss_query(
|
171
372
|
self,
|
172
|
-
):
|
173
|
-
|
174
|
-
|
175
|
-
|
373
|
+
) -> FaissQuery:
|
374
|
+
assert self._faiss_query is not None
|
375
|
+
return self._faiss_query
|
376
|
+
|
377
|
+
def set_simulator(self, simulator):
|
378
|
+
self._simulator = simulator
|
379
|
+
|
380
|
+
async def initialize_embeddings(self) -> None:
|
381
|
+
"""初始化所有需要 embedding 的字段"""
|
382
|
+
if not self._embedding_model or not self._faiss_query:
|
383
|
+
logger.warning(
|
384
|
+
"Search components not initialized, skipping embeddings initialization"
|
176
385
|
)
|
177
|
-
|
386
|
+
return
|
387
|
+
|
388
|
+
# 获取所有状态信息
|
389
|
+
profile, state, dynamic = await self.export()
|
390
|
+
|
391
|
+
# 为每个需要 embedding 的字段创建 embedding
|
392
|
+
for key, value in profile[0].items():
|
393
|
+
if self.should_embed(key):
|
394
|
+
semantic_text = self._generate_semantic_text(key, value)
|
395
|
+
doc_ids = await self.faiss_query.add_documents(
|
396
|
+
agent_id=self._agent_id,
|
397
|
+
documents=semantic_text,
|
398
|
+
extra_tags={
|
399
|
+
"type": "profile_state",
|
400
|
+
"key": key,
|
401
|
+
},
|
402
|
+
)
|
403
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
404
|
+
|
405
|
+
for key, value in state[0].items():
|
406
|
+
if self.should_embed(key):
|
407
|
+
semantic_text = self._generate_semantic_text(key, value)
|
408
|
+
doc_ids = await self.faiss_query.add_documents(
|
409
|
+
agent_id=self._agent_id,
|
410
|
+
documents=semantic_text,
|
411
|
+
extra_tags={
|
412
|
+
"type": "profile_state",
|
413
|
+
"key": key,
|
414
|
+
},
|
415
|
+
)
|
416
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
417
|
+
|
418
|
+
for key, value in dynamic[0].items():
|
419
|
+
if self.should_embed(key):
|
420
|
+
semantic_text = self._generate_semantic_text(key, value)
|
421
|
+
doc_ids = await self.faiss_query.add_documents(
|
422
|
+
agent_id=self._agent_id,
|
423
|
+
documents=semantic_text,
|
424
|
+
extra_tags={
|
425
|
+
"type": "profile_state",
|
426
|
+
"key": key,
|
427
|
+
},
|
428
|
+
)
|
429
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
430
|
+
|
431
|
+
def _get_memory_type_by_key(self, key: str) -> str:
|
432
|
+
"""根据键名确定记忆类型"""
|
433
|
+
try:
|
434
|
+
if key in self.profile.__dict__:
|
435
|
+
return "profile"
|
436
|
+
elif key in self.state.__dict__:
|
437
|
+
return "state"
|
438
|
+
else:
|
439
|
+
return "dynamic"
|
440
|
+
except:
|
441
|
+
return "dynamic"
|
442
|
+
|
443
|
+
def set_search_components(self, faiss_query, embedding_model):
|
444
|
+
"""设置搜索所需的组件"""
|
445
|
+
self._faiss_query = faiss_query
|
446
|
+
self._embedding_model = embedding_model
|
178
447
|
|
179
448
|
def set_agent_id(self, agent_id: int):
|
449
|
+
"""设置agent_id"""
|
450
|
+
self._agent_id = agent_id
|
451
|
+
|
452
|
+
def set_semantic_templates(self, templates: Dict[str, str]):
|
453
|
+
"""设置语义模板
|
454
|
+
|
455
|
+
Args:
|
456
|
+
templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
|
180
457
|
"""
|
181
|
-
|
458
|
+
self._semantic_templates = templates
|
459
|
+
|
460
|
+
def _generate_semantic_text(self, key: str, value: Any) -> str:
|
461
|
+
"""生成语义文本
|
462
|
+
|
463
|
+
如果key存在于模板中,使用自定义模板
|
464
|
+
否则使用默认模板 "my {key} is {value}"
|
182
465
|
"""
|
183
|
-
self.
|
466
|
+
if key in self._semantic_templates:
|
467
|
+
return self._semantic_templates[key].format(value)
|
468
|
+
return f"Your {key} is {value}"
|
184
469
|
|
185
|
-
@
|
186
|
-
def
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
470
|
+
@lock_decorator
|
471
|
+
async def search(
|
472
|
+
self, query: str, top_k: int = 3, filter: Optional[dict] = None
|
473
|
+
) -> str:
|
474
|
+
"""搜索相关记忆
|
475
|
+
|
476
|
+
Args:
|
477
|
+
query: 查询文本
|
478
|
+
top_k: 返回最相关的记忆数量
|
479
|
+
filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
|
480
|
+
|
481
|
+
Returns:
|
482
|
+
str: 格式化的相关记忆文本
|
483
|
+
"""
|
484
|
+
if not self._embedding_model:
|
485
|
+
return "Embedding model not initialized"
|
486
|
+
|
487
|
+
filter_dict = {"type": "profile_state"}
|
488
|
+
if filter is not None:
|
489
|
+
filter_dict.update(filter)
|
490
|
+
top_results: list[tuple[str, float, dict]] = (
|
491
|
+
await self.faiss_query.similarity_search( # type:ignore
|
492
|
+
query=query,
|
493
|
+
agent_id=self._agent_id,
|
494
|
+
k=top_k,
|
495
|
+
return_score_type="similarity_score",
|
496
|
+
filter=filter_dict,
|
191
497
|
)
|
192
|
-
|
498
|
+
)
|
499
|
+
# 格式化输出
|
500
|
+
formatted_results = []
|
501
|
+
for content, score, metadata in top_results:
|
502
|
+
formatted_results.append(f"- {content} ")
|
503
|
+
|
504
|
+
return "\n".join(formatted_results)
|
505
|
+
|
506
|
+
def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
|
507
|
+
"""设置需要 embedding 的字段"""
|
508
|
+
self._embedding_fields = embedding_fields
|
509
|
+
|
510
|
+
def should_embed(self, key: str) -> bool:
|
511
|
+
"""判断字段是否需要 embedding"""
|
512
|
+
return self._embedding_fields.get(key, False)
|
193
513
|
|
194
514
|
@lock_decorator
|
195
515
|
async def get(
|
@@ -197,19 +517,18 @@ class Memory:
|
|
197
517
|
key: Any,
|
198
518
|
mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
|
199
519
|
) -> Any:
|
200
|
-
"""
|
201
|
-
Retrieves a value from memory based on the given key and access mode.
|
520
|
+
"""从记忆中获取值
|
202
521
|
|
203
522
|
Args:
|
204
|
-
key
|
205
|
-
mode
|
523
|
+
key: 要获取的键
|
524
|
+
mode: 访问模式,"read only" 或 "read and write"
|
206
525
|
|
207
526
|
Returns:
|
208
|
-
|
527
|
+
获取到的值
|
209
528
|
|
210
529
|
Raises:
|
211
|
-
ValueError:
|
212
|
-
KeyError:
|
530
|
+
ValueError: 如果提供了无效的模式
|
531
|
+
KeyError: 如果在任何记忆部分都找不到该键
|
213
532
|
"""
|
214
533
|
if mode == "read only":
|
215
534
|
process_func = deepcopy
|
@@ -217,11 +536,12 @@ class Memory:
|
|
217
536
|
process_func = lambda x: x
|
218
537
|
else:
|
219
538
|
raise ValueError(f"Invalid get mode `{mode}`!")
|
220
|
-
|
539
|
+
|
540
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
221
541
|
try:
|
222
|
-
value = await
|
542
|
+
value = await mem.get(key)
|
223
543
|
return process_func(value)
|
224
|
-
except KeyError
|
544
|
+
except KeyError:
|
225
545
|
continue
|
226
546
|
raise KeyError(f"No attribute `{key}` in memories!")
|
227
547
|
|
@@ -239,32 +559,37 @@ class Memory:
|
|
239
559
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
240
560
|
logger.warning(f"Trying to write protected key `{key}`!")
|
241
561
|
return
|
242
|
-
|
562
|
+
|
563
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
243
564
|
try:
|
244
|
-
original_value = await
|
565
|
+
original_value = await mem.get(key)
|
245
566
|
if mode == "replace":
|
246
|
-
await
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
#
|
567
|
+
await mem.update(key, value, store_snapshot)
|
568
|
+
if self.should_embed(key) and self._embedding_model:
|
569
|
+
semantic_text = self._generate_semantic_text(key, value)
|
570
|
+
|
571
|
+
# 删除旧的 embedding
|
251
572
|
orig_doc_id = self._embedding_field_to_doc_id[key]
|
252
573
|
if orig_doc_id:
|
253
574
|
await self.faiss_query.delete_documents(
|
254
575
|
to_delete_ids=[orig_doc_id],
|
255
576
|
)
|
256
|
-
|
257
|
-
|
258
|
-
|
577
|
+
|
578
|
+
# 添加新的 embedding
|
579
|
+
doc_ids = await self.faiss_query.add_documents(
|
580
|
+
agent_id=self._agent_id,
|
581
|
+
documents=semantic_text,
|
259
582
|
extra_tags={
|
260
|
-
"type":
|
583
|
+
"type": self._get_memory_type(mem),
|
261
584
|
"key": key,
|
262
585
|
},
|
263
586
|
)
|
264
587
|
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
588
|
+
|
265
589
|
if key in self.watchers:
|
266
590
|
for callback in self.watchers[key]:
|
267
591
|
asyncio.create_task(callback())
|
592
|
+
|
268
593
|
elif mode == "merge":
|
269
594
|
if isinstance(original_value, set):
|
270
595
|
original_value.update(set(value))
|
@@ -278,14 +603,14 @@ class Memory:
|
|
278
603
|
logger.debug(
|
279
604
|
f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
|
280
605
|
)
|
281
|
-
await
|
282
|
-
if self.
|
283
|
-
|
606
|
+
await mem.update(key, value, store_snapshot)
|
607
|
+
if self.should_embed(key) and self._embedding_model:
|
608
|
+
semantic_text = self._generate_semantic_text(key, value)
|
284
609
|
doc_ids = await self.faiss_query.add_documents(
|
285
|
-
agent_id=self.
|
610
|
+
agent_id=self._agent_id,
|
286
611
|
documents=f"{key}: {str(original_value)}",
|
287
612
|
extra_tags={
|
288
|
-
"type":
|
613
|
+
"type": self._get_memory_type(mem),
|
289
614
|
"key": key,
|
290
615
|
},
|
291
616
|
)
|
@@ -302,63 +627,20 @@ class Memory:
|
|
302
627
|
|
303
628
|
def _get_memory_type(self, mem: Any) -> str:
|
304
629
|
"""获取记忆类型"""
|
305
|
-
if mem is self.
|
630
|
+
if mem is self.state:
|
306
631
|
return "state"
|
307
|
-
elif mem is self.
|
632
|
+
elif mem is self.profile:
|
308
633
|
return "profile"
|
309
634
|
else:
|
310
635
|
return "dynamic"
|
311
636
|
|
312
|
-
async def update_batch(
|
313
|
-
self,
|
314
|
-
content: Union[dict, Sequence[tuple[Any, Any]]],
|
315
|
-
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
316
|
-
store_snapshot: bool = False,
|
317
|
-
protect_llm_read_only_fields: bool = True,
|
318
|
-
) -> None:
|
319
|
-
"""
|
320
|
-
Updates multiple values in the memory at once.
|
321
|
-
|
322
|
-
Args:
|
323
|
-
content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
|
324
|
-
mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
|
325
|
-
store_snapshot (bool): Whether to store a snapshot of the memory after the update.
|
326
|
-
protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
|
327
|
-
|
328
|
-
Raises:
|
329
|
-
TypeError: If the content type is neither a dictionary nor a sequence of tuples.
|
330
|
-
"""
|
331
|
-
if isinstance(content, dict):
|
332
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
|
333
|
-
elif isinstance(content, Sequence):
|
334
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
|
335
|
-
else:
|
336
|
-
raise TypeError(f"Invalid content type `{type(content)}`!")
|
337
|
-
for k, v in _list_content[:1]:
|
338
|
-
await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
|
339
|
-
for k, v in _list_content[1:]:
|
340
|
-
await self.update(k, v, mode, False, protect_llm_read_only_fields)
|
341
|
-
|
342
637
|
@lock_decorator
|
343
638
|
async def add_watcher(self, key: str, callback: Callable) -> None:
|
344
|
-
"""
|
345
|
-
Adds a callback function to be invoked when the value
|
346
|
-
associated with the specified key in memory is updated.
|
347
|
-
|
348
|
-
Args:
|
349
|
-
key (str): The key for which the watcher is being registered.
|
350
|
-
callback (Callable): A callable function that will be executed
|
351
|
-
whenever the value associated with the specified key is updated.
|
352
|
-
|
353
|
-
Notes:
|
354
|
-
If the key does not already have any watchers, it will be
|
355
|
-
initialized with an empty list before appending the callback.
|
356
|
-
"""
|
639
|
+
"""添加值变更的监听器"""
|
357
640
|
if key not in self.watchers:
|
358
641
|
self.watchers[key] = []
|
359
642
|
self.watchers[key].append(callback)
|
360
643
|
|
361
|
-
@lock_decorator
|
362
644
|
async def export(
|
363
645
|
self,
|
364
646
|
) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
|
@@ -369,12 +651,11 @@ class Memory:
|
|
369
651
|
tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
|
370
652
|
"""
|
371
653
|
return (
|
372
|
-
await self.
|
373
|
-
await self.
|
374
|
-
await self.
|
654
|
+
await self.profile.export(),
|
655
|
+
await self.state.export(),
|
656
|
+
await self.dynamic.export(),
|
375
657
|
)
|
376
658
|
|
377
|
-
@lock_decorator
|
378
659
|
async def load(
|
379
660
|
self,
|
380
661
|
snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
|
@@ -390,41 +671,246 @@ class Memory:
|
|
390
671
|
_profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
|
391
672
|
for _snapshot, _mem in zip(
|
392
673
|
[_profile_snapshot, _state_snapshot, _dynamic_snapshot],
|
393
|
-
[self.
|
674
|
+
[self.state, self.profile, self.dynamic],
|
394
675
|
):
|
395
676
|
if _snapshot:
|
396
677
|
await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
|
397
678
|
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
679
|
+
|
680
|
+
class Memory:
|
681
|
+
"""
|
682
|
+
A class to manage different types of memory (state, profile, dynamic).
|
683
|
+
|
684
|
+
Attributes:
|
685
|
+
_state (StateMemory): Stores state-related data.
|
686
|
+
_profile (ProfileMemory): Stores profile-related data.
|
687
|
+
_dynamic (DynamicMemory): Stores dynamically configured data.
|
688
|
+
"""
|
689
|
+
|
690
|
+
def __init__(
|
691
|
+
self,
|
692
|
+
config: Optional[dict[Any, Any]] = None,
|
693
|
+
profile: Optional[dict[Any, Any]] = None,
|
694
|
+
base: Optional[dict[Any, Any]] = None,
|
695
|
+
activate_timestamp: bool = False,
|
696
|
+
embedding_model: Optional[Embeddings] = None,
|
697
|
+
faiss_query: Optional[FaissQuery] = None,
|
698
|
+
) -> None:
|
699
|
+
"""
|
700
|
+
Initializes the Memory with optional configuration.
|
403
701
|
|
404
702
|
Args:
|
405
|
-
|
406
|
-
|
407
|
-
|
703
|
+
config (Optional[dict[Any, Any]], optional):
|
704
|
+
A configuration dictionary for dynamic memory. The dictionary format is:
|
705
|
+
- Key: The name of the dynamic memory field.
|
706
|
+
- Value: Can be one of two formats:
|
707
|
+
1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
|
708
|
+
2. A callable that returns the default value when invoked (useful for complex default values).
|
709
|
+
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
710
|
+
Defaults to None.
|
711
|
+
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
712
|
+
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
713
|
+
motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
|
714
|
+
activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
|
715
|
+
embedding_model (Embeddings): The embedding model for memory search.
|
716
|
+
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
717
|
+
"""
|
718
|
+
self.watchers: dict[str, list[Callable]] = {}
|
719
|
+
self._lock = asyncio.Lock()
|
720
|
+
self._agent_id: int = -1
|
721
|
+
self._simulator = None
|
722
|
+
self._embedding_model = embedding_model
|
723
|
+
self._faiss_query = faiss_query
|
724
|
+
self._semantic_templates: dict[str, str] = {}
|
725
|
+
_dynamic_config: dict[Any, Any] = {}
|
726
|
+
_state_config: dict[Any, Any] = {}
|
727
|
+
_profile_config: dict[Any, Any] = {}
|
728
|
+
self._embedding_fields: dict[str, bool] = {}
|
729
|
+
self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
|
408
730
|
|
409
|
-
|
410
|
-
|
731
|
+
if config is not None:
|
732
|
+
for k, v in config.items():
|
733
|
+
try:
|
734
|
+
# 处理不同长度的配置元组
|
735
|
+
if isinstance(v, tuple):
|
736
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
737
|
+
_type, _value, enable_embedding, template = v
|
738
|
+
self._embedding_fields[k] = enable_embedding
|
739
|
+
self._semantic_templates[k] = template
|
740
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
741
|
+
_type, _value, enable_embedding = v
|
742
|
+
self._embedding_fields[k] = enable_embedding
|
743
|
+
else: # (_type, _value)
|
744
|
+
_type, _value = v
|
745
|
+
self._embedding_fields[k] = False
|
746
|
+
else:
|
747
|
+
_type = type(v)
|
748
|
+
_value = v
|
749
|
+
self._embedding_fields[k] = False
|
750
|
+
|
751
|
+
# 处理类型转换
|
752
|
+
try:
|
753
|
+
if isinstance(_type, type):
|
754
|
+
_value = _type(_value)
|
755
|
+
else:
|
756
|
+
if isinstance(_type, deque):
|
757
|
+
_type.extend(_value)
|
758
|
+
_value = deepcopy(_type)
|
759
|
+
else:
|
760
|
+
logger.warning(f"type `{_type}` is not supported!")
|
761
|
+
except TypeError as e:
|
762
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
763
|
+
except TypeError as e:
|
764
|
+
if isinstance(v, type):
|
765
|
+
_value = v()
|
766
|
+
else:
|
767
|
+
_value = v
|
768
|
+
self._embedding_fields[k] = False
|
769
|
+
|
770
|
+
if (
|
771
|
+
k in PROFILE_ATTRIBUTES
|
772
|
+
or k in STATE_ATTRIBUTES
|
773
|
+
or k == TIME_STAMP_KEY
|
774
|
+
):
|
775
|
+
logger.warning(f"key `{k}` already declared in memory!")
|
776
|
+
continue
|
777
|
+
|
778
|
+
_dynamic_config[k] = deepcopy(_value)
|
779
|
+
|
780
|
+
# 初始化各类记忆
|
781
|
+
self._dynamic = DynamicMemory(
|
782
|
+
required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
|
783
|
+
)
|
784
|
+
|
785
|
+
if profile is not None:
|
786
|
+
for k, v in profile.items():
|
787
|
+
if k not in PROFILE_ATTRIBUTES:
|
788
|
+
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
789
|
+
continue
|
790
|
+
try:
|
791
|
+
# 处理配置元组格式
|
792
|
+
if isinstance(v, tuple):
|
793
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
794
|
+
_type, _value, enable_embedding, template = v
|
795
|
+
self._embedding_fields[k] = enable_embedding
|
796
|
+
self._semantic_templates[k] = template
|
797
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
798
|
+
_type, _value, enable_embedding = v
|
799
|
+
self._embedding_fields[k] = enable_embedding
|
800
|
+
else: # (_type, _value)
|
801
|
+
_type, _value = v
|
802
|
+
self._embedding_fields[k] = False
|
803
|
+
|
804
|
+
# 处理类型转换
|
805
|
+
try:
|
806
|
+
if isinstance(_type, type):
|
807
|
+
_value = _type(_value)
|
808
|
+
else:
|
809
|
+
if isinstance(_type, deque):
|
810
|
+
_type.extend(_value)
|
811
|
+
_value = deepcopy(_type)
|
812
|
+
else:
|
813
|
+
logger.warning(f"type `{_type}` is not supported!")
|
814
|
+
except TypeError as e:
|
815
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
816
|
+
else:
|
817
|
+
# 保持对简单键值对的兼容
|
818
|
+
_value = v
|
819
|
+
self._embedding_fields[k] = False
|
820
|
+
except TypeError as e:
|
821
|
+
if isinstance(v, type):
|
822
|
+
_value = v()
|
823
|
+
else:
|
824
|
+
_value = v
|
825
|
+
self._embedding_fields[k] = False
|
826
|
+
|
827
|
+
_profile_config[k] = deepcopy(_value)
|
828
|
+
self._profile = ProfileMemory(
|
829
|
+
msg=_profile_config, activate_timestamp=activate_timestamp
|
830
|
+
)
|
831
|
+
if base is not None:
|
832
|
+
for k, v in base.items():
|
833
|
+
if k not in STATE_ATTRIBUTES:
|
834
|
+
logger.warning(f"key `{k}` is not a correct `base` field!")
|
835
|
+
continue
|
836
|
+
_state_config[k] = v
|
837
|
+
|
838
|
+
self._state = StateMemory(
|
839
|
+
msg=_state_config, activate_timestamp=activate_timestamp
|
840
|
+
)
|
841
|
+
|
842
|
+
# 组合 StatusMemory,并传递 embedding_fields 信息
|
843
|
+
self._status = StatusMemory(
|
844
|
+
profile=self._profile, state=self._state, dynamic=self._dynamic
|
845
|
+
)
|
846
|
+
self._status.set_embedding_fields(self._embedding_fields)
|
847
|
+
self._status.set_search_components(self._faiss_query, self._embedding_model)
|
848
|
+
|
849
|
+
# 新增 StreamMemory
|
850
|
+
self._stream = StreamMemory()
|
851
|
+
self._stream.set_status_memory(self._status)
|
852
|
+
self._stream.set_search_components(self._faiss_query, self._embedding_model)
|
853
|
+
|
854
|
+
def set_search_components(
|
855
|
+
self,
|
856
|
+
faiss_query: FaissQuery,
|
857
|
+
embedding_model: Embeddings,
|
858
|
+
):
|
859
|
+
self._embedding_model = embedding_model
|
860
|
+
self._faiss_query = faiss_query
|
861
|
+
self._stream.set_search_components(faiss_query, embedding_model)
|
862
|
+
self._status.set_search_components(faiss_query, embedding_model)
|
863
|
+
|
864
|
+
def set_agent_id(self, agent_id: int):
|
411
865
|
"""
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
866
|
+
Set the FaissQuery of the agent.
|
867
|
+
"""
|
868
|
+
self._agent_id = agent_id
|
869
|
+
self._stream.set_agent_id(agent_id)
|
870
|
+
self._status.set_agent_id(agent_id)
|
871
|
+
|
872
|
+
def set_simulator(self, simulator):
|
873
|
+
self._simulator = simulator
|
874
|
+
self._stream.set_simulator(simulator)
|
875
|
+
self._status.set_simulator(simulator)
|
876
|
+
|
877
|
+
@property
|
878
|
+
def status(self) -> StatusMemory:
|
879
|
+
return self._status
|
880
|
+
|
881
|
+
@property
|
882
|
+
def stream(self) -> StreamMemory:
|
883
|
+
return self._stream
|
884
|
+
|
885
|
+
@property
|
886
|
+
def embedding_model(
|
887
|
+
self,
|
888
|
+
):
|
889
|
+
if self._embedding_model is None:
|
890
|
+
raise RuntimeError(
|
891
|
+
f"embedding_model before assignment, please `set_embedding_model` first!"
|
421
892
|
)
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
893
|
+
return self._embedding_model
|
894
|
+
|
895
|
+
@property
|
896
|
+
def agent_id(
|
897
|
+
self,
|
898
|
+
):
|
899
|
+
if self._agent_id < 0:
|
900
|
+
raise RuntimeError(
|
901
|
+
f"agent_id before assignment, please `set_agent_id` first!"
|
428
902
|
)
|
903
|
+
return self._agent_id
|
429
904
|
|
430
|
-
|
905
|
+
@property
|
906
|
+
def faiss_query(self) -> FaissQuery:
|
907
|
+
"""FaissQuery"""
|
908
|
+
if self._faiss_query is None:
|
909
|
+
raise RuntimeError(
|
910
|
+
f"FaissQuery access before assignment, please `set_faiss_query` first!"
|
911
|
+
)
|
912
|
+
return self._faiss_query
|
913
|
+
|
914
|
+
async def initialize_embeddings(self):
|
915
|
+
"""初始化embedding"""
|
916
|
+
await self._status.initialize_embeddings()
|