pycityagent 2.0.0a52__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a53__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +48 -62
- pycityagent/agent/agent_base.py +66 -53
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +44 -56
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +204 -119
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/simulator.py +17 -12
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +720 -272
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +92 -99
- pycityagent/simulation/simulation.py +115 -40
- pycityagent/tools/tool.py +7 -9
- pycityagent/workflow/block.py +11 -4
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
pycityagent/memory/memory.py
CHANGED
@@ -3,10 +3,10 @@ import logging
|
|
3
3
|
from collections import defaultdict
|
4
4
|
from collections.abc import Callable, Sequence
|
5
5
|
from copy import deepcopy
|
6
|
-
from
|
7
|
-
from
|
6
|
+
from typing import Any, Literal, Optional, Union, Dict
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from enum import Enum
|
8
9
|
|
9
|
-
import numpy as np
|
10
10
|
from langchain_core.embeddings import Embeddings
|
11
11
|
from pyparsing import deque
|
12
12
|
|
@@ -19,197 +19,479 @@ from .state import StateMemory
|
|
19
19
|
|
20
20
|
logger = logging.getLogger("pycityagent")
|
21
21
|
|
22
|
+
class MemoryTag(str, Enum):
|
23
|
+
"""记忆标签枚举类"""
|
24
|
+
MOBILITY = "mobility"
|
25
|
+
SOCIAL = "social"
|
26
|
+
ECONOMY = "economy"
|
27
|
+
COGNITION = "cognition"
|
28
|
+
OTHER = "other"
|
29
|
+
EVENT = "event"
|
30
|
+
|
31
|
+
@dataclass
|
32
|
+
class MemoryNode:
|
33
|
+
"""记忆节点"""
|
34
|
+
tag: MemoryTag
|
35
|
+
day: int
|
36
|
+
t: int
|
37
|
+
location: str
|
38
|
+
description: str
|
39
|
+
cognition_id: Optional[int] = None # 关联的认知记忆ID
|
40
|
+
id: Optional[int] = None # 记忆ID
|
41
|
+
|
42
|
+
class StreamMemory:
|
43
|
+
"""用于存储时序性的流式信息"""
|
44
|
+
def __init__(self, max_len: int = 1000):
|
45
|
+
self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
|
46
|
+
self._memory_id_counter: int = 0 # 用于生成唯一ID
|
47
|
+
self._faiss_query = None
|
48
|
+
self._embedding_model = None
|
49
|
+
self._agent_id = -1
|
50
|
+
self._status_memory = None
|
51
|
+
self._simulator = None
|
52
|
+
|
53
|
+
def set_simulator(self, simulator):
|
54
|
+
self._simulator = simulator
|
55
|
+
|
56
|
+
def set_status_memory(self, status_memory):
|
57
|
+
self._status_memory = status_memory
|
58
|
+
|
59
|
+
def set_search_components(self, faiss_query, embedding_model):
|
60
|
+
"""设置搜索所需的组件"""
|
61
|
+
self._faiss_query = faiss_query
|
62
|
+
self._embedding_model = embedding_model
|
22
63
|
|
23
|
-
|
24
|
-
|
25
|
-
|
64
|
+
def set_agent_id(self, agent_id: int):
|
65
|
+
"""设置agent_id"""
|
66
|
+
self._agent_id = agent_id
|
26
67
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
68
|
+
async def _add_memory(self, tag: MemoryTag, description: str) -> int:
|
69
|
+
"""添加记忆节点的通用方法,返回记忆节点ID"""
|
70
|
+
if self._simulator is not None:
|
71
|
+
day = int(await self._simulator.get_simulator_day())
|
72
|
+
t = int(await self._simulator.get_time())
|
73
|
+
else:
|
74
|
+
day = 1
|
75
|
+
t = 1
|
76
|
+
position = await self._status_memory.get("position")
|
77
|
+
if 'aoi_position' in position:
|
78
|
+
location = position['aoi_position']['aoi_id']
|
79
|
+
elif 'lane_position' in position:
|
80
|
+
location = position['lane_position']['lane_id']
|
81
|
+
else:
|
82
|
+
location = "unknown"
|
83
|
+
|
84
|
+
current_id = self._memory_id_counter
|
85
|
+
self._memory_id_counter += 1
|
86
|
+
memory_node = MemoryNode(
|
87
|
+
tag=tag,
|
88
|
+
day=day,
|
89
|
+
t=t,
|
90
|
+
location=location,
|
91
|
+
description=description,
|
92
|
+
id=current_id,
|
93
|
+
)
|
94
|
+
self._memories.append(memory_node)
|
95
|
+
|
96
|
+
|
97
|
+
# 为新记忆创建 embedding
|
98
|
+
if self._embedding_model and self._faiss_query:
|
99
|
+
await self._faiss_query.add_documents(
|
100
|
+
agent_id=self._agent_id,
|
101
|
+
documents=description,
|
102
|
+
extra_tags={
|
103
|
+
"type": "stream",
|
104
|
+
"tag": tag,
|
105
|
+
"day": day,
|
106
|
+
"time": t,
|
107
|
+
},
|
108
|
+
)
|
109
|
+
|
110
|
+
return current_id
|
111
|
+
async def add_cognition(self, description: str) -> None:
|
112
|
+
"""添加认知记忆 Add cognition memory"""
|
113
|
+
return await self._add_memory(MemoryTag.COGNITION, description)
|
114
|
+
|
115
|
+
async def add_social(self, description: str) -> None:
|
116
|
+
"""添加社交记忆 Add social memory"""
|
117
|
+
return await self._add_memory(MemoryTag.SOCIAL, description)
|
118
|
+
|
119
|
+
async def add_economy(self, description: str) -> None:
|
120
|
+
"""添加经济记忆 Add economy memory"""
|
121
|
+
return await self._add_memory(MemoryTag.ECONOMY, description)
|
122
|
+
|
123
|
+
async def add_mobility(self, description: str) -> None:
|
124
|
+
"""添加移动记忆 Add mobility memory"""
|
125
|
+
return await self._add_memory(MemoryTag.MOBILITY, description)
|
126
|
+
|
127
|
+
async def add_event(self, description: str) -> None:
|
128
|
+
"""添加事件记忆 Add event memory"""
|
129
|
+
return await self._add_memory(MemoryTag.EVENT, description)
|
130
|
+
|
131
|
+
async def add_other(self, description: str) -> None:
|
132
|
+
"""添加其他记忆 Add other memory"""
|
133
|
+
return await self._add_memory(MemoryTag.OTHER, description)
|
134
|
+
|
135
|
+
async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
|
136
|
+
"""获取关联的认知记忆 Get related cognition memory"""
|
137
|
+
for memory in self._memories:
|
138
|
+
if memory.cognition_id == memory_id:
|
139
|
+
for cognition_memory in self._memories:
|
140
|
+
if (cognition_memory.tag == MemoryTag.COGNITION and
|
141
|
+
memory.cognition_id is not None):
|
142
|
+
return cognition_memory
|
143
|
+
return None
|
144
|
+
|
145
|
+
async def format_memory(self, memories: list[MemoryNode]) -> str:
|
146
|
+
"""格式化记忆"""
|
147
|
+
formatted_results = []
|
148
|
+
for memory in memories:
|
149
|
+
memory_tag = memory.tag
|
150
|
+
memory_day = memory.day
|
151
|
+
memory_time_seconds = memory.t
|
152
|
+
cognition_id = memory.cognition_id
|
153
|
+
|
154
|
+
# 格式化时间
|
155
|
+
if memory_time_seconds != 'unknown':
|
156
|
+
hours = memory_time_seconds // 3600
|
157
|
+
minutes = (memory_time_seconds % 3600) // 60
|
158
|
+
seconds = memory_time_seconds % 60
|
159
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
160
|
+
else:
|
161
|
+
memory_time = 'unknown'
|
162
|
+
|
163
|
+
memory_location = memory.location
|
164
|
+
|
165
|
+
# 添加认知信息(如果存在)
|
166
|
+
cognition_info = ""
|
167
|
+
if cognition_id is not None:
|
168
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
169
|
+
if cognition_memory:
|
170
|
+
cognition_info = f"\n Related cognition: {cognition_memory.description}"
|
171
|
+
|
172
|
+
formatted_results.append(
|
173
|
+
f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
|
174
|
+
f"location: {memory_location}]{cognition_info}"
|
175
|
+
)
|
176
|
+
return "\n".join(formatted_results)
|
32
177
|
|
33
|
-
def
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
) -> None:
|
43
|
-
"""
|
44
|
-
Initializes the Memory with optional configuration.
|
178
|
+
async def get_by_ids(self, memory_ids: Union[int, list[int]]) -> str:
|
179
|
+
"""获取指定ID的记忆"""
|
180
|
+
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
181
|
+
sorted_results = sorted(
|
182
|
+
memories,
|
183
|
+
key=lambda x: (x.day, x.t),
|
184
|
+
reverse=True
|
185
|
+
)
|
186
|
+
return self.format_memory(sorted_results)
|
45
187
|
|
188
|
+
async def search(
|
189
|
+
self,
|
190
|
+
query: str,
|
191
|
+
tag: Optional[MemoryTag] = None,
|
192
|
+
top_k: int = 3,
|
193
|
+
day_range: Optional[tuple[int, int]] = None, # 新增参数
|
194
|
+
time_range: Optional[tuple[int, int]] = None # 新增参数
|
195
|
+
) -> str:
|
196
|
+
"""Search stream memory
|
197
|
+
|
46
198
|
Args:
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
2. A callable that returns the default value when invoked (useful for complex default values).
|
53
|
-
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
54
|
-
Defaults to None.
|
55
|
-
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
56
|
-
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
57
|
-
motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
|
58
|
-
activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
|
59
|
-
embedding_model (Embeddings): The embedding model for memory search.
|
60
|
-
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
199
|
+
query: Query text
|
200
|
+
tag: Optional memory tag for filtering specific types of memories
|
201
|
+
top_k: Number of most relevant memories to return
|
202
|
+
day_range: Optional tuple of start and end days (start_day, end_day)
|
203
|
+
time_range: Optional tuple of start and end times (start_time, end_time)
|
61
204
|
"""
|
62
|
-
self.
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
#
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
if
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
try:
|
87
|
-
if isinstance(_type, type):
|
88
|
-
_value = _type(_value)
|
89
|
-
else:
|
90
|
-
if isinstance(_type, deque):
|
91
|
-
_type.extend(_value)
|
92
|
-
_value = deepcopy(_type)
|
93
|
-
else:
|
94
|
-
logger.warning(f"type `{_type}` is not supported!")
|
95
|
-
pass
|
96
|
-
except TypeError as e:
|
97
|
-
pass
|
98
|
-
except TypeError as e:
|
99
|
-
if isinstance(v, type):
|
100
|
-
_value = v()
|
101
|
-
else:
|
102
|
-
_value = v
|
103
|
-
self._embedding_fields[k] = False
|
104
|
-
|
105
|
-
if (
|
106
|
-
k in PROFILE_ATTRIBUTES
|
107
|
-
or k in STATE_ATTRIBUTES
|
108
|
-
or k == TIME_STAMP_KEY
|
109
|
-
):
|
110
|
-
logger.warning(f"key `{k}` already declared in memory!")
|
111
|
-
continue
|
112
|
-
|
113
|
-
_dynamic_config[k] = deepcopy(_value)
|
114
|
-
|
115
|
-
# 初始化各类记忆
|
116
|
-
self._dynamic = DynamicMemory(
|
117
|
-
required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
|
205
|
+
if not self._embedding_model or not self._faiss_query:
|
206
|
+
return "Search components not initialized"
|
207
|
+
|
208
|
+
filter_dict = {"type": "stream"}
|
209
|
+
|
210
|
+
if tag:
|
211
|
+
filter_dict["tag"] = tag
|
212
|
+
|
213
|
+
# 添加时间范围过滤
|
214
|
+
if day_range:
|
215
|
+
start_day, end_day = day_range
|
216
|
+
filter_dict["day"] = lambda x: start_day <= x <= end_day
|
217
|
+
|
218
|
+
if time_range:
|
219
|
+
start_time, end_time = time_range
|
220
|
+
filter_dict["time"] = lambda x: start_time <= x <= end_time
|
221
|
+
|
222
|
+
top_results = await self._faiss_query.similarity_search(
|
223
|
+
query=query,
|
224
|
+
agent_id=self._agent_id,
|
225
|
+
k=top_k,
|
226
|
+
return_score_type="similarity_score",
|
227
|
+
filter=filter_dict
|
118
228
|
)
|
119
229
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
_profile_config[k] = v
|
126
|
-
if motion is not None:
|
127
|
-
for k, v in motion.items():
|
128
|
-
if k not in STATE_ATTRIBUTES:
|
129
|
-
logger.warning(f"key `{k}` is not a correct `motion` field!")
|
130
|
-
continue
|
131
|
-
_state_config[k] = v
|
132
|
-
if base is not None:
|
133
|
-
for k, v in base.items():
|
134
|
-
if k not in STATE_ATTRIBUTES:
|
135
|
-
logger.warning(f"key `{k}` is not a correct `base` field!")
|
136
|
-
continue
|
137
|
-
_state_config[k] = v
|
138
|
-
self._state = StateMemory(
|
139
|
-
msg=_state_config, activate_timestamp=activate_timestamp
|
230
|
+
# 将结果按时间排序(先按天数,再按时间)
|
231
|
+
sorted_results = sorted(
|
232
|
+
top_results,
|
233
|
+
key=lambda x: (x[2].get('day', 0), x[2].get('time', 0)),
|
234
|
+
reverse=True
|
140
235
|
)
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
236
|
+
|
237
|
+
formatted_results = []
|
238
|
+
for content, score, metadata in sorted_results:
|
239
|
+
memory_tag = metadata.get('tag', 'unknown')
|
240
|
+
memory_day = metadata.get('day', 'unknown')
|
241
|
+
memory_time_seconds = metadata.get('time', 'unknown')
|
242
|
+
cognition_id = metadata.get('cognition_id', None)
|
243
|
+
|
244
|
+
# 格式化时间
|
245
|
+
if memory_time_seconds != 'unknown':
|
246
|
+
hours = memory_time_seconds // 3600
|
247
|
+
minutes = (memory_time_seconds % 3600) // 60
|
248
|
+
seconds = memory_time_seconds % 60
|
249
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
250
|
+
else:
|
251
|
+
memory_time = 'unknown'
|
252
|
+
|
253
|
+
memory_location = metadata.get('location', 'unknown')
|
254
|
+
|
255
|
+
# 添加认知信息(如果存在)
|
256
|
+
cognition_info = ""
|
257
|
+
if cognition_id is not None:
|
258
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
259
|
+
if cognition_memory:
|
260
|
+
cognition_info = f"\n Related cognition: {cognition_memory.description}"
|
261
|
+
|
262
|
+
formatted_results.append(
|
263
|
+
f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
|
264
|
+
f"location: {memory_location}]{cognition_info}"
|
160
265
|
)
|
161
|
-
return
|
266
|
+
return "\n".join(formatted_results)
|
162
267
|
|
163
|
-
def
|
268
|
+
async def search_today(
|
269
|
+
self,
|
270
|
+
query: str = "", # 可选的查询文本
|
271
|
+
tag: Optional[MemoryTag] = None,
|
272
|
+
top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
|
273
|
+
) -> str:
|
274
|
+
"""Search all memory events from today
|
275
|
+
|
276
|
+
Args:
|
277
|
+
query: Optional query text, returns all memories of the day if empty
|
278
|
+
tag: Optional memory tag for filtering specific types of memories
|
279
|
+
top_k: Number of most relevant memories to return, defaults to 100
|
280
|
+
|
281
|
+
Returns:
|
282
|
+
str: Formatted text of today's memories
|
164
283
|
"""
|
165
|
-
|
284
|
+
if self._simulator is None:
|
285
|
+
return "Simulator not initialized"
|
286
|
+
|
287
|
+
current_day = int(await self._simulator.get_simulator_day())
|
288
|
+
|
289
|
+
# 使用 search 方法,设置 day_range 为当天
|
290
|
+
return await self.search(
|
291
|
+
query=query,
|
292
|
+
tag=tag,
|
293
|
+
top_k=top_k,
|
294
|
+
day_range=(current_day, current_day)
|
295
|
+
)
|
296
|
+
|
297
|
+
async def add_cognition_to_memory(self, memory_id: Union[int, list[int]], cognition: str) -> None:
|
298
|
+
"""为已存在的记忆添加认知
|
299
|
+
|
300
|
+
Args:
|
301
|
+
memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
|
302
|
+
cognition: 认知描述
|
166
303
|
"""
|
167
|
-
|
304
|
+
# 将单个ID转换为列表以统一处理
|
305
|
+
memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
|
306
|
+
|
307
|
+
# 找到所有对应的记忆
|
308
|
+
target_memories = []
|
309
|
+
for memory in self._memories:
|
310
|
+
if id(memory) in memory_ids:
|
311
|
+
target_memories.append(memory)
|
312
|
+
|
313
|
+
if not target_memories:
|
314
|
+
raise ValueError(f"No memories found with ids {memory_ids}")
|
315
|
+
|
316
|
+
# 添加认知记忆
|
317
|
+
cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
|
318
|
+
|
319
|
+
# 更新所有原记忆的认知ID
|
320
|
+
for target_memory in target_memories:
|
321
|
+
target_memory.cognition_id = cognition_id
|
322
|
+
|
323
|
+
async def get_all(self) -> list[MemoryNode]:
|
324
|
+
"""获取所有流式信息"""
|
325
|
+
return list(self._memories)
|
326
|
+
|
327
|
+
class StatusMemory:
|
328
|
+
"""组合现有的三种记忆类型"""
|
329
|
+
def __init__(self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory):
|
330
|
+
self.profile = profile
|
331
|
+
self.state = state
|
332
|
+
self.dynamic = dynamic
|
333
|
+
self._faiss_query = None
|
334
|
+
self._embedding_model = None
|
335
|
+
self._simulator = None
|
336
|
+
self._agent_id = -1
|
337
|
+
self._semantic_templates = {} # 用户可配置的模板
|
338
|
+
self._embedding_fields = {} # 需要 embedding 的字段
|
339
|
+
self._embedding_field_to_doc_id = defaultdict(str) # 新增
|
340
|
+
self.watchers = {} # 新增
|
341
|
+
self._lock = asyncio.Lock() # 新增
|
342
|
+
|
343
|
+
def set_simulator(self, simulator):
|
344
|
+
self._simulator = simulator
|
345
|
+
|
346
|
+
async def initialize_embeddings(self) -> None:
|
347
|
+
"""初始化所有需要 embedding 的字段"""
|
348
|
+
if not self._embedding_model or not self._faiss_query:
|
349
|
+
logger.warning("Search components not initialized, skipping embeddings initialization")
|
350
|
+
return
|
351
|
+
|
352
|
+
# 获取所有状态信息
|
353
|
+
profile, state, dynamic = await self.export()
|
354
|
+
|
355
|
+
# 为每个需要 embedding 的字段创建 embedding
|
356
|
+
for key, value in profile[0].items():
|
357
|
+
if self.should_embed(key):
|
358
|
+
semantic_text = self._generate_semantic_text(key, value)
|
359
|
+
doc_ids = await self._faiss_query.add_documents(
|
360
|
+
agent_id=self._agent_id,
|
361
|
+
documents=semantic_text,
|
362
|
+
extra_tags={
|
363
|
+
"type": "profile_state",
|
364
|
+
"key": key,
|
365
|
+
},
|
366
|
+
)
|
367
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
368
|
+
|
369
|
+
for key, value in state[0].items():
|
370
|
+
if self.should_embed(key):
|
371
|
+
semantic_text = self._generate_semantic_text(key, value)
|
372
|
+
doc_ids = await self._faiss_query.add_documents(
|
373
|
+
agent_id=self._agent_id,
|
374
|
+
documents=semantic_text,
|
375
|
+
extra_tags={
|
376
|
+
"type": "profile_state",
|
377
|
+
"key": key,
|
378
|
+
},
|
379
|
+
)
|
380
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
381
|
+
|
382
|
+
for key, value in dynamic[0].items():
|
383
|
+
if self.should_embed(key):
|
384
|
+
semantic_text = self._generate_semantic_text(key, value)
|
385
|
+
doc_ids = await self._faiss_query.add_documents(
|
386
|
+
agent_id=self._agent_id,
|
387
|
+
documents=semantic_text,
|
388
|
+
extra_tags={
|
389
|
+
"type": "profile_state",
|
390
|
+
"key": key,
|
391
|
+
},
|
392
|
+
)
|
393
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
394
|
+
|
395
|
+
def _get_memory_type_by_key(self, key: str) -> str:
|
396
|
+
"""根据键名确定记忆类型"""
|
397
|
+
try:
|
398
|
+
if key in self.profile.__dict__:
|
399
|
+
return "profile"
|
400
|
+
elif key in self.state.__dict__:
|
401
|
+
return "state"
|
402
|
+
else:
|
403
|
+
return "dynamic"
|
404
|
+
except:
|
405
|
+
return "dynamic"
|
168
406
|
|
169
|
-
|
170
|
-
|
171
|
-
self
|
172
|
-
|
173
|
-
if self._agent_id < 0:
|
174
|
-
raise RuntimeError(
|
175
|
-
f"agent_id before assignment, please `set_agent_id` first!"
|
176
|
-
)
|
177
|
-
return self._agent_id
|
407
|
+
def set_search_components(self, faiss_query, embedding_model):
|
408
|
+
"""设置搜索所需的组件"""
|
409
|
+
self._faiss_query = faiss_query
|
410
|
+
self._embedding_model = embedding_model
|
178
411
|
|
179
412
|
def set_agent_id(self, agent_id: int):
|
413
|
+
"""设置agent_id"""
|
414
|
+
self._agent_id = agent_id
|
415
|
+
|
416
|
+
def set_semantic_templates(self, templates: Dict[str, str]):
|
417
|
+
"""设置语义模板
|
418
|
+
|
419
|
+
Args:
|
420
|
+
templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
|
180
421
|
"""
|
181
|
-
|
422
|
+
self._semantic_templates = templates
|
423
|
+
|
424
|
+
def _generate_semantic_text(self, key: str, value: Any) -> str:
|
425
|
+
"""生成语义文本
|
426
|
+
|
427
|
+
如果key存在于模板中,使用自定义模板
|
428
|
+
否则使用默认模板 "my {key} is {value}"
|
182
429
|
"""
|
183
|
-
self.
|
430
|
+
if key in self._semantic_templates:
|
431
|
+
return self._semantic_templates[key].format(value)
|
432
|
+
return f"Your {key} is {value}"
|
433
|
+
|
434
|
+
@lock_decorator
|
435
|
+
async def search(
|
436
|
+
self, query: str, top_k: int = 3, filter: Optional[dict] = None
|
437
|
+
) -> str:
|
438
|
+
"""搜索相关记忆
|
184
439
|
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
440
|
+
Args:
|
441
|
+
query: 查询文本
|
442
|
+
top_k: 返回最相关的记忆数量
|
443
|
+
filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
|
444
|
+
|
445
|
+
Returns:
|
446
|
+
str: 格式化的相关记忆文本
|
447
|
+
"""
|
448
|
+
if not self._embedding_model:
|
449
|
+
return "Embedding model not initialized"
|
450
|
+
|
451
|
+
filter_dict = {"type": "profile_state"}
|
452
|
+
if filter is not None:
|
453
|
+
filter_dict.update(filter)
|
454
|
+
top_results: list[tuple[str, float, dict]] = (
|
455
|
+
await self._faiss_query.similarity_search( # type:ignore
|
456
|
+
query=query,
|
457
|
+
agent_id=self._agent_id,
|
458
|
+
k=top_k,
|
459
|
+
return_score_type="similarity_score",
|
460
|
+
filter=filter_dict,
|
191
461
|
)
|
192
|
-
|
462
|
+
)
|
463
|
+
# 格式化输出
|
464
|
+
formatted_results = []
|
465
|
+
for content, score, metadata in top_results:
|
466
|
+
formatted_results.append(
|
467
|
+
f"- {content} "
|
468
|
+
)
|
469
|
+
|
470
|
+
return "\n".join(formatted_results)
|
471
|
+
|
472
|
+
def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
|
473
|
+
"""设置需要 embedding 的字段"""
|
474
|
+
self._embedding_fields = embedding_fields
|
475
|
+
|
476
|
+
def should_embed(self, key: str) -> bool:
|
477
|
+
"""判断字段是否需要 embedding"""
|
478
|
+
return self._embedding_fields.get(key, False)
|
193
479
|
|
194
480
|
@lock_decorator
|
195
|
-
async def get(
|
196
|
-
|
197
|
-
|
198
|
-
mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
|
199
|
-
) -> Any:
|
200
|
-
"""
|
201
|
-
Retrieves a value from memory based on the given key and access mode.
|
481
|
+
async def get(self, key: Any,
|
482
|
+
mode: Union[Literal["read only"], Literal["read and write"]] = "read only") -> Any:
|
483
|
+
"""从记忆中获取值
|
202
484
|
|
203
485
|
Args:
|
204
|
-
key
|
205
|
-
mode
|
486
|
+
key: 要获取的键
|
487
|
+
mode: 访问模式,"read only" 或 "read and write"
|
206
488
|
|
207
489
|
Returns:
|
208
|
-
|
490
|
+
获取到的值
|
209
491
|
|
210
492
|
Raises:
|
211
|
-
ValueError:
|
212
|
-
KeyError:
|
493
|
+
ValueError: 如果提供了无效的模式
|
494
|
+
KeyError: 如果在任何记忆部分都找不到该键
|
213
495
|
"""
|
214
496
|
if mode == "read only":
|
215
497
|
process_func = deepcopy
|
@@ -217,54 +499,56 @@ class Memory:
|
|
217
499
|
process_func = lambda x: x
|
218
500
|
else:
|
219
501
|
raise ValueError(f"Invalid get mode `{mode}`!")
|
220
|
-
|
502
|
+
|
503
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
221
504
|
try:
|
222
|
-
value = await
|
505
|
+
value = await mem.get(key)
|
223
506
|
return process_func(value)
|
224
|
-
except KeyError
|
507
|
+
except KeyError:
|
225
508
|
continue
|
226
509
|
raise KeyError(f"No attribute `{key}` in memories!")
|
227
510
|
|
228
511
|
@lock_decorator
|
229
|
-
async def update(
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
234
|
-
store_snapshot: bool = False,
|
235
|
-
protect_llm_read_only_fields: bool = True,
|
236
|
-
) -> None:
|
512
|
+
async def update(self, key: Any, value: Any,
|
513
|
+
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
514
|
+
store_snapshot: bool = False,
|
515
|
+
protect_llm_read_only_fields: bool = True) -> None:
|
237
516
|
"""更新记忆值并在必要时更新embedding"""
|
238
517
|
if protect_llm_read_only_fields:
|
239
518
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
240
519
|
logger.warning(f"Trying to write protected key `{key}`!")
|
241
520
|
return
|
242
|
-
|
521
|
+
|
522
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
243
523
|
try:
|
244
|
-
original_value = await
|
524
|
+
original_value = await mem.get(key)
|
245
525
|
if mode == "replace":
|
246
|
-
await
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
#
|
526
|
+
await mem.update(key, value, store_snapshot)
|
527
|
+
if self.should_embed(key) and self._embedding_model:
|
528
|
+
semantic_text = self._generate_semantic_text(key, value)
|
529
|
+
|
530
|
+
# 删除旧的 embedding
|
251
531
|
orig_doc_id = self._embedding_field_to_doc_id[key]
|
252
532
|
if orig_doc_id:
|
253
|
-
await self.
|
533
|
+
await self._faiss_query.delete_documents(
|
254
534
|
to_delete_ids=[orig_doc_id],
|
255
535
|
)
|
256
|
-
|
257
|
-
|
258
|
-
|
536
|
+
|
537
|
+
# 添加新的 embedding
|
538
|
+
doc_ids = await self._faiss_query.add_documents(
|
539
|
+
agent_id=self._agent_id,
|
540
|
+
documents=semantic_text,
|
259
541
|
extra_tags={
|
260
|
-
"type":
|
542
|
+
"type": self._get_memory_type(mem),
|
261
543
|
"key": key,
|
262
544
|
},
|
263
545
|
)
|
264
546
|
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
547
|
+
|
265
548
|
if key in self.watchers:
|
266
549
|
for callback in self.watchers[key]:
|
267
550
|
asyncio.create_task(callback())
|
551
|
+
|
268
552
|
elif mode == "merge":
|
269
553
|
if isinstance(original_value, set):
|
270
554
|
original_value.update(set(value))
|
@@ -278,14 +562,14 @@ class Memory:
|
|
278
562
|
logger.debug(
|
279
563
|
f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
|
280
564
|
)
|
281
|
-
await
|
282
|
-
if self.
|
283
|
-
|
284
|
-
doc_ids = await self.
|
285
|
-
agent_id=self.
|
565
|
+
await mem.update(key, value, store_snapshot)
|
566
|
+
if self.should_embed(key) and self._embedding_model:
|
567
|
+
semantic_text = self._generate_semantic_text(key, value)
|
568
|
+
doc_ids = await self._faiss_query.add_documents(
|
569
|
+
agent_id=self._agent_id,
|
286
570
|
documents=f"{key}: {str(original_value)}",
|
287
571
|
extra_tags={
|
288
|
-
"type":
|
572
|
+
"type": self._get_memory_type(mem),
|
289
573
|
"key": key,
|
290
574
|
},
|
291
575
|
)
|
@@ -302,63 +586,20 @@ class Memory:
|
|
302
586
|
|
303
587
|
def _get_memory_type(self, mem: Any) -> str:
|
304
588
|
"""获取记忆类型"""
|
305
|
-
if mem is self.
|
589
|
+
if mem is self.state:
|
306
590
|
return "state"
|
307
|
-
elif mem is self.
|
591
|
+
elif mem is self.profile:
|
308
592
|
return "profile"
|
309
593
|
else:
|
310
594
|
return "dynamic"
|
311
595
|
|
312
|
-
async def update_batch(
|
313
|
-
self,
|
314
|
-
content: Union[dict, Sequence[tuple[Any, Any]]],
|
315
|
-
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
316
|
-
store_snapshot: bool = False,
|
317
|
-
protect_llm_read_only_fields: bool = True,
|
318
|
-
) -> None:
|
319
|
-
"""
|
320
|
-
Updates multiple values in the memory at once.
|
321
|
-
|
322
|
-
Args:
|
323
|
-
content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
|
324
|
-
mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
|
325
|
-
store_snapshot (bool): Whether to store a snapshot of the memory after the update.
|
326
|
-
protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
|
327
|
-
|
328
|
-
Raises:
|
329
|
-
TypeError: If the content type is neither a dictionary nor a sequence of tuples.
|
330
|
-
"""
|
331
|
-
if isinstance(content, dict):
|
332
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
|
333
|
-
elif isinstance(content, Sequence):
|
334
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
|
335
|
-
else:
|
336
|
-
raise TypeError(f"Invalid content type `{type(content)}`!")
|
337
|
-
for k, v in _list_content[:1]:
|
338
|
-
await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
|
339
|
-
for k, v in _list_content[1:]:
|
340
|
-
await self.update(k, v, mode, False, protect_llm_read_only_fields)
|
341
|
-
|
342
596
|
@lock_decorator
|
343
597
|
async def add_watcher(self, key: str, callback: Callable) -> None:
|
344
|
-
"""
|
345
|
-
Adds a callback function to be invoked when the value
|
346
|
-
associated with the specified key in memory is updated.
|
347
|
-
|
348
|
-
Args:
|
349
|
-
key (str): The key for which the watcher is being registered.
|
350
|
-
callback (Callable): A callable function that will be executed
|
351
|
-
whenever the value associated with the specified key is updated.
|
352
|
-
|
353
|
-
Notes:
|
354
|
-
If the key does not already have any watchers, it will be
|
355
|
-
initialized with an empty list before appending the callback.
|
356
|
-
"""
|
598
|
+
"""添加值变更的监听器"""
|
357
599
|
if key not in self.watchers:
|
358
600
|
self.watchers[key] = []
|
359
601
|
self.watchers[key].append(callback)
|
360
602
|
|
361
|
-
@lock_decorator
|
362
603
|
async def export(
|
363
604
|
self,
|
364
605
|
) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
|
@@ -369,12 +610,11 @@ class Memory:
|
|
369
610
|
tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
|
370
611
|
"""
|
371
612
|
return (
|
372
|
-
await self.
|
373
|
-
await self.
|
374
|
-
await self.
|
613
|
+
await self.profile.export(),
|
614
|
+
await self.state.export(),
|
615
|
+
await self.dynamic.export(),
|
375
616
|
)
|
376
617
|
|
377
|
-
@lock_decorator
|
378
618
|
async def load(
|
379
619
|
self,
|
380
620
|
snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
|
@@ -390,41 +630,249 @@ class Memory:
|
|
390
630
|
_profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
|
391
631
|
for _snapshot, _mem in zip(
|
392
632
|
[_profile_snapshot, _state_snapshot, _dynamic_snapshot],
|
393
|
-
[self.
|
633
|
+
[self.state, self.profile, self.dynamic],
|
394
634
|
):
|
395
635
|
if _snapshot:
|
396
636
|
await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
|
397
637
|
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
638
|
+
class Memory:
|
639
|
+
"""
|
640
|
+
A class to manage different types of memory (state, profile, dynamic).
|
641
|
+
|
642
|
+
Attributes:
|
643
|
+
_state (StateMemory): Stores state-related data.
|
644
|
+
_profile (ProfileMemory): Stores profile-related data.
|
645
|
+
_dynamic (DynamicMemory): Stores dynamically configured data.
|
646
|
+
"""
|
647
|
+
|
648
|
+
def __init__(
|
649
|
+
self,
|
650
|
+
config: Optional[dict[Any, Any]] = None,
|
651
|
+
profile: Optional[dict[Any, Any]] = None,
|
652
|
+
base: Optional[dict[Any, Any]] = None,
|
653
|
+
activate_timestamp: bool = False,
|
654
|
+
embedding_model: Optional[Embeddings] = None,
|
655
|
+
faiss_query: Optional[FaissQuery] = None,
|
656
|
+
) -> None:
|
657
|
+
"""
|
658
|
+
Initializes the Memory with optional configuration.
|
403
659
|
|
404
660
|
Args:
|
405
|
-
|
406
|
-
|
407
|
-
|
661
|
+
config (Optional[dict[Any, Any]], optional):
|
662
|
+
A configuration dictionary for dynamic memory. The dictionary format is:
|
663
|
+
- Key: The name of the dynamic memory field.
|
664
|
+
- Value: Can be one of two formats:
|
665
|
+
1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
|
666
|
+
2. A callable that returns the default value when invoked (useful for complex default values).
|
667
|
+
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
668
|
+
Defaults to None.
|
669
|
+
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
670
|
+
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
671
|
+
motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
|
672
|
+
activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
|
673
|
+
embedding_model (Embeddings): The embedding model for memory search.
|
674
|
+
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
675
|
+
"""
|
676
|
+
self.watchers: dict[str, list[Callable]] = {}
|
677
|
+
self._lock = asyncio.Lock()
|
678
|
+
self._agent_id: int = -1
|
679
|
+
self._simulator = None
|
680
|
+
self._embedding_model = embedding_model
|
681
|
+
self._faiss_query = faiss_query
|
682
|
+
self._semantic_templates: dict[str, str] = {}
|
683
|
+
_dynamic_config: dict[Any, Any] = {}
|
684
|
+
_state_config: dict[Any, Any] = {}
|
685
|
+
_profile_config: dict[Any, Any] = {}
|
686
|
+
self._embedding_fields: dict[str, bool] = {}
|
687
|
+
self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
|
408
688
|
|
409
|
-
|
410
|
-
|
689
|
+
if config is not None:
|
690
|
+
for k, v in config.items():
|
691
|
+
try:
|
692
|
+
# 处理不同长度的配置元组
|
693
|
+
if isinstance(v, tuple):
|
694
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
695
|
+
_type, _value, enable_embedding, template = v
|
696
|
+
self._embedding_fields[k] = enable_embedding
|
697
|
+
self._semantic_templates[k] = template
|
698
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
699
|
+
_type, _value, enable_embedding = v
|
700
|
+
self._embedding_fields[k] = enable_embedding
|
701
|
+
else: # (_type, _value)
|
702
|
+
_type, _value = v
|
703
|
+
self._embedding_fields[k] = False
|
704
|
+
else:
|
705
|
+
_type = type(v)
|
706
|
+
_value = v
|
707
|
+
self._embedding_fields[k] = False
|
708
|
+
|
709
|
+
# 处理类型转换
|
710
|
+
try:
|
711
|
+
if isinstance(_type, type):
|
712
|
+
_value = _type(_value)
|
713
|
+
else:
|
714
|
+
if isinstance(_type, deque):
|
715
|
+
_type.extend(_value)
|
716
|
+
_value = deepcopy(_type)
|
717
|
+
else:
|
718
|
+
logger.warning(f"type `{_type}` is not supported!")
|
719
|
+
except TypeError as e:
|
720
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
721
|
+
except TypeError as e:
|
722
|
+
if isinstance(v, type):
|
723
|
+
_value = v()
|
724
|
+
else:
|
725
|
+
_value = v
|
726
|
+
self._embedding_fields[k] = False
|
727
|
+
|
728
|
+
if (
|
729
|
+
k in PROFILE_ATTRIBUTES
|
730
|
+
or k in STATE_ATTRIBUTES
|
731
|
+
or k == TIME_STAMP_KEY
|
732
|
+
):
|
733
|
+
logger.warning(f"key `{k}` already declared in memory!")
|
734
|
+
continue
|
735
|
+
|
736
|
+
_dynamic_config[k] = deepcopy(_value)
|
737
|
+
|
738
|
+
# 初始化各类记忆
|
739
|
+
self._dynamic = DynamicMemory(
|
740
|
+
required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
|
741
|
+
)
|
742
|
+
|
743
|
+
if profile is not None:
|
744
|
+
for k, v in profile.items():
|
745
|
+
if k not in PROFILE_ATTRIBUTES:
|
746
|
+
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
747
|
+
continue
|
748
|
+
|
749
|
+
try:
|
750
|
+
# 处理配置元组格式
|
751
|
+
if isinstance(v, tuple):
|
752
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
753
|
+
_type, _value, enable_embedding, template = v
|
754
|
+
self._embedding_fields[k] = enable_embedding
|
755
|
+
self._semantic_templates[k] = template
|
756
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
757
|
+
_type, _value, enable_embedding = v
|
758
|
+
self._embedding_fields[k] = enable_embedding
|
759
|
+
else: # (_type, _value)
|
760
|
+
_type, _value = v
|
761
|
+
self._embedding_fields[k] = False
|
762
|
+
|
763
|
+
# 处理类型转换
|
764
|
+
try:
|
765
|
+
if isinstance(_type, type):
|
766
|
+
_value = _type(_value)
|
767
|
+
else:
|
768
|
+
if isinstance(_type, deque):
|
769
|
+
_type.extend(_value)
|
770
|
+
_value = deepcopy(_type)
|
771
|
+
else:
|
772
|
+
logger.warning(f"type `{_type}` is not supported!")
|
773
|
+
except TypeError as e:
|
774
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
775
|
+
else:
|
776
|
+
# 保持对简单键值对的兼容
|
777
|
+
_value = v
|
778
|
+
self._embedding_fields[k] = False
|
779
|
+
except TypeError as e:
|
780
|
+
if isinstance(v, type):
|
781
|
+
_value = v()
|
782
|
+
else:
|
783
|
+
_value = v
|
784
|
+
self._embedding_fields[k] = False
|
785
|
+
|
786
|
+
_profile_config[k] = deepcopy(_value)
|
787
|
+
self._profile = ProfileMemory(
|
788
|
+
msg=_profile_config, activate_timestamp=activate_timestamp
|
789
|
+
)
|
790
|
+
|
791
|
+
if base is not None:
|
792
|
+
for k, v in base.items():
|
793
|
+
if k not in STATE_ATTRIBUTES:
|
794
|
+
logger.warning(f"key `{k}` is not a correct `base` field!")
|
795
|
+
continue
|
796
|
+
_state_config[k] = v
|
797
|
+
|
798
|
+
self._state = StateMemory(
|
799
|
+
msg=_state_config, activate_timestamp=activate_timestamp
|
800
|
+
)
|
801
|
+
|
802
|
+
# 组合 StatusMemory,并传递 embedding_fields 信息
|
803
|
+
self._status = StatusMemory(
|
804
|
+
profile=self._profile,
|
805
|
+
state=self._state,
|
806
|
+
dynamic=self._dynamic
|
807
|
+
)
|
808
|
+
self._status.set_embedding_fields(self._embedding_fields)
|
809
|
+
self._status.set_search_components(self._faiss_query, self._embedding_model)
|
810
|
+
|
811
|
+
# 新增 StreamMemory
|
812
|
+
self._stream = StreamMemory()
|
813
|
+
self._stream.set_status_memory(self._status)
|
814
|
+
self._stream.set_search_components(self._faiss_query, self._embedding_model)
|
815
|
+
|
816
|
+
def set_search_components(
|
817
|
+
self,
|
818
|
+
faiss_query: FaissQuery,
|
819
|
+
embedding_model: Embeddings,
|
820
|
+
):
|
821
|
+
self._embedding_model = embedding_model
|
822
|
+
self._faiss_query = faiss_query
|
823
|
+
self._stream.set_search_components(faiss_query, embedding_model)
|
824
|
+
self._status.set_search_components(faiss_query, embedding_model)
|
825
|
+
|
826
|
+
def set_agent_id(self, agent_id: int):
|
411
827
|
"""
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
828
|
+
Set the FaissQuery of the agent.
|
829
|
+
"""
|
830
|
+
self._agent_id = agent_id
|
831
|
+
self._stream.set_agent_id(agent_id)
|
832
|
+
self._status.set_agent_id(agent_id)
|
833
|
+
|
834
|
+
def set_simulator(self, simulator):
|
835
|
+
self._simulator = simulator
|
836
|
+
self._stream.set_simulator(simulator)
|
837
|
+
self._status.set_simulator(simulator)
|
838
|
+
|
839
|
+
@property
|
840
|
+
def status(self) -> StatusMemory:
|
841
|
+
return self._status
|
842
|
+
|
843
|
+
@property
|
844
|
+
def stream(self) -> StreamMemory:
|
845
|
+
return self._stream
|
846
|
+
|
847
|
+
@property
|
848
|
+
def embedding_model(
|
849
|
+
self,
|
850
|
+
):
|
851
|
+
if self._embedding_model is None:
|
852
|
+
raise RuntimeError(
|
853
|
+
f"embedding_model before assignment, please `set_embedding_model` first!"
|
421
854
|
)
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
855
|
+
return self._embedding_model
|
856
|
+
|
857
|
+
@property
|
858
|
+
def agent_id(
|
859
|
+
self,
|
860
|
+
):
|
861
|
+
if self._agent_id < 0:
|
862
|
+
raise RuntimeError(
|
863
|
+
f"agent_id before assignment, please `set_agent_id` first!"
|
428
864
|
)
|
865
|
+
return self._agent_id
|
429
866
|
|
430
|
-
|
867
|
+
@property
|
868
|
+
def faiss_query(self) -> FaissQuery:
|
869
|
+
"""FaissQuery"""
|
870
|
+
if self._faiss_query is None:
|
871
|
+
raise RuntimeError(
|
872
|
+
f"FaissQuery access before assignment, please `set_faiss_query` first!"
|
873
|
+
)
|
874
|
+
return self._faiss_query
|
875
|
+
|
876
|
+
async def initialize_embeddings(self):
|
877
|
+
"""初始化embedding"""
|
878
|
+
await self._status.initialize_embeddings()
|