pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +83 -62
- pycityagent/agent/agent_base.py +81 -54
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +45 -57
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +325 -127
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +18 -14
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +733 -247
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +137 -96
- pycityagent/simulation/simulation.py +184 -38
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/tools/tool.py +7 -9
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +14 -7
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
pycityagent/memory/memory.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
import asyncio
|
2
2
|
import logging
|
3
3
|
from collections import defaultdict
|
4
|
-
from collections.abc import Callable, Sequence
|
4
|
+
from collections.abc import Callable, Coroutine, Sequence
|
5
5
|
from copy import deepcopy
|
6
|
-
from
|
7
|
-
from
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from enum import Enum
|
8
|
+
from typing import Any, Dict, Literal, Optional, Union
|
8
9
|
|
9
|
-
import numpy as np
|
10
10
|
from langchain_core.embeddings import Embeddings
|
11
11
|
from pyparsing import deque
|
12
12
|
|
@@ -20,176 +20,496 @@ from .state import StateMemory
|
|
20
20
|
logger = logging.getLogger("pycityagent")
|
21
21
|
|
22
22
|
|
23
|
-
class
|
24
|
-
"""
|
25
|
-
A class to manage different types of memory (state, profile, dynamic).
|
23
|
+
class MemoryTag(str, Enum):
|
24
|
+
"""记忆标签枚举类"""
|
26
25
|
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
""
|
26
|
+
MOBILITY = "mobility"
|
27
|
+
SOCIAL = "social"
|
28
|
+
ECONOMY = "economy"
|
29
|
+
COGNITION = "cognition"
|
30
|
+
OTHER = "other"
|
31
|
+
EVENT = "event"
|
32
32
|
|
33
|
-
|
33
|
+
|
34
|
+
@dataclass
|
35
|
+
class MemoryNode:
|
36
|
+
"""记忆节点"""
|
37
|
+
|
38
|
+
tag: MemoryTag
|
39
|
+
day: int
|
40
|
+
t: int
|
41
|
+
location: str
|
42
|
+
description: str
|
43
|
+
cognition_id: Optional[int] = None # 关联的认知记忆ID
|
44
|
+
id: Optional[int] = None # 记忆ID
|
45
|
+
|
46
|
+
|
47
|
+
class StreamMemory:
|
48
|
+
"""用于存储时序性的流式信息"""
|
49
|
+
|
50
|
+
def __init__(self, max_len: int = 1000):
|
51
|
+
self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
|
52
|
+
self._memory_id_counter: int = 0 # 用于生成唯一ID
|
53
|
+
self._faiss_query = None
|
54
|
+
self._embedding_model = None
|
55
|
+
self._agent_id = -1
|
56
|
+
self._status_memory = None
|
57
|
+
self._simulator = None
|
58
|
+
|
59
|
+
@property
|
60
|
+
def faiss_query(
|
34
61
|
self,
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
motion: Optional[dict[Any, Any]] = None,
|
39
|
-
activate_timestamp: bool = False,
|
40
|
-
embedding_model: Optional[Embeddings] = None,
|
41
|
-
faiss_query: Optional[FaissQuery] = None,
|
42
|
-
) -> None:
|
43
|
-
"""
|
44
|
-
Initializes the Memory with optional configuration.
|
62
|
+
) -> FaissQuery:
|
63
|
+
assert self._faiss_query is not None
|
64
|
+
return self._faiss_query
|
45
65
|
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
2. A callable that returns the default value when invoked (useful for complex default values).
|
53
|
-
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
54
|
-
Defaults to None.
|
55
|
-
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
56
|
-
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
57
|
-
motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
|
58
|
-
activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
|
59
|
-
embedding_model (Embeddings): The embedding model for memory search.
|
60
|
-
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
61
|
-
"""
|
62
|
-
self.watchers: dict[str, list[Callable]] = {}
|
63
|
-
self._lock = asyncio.Lock()
|
64
|
-
self._agent_id: int = -1
|
65
|
-
self._embedding_model = embedding_model
|
66
|
+
@property
|
67
|
+
def status_memory(
|
68
|
+
self,
|
69
|
+
):
|
70
|
+
assert self._status_memory is not None
|
71
|
+
return self._status_memory
|
66
72
|
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
self.
|
72
|
-
|
73
|
+
def set_simulator(self, simulator):
|
74
|
+
self._simulator = simulator
|
75
|
+
|
76
|
+
def set_status_memory(self, status_memory):
|
77
|
+
self._status_memory = status_memory
|
78
|
+
|
79
|
+
def set_search_components(self, faiss_query, embedding_model):
|
80
|
+
"""设置搜索所需的组件"""
|
73
81
|
self._faiss_query = faiss_query
|
82
|
+
self._embedding_model = embedding_model
|
74
83
|
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
# 处理新的三元组格式
|
79
|
-
if isinstance(v, tuple) and len(v) == 3:
|
80
|
-
_type, _value, enable_embedding = v
|
81
|
-
self._embedding_fields[k] = enable_embedding
|
82
|
-
else:
|
83
|
-
_type, _value = v
|
84
|
-
self._embedding_fields[k] = False
|
84
|
+
def set_agent_id(self, agent_id: int):
|
85
|
+
"""设置agent_id"""
|
86
|
+
self._agent_id = agent_id
|
85
87
|
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
88
|
+
async def _add_memory(self, tag: MemoryTag, description: str) -> int:
|
89
|
+
"""添加记忆节点的通用方法,返回记忆节点ID"""
|
90
|
+
if self._simulator is not None:
|
91
|
+
day = int(await self._simulator.get_simulator_day())
|
92
|
+
t = int(await self._simulator.get_time())
|
93
|
+
else:
|
94
|
+
day = 1
|
95
|
+
t = 1
|
96
|
+
position = await self.status_memory.get("position")
|
97
|
+
if "aoi_position" in position:
|
98
|
+
location = position["aoi_position"]["aoi_id"]
|
99
|
+
elif "lane_position" in position:
|
100
|
+
location = position["lane_position"]["lane_id"]
|
101
|
+
else:
|
102
|
+
location = "unknown"
|
103
|
+
|
104
|
+
current_id = self._memory_id_counter
|
105
|
+
self._memory_id_counter += 1
|
106
|
+
memory_node = MemoryNode(
|
107
|
+
tag=tag,
|
108
|
+
day=day,
|
109
|
+
t=t,
|
110
|
+
location=location,
|
111
|
+
description=description,
|
112
|
+
id=current_id,
|
113
|
+
)
|
114
|
+
self._memories.append(memory_node)
|
115
|
+
|
116
|
+
# 为新记忆创建 embedding
|
117
|
+
if self._embedding_model and self._faiss_query:
|
118
|
+
await self.faiss_query.add_documents(
|
119
|
+
agent_id=self._agent_id,
|
120
|
+
documents=description,
|
121
|
+
extra_tags={
|
122
|
+
"type": "stream",
|
123
|
+
"tag": tag,
|
124
|
+
"day": day,
|
125
|
+
"time": t,
|
126
|
+
},
|
127
|
+
)
|
104
128
|
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
129
|
+
return current_id
|
130
|
+
|
131
|
+
async def add_cognition(self, description: str) -> int:
|
132
|
+
"""添加认知记忆 Add cognition memory"""
|
133
|
+
return await self._add_memory(MemoryTag.COGNITION, description)
|
134
|
+
|
135
|
+
async def add_social(self, description: str) -> int:
|
136
|
+
"""添加社交记忆 Add social memory"""
|
137
|
+
return await self._add_memory(MemoryTag.SOCIAL, description)
|
138
|
+
|
139
|
+
async def add_economy(self, description: str) -> int:
|
140
|
+
"""添加经济记忆 Add economy memory"""
|
141
|
+
return await self._add_memory(MemoryTag.ECONOMY, description)
|
142
|
+
|
143
|
+
async def add_mobility(self, description: str) -> int:
|
144
|
+
"""添加移动记忆 Add mobility memory"""
|
145
|
+
return await self._add_memory(MemoryTag.MOBILITY, description)
|
146
|
+
|
147
|
+
async def add_event(self, description: str) -> int:
|
148
|
+
"""添加事件记忆 Add event memory"""
|
149
|
+
return await self._add_memory(MemoryTag.EVENT, description)
|
150
|
+
|
151
|
+
async def add_other(self, description: str) -> int:
|
152
|
+
"""添加其他记忆 Add other memory"""
|
153
|
+
return await self._add_memory(MemoryTag.OTHER, description)
|
154
|
+
|
155
|
+
async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
|
156
|
+
"""获取关联的认知记忆 Get related cognition memory"""
|
157
|
+
for memory in self._memories:
|
158
|
+
if memory.cognition_id == memory_id:
|
159
|
+
for cognition_memory in self._memories:
|
160
|
+
if (
|
161
|
+
cognition_memory.tag == MemoryTag.COGNITION
|
162
|
+
and memory.cognition_id is not None
|
163
|
+
):
|
164
|
+
return cognition_memory
|
165
|
+
return None
|
166
|
+
|
167
|
+
async def format_memory(self, memories: list[MemoryNode]) -> str:
|
168
|
+
"""格式化记忆"""
|
169
|
+
formatted_results = []
|
170
|
+
for memory in memories:
|
171
|
+
memory_tag = memory.tag
|
172
|
+
memory_day = memory.day
|
173
|
+
memory_time_seconds = memory.t
|
174
|
+
cognition_id = memory.cognition_id
|
175
|
+
|
176
|
+
# 格式化时间
|
177
|
+
if memory_time_seconds != "unknown":
|
178
|
+
hours = memory_time_seconds // 3600
|
179
|
+
minutes = (memory_time_seconds % 3600) // 60
|
180
|
+
seconds = memory_time_seconds % 60
|
181
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
182
|
+
else:
|
183
|
+
memory_time = "unknown"
|
184
|
+
|
185
|
+
memory_location = memory.location
|
186
|
+
|
187
|
+
# 添加认知信息(如果存在)
|
188
|
+
cognition_info = ""
|
189
|
+
if cognition_id is not None:
|
190
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
191
|
+
if cognition_memory:
|
192
|
+
cognition_info = (
|
193
|
+
f"\n Related cognition: {cognition_memory.description}"
|
194
|
+
)
|
112
195
|
|
113
|
-
|
196
|
+
formatted_results.append(
|
197
|
+
f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
|
198
|
+
f"location: {memory_location}]{cognition_info}"
|
199
|
+
)
|
200
|
+
return "\n".join(formatted_results)
|
114
201
|
|
115
|
-
|
116
|
-
self
|
117
|
-
|
118
|
-
|
202
|
+
async def get_by_ids(
|
203
|
+
self, memory_ids: Union[int, list[int]]
|
204
|
+
) -> Coroutine[Any, Any, str]:
|
205
|
+
"""获取指定ID的记忆"""
|
206
|
+
memories = [memory for memory in self._memories if memory.id in memory_ids]
|
207
|
+
sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
|
208
|
+
return self.format_memory(sorted_results)
|
119
209
|
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
210
|
+
async def search(
|
211
|
+
self,
|
212
|
+
query: str,
|
213
|
+
tag: Optional[MemoryTag] = None,
|
214
|
+
top_k: int = 3,
|
215
|
+
day_range: Optional[tuple[int, int]] = None, # 新增参数
|
216
|
+
time_range: Optional[tuple[int, int]] = None, # 新增参数
|
217
|
+
) -> str:
|
218
|
+
"""Search stream memory
|
219
|
+
|
220
|
+
Args:
|
221
|
+
query: Query text
|
222
|
+
tag: Optional memory tag for filtering specific types of memories
|
223
|
+
top_k: Number of most relevant memories to return
|
224
|
+
day_range: Optional tuple of start and end days (start_day, end_day)
|
225
|
+
time_range: Optional tuple of start and end times (start_time, end_time)
|
226
|
+
"""
|
227
|
+
if not self._embedding_model or not self._faiss_query:
|
228
|
+
return "Search components not initialized"
|
229
|
+
|
230
|
+
filter_dict: dict[str, Any] = {"type": "stream"}
|
231
|
+
|
232
|
+
if tag:
|
233
|
+
filter_dict["tag"] = tag
|
234
|
+
|
235
|
+
# 添加时间范围过滤
|
236
|
+
if day_range:
|
237
|
+
start_day, end_day = day_range
|
238
|
+
filter_dict["day"] = lambda x: start_day <= x <= end_day
|
239
|
+
|
240
|
+
if time_range:
|
241
|
+
start_time, end_time = time_range
|
242
|
+
filter_dict["time"] = lambda x: start_time <= x <= end_time
|
243
|
+
|
244
|
+
top_results = await self.faiss_query.similarity_search(
|
245
|
+
query=query,
|
246
|
+
agent_id=self._agent_id,
|
247
|
+
k=top_k,
|
248
|
+
return_score_type="similarity_score",
|
249
|
+
filter=filter_dict,
|
140
250
|
)
|
141
|
-
|
142
|
-
|
251
|
+
|
252
|
+
# 将结果按时间排序(先按天数,再按时间)
|
253
|
+
sorted_results = sorted(
|
254
|
+
top_results,
|
255
|
+
key=lambda x: (x[2].get("day", 0), x[2].get("time", 0)), # type:ignore
|
256
|
+
reverse=True,
|
143
257
|
)
|
144
|
-
# self.memories = [] # 存储记忆内容
|
145
|
-
# self.embeddings = [] # 存储记忆的向量表示
|
146
258
|
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
259
|
+
formatted_results = []
|
260
|
+
for content, score, metadata in sorted_results: # type:ignore
|
261
|
+
memory_tag = metadata.get("tag", "unknown")
|
262
|
+
memory_day = metadata.get("day", "unknown")
|
263
|
+
memory_time_seconds = metadata.get("time", "unknown")
|
264
|
+
cognition_id = metadata.get("cognition_id", None)
|
265
|
+
|
266
|
+
# 格式化时间
|
267
|
+
if memory_time_seconds != "unknown":
|
268
|
+
hours = memory_time_seconds // 3600
|
269
|
+
minutes = (memory_time_seconds % 3600) // 60
|
270
|
+
seconds = memory_time_seconds % 60
|
271
|
+
memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
|
272
|
+
else:
|
273
|
+
memory_time = "unknown"
|
274
|
+
|
275
|
+
memory_location = metadata.get("location", "unknown")
|
276
|
+
|
277
|
+
# 添加认知信息(如果存在)
|
278
|
+
cognition_info = ""
|
279
|
+
if cognition_id is not None:
|
280
|
+
cognition_memory = await self.get_related_cognition(cognition_id)
|
281
|
+
if cognition_memory:
|
282
|
+
cognition_info = (
|
283
|
+
f"\n Related cognition: {cognition_memory.description}"
|
284
|
+
)
|
152
285
|
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
):
|
157
|
-
if self._embedding_model is None:
|
158
|
-
raise RuntimeError(
|
159
|
-
f"embedding_model before assignment, please `set_embedding_model` first!"
|
286
|
+
formatted_results.append(
|
287
|
+
f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
|
288
|
+
f"location: {memory_location}]{cognition_info}"
|
160
289
|
)
|
161
|
-
return
|
290
|
+
return "\n".join(formatted_results)
|
162
291
|
|
163
|
-
def
|
292
|
+
async def search_today(
|
293
|
+
self,
|
294
|
+
query: str = "", # 可选的查询文本
|
295
|
+
tag: Optional[MemoryTag] = None,
|
296
|
+
top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
|
297
|
+
) -> str:
|
298
|
+
"""Search all memory events from today
|
299
|
+
|
300
|
+
Args:
|
301
|
+
query: Optional query text, returns all memories of the day if empty
|
302
|
+
tag: Optional memory tag for filtering specific types of memories
|
303
|
+
top_k: Number of most relevant memories to return, defaults to 100
|
304
|
+
|
305
|
+
Returns:
|
306
|
+
str: Formatted text of today's memories
|
164
307
|
"""
|
165
|
-
|
308
|
+
if self._simulator is None:
|
309
|
+
return "Simulator not initialized"
|
310
|
+
|
311
|
+
current_day = int(await self._simulator.get_simulator_day())
|
312
|
+
|
313
|
+
# 使用 search 方法,设置 day_range 为当天
|
314
|
+
return await self.search(
|
315
|
+
query=query, tag=tag, top_k=top_k, day_range=(current_day, current_day)
|
316
|
+
)
|
317
|
+
|
318
|
+
async def add_cognition_to_memory(
|
319
|
+
self, memory_id: Union[int, list[int]], cognition: str
|
320
|
+
) -> None:
|
321
|
+
"""为已存在的记忆添加认知
|
322
|
+
|
323
|
+
Args:
|
324
|
+
memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
|
325
|
+
cognition: 认知描述
|
166
326
|
"""
|
167
|
-
|
327
|
+
# 将单个ID转换为列表以统一处理
|
328
|
+
memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
|
329
|
+
|
330
|
+
# 找到所有对应的记忆
|
331
|
+
target_memories = []
|
332
|
+
for memory in self._memories:
|
333
|
+
if id(memory) in memory_ids:
|
334
|
+
target_memories.append(memory)
|
335
|
+
|
336
|
+
if not target_memories:
|
337
|
+
raise ValueError(f"No memories found with ids {memory_ids}")
|
338
|
+
|
339
|
+
# 添加认知记忆
|
340
|
+
cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
|
341
|
+
|
342
|
+
# 更新所有原记忆的认知ID
|
343
|
+
for target_memory in target_memories:
|
344
|
+
target_memory.cognition_id = cognition_id
|
345
|
+
|
346
|
+
async def get_all(self) -> list[MemoryNode]:
|
347
|
+
"""获取所有流式信息"""
|
348
|
+
return list(self._memories)
|
349
|
+
|
350
|
+
|
351
|
+
class StatusMemory:
|
352
|
+
"""组合现有的三种记忆类型"""
|
353
|
+
|
354
|
+
def __init__(
|
355
|
+
self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory
|
356
|
+
):
|
357
|
+
self.profile = profile
|
358
|
+
self.state = state
|
359
|
+
self.dynamic = dynamic
|
360
|
+
self._faiss_query = None
|
361
|
+
self._embedding_model = None
|
362
|
+
self._simulator = None
|
363
|
+
self._agent_id = -1
|
364
|
+
self._semantic_templates = {} # 用户可配置的模板
|
365
|
+
self._embedding_fields = {} # 需要 embedding 的字段
|
366
|
+
self._embedding_field_to_doc_id = defaultdict(str) # 新增
|
367
|
+
self.watchers = {} # 新增
|
368
|
+
self._lock = asyncio.Lock() # 新增
|
168
369
|
|
169
370
|
@property
|
170
|
-
def
|
371
|
+
def faiss_query(
|
171
372
|
self,
|
172
|
-
):
|
173
|
-
|
174
|
-
|
175
|
-
|
373
|
+
) -> FaissQuery:
|
374
|
+
assert self._faiss_query is not None
|
375
|
+
return self._faiss_query
|
376
|
+
|
377
|
+
def set_simulator(self, simulator):
|
378
|
+
self._simulator = simulator
|
379
|
+
|
380
|
+
async def initialize_embeddings(self) -> None:
|
381
|
+
"""初始化所有需要 embedding 的字段"""
|
382
|
+
if not self._embedding_model or not self._faiss_query:
|
383
|
+
logger.warning(
|
384
|
+
"Search components not initialized, skipping embeddings initialization"
|
176
385
|
)
|
177
|
-
|
386
|
+
return
|
387
|
+
|
388
|
+
# 获取所有状态信息
|
389
|
+
profile, state, dynamic = await self.export()
|
390
|
+
|
391
|
+
# 为每个需要 embedding 的字段创建 embedding
|
392
|
+
for key, value in profile[0].items():
|
393
|
+
if self.should_embed(key):
|
394
|
+
semantic_text = self._generate_semantic_text(key, value)
|
395
|
+
doc_ids = await self.faiss_query.add_documents(
|
396
|
+
agent_id=self._agent_id,
|
397
|
+
documents=semantic_text,
|
398
|
+
extra_tags={
|
399
|
+
"type": "profile_state",
|
400
|
+
"key": key,
|
401
|
+
},
|
402
|
+
)
|
403
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
404
|
+
|
405
|
+
for key, value in state[0].items():
|
406
|
+
if self.should_embed(key):
|
407
|
+
semantic_text = self._generate_semantic_text(key, value)
|
408
|
+
doc_ids = await self.faiss_query.add_documents(
|
409
|
+
agent_id=self._agent_id,
|
410
|
+
documents=semantic_text,
|
411
|
+
extra_tags={
|
412
|
+
"type": "profile_state",
|
413
|
+
"key": key,
|
414
|
+
},
|
415
|
+
)
|
416
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
417
|
+
|
418
|
+
for key, value in dynamic[0].items():
|
419
|
+
if self.should_embed(key):
|
420
|
+
semantic_text = self._generate_semantic_text(key, value)
|
421
|
+
doc_ids = await self.faiss_query.add_documents(
|
422
|
+
agent_id=self._agent_id,
|
423
|
+
documents=semantic_text,
|
424
|
+
extra_tags={
|
425
|
+
"type": "profile_state",
|
426
|
+
"key": key,
|
427
|
+
},
|
428
|
+
)
|
429
|
+
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
430
|
+
|
431
|
+
def _get_memory_type_by_key(self, key: str) -> str:
|
432
|
+
"""根据键名确定记忆类型"""
|
433
|
+
try:
|
434
|
+
if key in self.profile.__dict__:
|
435
|
+
return "profile"
|
436
|
+
elif key in self.state.__dict__:
|
437
|
+
return "state"
|
438
|
+
else:
|
439
|
+
return "dynamic"
|
440
|
+
except:
|
441
|
+
return "dynamic"
|
442
|
+
|
443
|
+
def set_search_components(self, faiss_query, embedding_model):
|
444
|
+
"""设置搜索所需的组件"""
|
445
|
+
self._faiss_query = faiss_query
|
446
|
+
self._embedding_model = embedding_model
|
178
447
|
|
179
448
|
def set_agent_id(self, agent_id: int):
|
449
|
+
"""设置agent_id"""
|
450
|
+
self._agent_id = agent_id
|
451
|
+
|
452
|
+
def set_semantic_templates(self, templates: Dict[str, str]):
|
453
|
+
"""设置语义模板
|
454
|
+
|
455
|
+
Args:
|
456
|
+
templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
|
180
457
|
"""
|
181
|
-
|
458
|
+
self._semantic_templates = templates
|
459
|
+
|
460
|
+
def _generate_semantic_text(self, key: str, value: Any) -> str:
|
461
|
+
"""生成语义文本
|
462
|
+
|
463
|
+
如果key存在于模板中,使用自定义模板
|
464
|
+
否则使用默认模板 "my {key} is {value}"
|
182
465
|
"""
|
183
|
-
self.
|
466
|
+
if key in self._semantic_templates:
|
467
|
+
return self._semantic_templates[key].format(value)
|
468
|
+
return f"Your {key} is {value}"
|
184
469
|
|
185
|
-
@
|
186
|
-
def
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
470
|
+
@lock_decorator
|
471
|
+
async def search(
|
472
|
+
self, query: str, top_k: int = 3, filter: Optional[dict] = None
|
473
|
+
) -> str:
|
474
|
+
"""搜索相关记忆
|
475
|
+
|
476
|
+
Args:
|
477
|
+
query: 查询文本
|
478
|
+
top_k: 返回最相关的记忆数量
|
479
|
+
filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
|
480
|
+
|
481
|
+
Returns:
|
482
|
+
str: 格式化的相关记忆文本
|
483
|
+
"""
|
484
|
+
if not self._embedding_model:
|
485
|
+
return "Embedding model not initialized"
|
486
|
+
|
487
|
+
filter_dict = {"type": "profile_state"}
|
488
|
+
if filter is not None:
|
489
|
+
filter_dict.update(filter)
|
490
|
+
top_results: list[tuple[str, float, dict]] = (
|
491
|
+
await self.faiss_query.similarity_search( # type:ignore
|
492
|
+
query=query,
|
493
|
+
agent_id=self._agent_id,
|
494
|
+
k=top_k,
|
495
|
+
return_score_type="similarity_score",
|
496
|
+
filter=filter_dict,
|
191
497
|
)
|
192
|
-
|
498
|
+
)
|
499
|
+
# 格式化输出
|
500
|
+
formatted_results = []
|
501
|
+
for content, score, metadata in top_results:
|
502
|
+
formatted_results.append(f"- {content} ")
|
503
|
+
|
504
|
+
return "\n".join(formatted_results)
|
505
|
+
|
506
|
+
def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
|
507
|
+
"""设置需要 embedding 的字段"""
|
508
|
+
self._embedding_fields = embedding_fields
|
509
|
+
|
510
|
+
def should_embed(self, key: str) -> bool:
|
511
|
+
"""判断字段是否需要 embedding"""
|
512
|
+
return self._embedding_fields.get(key, False)
|
193
513
|
|
194
514
|
@lock_decorator
|
195
515
|
async def get(
|
@@ -197,19 +517,18 @@ class Memory:
|
|
197
517
|
key: Any,
|
198
518
|
mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
|
199
519
|
) -> Any:
|
200
|
-
"""
|
201
|
-
Retrieves a value from memory based on the given key and access mode.
|
520
|
+
"""从记忆中获取值
|
202
521
|
|
203
522
|
Args:
|
204
|
-
key
|
205
|
-
mode
|
523
|
+
key: 要获取的键
|
524
|
+
mode: 访问模式,"read only" 或 "read and write"
|
206
525
|
|
207
526
|
Returns:
|
208
|
-
|
527
|
+
获取到的值
|
209
528
|
|
210
529
|
Raises:
|
211
|
-
ValueError:
|
212
|
-
KeyError:
|
530
|
+
ValueError: 如果提供了无效的模式
|
531
|
+
KeyError: 如果在任何记忆部分都找不到该键
|
213
532
|
"""
|
214
533
|
if mode == "read only":
|
215
534
|
process_func = deepcopy
|
@@ -217,11 +536,12 @@ class Memory:
|
|
217
536
|
process_func = lambda x: x
|
218
537
|
else:
|
219
538
|
raise ValueError(f"Invalid get mode `{mode}`!")
|
220
|
-
|
539
|
+
|
540
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
221
541
|
try:
|
222
|
-
value = await
|
542
|
+
value = await mem.get(key)
|
223
543
|
return process_func(value)
|
224
|
-
except KeyError
|
544
|
+
except KeyError:
|
225
545
|
continue
|
226
546
|
raise KeyError(f"No attribute `{key}` in memories!")
|
227
547
|
|
@@ -239,32 +559,37 @@ class Memory:
|
|
239
559
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
240
560
|
logger.warning(f"Trying to write protected key `{key}`!")
|
241
561
|
return
|
242
|
-
|
562
|
+
|
563
|
+
for mem in [self.state, self.profile, self.dynamic]:
|
243
564
|
try:
|
244
|
-
original_value = await
|
565
|
+
original_value = await mem.get(key)
|
245
566
|
if mode == "replace":
|
246
|
-
await
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
#
|
567
|
+
await mem.update(key, value, store_snapshot)
|
568
|
+
if self.should_embed(key) and self._embedding_model:
|
569
|
+
semantic_text = self._generate_semantic_text(key, value)
|
570
|
+
|
571
|
+
# 删除旧的 embedding
|
251
572
|
orig_doc_id = self._embedding_field_to_doc_id[key]
|
252
573
|
if orig_doc_id:
|
253
574
|
await self.faiss_query.delete_documents(
|
254
575
|
to_delete_ids=[orig_doc_id],
|
255
576
|
)
|
256
|
-
|
257
|
-
|
258
|
-
|
577
|
+
|
578
|
+
# 添加新的 embedding
|
579
|
+
doc_ids = await self.faiss_query.add_documents(
|
580
|
+
agent_id=self._agent_id,
|
581
|
+
documents=semantic_text,
|
259
582
|
extra_tags={
|
260
|
-
"type":
|
583
|
+
"type": self._get_memory_type(mem),
|
261
584
|
"key": key,
|
262
585
|
},
|
263
586
|
)
|
264
587
|
self._embedding_field_to_doc_id[key] = doc_ids[0]
|
588
|
+
|
265
589
|
if key in self.watchers:
|
266
590
|
for callback in self.watchers[key]:
|
267
591
|
asyncio.create_task(callback())
|
592
|
+
|
268
593
|
elif mode == "merge":
|
269
594
|
if isinstance(original_value, set):
|
270
595
|
original_value.update(set(value))
|
@@ -278,14 +603,14 @@ class Memory:
|
|
278
603
|
logger.debug(
|
279
604
|
f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
|
280
605
|
)
|
281
|
-
await
|
282
|
-
if self.
|
283
|
-
|
606
|
+
await mem.update(key, value, store_snapshot)
|
607
|
+
if self.should_embed(key) and self._embedding_model:
|
608
|
+
semantic_text = self._generate_semantic_text(key, value)
|
284
609
|
doc_ids = await self.faiss_query.add_documents(
|
285
|
-
agent_id=self.
|
610
|
+
agent_id=self._agent_id,
|
286
611
|
documents=f"{key}: {str(original_value)}",
|
287
612
|
extra_tags={
|
288
|
-
"type":
|
613
|
+
"type": self._get_memory_type(mem),
|
289
614
|
"key": key,
|
290
615
|
},
|
291
616
|
)
|
@@ -302,63 +627,20 @@ class Memory:
|
|
302
627
|
|
303
628
|
def _get_memory_type(self, mem: Any) -> str:
|
304
629
|
"""获取记忆类型"""
|
305
|
-
if mem is self.
|
630
|
+
if mem is self.state:
|
306
631
|
return "state"
|
307
|
-
elif mem is self.
|
632
|
+
elif mem is self.profile:
|
308
633
|
return "profile"
|
309
634
|
else:
|
310
635
|
return "dynamic"
|
311
636
|
|
312
|
-
async def update_batch(
|
313
|
-
self,
|
314
|
-
content: Union[dict, Sequence[tuple[Any, Any]]],
|
315
|
-
mode: Union[Literal["replace"], Literal["merge"]] = "replace",
|
316
|
-
store_snapshot: bool = False,
|
317
|
-
protect_llm_read_only_fields: bool = True,
|
318
|
-
) -> None:
|
319
|
-
"""
|
320
|
-
Updates multiple values in the memory at once.
|
321
|
-
|
322
|
-
Args:
|
323
|
-
content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
|
324
|
-
mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
|
325
|
-
store_snapshot (bool): Whether to store a snapshot of the memory after the update.
|
326
|
-
protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
|
327
|
-
|
328
|
-
Raises:
|
329
|
-
TypeError: If the content type is neither a dictionary nor a sequence of tuples.
|
330
|
-
"""
|
331
|
-
if isinstance(content, dict):
|
332
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
|
333
|
-
elif isinstance(content, Sequence):
|
334
|
-
_list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
|
335
|
-
else:
|
336
|
-
raise TypeError(f"Invalid content type `{type(content)}`!")
|
337
|
-
for k, v in _list_content[:1]:
|
338
|
-
await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
|
339
|
-
for k, v in _list_content[1:]:
|
340
|
-
await self.update(k, v, mode, False, protect_llm_read_only_fields)
|
341
|
-
|
342
637
|
@lock_decorator
|
343
638
|
async def add_watcher(self, key: str, callback: Callable) -> None:
|
344
|
-
"""
|
345
|
-
Adds a callback function to be invoked when the value
|
346
|
-
associated with the specified key in memory is updated.
|
347
|
-
|
348
|
-
Args:
|
349
|
-
key (str): The key for which the watcher is being registered.
|
350
|
-
callback (Callable): A callable function that will be executed
|
351
|
-
whenever the value associated with the specified key is updated.
|
352
|
-
|
353
|
-
Notes:
|
354
|
-
If the key does not already have any watchers, it will be
|
355
|
-
initialized with an empty list before appending the callback.
|
356
|
-
"""
|
639
|
+
"""添加值变更的监听器"""
|
357
640
|
if key not in self.watchers:
|
358
641
|
self.watchers[key] = []
|
359
642
|
self.watchers[key].append(callback)
|
360
643
|
|
361
|
-
@lock_decorator
|
362
644
|
async def export(
|
363
645
|
self,
|
364
646
|
) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
|
@@ -369,12 +651,11 @@ class Memory:
|
|
369
651
|
tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
|
370
652
|
"""
|
371
653
|
return (
|
372
|
-
await self.
|
373
|
-
await self.
|
374
|
-
await self.
|
654
|
+
await self.profile.export(),
|
655
|
+
await self.state.export(),
|
656
|
+
await self.dynamic.export(),
|
375
657
|
)
|
376
658
|
|
377
|
-
@lock_decorator
|
378
659
|
async def load(
|
379
660
|
self,
|
380
661
|
snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
|
@@ -390,41 +671,246 @@ class Memory:
|
|
390
671
|
_profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
|
391
672
|
for _snapshot, _mem in zip(
|
392
673
|
[_profile_snapshot, _state_snapshot, _dynamic_snapshot],
|
393
|
-
[self.
|
674
|
+
[self.state, self.profile, self.dynamic],
|
394
675
|
):
|
395
676
|
if _snapshot:
|
396
677
|
await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
|
397
678
|
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
679
|
+
|
680
|
+
class Memory:
|
681
|
+
"""
|
682
|
+
A class to manage different types of memory (state, profile, dynamic).
|
683
|
+
|
684
|
+
Attributes:
|
685
|
+
_state (StateMemory): Stores state-related data.
|
686
|
+
_profile (ProfileMemory): Stores profile-related data.
|
687
|
+
_dynamic (DynamicMemory): Stores dynamically configured data.
|
688
|
+
"""
|
689
|
+
|
690
|
+
def __init__(
|
691
|
+
self,
|
692
|
+
config: Optional[dict[Any, Any]] = None,
|
693
|
+
profile: Optional[dict[Any, Any]] = None,
|
694
|
+
base: Optional[dict[Any, Any]] = None,
|
695
|
+
activate_timestamp: bool = False,
|
696
|
+
embedding_model: Optional[Embeddings] = None,
|
697
|
+
faiss_query: Optional[FaissQuery] = None,
|
698
|
+
) -> None:
|
699
|
+
"""
|
700
|
+
Initializes the Memory with optional configuration.
|
403
701
|
|
404
702
|
Args:
|
405
|
-
|
406
|
-
|
407
|
-
|
703
|
+
config (Optional[dict[Any, Any]], optional):
|
704
|
+
A configuration dictionary for dynamic memory. The dictionary format is:
|
705
|
+
- Key: The name of the dynamic memory field.
|
706
|
+
- Value: Can be one of two formats:
|
707
|
+
1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
|
708
|
+
2. A callable that returns the default value when invoked (useful for complex default values).
|
709
|
+
Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
|
710
|
+
Defaults to None.
|
711
|
+
profile (Optional[dict[Any, Any]], optional): profile attribute dict.
|
712
|
+
base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
|
713
|
+
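motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator. NOTE(review): the signature above has no `motion` parameter; this entry looks stale — confirm and remove.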
|
714
|
+
activate_timestamp (bool): Whether to activate timestamp storage in MemoryUnit. Defaults to False.
|
715
|
+
embedding_model (Embeddings): The embedding model for memory search.
|
716
|
+
faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
|
717
|
+
"""
|
718
|
+
self.watchers: dict[str, list[Callable]] = {}
|
719
|
+
self._lock = asyncio.Lock()
|
720
|
+
self._agent_id: int = -1
|
721
|
+
self._simulator = None
|
722
|
+
self._embedding_model = embedding_model
|
723
|
+
self._faiss_query = faiss_query
|
724
|
+
self._semantic_templates: dict[str, str] = {}
|
725
|
+
_dynamic_config: dict[Any, Any] = {}
|
726
|
+
_state_config: dict[Any, Any] = {}
|
727
|
+
_profile_config: dict[Any, Any] = {}
|
728
|
+
self._embedding_fields: dict[str, bool] = {}
|
729
|
+
self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
|
408
730
|
|
409
|
-
|
410
|
-
|
731
|
+
if config is not None:
|
732
|
+
for k, v in config.items():
|
733
|
+
try:
|
734
|
+
# Handle config tuples of different lengths
|
735
|
+
if isinstance(v, tuple):
|
736
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
737
|
+
_type, _value, enable_embedding, template = v
|
738
|
+
self._embedding_fields[k] = enable_embedding
|
739
|
+
self._semantic_templates[k] = template
|
740
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
741
|
+
_type, _value, enable_embedding = v
|
742
|
+
self._embedding_fields[k] = enable_embedding
|
743
|
+
else: # (_type, _value)
|
744
|
+
_type, _value = v
|
745
|
+
self._embedding_fields[k] = False
|
746
|
+
else:
|
747
|
+
_type = type(v)
|
748
|
+
_value = v
|
749
|
+
self._embedding_fields[k] = False
|
750
|
+
|
751
|
+
# Handle type conversion
|
752
|
+
try:
|
753
|
+
if isinstance(_type, type):
|
754
|
+
_value = _type(_value)
|
755
|
+
else:
|
756
|
+
if isinstance(_type, deque):
|
757
|
+
_type.extend(_value)
|
758
|
+
_value = deepcopy(_type)
|
759
|
+
else:
|
760
|
+
logger.warning(f"type `{_type}` is not supported!")
|
761
|
+
except TypeError as e:
|
762
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
763
|
+
except TypeError as e:
|
764
|
+
if isinstance(v, type):
|
765
|
+
_value = v()
|
766
|
+
else:
|
767
|
+
_value = v
|
768
|
+
self._embedding_fields[k] = False
|
769
|
+
|
770
|
+
if (
|
771
|
+
k in PROFILE_ATTRIBUTES
|
772
|
+
or k in STATE_ATTRIBUTES
|
773
|
+
or k == TIME_STAMP_KEY
|
774
|
+
):
|
775
|
+
logger.warning(f"key `{k}` already declared in memory!")
|
776
|
+
continue
|
777
|
+
|
778
|
+
_dynamic_config[k] = deepcopy(_value)
|
779
|
+
|
780
|
+
# Initialize each kind of memory
|
781
|
+
self._dynamic = DynamicMemory(
|
782
|
+
required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
|
783
|
+
)
|
784
|
+
|
785
|
+
if profile is not None:
|
786
|
+
for k, v in profile.items():
|
787
|
+
if k not in PROFILE_ATTRIBUTES:
|
788
|
+
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
789
|
+
continue
|
790
|
+
try:
|
791
|
+
# Handle the config tuple format
|
792
|
+
if isinstance(v, tuple):
|
793
|
+
if len(v) == 4: # (_type, _value, enable_embedding, template)
|
794
|
+
_type, _value, enable_embedding, template = v
|
795
|
+
self._embedding_fields[k] = enable_embedding
|
796
|
+
self._semantic_templates[k] = template
|
797
|
+
elif len(v) == 3: # (_type, _value, enable_embedding)
|
798
|
+
_type, _value, enable_embedding = v
|
799
|
+
self._embedding_fields[k] = enable_embedding
|
800
|
+
else: # (_type, _value)
|
801
|
+
_type, _value = v
|
802
|
+
self._embedding_fields[k] = False
|
803
|
+
|
804
|
+
# Handle type conversion
|
805
|
+
try:
|
806
|
+
if isinstance(_type, type):
|
807
|
+
_value = _type(_value)
|
808
|
+
else:
|
809
|
+
if isinstance(_type, deque):
|
810
|
+
_type.extend(_value)
|
811
|
+
_value = deepcopy(_type)
|
812
|
+
else:
|
813
|
+
logger.warning(f"type `{_type}` is not supported!")
|
814
|
+
except TypeError as e:
|
815
|
+
logger.warning(f"Type conversion failed for key {k}: {e}")
|
816
|
+
else:
|
817
|
+
# Stay compatible with plain key-value pairs
|
818
|
+
_value = v
|
819
|
+
self._embedding_fields[k] = False
|
820
|
+
except TypeError as e:
|
821
|
+
if isinstance(v, type):
|
822
|
+
_value = v()
|
823
|
+
else:
|
824
|
+
_value = v
|
825
|
+
self._embedding_fields[k] = False
|
826
|
+
|
827
|
+
_profile_config[k] = deepcopy(_value)
|
828
|
+
self._profile = ProfileMemory(
|
829
|
+
msg=_profile_config, activate_timestamp=activate_timestamp
|
830
|
+
)
|
831
|
+
if base is not None:
|
832
|
+
for k, v in base.items():
|
833
|
+
if k not in STATE_ATTRIBUTES:
|
834
|
+
logger.warning(f"key `{k}` is not a correct `base` field!")
|
835
|
+
continue
|
836
|
+
_state_config[k] = v
|
837
|
+
|
838
|
+
self._state = StateMemory(
|
839
|
+
msg=_state_config, activate_timestamp=activate_timestamp
|
840
|
+
)
|
841
|
+
|
842
|
+
# Compose the StatusMemory and pass it the embedding_fields information
|
843
|
+
self._status = StatusMemory(
|
844
|
+
profile=self._profile, state=self._state, dynamic=self._dynamic
|
845
|
+
)
|
846
|
+
self._status.set_embedding_fields(self._embedding_fields)
|
847
|
+
self._status.set_search_components(self._faiss_query, self._embedding_model)
|
848
|
+
|
849
|
+
# Create the StreamMemory
|
850
|
+
self._stream = StreamMemory()
|
851
|
+
self._stream.set_status_memory(self._status)
|
852
|
+
self._stream.set_search_components(self._faiss_query, self._embedding_model)
|
853
|
+
|
854
|
+
def set_search_components(
|
855
|
+
self,
|
856
|
+
faiss_query: FaissQuery,
|
857
|
+
embedding_model: Embeddings,
|
858
|
+
):
|
859
|
+
self._embedding_model = embedding_model
|
860
|
+
self._faiss_query = faiss_query
|
861
|
+
self._stream.set_search_components(faiss_query, embedding_model)
|
862
|
+
self._status.set_search_components(faiss_query, embedding_model)
|
863
|
+
|
864
|
+
def set_agent_id(self, agent_id: int):
|
411
865
|
"""
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
866
|
+
Set the agent id and propagate it to the stream and status memories.
|
867
|
+
"""
|
868
|
+
self._agent_id = agent_id
|
869
|
+
self._stream.set_agent_id(agent_id)
|
870
|
+
self._status.set_agent_id(agent_id)
|
871
|
+
|
872
|
+
def set_simulator(self, simulator):
|
873
|
+
self._simulator = simulator
|
874
|
+
self._stream.set_simulator(simulator)
|
875
|
+
self._status.set_simulator(simulator)
|
876
|
+
|
877
|
+
@property
|
878
|
+
def status(self) -> StatusMemory:
|
879
|
+
return self._status
|
880
|
+
|
881
|
+
@property
|
882
|
+
def stream(self) -> StreamMemory:
|
883
|
+
return self._stream
|
884
|
+
|
885
|
+
@property
|
886
|
+
def embedding_model(
|
887
|
+
self,
|
888
|
+
):
|
889
|
+
if self._embedding_model is None:
|
890
|
+
raise RuntimeError(
|
891
|
+
f"embedding_model before assignment, please `set_embedding_model` first!"
|
421
892
|
)
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
893
|
+
return self._embedding_model
|
894
|
+
|
895
|
+
@property
|
896
|
+
def agent_id(
|
897
|
+
self,
|
898
|
+
):
|
899
|
+
if self._agent_id < 0:
|
900
|
+
raise RuntimeError(
|
901
|
+
f"agent_id before assignment, please `set_agent_id` first!"
|
428
902
|
)
|
903
|
+
return self._agent_id
|
429
904
|
|
430
|
-
|
905
|
+
@property
|
906
|
+
def faiss_query(self) -> FaissQuery:
|
907
|
+
"""FaissQuery"""
|
908
|
+
if self._faiss_query is None:
|
909
|
+
raise RuntimeError(
|
910
|
+
f"FaissQuery access before assignment, please `set_faiss_query` first!"
|
911
|
+
)
|
912
|
+
return self._faiss_query
|
913
|
+
|
914
|
+
async def initialize_embeddings(self):
|
915
|
+
"""初始化embedding"""
|
916
|
+
await self._status.initialize_embeddings()
|