pycityagent 2.0.0a51__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a53__cp39-cp39-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +48 -62
- pycityagent/agent/agent_base.py +66 -53
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +44 -56
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +204 -119
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/simulator.py +17 -12
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +720 -272
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +92 -99
- pycityagent/simulation/simulation.py +115 -40
- pycityagent/tools/tool.py +7 -10
- pycityagent/workflow/block.py +11 -4
- {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
- {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +1 -1
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
pycityagent/memory/memory.py
CHANGED
@@ -3,10 +3,10 @@ import logging
|
|
3
3
|
from collections import defaultdict
|
4
4
|
from collections.abc import Callable, Sequence
|
5
5
|
from copy import deepcopy
|
6
|
-
from
|
7
|
-
from
|
6
|
+
from typing import Any, Literal, Optional, Union, Dict
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from enum import Enum
|
8
9
|
|
9
|
-
import numpy as np
|
10
10
|
from langchain_core.embeddings import Embeddings
|
11
11
|
from pyparsing import deque
|
12
12
|
|
@@ -19,197 +19,479 @@ from .state import StateMemory
|
|
19
19
|
|
20
20
|
logger = logging.getLogger("pycityagent")
|
21
21
|
|
22
|
+
class MemoryTag(str, Enum):
|
23
|
+
"""记忆标签枚举类"""
|
24
|
+
MOBILITY = "mobility"
|
25
|
+
SOCIAL = "social"
|
26
|
+
ECONOMY = "economy"
|
27
|
+
COGNITION = "cognition"
|
28
|
+
OTHER = "other"
|
29
|
+
EVENT = "event"
|
30
|
+
|
31
|
+
@dataclass
|
32
|
+
class MemoryNode:
|
33
|
+
"""记忆节点"""
|
34
|
+
tag: MemoryTag
|
35
|
+
day: int
|
36
|
+
t: int
|
37
|
+
location: str
|
38
|
+
description: str
|
39
|
+
cognition_id: Optional[int] = None # 关联的认知记忆ID
|
40
|
+
id: Optional[int] = None # 记忆ID
|
41
|
+
|
42
|
+
class StreamMemory:
|
43
|
+
"""用于存储时序性的流式信息"""
|
44
|
+
def __init__(self, max_len: int = 1000):
|
45
|
+
self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
|
46
|
+
self._memory_id_counter: int = 0 # 用于生成唯一ID
|
47
|
+
self._faiss_query = None
|
48
|
+
self._embedding_model = None
|
49
|
+
self._agent_id = -1
|
50
|
+
self._status_memory = None
|
51
|
+
self._simulator = None
|
52
|
+
|
53
|
+
def set_simulator(self, simulator):
|
54
|
+
self._simulator = simulator
|
55
|
+
|
56
|
+
def set_status_memory(self, status_memory):
|
57
|
+
self._status_memory = status_memory
|
58
|
+
|
59
|
+
def set_search_components(self, faiss_query, embedding_model):
|
60
|
+
"""设置搜索所需的组件"""
|
61
|
+
self._faiss_query = faiss_query
|
62
|
+
self._embedding_model = embedding_model
|
22
63
|
|
23
|
-
|
24
|
-
|
25
|
-
|
64
|
+
def set_agent_id(self, agent_id: int):
|
65
|
+
"""设置agent_id"""
|
66
|
+
self._agent_id = agent_id
|
26
67
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
68
|
+
async def _add_memory(self, tag: MemoryTag, description: str) -> int:
|
69
|
+
"""添加记忆节点的通用方法,返回记忆节点ID"""
|
70
|
+
if self._simulator is not None:
|
71
|
+
day = int(await self._simulator.get_simulator_day())
|
72
|
+
t = int(await self._simulator.get_time())
|
73
|
+
else:
|
74
|
+
day = 1
|
75
|
+
t = 1
|
76
|
+
position = await self._status_memory.get("position")
|
77
|
+
if 'aoi_position' in position:
|
78
|
+
location = position['aoi_position']['aoi_id']
|
79
|
+
elif 'lane_position' in position:
|
80
|
+
location = position['lane_position']['lane_id']
|
81
|
+
else:
|
82
|
+
location = "unknown"
|
83
|
+
|
84
|
+
current_id = self._memory_id_counter
|
85
|
+
self._memory_id_counter += 1
|
86
|
+
memory_node = MemoryNode(
|
87
|
+
tag=tag,
|
88
|
+
day=day,
|
89
|
+
t=t,
|
90
|
+
location=location,
|
91
|
+
description=description,
|
92
|
+
id=current_id,
|
93
|
+
)
|
94
|
+
self._memories.append(memory_node)
|
95
|
+
|
96
|
+
|
97
|
+
# 为新记忆创建 embedding
|
98
|
+
if self._embedding_model and self._faiss_query:
|
99
|
+
await self._faiss_query.add_documents(
|
100
|
+
agent_id=self._agent_id,
|
101
|
+
documents=description,
|
102
|
+
extra_tags={
|
103
|
+
"type": "stream",
|
104
|
+
"tag": tag,
|
105
|
+
"day": day,
|
106
|
+
"time": t,
|
107
|
+
},
|
108
|
+
)
|
109
|
+
|
110
|
+
return current_id
|
111
|
+
async def add_cognition(self, description: str) -> None:
|
112
|
+
"""添加认知记忆 Add cognition memory"""
|
113
|
+
return await self._add_memory(MemoryTag.COGNITION, description)
|
114
|
+
|
115
|
+
async def add_social(self, description: str) -> None:
|
116
|
+
"""添加社交记忆 Add social memory"""
|
117
|
+
return await self._add_memory(MemoryTag.SOCIAL, description)
|
118
|
+
|
119
|
+
async def add_economy(self, description: str) -> None:
|
120
|
+
"""添加经济记忆 Add economy memory"""
|
121
|
+
return await self._add_memory(MemoryTag.ECONOMY, description)
|
122
|
+
|
123
|
+
async def add_mobility(self, description: str) -> None:
|
124
|
+
"""添加移动记忆 Add mobility memory"""
|
125
|
+
return await self._add_memory(MemoryTag.MOBILITY, description)
|
126
|
+
|
127
|
+
async def add_event(self, description: str) -> None:
|
128
|
+
"""添加事件记忆 Add event memory"""
|
129
|
+
return await self._add_memory(MemoryTag.EVENT, description)
|
130
|
+
|
131
|
+
async def add_other(self, description: str) -> None:
|
132
|
+
"""添加其他记忆 Add other memory"""
|
133
|
+
return await self._add_memory(MemoryTag.OTHER, description)
|
134
|
+
|
135
|
+
async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
|
136
|
+
"""获取关联的认知记忆 Get related cognition memory"""
|
137
|
+
for memory in self._memories:
|
138
|
+
if memory.cognition_id == memory_id:
|
139
|
+
for cognition_memory in self._memories:
|
140
|
+
if (cognition_memory.tag == MemoryTag.COGNITION and
|
141
|
+
memory.cognition_id is not None):
|
142
|
+
return cognition_memory
|
143
|
+
return None
|
144
|
+
|
145
|
+
async def format_memory(self, memories: list[MemoryNode]) -> str:
|
146
|
+
"""格式化记忆"""
|
147
|
+
formatted_results = []
|
148
|
+
for memory in memories:
|
149
|
+
memory_tag = memory.tag
|
150
|
+
memory_day = memory.day
|
151
|
+
memory_time_seconds = memory.t
|
152
|
+
cognition_id = memory.cognition_id
|
153
|
+
|
154
|
+
# 格式化时间
|
155
|
+
if memory_time_seconds != 'unknown':
|
156
|
+
hours = memory_time_seconds // 3600
|
157
|
+
minutes = (memory_time_seconds % 3600) // 60
|
158
|
+
seconds = memory_time_seconds % 60
|
159
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
160
|
+
else:
|
161
|
+
memory_time = 'unknown'
|
162
|
+
|
163
|
+
memory_location = memory.location
|
164
|
+
|
165
|
+
# 添加认知信息(如果存在)
|
166
|
+
cognition_info = ""
|
167
|
+
if cognition_id is not None:
|
168
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
169
|
+
if cognition_memory:
|
170
|
+
cognition_info = f"\n Related cognition: {cognition_memory.description}"
|
171
|
+
|
172
|
+
formatted_results.append(
|
173
|
+
f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
|
174
|
+
f"location: {memory_location}]{cognition_info}"
|
175
|
+
)
|
176
|
+
return "\n".join(formatted_results)
|
32
177
|
|
33
|
-
def
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
) -> None:
|
43
|
-
"""
|
44
|
-
Initializes the Memory with optional configuration.
|
178
|
+
async def get_by_ids(self, memory_ids: Union[int, list[int]]) -> str:
|
179
|
+
"""获取指定ID的记忆"""
|
180
|
+
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
181
|
+
sorted_results = sorted(
|
182
|
+
memories,
|
183
|
+
key=lambda x: (x.day, x.t),
|
184
|
+
reverse=True
|
185
|
+
)
|
186
|
+
return self.format_memory(sorted_results)
|
45
187
|
|
188
|
+
async def search(
|
189
|
+
self,
|
190
|
+
query: str,
|
191
|
+
tag: Optional[MemoryTag] = None,
|
192
|
+
top_k: int = 3,
|
193
|
+
day_range: Optional[tuple[int, int]] = None, # 新增参数
|
194
|
+
time_range: Optional[tuple[int, int]] = None # 新增参数
|
195
|
+
) -> str:
|
196
|
+
"""Search stream memory
|
197
|
+
|
46
198
|
Args:
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
2. A callable that returns the default value when invoked (useful for complex default values).
|
53
|
-
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
54
|
-
Defaults to None.
|
55
|
-
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
56
|
-
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
57
|
-
motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
|
58
|
-
activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
|
59
|
-
embedding_model (Embeddings): The embedding model for memory search.
|
60
|
-
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
199
|
+
query: Query text
|
200
|
+
tag: Optional memory tag for filtering specific types of memories
|
201
|
+
top_k: Number of most relevant memories to return
|
202
|
+
day_range: Optional tuple of start and end days (start_day, end_day)
|
203
|
+
time_range: Optional tuple of start and end times (start_time, end_time)
|
61
204
|
"""
|
62
|
-
self.
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
#
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
if
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
try:
|
87
|
-
if isinstance(_type, type):
|
88
|
-
_value = _type(_value)
|
89
|
-
else:
|
90
|
-
if isinstance(_type, deque):
|
91
|
-
_type.extend(_value)
|
92
|
-
_value = deepcopy(_type)
|
93
|
-
else:
|
94
|
-
logger.warning(f"type `{_type}` is not supported!")
|
95
|
-
pass
|
96
|
-
except TypeError as e:
|
97
|
-
pass
|
98
|
-
except TypeError as e:
|
99
|
-
if isinstance(v, type):
|
100
|
-
_value = v()
|
101
|
-
else:
|
102
|
-
_value = v
|
103
|
-
self._embedding_fields[k] = False
|
104
|
-
|
105
|
-
if (
|
106
|
-
k in PROFILE_ATTRIBUTES
|
107
|
-
or k in STATE_ATTRIBUTES
|
108
|
-
or k == TIME_STAMP_KEY
|
109
|
-
):
|
110
|
-
logger.warning(f"key `{k}` already declared in memory!")
|
111
|
-
continue
|
112
|
-
|
113
|
-
_dynamic_config[k] = deepcopy(_value)
|
114
|
-
|
115
|
-
# 初始化各类记忆
|
116
|
-
self._dynamic = DynamicMemory(
|
117
|
-
required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
|
205
|
+
if not self._embedding_model or not self._faiss_query:
|
206
|
+
return "Search components not initialized"
|
207
|
+
|
208
|
+
filter_dict = {"type": "stream"}
|
209
|
+
|
210
|
+
if tag:
|
211
|
+
filter_dict["tag"] = tag
|
212
|
+
|
213
|
+
# 添加时间范围过滤
|
214
|
+
if day_range:
|
215
|
+
start_day, end_day = day_range
|
216
|
+
filter_dict["day"] = lambda x: start_day <= x <= end_day
|
217
|
+
|
218
|
+
if time_range:
|
219
|
+
start_time, end_time = time_range
|
220
|
+
filter_dict["time"] = lambda x: start_time <= x <= end_time
|
221
|
+
|
222
|
+
top_results = await self._faiss_query.similarity_search(
|
223
|
+
query=query,
|
224
|
+
agent_id=self._agent_id,
|
225
|
+
k=top_k,
|
226
|
+
return_score_type="similarity_score",
|
227
|
+
filter=filter_dict
|
118
228
|
)
|
119
229
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
_profile_config[k] = v
|
126
|
-
if motion is not None:
|
127
|
-
for k, v in motion.items():
|
128
|
-
if k not in STATE_ATTRIBUTES:
|
129
|
-
logger.warning(f"key `{k}` is not a correct `motion` field!")
|
130
|
-
continue
|
131
|
-
_state_config[k] = v
|
132
|
-
if base is not None:
|
133
|
-
for k, v in base.items():
|
134
|
-
if k not in STATE_ATTRIBUTES:
|
135
|
-
logger.warning(f"key `{k}` is not a correct `base` field!")
|
136
|
-
continue
|
137
|
-
_state_config[k] = v
|
138
|
-
self._state = StateMemory(
|
139
|
-
msg=_state_config, activate_timestamp=activate_timestamp
|
230
|
+
# 将结果按时间排序(先按天数,再按时间)
|
231
|
+
sorted_results = sorted(
|
232
|
+
top_results,
|
233
|
+
key=lambda x: (x[2].get('day', 0), x[2].get('time', 0)),
|
234
|
+
reverse=True
|
140
235
|
)
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
236
|
+
|
237
|
+
formatted_results = []
|
238
|
+
for content, score, metadata in sorted_results:
|
239
|
+
memory_tag = metadata.get('tag', 'unknown')
|
240
|
+
memory_day = metadata.get('day', 'unknown')
|
241
|
+
memory_time_seconds = metadata.get('time', 'unknown')
|
242
|
+
cognition_id = metadata.get('cognition_id', None)
|
243
|
+
|
244
|
+
# 格式化时间
|
245
|
+
if memory_time_seconds != 'unknown':
|
246
|
+
hours = memory_time_seconds // 3600
|
247
|
+
minutes = (memory_time_seconds % 3600) // 60
|
248
|
+
seconds = memory_time_seconds % 60
|
249
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
250
|
+
else:
|
251
|
+
memory_time = 'unknown'
|
252
|
+
|
253
|
+
memory_location = metadata.get('location', 'unknown')
|
254
|
+
|
255
|
+
# 添加认知信息(如果存在)
|
256
|
+
cognition_info = ""
|
257
|
+
if cognition_id is not None:
|
258
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
259
|
+
if cognition_memory:
|
260
|
+
cognition_info = f"\n Related cognition: {cognition_memory.description}"
|
261
|
+
|
262
|
+
formatted_results.append(
|
263
|
+
f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
|
264
|
+
f"location: {memory_location}]{cognition_info}"
|
160
265
|
)
|
161
|
-
return
|
266
|
+
return "\n".join(formatted_results)
|
162
267
|
|
163
|
-
def
|
268
|
+
async def search_today(
|
269
|
+
self,
|
270
|
+
query: str = "", # 可选的查询文本
|
271
|
+
tag: Optional[MemoryTag] = None,
|
272
|
+
top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
|
273
|
+
) -> str:
|
274
|
+
"""Search all memory events from today
|
275
|
+
|
276
|
+
Args:
|
277
|
+
query: Optional query text, returns all memories of the day if empty
|
278
|
+
tag: Optional memory tag for filtering specific types of memories
|
279
|
+
top_k: Number of most relevant memories to return, defaults to 100
|
280
|
+
|
281
|
+
Returns:
|
282
|
+
str: Formatted text of today's memories
|
164
283
|
"""
|
165
|
-
|
284
|
+
if self._simulator is None:
|
285
|
+
return "Simulator not initialized"
|
286
|
+
|
287
|
+
current_day = int(await self._simulator.get_simulator_day())
|
288
|
+
|
289
|
+
# 使用 search 方法,设置 day_range 为当天
|
290
|
+
return await self.search(
|
291
|
+
query=query,
|
292
|
+
tag=tag,
|
293
|
+
top_k=top_k,
|
294
|
+
day_range=(current_day, current_day)
|
295
|
+
)
|
296
|
+
|
297
|
+
async def add_cognition_to_memory(self, memory_id: Union[int, list[int]], cognition: str) -> None:
|
298
|
+
"""为已存在的记忆添加认知
|
299
|
+
|
300
|
+
Args:
|
301
|
+
memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
|
302
|
+
cognition: 认知描述
|
166
303
|
"""
|
167
|
-
|
304
|
+
# 将单个ID转换为列表以统一处理
|
305
|
+
memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
|
306
|
+
|
307
|
+
# 找到所有对应的记忆
|
308
|
+
target_memories = []
|
309
|
+
for memory in self._memories:
|
310
|
+
if id(memory) in memory_ids:
|
311
|
+
target_memories.append(memory)
|
312
|
+
|
313
|
+
if not target_memories:
|
314
|
+
raise ValueError(f"No memories found with ids {memory_ids}")
|
315
|
+
|
316
|
+
# 添加认知记忆
|
317
|
+
cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
|
318
|
+
|
319
|
+
# 更新所有原记忆的认知ID
|
320
|
+
for target_memory in target_memories:
|
321
|
+
target_memory.cognition_id = cognition_id
|
322
|
+
|
323
|
+
async def get_all(self) -> list[MemoryNode]:
|
324
|
+
"""获取所有流式信息"""
|
325
|
+
return list(self._memories)
|
326
|
+
|
327
|
+
class StatusMemory:
|
328
|
+
"""组合现有的三种记忆类型"""
|
329
|
+
def __init__(self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory):
|
330
|
+
self.profile = profile
|
331
|
+
self.state = state
|
332
|
+
self.dynamic = dynamic
|
333
|
+
self._faiss_query = None
|
334
|
+
self._embedding_model = None
|
335
|
+
self._simulator = None
|
336
|
+
self._agent_id = -1
|
337
|
+
self._semantic_templates = {} # 用户可配置的模板
|
338
|
+
self._embedding_fields = {} # 需要 embedding 的字段
|
339
|
+
self._embedding_field_to_doc_id = defaultdict(str) # 新增
|
340
|
+
self.watchers = {} # 新增
|
341
|
+
self._lock = asyncio.Lock() # 新增
|
342
|
+
|
343
|
+
def set_simulator(self, simulator):
|
344
|
+
self._simulator = simulator
|
345
|
+
|
346
|
+
async def initialize_embeddings(self) -> None:
|
347
|
+
"""初始化所有需要 embedding 的字段"""
|
348
|
+
if not self._embedding_model or not self._faiss_query:
|
349
|
+
logger.warning("Search components not initialized, skipping embeddings initialization")
|
350
|
+
return
|
351
|
+
|
352
|
+
# 获取所有状态信息
|
353
|
+
profile, state, dynamic = await self.export()
|
354
|
+
|
355
|
+
# 为每个需要 embedding 的字段创建 embedding
|
356
|
+
for key, value in profile[0].items():
|
357
|
+
if self.should_embed(key):
|
358
|
+
semantic_text = self._generate_semantic_text(key, value)
|
359
|
+
doc_ids = await self._faiss_query.add_documents(
|
360
|
+
agent_id=self._agent_id,
|
361
|
+
documents=semantic_text,
|
362
|
+
extra_tags={
|
363
|
+
"type": "profile_state",
|
364
|
+
"key": key,
|
365
|
+
},
|
366
|
+
)
|
367
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
368
|
+
|
369
|
+
for key, value in state[0].items():
|
370
|
+
if self.should_embed(key):
|
371
|
+
semantic_text = self._generate_semantic_text(key, value)
|
372
|
+
doc_ids = await self._faiss_query.add_documents(
|
373
|
+
agent_id=self._agent_id,
|
374
|
+
documents=semantic_text,
|
375
|
+
extra_tags={
|
376
|
+
"type": "profile_state",
|
377
|
+
"key": key,
|
378
|
+
},
|
379
|
+
)
|
380
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
381
|
+
|
382
|
+
for key, value in dynamic[0].items():
|
383
|
+
if self.should_embed(key):
|
384
|
+
semantic_text = self._generate_semantic_text(key, value)
|
385
|
+
doc_ids = await self._faiss_query.add_documents(
|
386
|
+
agent_id=self._agent_id,
|
387
|
+
documents=semantic_text,
|
388
|
+
extra_tags={
|
389
|
+
"type": "profile_state",
|
390
|
+
"key": key,
|
391
|
+
},
|
392
|
+
)
|
393
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
394
|
+
|
395
|
+
def _get_memory_type_by_key(self, key: str) -> str:
|
396
|
+
"""根据键名确定记忆类型"""
|
397
|
+
try:
|
398
|
+
if key in self.profile.__dict__:
|
399
|
+
return "profile"
|
400
|
+
elif key in self.state.__dict__:
|
401
|
+
return "state"
|
402
|
+
else:
|
403
|
+
return "dynamic"
|
404
|
+
except:
|
405
|
+
return "dynamic"
|
168
406
|
|
169
|
-
|
170
|
-
|
171
|
-
self
|
172
|
-
|
173
|
-
if self._agent_id < 0:
|
174
|
-
raise RuntimeError(
|
175
|
-
f"agent_id before assignment, please `set_agent_id` first!"
|
176
|
-
)
|
177
|
-
return self._agent_id
|
407
|
+
def set_search_components(self, faiss_query, embedding_model):
|
408
|
+
"""设置搜索所需的组件"""
|
409
|
+
self._faiss_query = faiss_query
|
410
|
+
self._embedding_model = embedding_model
|
178
411
|
|
179
412
|
def set_agent_id(self, agent_id: int):
|
413
|
+
"""设置agent_id"""
|
414
|
+
self._agent_id = agent_id
|
415
|
+
|
416
|
+
def set_semantic_templates(self, templates: Dict[str, str]):
|
417
|
+
"""设置语义模板
|
418
|
+
|
419
|
+
Args:
|
420
|
+
templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
|
180
421
|
"""
|
181
|
-
|
422
|
+
self._semantic_templates = templates
|
423
|
+
|
424
|
+
def _generate_semantic_text(self, key: str, value: Any) -> str:
|
425
|
+
"""生成语义文本
|
426
|
+
|
427
|
+
如果key存在于模板中,使用自定义模板
|
428
|
+
否则使用默认模板 "my {key} is {value}"
|
182
429
|
"""
|
183
|
-
self.
|
430
|
+
if key in self._semantic_templates:
|
431
|
+
return self._semantic_templates[key].format(value)
|
432
|
+
return f"Your {key} is {value}"
|
433
|
+
|
434
|
+
@lock_decorator
|
435
|
+
async def search(
|
436
|
+
self, query: str, top_k: int = 3, filter: Optional[dict] = None
|
437
|
+
) -> str:
|
438
|
+
"""搜索相关记忆
|
184
439
|
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
440
|
+
Args:
|
441
|
+
query: 查询文本
|
442
|
+
top_k: 返回最相关的记忆数量
|
443
|
+
filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
|
444
|
+
|
445
|
+
Returns:
|
446
|
+
str: 格式化的相关记忆文本
|
447
|
+
"""
|
448
|
+
if not self._embedding_model:
|
449
|
+
return "Embedding model not initialized"
|
450
|
+
|
451
|
+
filter_dict = {"type": "profile_state"}
|
452
|
+
if filter is not None:
|
453
|
+
filter_dict.update(filter)
|
454
|
+
top_results: list[tuple[str, float, dict]] = (
|
455
|
+
await self._faiss_query.similarity_search( # type:ignore
|
456
|
+
query=query,
|
457
|
+
agent_id=self._agent_id,
|
458
|
+
k=top_k,
|
459
|
+
return_score_type="similarity_score",
|
460
|
+
filter=filter_dict,
|
191
461
|
)
|
192
|
-
|
462
|
+
)
|
463
|
+
# 格式化输出
|
464
|
+
formatted_results = []
|
465
|
+
for content, score, metadata in top_results:
|
466
|
+
formatted_results.append(
|
467
|
+
f"- {content} "
|
468
|
+
)
|
469
|
+
|
470
|
+
return "\n".join(formatted_results)
|
471
|
+
|
472
|
+
def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
|
473
|
+
"""设置需要 embedding 的字段"""
|
474
|
+
self._embedding_fields = embedding_fields
|
475
|
+
|
476
|
+
def should_embed(self, key: str) -> bool:
|
477
|
+
"""判断字段是否需要 embedding"""
|
478
|
+
return self._embedding_fields.get(key, False)
|
193
479
|
|
194
480
|
@lock_decorator
|
195
|
-
async def get(
|
196
|
-
|
197
|
-
|
198
|
-
mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
|
199
|
-
) -> Any:
|
200
|
-
"""
|
201
|
-
Retrieves a value from memory based on the given key and access mode.
|
481
|
+
async def get(self, key: Any,
|
482
|
+
mode: Union[Literal["read only"], Literal["read and write"]] = "read only") -> Any:
|
483
|
+
"""从记忆中获取值
|
202
484
|
|
203
485
|
Args:
|
204
|
-
key
|
205
|
-
mode
|
486
|
+
key: 要获取的键
|
487
|
+
mode: 访问模式,"read only" 或 "read and write"
|
206
488
|
|
207
489
|
Returns:
|
208
|
-
|
490
|
+
获取到的值
|
209
491
|
|
210
492
|
Raises:
|
211
|
-
ValueError:
|
212
|
-
KeyError:
|
493
|
+
ValueError: 如果提供了无效的模式
|
494
|
+
KeyError: 如果在任何记忆部分都找不到该键
|
213
495
|
"""
|
214
496
|
if mode == "read only":
|
215
497
|
process_func = deepcopy
|
@@ -217,54 +499,56 @@ class Memory:
|
|
217
499
|
process_func = lambda x: x
|
218
500
|
else:
|
219
501
|
raise ValueError(f"Invalid get mode `{mode}`!")
|
220
|
-
|
502
|
+
|
503
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
221
504
|
try:
|
222
|
-
value = await
|
505
|
+
value = await mem.get(key)
|
223
506
|
return process_func(value)
|
224
|
-
except KeyError
|
507
|
+
except KeyError:
|
225
508
|
continue
|
226
509
|
raise KeyError(f"No attribute `{key}` in memories!")
|
227
510
|
|
228
511
|
@lock_decorator
|
229
|
-
async def update(
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
234
|
-
store_snapshot: bool = False,
|
235
|
-
protect_llm_read_only_fields: bool = True,
|
236
|
-
) -> None:
|
512
|
+
async def update(self, key: Any, value: Any,
|
513
|
+
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
514
|
+
store_snapshot: bool = False,
|
515
|
+
protect_llm_read_only_fields: bool = True) -> None:
|
237
516
|
"""更新记忆值并在必要时更新embedding"""
|
238
517
|
if protect_llm_read_only_fields:
|
239
518
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
240
519
|
logger.warning(f"Trying to write protected key `{key}`!")
|
241
520
|
return
|
242
|
-
|
521
|
+
|
522
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
243
523
|
try:
|
244
|
-
original_value = await
|
524
|
+
original_value = await mem.get(key)
|
245
525
|
if mode == "replace":
|
246
|
-
await
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
#
|
526
|
+
await mem.update(key, value, store_snapshot)
|
527
|
+
if self.should_embed(key) and self._embedding_model:
|
528
|
+
semantic_text = self._generate_semantic_text(key, value)
|
529
|
+
|
530
|
+
# 删除旧的 embedding
|
251
531
|
orig_doc_id = self._embedding_field_to_doc_id[key]
|
252
532
|
if orig_doc_id:
|
253
|
-
await self.
|
533
|
+
await self._faiss_query.delete_documents(
|
254
534
|
to_delete_ids=[orig_doc_id],
|
255
535
|
)
|
256
|
-
|
257
|
-
|
258
|
-
|
536
|
+
|
537
|
+
# 添加新的 embedding
|
538
|
+
doc_ids = await self._faiss_query.add_documents(
|
539
|
+
agent_id=self._agent_id,
|
540
|
+
documents=semantic_text,
|
259
541
|
extra_tags={
|
260
|
-
"type":
|
542
|
+
"type": self._get_memory_type(mem),
|
261
543
|
"key": key,
|
262
544
|
},
|
263
545
|
)
|
264
546
|
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
547
|
+
|
265
548
|
if key in self.watchers:
|
266
549
|
for callback in self.watchers[key]:
|
267
550
|
asyncio.create_task(callback())
|
551
|
+
|
268
552
|
elif mode == "merge":
|
269
553
|
if isinstance(original_value, set):
|
270
554
|
original_value.update(set(value))
|
@@ -278,14 +562,14 @@ class Memory:
|
|
278
562
|
logger.debug(
|
279
563
|
f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
|
280
564
|
)
|
281
|
-
await
|
282
|
-
if self.
|
283
|
-
|
284
|
-
doc_ids = await self.
|
285
|
-
agent_id=self.
|
565
|
+
await mem.update(key, value, store_snapshot)
|
566
|
+
if self.should_embed(key) and self._embedding_model:
|
567
|
+
semantic_text = self._generate_semantic_text(key, value)
|
568
|
+
doc_ids = await self._faiss_query.add_documents(
|
569
|
+
agent_id=self._agent_id,
|
286
570
|
documents=f"{key}: {str(original_value)}",
|
287
571
|
extra_tags={
|
288
|
-
"type":
|
572
|
+
"type": self._get_memory_type(mem),
|
289
573
|
"key": key,
|
290
574
|
},
|
291
575
|
)
|
@@ -302,63 +586,20 @@ class Memory:
|
|
302
586
|
|
303
587
|
def _get_memory_type(self, mem: Any) -> str:
|
304
588
|
"""获取记忆类型"""
|
305
|
-
if mem is self.
|
589
|
+
if mem is self.state:
|
306
590
|
return "state"
|
307
|
-
elif mem is self.
|
591
|
+
elif mem is self.profile:
|
308
592
|
return "profile"
|
309
593
|
else:
|
310
594
|
return "dynamic"
|
311
595
|
|
312
|
-
async def update_batch(
|
313
|
-
self,
|
314
|
-
content: Union[dict, Sequence[tuple[Any, Any]]],
|
315
|
-
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
316
|
-
store_snapshot: bool = False,
|
317
|
-
protect_llm_read_only_fields: bool = True,
|
318
|
-
) -> None:
|
319
|
-
"""
|
320
|
-
Updates multiple values in the memory at once.
|
321
|
-
|
322
|
-
Args:
|
323
|
-
content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
|
324
|
-
mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
|
325
|
-
store_snapshot (bool): Whether to store a snapshot of the memory after the update.
|
326
|
-
protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
|
327
|
-
|
328
|
-
Raises:
|
329
|
-
TypeError: If the content type is neither a dictionary nor a sequence of tuples.
|
330
|
-
"""
|
331
|
-
if isinstance(content, dict):
|
332
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
|
333
|
-
elif isinstance(content, Sequence):
|
334
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
|
335
|
-
else:
|
336
|
-
raise TypeError(f"Invalid content type `{type(content)}`!")
|
337
|
-
for k, v in _list_content[:1]:
|
338
|
-
await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
|
339
|
-
for k, v in _list_content[1:]:
|
340
|
-
await self.update(k, v, mode, False, protect_llm_read_only_fields)
|
341
|
-
|
342
596
|
@lock_decorator
|
343
597
|
async def add_watcher(self, key: str, callback: Callable) -> None:
|
344
|
-
"""
|
345
|
-
Adds a callback function to be invoked when the value
|
346
|
-
associated with the specified key in memory is updated.
|
347
|
-
|
348
|
-
Args:
|
349
|
-
key (str): The key for which the watcher is being registered.
|
350
|
-
callback (Callable): A callable function that will be executed
|
351
|
-
whenever the value associated with the specified key is updated.
|
352
|
-
|
353
|
-
Notes:
|
354
|
-
If the key does not already have any watchers, it will be
|
355
|
-
initialized with an empty list before appending the callback.
|
356
|
-
"""
|
598
|
+
"""添加值变更的监听器"""
|
357
599
|
if key not in self.watchers:
|
358
600
|
self.watchers[key] = []
|
359
601
|
self.watchers[key].append(callback)
|
360
602
|
|
361
|
-
@lock_decorator
|
362
603
|
async def export(
|
363
604
|
self,
|
364
605
|
) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
|
@@ -369,12 +610,11 @@ class Memory:
|
|
369
610
|
tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
|
370
611
|
"""
|
371
612
|
return (
|
372
|
-
await self.
|
373
|
-
await self.
|
374
|
-
await self.
|
613
|
+
await self.profile.export(),
|
614
|
+
await self.state.export(),
|
615
|
+
await self.dynamic.export(),
|
375
616
|
)
|
376
617
|
|
377
|
-
@lock_decorator
|
378
618
|
async def load(
|
379
619
|
self,
|
380
620
|
snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
|
@@ -390,41 +630,249 @@ class Memory:
|
|
390
630
|
_profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
|
391
631
|
for _snapshot, _mem in zip(
|
392
632
|
[_profile_snapshot, _state_snapshot, _dynamic_snapshot],
|
393
|
-
[self.
|
633
|
+
[self.state, self.profile, self.dynamic],
|
394
634
|
):
|
395
635
|
if _snapshot:
|
396
636
|
await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
|
397
637
|
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
638
|
+
class Memory:
|
639
|
+
"""
|
640
|
+
A class to manage different types of memory (state, profile, dynamic).
|
641
|
+
|
642
|
+
Attributes:
|
643
|
+
_state (StateMemory): Stores state-related data.
|
644
|
+
_profile (ProfileMemory): Stores profile-related data.
|
645
|
+
_dynamic (DynamicMemory): Stores dynamically configured data.
|
646
|
+
"""
|
647
|
+
|
648
|
+
def __init__(
|
649
|
+
self,
|
650
|
+
config: Optional[dict[Any, Any]] = None,
|
651
|
+
profile: Optional[dict[Any, Any]] = None,
|
652
|
+
base: Optional[dict[Any, Any]] = None,
|
653
|
+
activate_timestamp: bool = False,
|
654
|
+
embedding_model: Optional[Embeddings] = None,
|
655
|
+
faiss_query: Optional[FaissQuery] = None,
|
656
|
+
) -> None:
|
657
|
+
"""
|
658
|
+
Initializes the Memory with optional configuration.
|
403
659
|
|
404
660
|
Args:
|
405
|
-
|
406
|
-
|
407
|
-
|
661
|
+
config (Optional[dict[Any, Any]], optional):
|
662
|
+
A configuration dictionary for dynamic memory. The dictionary format is:
|
663
|
+
- Key: The name of the dynamic memory field.
|
664
|
+
- Value: Can be one of two formats:
|
665
|
+
1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
|
666
|
+
2. A callable that returns the default value when invoked (useful for complex default values).
|
667
|
+
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
668
|
+
Defaults to None.
|
669
|
+
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
670
|
+
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
671
|
+
motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
|
672
|
+
activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
|
673
|
+
embedding_model (Embeddings): The embedding model for memory search.
|
674
|
+
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
675
|
+
"""
|
676
|
+
self.watchers: dict[str, list[Callable]] = {}
|
677
|
+
self._lock = asyncio.Lock()
|
678
|
+
self._agent_id: int = -1
|
679
|
+
self._simulator = None
|
680
|
+
self._embedding_model = embedding_model
|
681
|
+
self._faiss_query = faiss_query
|
682
|
+
self._semantic_templates: dict[str, str] = {}
|
683
|
+
_dynamic_config: dict[Any, Any] = {}
|
684
|
+
_state_config: dict[Any, Any] = {}
|
685
|
+
_profile_config: dict[Any, Any] = {}
|
686
|
+
self._embedding_fields: dict[str, bool] = {}
|
687
|
+
self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
|
408
688
|
|
409
|
-
|
410
|
-
|
689
|
+
if config is not None:
|
690
|
+
for k, v in config.items():
|
691
|
+
try:
|
692
|
+
# 处理不同长度的配置元组
|
693
|
+
if isinstance(v, tuple):
|
694
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
695
|
+
_type, _value, enable_embedding, template = v
|
696
|
+
self._embedding_fields[k] = enable_embedding
|
697
|
+
self._semantic_templates[k] = template
|
698
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
699
|
+
_type, _value, enable_embedding = v
|
700
|
+
self._embedding_fields[k] = enable_embedding
|
701
|
+
else: # (_type, _value)
|
702
|
+
_type, _value = v
|
703
|
+
self._embedding_fields[k] = False
|
704
|
+
else:
|
705
|
+
_type = type(v)
|
706
|
+
_value = v
|
707
|
+
self._embedding_fields[k] = False
|
708
|
+
|
709
|
+
# 处理类型转换
|
710
|
+
try:
|
711
|
+
if isinstance(_type, type):
|
712
|
+
_value = _type(_value)
|
713
|
+
else:
|
714
|
+
if isinstance(_type, deque):
|
715
|
+
_type.extend(_value)
|
716
|
+
_value = deepcopy(_type)
|
717
|
+
else:
|
718
|
+
logger.warning(f"type `{_type}` is not supported!")
|
719
|
+
except TypeError as e:
|
720
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
721
|
+
except TypeError as e:
|
722
|
+
if isinstance(v, type):
|
723
|
+
_value = v()
|
724
|
+
else:
|
725
|
+
_value = v
|
726
|
+
self._embedding_fields[k] = False
|
727
|
+
|
728
|
+
if (
|
729
|
+
k in PROFILE_ATTRIBUTES
|
730
|
+
or k in STATE_ATTRIBUTES
|
731
|
+
or k == TIME_STAMP_KEY
|
732
|
+
):
|
733
|
+
logger.warning(f"key `{k}` already declared in memory!")
|
734
|
+
continue
|
735
|
+
|
736
|
+
_dynamic_config[k] = deepcopy(_value)
|
737
|
+
|
738
|
+
# 初始化各类记忆
|
739
|
+
self._dynamic = DynamicMemory(
|
740
|
+
required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
|
741
|
+
)
|
742
|
+
|
743
|
+
if profile is not None:
|
744
|
+
for k, v in profile.items():
|
745
|
+
if k not in PROFILE_ATTRIBUTES:
|
746
|
+
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
747
|
+
continue
|
748
|
+
|
749
|
+
try:
|
750
|
+
# 处理配置元组格式
|
751
|
+
if isinstance(v, tuple):
|
752
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
753
|
+
_type, _value, enable_embedding, template = v
|
754
|
+
self._embedding_fields[k] = enable_embedding
|
755
|
+
self._semantic_templates[k] = template
|
756
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
757
|
+
_type, _value, enable_embedding = v
|
758
|
+
self._embedding_fields[k] = enable_embedding
|
759
|
+
else: # (_type, _value)
|
760
|
+
_type, _value = v
|
761
|
+
self._embedding_fields[k] = False
|
762
|
+
|
763
|
+
# 处理类型转换
|
764
|
+
try:
|
765
|
+
if isinstance(_type, type):
|
766
|
+
_value = _type(_value)
|
767
|
+
else:
|
768
|
+
if isinstance(_type, deque):
|
769
|
+
_type.extend(_value)
|
770
|
+
_value = deepcopy(_type)
|
771
|
+
else:
|
772
|
+
logger.warning(f"type `{_type}` is not supported!")
|
773
|
+
except TypeError as e:
|
774
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
775
|
+
else:
|
776
|
+
# 保持对简单键值对的兼容
|
777
|
+
_value = v
|
778
|
+
self._embedding_fields[k] = False
|
779
|
+
except TypeError as e:
|
780
|
+
if isinstance(v, type):
|
781
|
+
_value = v()
|
782
|
+
else:
|
783
|
+
_value = v
|
784
|
+
self._embedding_fields[k] = False
|
785
|
+
|
786
|
+
_profile_config[k] = deepcopy(_value)
|
787
|
+
self._profile = ProfileMemory(
|
788
|
+
msg=_profile_config, activate_timestamp=activate_timestamp
|
789
|
+
)
|
790
|
+
|
791
|
+
if base is not None:
|
792
|
+
for k, v in base.items():
|
793
|
+
if k not in STATE_ATTRIBUTES:
|
794
|
+
logger.warning(f"key `{k}` is not a correct `base` field!")
|
795
|
+
continue
|
796
|
+
_state_config[k] = v
|
797
|
+
|
798
|
+
self._state = StateMemory(
|
799
|
+
msg=_state_config, activate_timestamp=activate_timestamp
|
800
|
+
)
|
801
|
+
|
802
|
+
# 组合 StatusMemory,并传递 embedding_fields 信息
|
803
|
+
self._status = StatusMemory(
|
804
|
+
profile=self._profile,
|
805
|
+
state=self._state,
|
806
|
+
dynamic=self._dynamic
|
807
|
+
)
|
808
|
+
self._status.set_embedding_fields(self._embedding_fields)
|
809
|
+
self._status.set_search_components(self._faiss_query, self._embedding_model)
|
810
|
+
|
811
|
+
# 新增 StreamMemory
|
812
|
+
self._stream = StreamMemory()
|
813
|
+
self._stream.set_status_memory(self._status)
|
814
|
+
self._stream.set_search_components(self._faiss_query, self._embedding_model)
|
815
|
+
|
816
|
+
def set_search_components(
|
817
|
+
self,
|
818
|
+
faiss_query: FaissQuery,
|
819
|
+
embedding_model: Embeddings,
|
820
|
+
):
|
821
|
+
self._embedding_model = embedding_model
|
822
|
+
self._faiss_query = faiss_query
|
823
|
+
self._stream.set_search_components(faiss_query, embedding_model)
|
824
|
+
self._status.set_search_components(faiss_query, embedding_model)
|
825
|
+
|
826
|
+
def set_agent_id(self, agent_id: int):
|
411
827
|
"""
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
828
|
+
Set the FaissQuery of the agent.
|
829
|
+
"""
|
830
|
+
self._agent_id = agent_id
|
831
|
+
self._stream.set_agent_id(agent_id)
|
832
|
+
self._status.set_agent_id(agent_id)
|
833
|
+
|
834
|
+
def set_simulator(self, simulator):
|
835
|
+
self._simulator = simulator
|
836
|
+
self._stream.set_simulator(simulator)
|
837
|
+
self._status.set_simulator(simulator)
|
838
|
+
|
839
|
+
@property
|
840
|
+
def status(self) -> StatusMemory:
|
841
|
+
return self._status
|
842
|
+
|
843
|
+
@property
|
844
|
+
def stream(self) -> StreamMemory:
|
845
|
+
return self._stream
|
846
|
+
|
847
|
+
@property
|
848
|
+
def embedding_model(
|
849
|
+
self,
|
850
|
+
):
|
851
|
+
if self._embedding_model is None:
|
852
|
+
raise RuntimeError(
|
853
|
+
f"embedding_model before assignment, please `set_embedding_model` first!"
|
421
854
|
)
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
855
|
+
return self._embedding_model
|
856
|
+
|
857
|
+
@property
|
858
|
+
def agent_id(
|
859
|
+
self,
|
860
|
+
):
|
861
|
+
if self._agent_id < 0:
|
862
|
+
raise RuntimeError(
|
863
|
+
f"agent_id before assignment, please `set_agent_id` first!"
|
428
864
|
)
|
865
|
+
return self._agent_id
|
429
866
|
|
430
|
-
|
867
|
+
@property
|
868
|
+
def faiss_query(self) -> FaissQuery:
|
869
|
+
"""FaissQuery"""
|
870
|
+
if self._faiss_query is None:
|
871
|
+
raise RuntimeError(
|
872
|
+
f"FaissQuery access before assignment, please `set_faiss_query` first!"
|
873
|
+
)
|
874
|
+
return self._faiss_query
|
875
|
+
|
876
|
+
async def initialize_embeddings(self):
|
877
|
+
"""初始化embedding"""
|
878
|
+
await self._status.initialize_embeddings()
|