pycityagent 2.0.0a52__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a54__cp39-cp39-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (49) hide show
  1. pycityagent/agent/agent.py +83 -62
  2. pycityagent/agent/agent_base.py +81 -54
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +45 -57
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/message_intercept.py +99 -0
  17. pycityagent/cityagent/nbsagent.py +6 -29
  18. pycityagent/cityagent/societyagent.py +325 -127
  19. pycityagent/cli/wrapper.py +4 -0
  20. pycityagent/economy/econ_client.py +0 -2
  21. pycityagent/environment/__init__.py +7 -1
  22. pycityagent/environment/sim/client.py +10 -1
  23. pycityagent/environment/sim/clock_service.py +2 -2
  24. pycityagent/environment/sim/pause_service.py +61 -0
  25. pycityagent/environment/sim/sim_env.py +34 -46
  26. pycityagent/environment/simulator.py +18 -14
  27. pycityagent/llm/embeddings.py +0 -24
  28. pycityagent/llm/llm.py +18 -10
  29. pycityagent/memory/faiss_query.py +29 -26
  30. pycityagent/memory/memory.py +733 -247
  31. pycityagent/message/__init__.py +8 -1
  32. pycityagent/message/message_interceptor.py +322 -0
  33. pycityagent/message/messager.py +42 -11
  34. pycityagent/pycityagent-sim +0 -0
  35. pycityagent/simulation/agentgroup.py +137 -96
  36. pycityagent/simulation/simulation.py +184 -38
  37. pycityagent/simulation/storage/pg.py +2 -2
  38. pycityagent/tools/tool.py +7 -9
  39. pycityagent/utils/__init__.py +7 -2
  40. pycityagent/utils/pg_query.py +1 -0
  41. pycityagent/utils/survey_util.py +26 -23
  42. pycityagent/workflow/block.py +14 -7
  43. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
  44. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
  45. pycityagent/cityagent/blocks/time_block.py +0 -116
  46. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
  47. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
  48. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
  49. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from collections import defaultdict
4
- from collections.abc import Callable, Sequence
4
+ from collections.abc import Callable, Coroutine, Sequence
5
5
  from copy import deepcopy
6
- from datetime import datetime
7
- from typing import Any, Literal, Optional, Union
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from typing import Any, Dict, Literal, Optional, Union
8
9
 
9
- import numpy as np
10
10
  from langchain_core.embeddings import Embeddings
11
11
  from pyparsing import deque
12
12
 
@@ -20,176 +20,496 @@ from .state import StateMemory
20
20
  logger = logging.getLogger("pycityagent")
21
21
 
22
22
 
23
- class Memory:
24
- """
25
- A class to manage different types of memory (state, profile, dynamic).
23
+ class MemoryTag(str, Enum):
24
+ """记忆标签枚举类"""
26
25
 
27
- Attributes:
28
- _state (StateMemory): Stores state-related data.
29
- _profile (ProfileMemory): Stores profile-related data.
30
- _dynamic (DynamicMemory): Stores dynamically configured data.
31
- """
26
+ MOBILITY = "mobility"
27
+ SOCIAL = "social"
28
+ ECONOMY = "economy"
29
+ COGNITION = "cognition"
30
+ OTHER = "other"
31
+ EVENT = "event"
32
32
 
33
- def __init__(
33
+
34
+ @dataclass
35
+ class MemoryNode:
36
+ """记忆节点"""
37
+
38
+ tag: MemoryTag
39
+ day: int
40
+ t: int
41
+ location: str
42
+ description: str
43
+ cognition_id: Optional[int] = None # 关联的认知记忆ID
44
+ id: Optional[int] = None # 记忆ID
45
+
46
+
47
+ class StreamMemory:
48
+ """用于存储时序性的流式信息"""
49
+
50
+ def __init__(self, max_len: int = 1000):
51
+ self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
52
+ self._memory_id_counter: int = 0 # 用于生成唯一ID
53
+ self._faiss_query = None
54
+ self._embedding_model = None
55
+ self._agent_id = -1
56
+ self._status_memory = None
57
+ self._simulator = None
58
+
59
+ @property
60
+ def faiss_query(
34
61
  self,
35
- config: Optional[dict[Any, Any]] = None,
36
- profile: Optional[dict[Any, Any]] = None,
37
- base: Optional[dict[Any, Any]] = None,
38
- motion: Optional[dict[Any, Any]] = None,
39
- activate_timestamp: bool = False,
40
- embedding_model: Optional[Embeddings] = None,
41
- faiss_query: Optional[FaissQuery] = None,
42
- ) -> None:
43
- """
44
- Initializes the Memory with optional configuration.
62
+ ) -> FaissQuery:
63
+ assert self._faiss_query is not None
64
+ return self._faiss_query
45
65
 
46
- Args:
47
- config (Optional[dict[Any, Any]], optional):
48
- A configuration dictionary for dynamic memory. The dictionary format is:
49
- - Key: The name of the dynamic memory field.
50
- - Value: Can be one of two formats:
51
- 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
52
- 2. A callable that returns the default value when invoked (useful for complex default values).
53
- Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
54
- Defaults to None.
55
- profile (Optional[dict[Any, Any]], optional): profile attribute dict.
56
- base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
57
- motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
58
- activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
59
- embedding_model (Embeddings): The embedding model for memory search.
60
- faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
61
- """
62
- self.watchers: dict[str, list[Callable]] = {}
63
- self._lock = asyncio.Lock()
64
- self._agent_id: int = -1
65
- self._embedding_model = embedding_model
66
+ @property
67
+ def status_memory(
68
+ self,
69
+ ):
70
+ assert self._status_memory is not None
71
+ return self._status_memory
66
72
 
67
- _dynamic_config: dict[Any, Any] = {}
68
- _state_config: dict[Any, Any] = {}
69
- _profile_config: dict[Any, Any] = {}
70
- # 记录哪些字段需要embedding
71
- self._embedding_fields: dict[str, bool] = {}
72
- self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
73
+ def set_simulator(self, simulator):
74
+ self._simulator = simulator
75
+
76
+ def set_status_memory(self, status_memory):
77
+ self._status_memory = status_memory
78
+
79
+ def set_search_components(self, faiss_query, embedding_model):
80
+ """设置搜索所需的组件"""
73
81
  self._faiss_query = faiss_query
82
+ self._embedding_model = embedding_model
74
83
 
75
- if config is not None:
76
- for k, v in config.items():
77
- try:
78
- # 处理新的三元组格式
79
- if isinstance(v, tuple) and len(v) == 3:
80
- _type, _value, enable_embedding = v
81
- self._embedding_fields[k] = enable_embedding
82
- else:
83
- _type, _value = v
84
- self._embedding_fields[k] = False
84
+ def set_agent_id(self, agent_id: int):
85
+ """设置agent_id"""
86
+ self._agent_id = agent_id
85
87
 
86
- try:
87
- if isinstance(_type, type):
88
- _value = _type(_value)
89
- else:
90
- if isinstance(_type, deque):
91
- _type.extend(_value)
92
- _value = deepcopy(_type)
93
- else:
94
- logger.warning(f"type `{_type}` is not supported!")
95
- pass
96
- except TypeError as e:
97
- pass
98
- except TypeError as e:
99
- if isinstance(v, type):
100
- _value = v()
101
- else:
102
- _value = v
103
- self._embedding_fields[k] = False
88
+ async def _add_memory(self, tag: MemoryTag, description: str) -> int:
89
+ """添加记忆节点的通用方法,返回记忆节点ID"""
90
+ if self._simulator is not None:
91
+ day = int(await self._simulator.get_simulator_day())
92
+ t = int(await self._simulator.get_time())
93
+ else:
94
+ day = 1
95
+ t = 1
96
+ position = await self.status_memory.get("position")
97
+ if "aoi_position" in position:
98
+ location = position["aoi_position"]["aoi_id"]
99
+ elif "lane_position" in position:
100
+ location = position["lane_position"]["lane_id"]
101
+ else:
102
+ location = "unknown"
103
+
104
+ current_id = self._memory_id_counter
105
+ self._memory_id_counter += 1
106
+ memory_node = MemoryNode(
107
+ tag=tag,
108
+ day=day,
109
+ t=t,
110
+ location=location,
111
+ description=description,
112
+ id=current_id,
113
+ )
114
+ self._memories.append(memory_node)
115
+
116
+ # 为新记忆创建 embedding
117
+ if self._embedding_model and self._faiss_query:
118
+ await self.faiss_query.add_documents(
119
+ agent_id=self._agent_id,
120
+ documents=description,
121
+ extra_tags={
122
+ "type": "stream",
123
+ "tag": tag,
124
+ "day": day,
125
+ "time": t,
126
+ },
127
+ )
104
128
 
105
- if (
106
- k in PROFILE_ATTRIBUTES
107
- or k in STATE_ATTRIBUTES
108
- or k == TIME_STAMP_KEY
109
- ):
110
- logger.warning(f"key `{k}` already declared in memory!")
111
- continue
129
+ return current_id
130
+
131
+ async def add_cognition(self, description: str) -> int:
132
+ """添加认知记忆 Add cognition memory"""
133
+ return await self._add_memory(MemoryTag.COGNITION, description)
134
+
135
+ async def add_social(self, description: str) -> int:
136
+ """添加社交记忆 Add social memory"""
137
+ return await self._add_memory(MemoryTag.SOCIAL, description)
138
+
139
+ async def add_economy(self, description: str) -> int:
140
+ """添加经济记忆 Add economy memory"""
141
+ return await self._add_memory(MemoryTag.ECONOMY, description)
142
+
143
+ async def add_mobility(self, description: str) -> int:
144
+ """添加移动记忆 Add mobility memory"""
145
+ return await self._add_memory(MemoryTag.MOBILITY, description)
146
+
147
+ async def add_event(self, description: str) -> int:
148
+ """添加事件记忆 Add event memory"""
149
+ return await self._add_memory(MemoryTag.EVENT, description)
150
+
151
+ async def add_other(self, description: str) -> int:
152
+ """添加其他记忆 Add other memory"""
153
+ return await self._add_memory(MemoryTag.OTHER, description)
154
+
155
+ async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
156
+ """获取关联的认知记忆 Get related cognition memory"""
157
+ for memory in self._memories:
158
+ if memory.cognition_id == memory_id:
159
+ for cognition_memory in self._memories:
160
+ if (
161
+ cognition_memory.tag == MemoryTag.COGNITION
162
+ and memory.cognition_id is not None
163
+ ):
164
+ return cognition_memory
165
+ return None
166
+
167
+ async def format_memory(self, memories: list[MemoryNode]) -> str:
168
+ """格式化记忆"""
169
+ formatted_results = []
170
+ for memory in memories:
171
+ memory_tag = memory.tag
172
+ memory_day = memory.day
173
+ memory_time_seconds = memory.t
174
+ cognition_id = memory.cognition_id
175
+
176
+ # 格式化时间
177
+ if memory_time_seconds != "unknown":
178
+ hours = memory_time_seconds // 3600
179
+ minutes = (memory_time_seconds % 3600) // 60
180
+ seconds = memory_time_seconds % 60
181
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
182
+ else:
183
+ memory_time = "unknown"
184
+
185
+ memory_location = memory.location
186
+
187
+ # 添加认知信息(如果存在)
188
+ cognition_info = ""
189
+ if cognition_id is not None:
190
+ cognition_memory = await self.get_related_cognition(cognition_id)
191
+ if cognition_memory:
192
+ cognition_info = (
193
+ f"\n Related cognition: {cognition_memory.description}"
194
+ )
112
195
 
113
- _dynamic_config[k] = deepcopy(_value)
196
+ formatted_results.append(
197
+ f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
198
+ f"location: {memory_location}]{cognition_info}"
199
+ )
200
+ return "\n".join(formatted_results)
114
201
 
115
- # 初始化各类记忆
116
- self._dynamic = DynamicMemory(
117
- required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
118
- )
202
+ async def get_by_ids(
203
+ self, memory_ids: Union[int, list[int]]
204
+ ) -> Coroutine[Any, Any, str]:
205
+ """获取指定ID的记忆"""
206
+ memories = [memory for memory in self._memories if memory.id in memory_ids]
207
+ sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
208
+ return self.format_memory(sorted_results)
119
209
 
120
- if profile is not None:
121
- for k, v in profile.items():
122
- if k not in PROFILE_ATTRIBUTES:
123
- logger.warning(f"key `{k}` is not a correct `profile` field!")
124
- continue
125
- _profile_config[k] = v
126
- if motion is not None:
127
- for k, v in motion.items():
128
- if k not in STATE_ATTRIBUTES:
129
- logger.warning(f"key `{k}` is not a correct `motion` field!")
130
- continue
131
- _state_config[k] = v
132
- if base is not None:
133
- for k, v in base.items():
134
- if k not in STATE_ATTRIBUTES:
135
- logger.warning(f"key `{k}` is not a correct `base` field!")
136
- continue
137
- _state_config[k] = v
138
- self._state = StateMemory(
139
- msg=_state_config, activate_timestamp=activate_timestamp
210
+ async def search(
211
+ self,
212
+ query: str,
213
+ tag: Optional[MemoryTag] = None,
214
+ top_k: int = 3,
215
+ day_range: Optional[tuple[int, int]] = None, # 新增参数
216
+ time_range: Optional[tuple[int, int]] = None, # 新增参数
217
+ ) -> str:
218
+ """Search stream memory
219
+
220
+ Args:
221
+ query: Query text
222
+ tag: Optional memory tag for filtering specific types of memories
223
+ top_k: Number of most relevant memories to return
224
+ day_range: Optional tuple of start and end days (start_day, end_day)
225
+ time_range: Optional tuple of start and end times (start_time, end_time)
226
+ """
227
+ if not self._embedding_model or not self._faiss_query:
228
+ return "Search components not initialized"
229
+
230
+ filter_dict: dict[str, Any] = {"type": "stream"}
231
+
232
+ if tag:
233
+ filter_dict["tag"] = tag
234
+
235
+ # 添加时间范围过滤
236
+ if day_range:
237
+ start_day, end_day = day_range
238
+ filter_dict["day"] = lambda x: start_day <= x <= end_day
239
+
240
+ if time_range:
241
+ start_time, end_time = time_range
242
+ filter_dict["time"] = lambda x: start_time <= x <= end_time
243
+
244
+ top_results = await self.faiss_query.similarity_search(
245
+ query=query,
246
+ agent_id=self._agent_id,
247
+ k=top_k,
248
+ return_score_type="similarity_score",
249
+ filter=filter_dict,
140
250
  )
141
- self._profile = ProfileMemory(
142
- msg=_profile_config, activate_timestamp=activate_timestamp
251
+
252
+ # 将结果按时间排序(先按天数,再按时间)
253
+ sorted_results = sorted(
254
+ top_results,
255
+ key=lambda x: (x[2].get("day", 0), x[2].get("time", 0)), # type:ignore
256
+ reverse=True,
143
257
  )
144
- # self.memories = [] # 存储记忆内容
145
- # self.embeddings = [] # 存储记忆的向量表示
146
258
 
147
- def set_embedding_model(
148
- self,
149
- embedding_model: Embeddings,
150
- ):
151
- self._embedding_model = embedding_model
259
+ formatted_results = []
260
+ for content, score, metadata in sorted_results: # type:ignore
261
+ memory_tag = metadata.get("tag", "unknown")
262
+ memory_day = metadata.get("day", "unknown")
263
+ memory_time_seconds = metadata.get("time", "unknown")
264
+ cognition_id = metadata.get("cognition_id", None)
265
+
266
+ # 格式化时间
267
+ if memory_time_seconds != "unknown":
268
+ hours = memory_time_seconds // 3600
269
+ minutes = (memory_time_seconds % 3600) // 60
270
+ seconds = memory_time_seconds % 60
271
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
272
+ else:
273
+ memory_time = "unknown"
274
+
275
+ memory_location = metadata.get("location", "unknown")
276
+
277
+ # 添加认知信息(如果存在)
278
+ cognition_info = ""
279
+ if cognition_id is not None:
280
+ cognition_memory = await self.get_related_cognition(cognition_id)
281
+ if cognition_memory:
282
+ cognition_info = (
283
+ f"\n Related cognition: {cognition_memory.description}"
284
+ )
152
285
 
153
- @property
154
- def embedding_model(
155
- self,
156
- ):
157
- if self._embedding_model is None:
158
- raise RuntimeError(
159
- f"embedding_model before assignment, please `set_embedding_model` first!"
286
+ formatted_results.append(
287
+ f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
288
+ f"location: {memory_location}]{cognition_info}"
160
289
  )
161
- return self._embedding_model
290
+ return "\n".join(formatted_results)
162
291
 
163
- def set_faiss_query(self, faiss_query: FaissQuery):
292
+ async def search_today(
293
+ self,
294
+ query: str = "", # 可选的查询文本
295
+ tag: Optional[MemoryTag] = None,
296
+ top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
297
+ ) -> str:
298
+ """Search all memory events from today
299
+
300
+ Args:
301
+ query: Optional query text, returns all memories of the day if empty
302
+ tag: Optional memory tag for filtering specific types of memories
303
+ top_k: Number of most relevant memories to return, defaults to 100
304
+
305
+ Returns:
306
+ str: Formatted text of today's memories
164
307
  """
165
- Set the FaissQuery of the agent.
308
+ if self._simulator is None:
309
+ return "Simulator not initialized"
310
+
311
+ current_day = int(await self._simulator.get_simulator_day())
312
+
313
+ # 使用 search 方法,设置 day_range 为当天
314
+ return await self.search(
315
+ query=query, tag=tag, top_k=top_k, day_range=(current_day, current_day)
316
+ )
317
+
318
+ async def add_cognition_to_memory(
319
+ self, memory_id: Union[int, list[int]], cognition: str
320
+ ) -> None:
321
+ """为已存在的记忆添加认知
322
+
323
+ Args:
324
+ memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
325
+ cognition: 认知描述
166
326
  """
167
- self._faiss_query = faiss_query
327
+ # 将单个ID转换为列表以统一处理
328
+ memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
329
+
330
+ # 找到所有对应的记忆
331
+ target_memories = []
332
+ for memory in self._memories:
333
+ if id(memory) in memory_ids:
334
+ target_memories.append(memory)
335
+
336
+ if not target_memories:
337
+ raise ValueError(f"No memories found with ids {memory_ids}")
338
+
339
+ # 添加认知记忆
340
+ cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
341
+
342
+ # 更新所有原记忆的认知ID
343
+ for target_memory in target_memories:
344
+ target_memory.cognition_id = cognition_id
345
+
346
+ async def get_all(self) -> list[MemoryNode]:
347
+ """获取所有流式信息"""
348
+ return list(self._memories)
349
+
350
+
351
+ class StatusMemory:
352
+ """组合现有的三种记忆类型"""
353
+
354
+ def __init__(
355
+ self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory
356
+ ):
357
+ self.profile = profile
358
+ self.state = state
359
+ self.dynamic = dynamic
360
+ self._faiss_query = None
361
+ self._embedding_model = None
362
+ self._simulator = None
363
+ self._agent_id = -1
364
+ self._semantic_templates = {} # 用户可配置的模板
365
+ self._embedding_fields = {} # 需要 embedding 的字段
366
+ self._embedding_field_to_doc_id = defaultdict(str) # 新增
367
+ self.watchers = {} # 新增
368
+ self._lock = asyncio.Lock() # 新增
168
369
 
169
370
  @property
170
- def agent_id(
371
+ def faiss_query(
171
372
  self,
172
- ):
173
- if self._agent_id < 0:
174
- raise RuntimeError(
175
- f"agent_id before assignment, please `set_agent_id` first!"
373
+ ) -> FaissQuery:
374
+ assert self._faiss_query is not None
375
+ return self._faiss_query
376
+
377
+ def set_simulator(self, simulator):
378
+ self._simulator = simulator
379
+
380
+ async def initialize_embeddings(self) -> None:
381
+ """初始化所有需要 embedding 的字段"""
382
+ if not self._embedding_model or not self._faiss_query:
383
+ logger.warning(
384
+ "Search components not initialized, skipping embeddings initialization"
176
385
  )
177
- return self._agent_id
386
+ return
387
+
388
+ # 获取所有状态信息
389
+ profile, state, dynamic = await self.export()
390
+
391
+ # 为每个需要 embedding 的字段创建 embedding
392
+ for key, value in profile[0].items():
393
+ if self.should_embed(key):
394
+ semantic_text = self._generate_semantic_text(key, value)
395
+ doc_ids = await self.faiss_query.add_documents(
396
+ agent_id=self._agent_id,
397
+ documents=semantic_text,
398
+ extra_tags={
399
+ "type": "profile_state",
400
+ "key": key,
401
+ },
402
+ )
403
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
404
+
405
+ for key, value in state[0].items():
406
+ if self.should_embed(key):
407
+ semantic_text = self._generate_semantic_text(key, value)
408
+ doc_ids = await self.faiss_query.add_documents(
409
+ agent_id=self._agent_id,
410
+ documents=semantic_text,
411
+ extra_tags={
412
+ "type": "profile_state",
413
+ "key": key,
414
+ },
415
+ )
416
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
417
+
418
+ for key, value in dynamic[0].items():
419
+ if self.should_embed(key):
420
+ semantic_text = self._generate_semantic_text(key, value)
421
+ doc_ids = await self.faiss_query.add_documents(
422
+ agent_id=self._agent_id,
423
+ documents=semantic_text,
424
+ extra_tags={
425
+ "type": "profile_state",
426
+ "key": key,
427
+ },
428
+ )
429
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
430
+
431
+ def _get_memory_type_by_key(self, key: str) -> str:
432
+ """根据键名确定记忆类型"""
433
+ try:
434
+ if key in self.profile.__dict__:
435
+ return "profile"
436
+ elif key in self.state.__dict__:
437
+ return "state"
438
+ else:
439
+ return "dynamic"
440
+ except:
441
+ return "dynamic"
442
+
443
+ def set_search_components(self, faiss_query, embedding_model):
444
+ """设置搜索所需的组件"""
445
+ self._faiss_query = faiss_query
446
+ self._embedding_model = embedding_model
178
447
 
179
448
  def set_agent_id(self, agent_id: int):
449
+ """设置agent_id"""
450
+ self._agent_id = agent_id
451
+
452
+ def set_semantic_templates(self, templates: Dict[str, str]):
453
+ """设置语义模板
454
+
455
+ Args:
456
+ templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
180
457
  """
181
- Set the FaissQuery of the agent.
458
+ self._semantic_templates = templates
459
+
460
+ def _generate_semantic_text(self, key: str, value: Any) -> str:
461
+ """生成语义文本
462
+
463
+ 如果key存在于模板中,使用自定义模板
464
+ 否则使用默认模板 "my {key} is {value}"
182
465
  """
183
- self._agent_id = agent_id
466
+ if key in self._semantic_templates:
467
+ return self._semantic_templates[key].format(value)
468
+ return f"Your {key} is {value}"
184
469
 
185
- @property
186
- def faiss_query(self) -> FaissQuery:
187
- """FaissQuery"""
188
- if self._faiss_query is None:
189
- raise RuntimeError(
190
- f"FaissQuery access before assignment, please `set_faiss_query` first!"
470
+ @lock_decorator
471
+ async def search(
472
+ self, query: str, top_k: int = 3, filter: Optional[dict] = None
473
+ ) -> str:
474
+ """搜索相关记忆
475
+
476
+ Args:
477
+ query: 查询文本
478
+ top_k: 返回最相关的记忆数量
479
+ filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
480
+
481
+ Returns:
482
+ str: 格式化的相关记忆文本
483
+ """
484
+ if not self._embedding_model:
485
+ return "Embedding model not initialized"
486
+
487
+ filter_dict = {"type": "profile_state"}
488
+ if filter is not None:
489
+ filter_dict.update(filter)
490
+ top_results: list[tuple[str, float, dict]] = (
491
+ await self.faiss_query.similarity_search( # type:ignore
492
+ query=query,
493
+ agent_id=self._agent_id,
494
+ k=top_k,
495
+ return_score_type="similarity_score",
496
+ filter=filter_dict,
191
497
  )
192
- return self._faiss_query
498
+ )
499
+ # 格式化输出
500
+ formatted_results = []
501
+ for content, score, metadata in top_results:
502
+ formatted_results.append(f"- {content} ")
503
+
504
+ return "\n".join(formatted_results)
505
+
506
+ def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
507
+ """设置需要 embedding 的字段"""
508
+ self._embedding_fields = embedding_fields
509
+
510
+ def should_embed(self, key: str) -> bool:
511
+ """判断字段是否需要 embedding"""
512
+ return self._embedding_fields.get(key, False)
193
513
 
194
514
  @lock_decorator
195
515
  async def get(
@@ -197,19 +517,18 @@ class Memory:
197
517
  key: Any,
198
518
  mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
199
519
  ) -> Any:
200
- """
201
- Retrieves a value from memory based on the given key and access mode.
520
+ """从记忆中获取值
202
521
 
203
522
  Args:
204
- key (Any): The key of the item to retrieve.
205
- mode (Union[Literal["read only"], Literal["read and write"]], optional): Access mode for the item. Defaults to "read only".
523
+ key: 要获取的键
524
+ mode: 访问模式,"read only" "read and write"
206
525
 
207
526
  Returns:
208
- Any: The value associated with the key.
527
+ 获取到的值
209
528
 
210
529
  Raises:
211
- ValueError: If an invalid mode is provided.
212
- KeyError: If the key is not found in any of the memory sections.
530
+ ValueError: 如果提供了无效的模式
531
+ KeyError: 如果在任何记忆部分都找不到该键
213
532
  """
214
533
  if mode == "read only":
215
534
  process_func = deepcopy
@@ -217,11 +536,12 @@ class Memory:
217
536
  process_func = lambda x: x
218
537
  else:
219
538
  raise ValueError(f"Invalid get mode `{mode}`!")
220
- for _mem in [self._state, self._profile, self._dynamic]:
539
+
540
+ for mem in [self.state, self.profile, self.dynamic]:
221
541
  try:
222
- value = await _mem.get(key)
542
+ value = await mem.get(key)
223
543
  return process_func(value)
224
- except KeyError as e:
544
+ except KeyError:
225
545
  continue
226
546
  raise KeyError(f"No attribute `{key}` in memories!")
227
547
 
@@ -239,32 +559,37 @@ class Memory:
239
559
  if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
240
560
  logger.warning(f"Trying to write protected key `{key}`!")
241
561
  return
242
- for _mem in [self._state, self._profile, self._dynamic]:
562
+
563
+ for mem in [self.state, self.profile, self.dynamic]:
243
564
  try:
244
- original_value = await _mem.get(key)
565
+ original_value = await mem.get(key)
245
566
  if mode == "replace":
246
- await _mem.update(key, value, store_snapshot)
247
- # 如果字段需要embedding,则更新embedding
248
- if self._embedding_fields.get(key, False) and self.embedding_model:
249
- memory_type = self._get_memory_type(_mem)
250
- # 覆盖更新删除原vector
567
+ await mem.update(key, value, store_snapshot)
568
+ if self.should_embed(key) and self._embedding_model:
569
+ semantic_text = self._generate_semantic_text(key, value)
570
+
571
+ # 删除旧的 embedding
251
572
  orig_doc_id = self._embedding_field_to_doc_id[key]
252
573
  if orig_doc_id:
253
574
  await self.faiss_query.delete_documents(
254
575
  to_delete_ids=[orig_doc_id],
255
576
  )
256
- doc_ids: list[str] = await self.faiss_query.add_documents(
257
- agent_id=self.agent_id,
258
- documents=f"{key}: {str(value)}",
577
+
578
+ # 添加新的 embedding
579
+ doc_ids = await self.faiss_query.add_documents(
580
+ agent_id=self._agent_id,
581
+ documents=semantic_text,
259
582
  extra_tags={
260
- "type": memory_type,
583
+ "type": self._get_memory_type(mem),
261
584
  "key": key,
262
585
  },
263
586
  )
264
587
  self._embedding_field_to_doc_id[key] = doc_ids[0]
588
+
265
589
  if key in self.watchers:
266
590
  for callback in self.watchers[key]:
267
591
  asyncio.create_task(callback())
592
+
268
593
  elif mode == "merge":
269
594
  if isinstance(original_value, set):
270
595
  original_value.update(set(value))
@@ -278,14 +603,14 @@ class Memory:
278
603
  logger.debug(
279
604
  f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
280
605
  )
281
- await _mem.update(key, value, store_snapshot)
282
- if self._embedding_fields.get(key, False) and self.embedding_model:
283
- memory_type = self._get_memory_type(_mem)
606
+ await mem.update(key, value, store_snapshot)
607
+ if self.should_embed(key) and self._embedding_model:
608
+ semantic_text = self._generate_semantic_text(key, value)
284
609
  doc_ids = await self.faiss_query.add_documents(
285
- agent_id=self.agent_id,
610
+ agent_id=self._agent_id,
286
611
  documents=f"{key}: {str(original_value)}",
287
612
  extra_tags={
288
- "type": memory_type,
613
+ "type": self._get_memory_type(mem),
289
614
  "key": key,
290
615
  },
291
616
  )
@@ -302,63 +627,20 @@ class Memory:
302
627
 
303
628
  def _get_memory_type(self, mem: Any) -> str:
304
629
  """获取记忆类型"""
305
- if mem is self._state:
630
+ if mem is self.state:
306
631
  return "state"
307
- elif mem is self._profile:
632
+ elif mem is self.profile:
308
633
  return "profile"
309
634
  else:
310
635
  return "dynamic"
311
636
 
312
- async def update_batch(
313
- self,
314
- content: Union[dict, Sequence[tuple[Any, Any]]],
315
- mode: Union[Literal["replace"], Literal["merge"]] = "replace",
316
- store_snapshot: bool = False,
317
- protect_llm_read_only_fields: bool = True,
318
- ) -> None:
319
- """
320
- Updates multiple values in the memory at once.
321
-
322
- Args:
323
- content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
324
- mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
325
- store_snapshot (bool): Whether to store a snapshot of the memory after the update.
326
- protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
327
-
328
- Raises:
329
- TypeError: If the content type is neither a dictionary nor a sequence of tuples.
330
- """
331
- if isinstance(content, dict):
332
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
333
- elif isinstance(content, Sequence):
334
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
335
- else:
336
- raise TypeError(f"Invalid content type `{type(content)}`!")
337
- for k, v in _list_content[:1]:
338
- await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
339
- for k, v in _list_content[1:]:
340
- await self.update(k, v, mode, False, protect_llm_read_only_fields)
341
-
342
637
  @lock_decorator
343
638
  async def add_watcher(self, key: str, callback: Callable) -> None:
344
- """
345
- Adds a callback function to be invoked when the value
346
- associated with the specified key in memory is updated.
347
-
348
- Args:
349
- key (str): The key for which the watcher is being registered.
350
- callback (Callable): A callable function that will be executed
351
- whenever the value associated with the specified key is updated.
352
-
353
- Notes:
354
- If the key does not already have any watchers, it will be
355
- initialized with an empty list before appending the callback.
356
- """
639
+ """添加值变更的监听器"""
357
640
  if key not in self.watchers:
358
641
  self.watchers[key] = []
359
642
  self.watchers[key].append(callback)
360
643
 
361
- @lock_decorator
362
644
  async def export(
363
645
  self,
364
646
  ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
@@ -369,12 +651,11 @@ class Memory:
369
651
  tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
370
652
  """
371
653
  return (
372
- await self._profile.export(),
373
- await self._state.export(),
374
- await self._dynamic.export(),
654
+ await self.profile.export(),
655
+ await self.state.export(),
656
+ await self.dynamic.export(),
375
657
  )
376
658
 
377
- @lock_decorator
378
659
  async def load(
379
660
  self,
380
661
  snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
@@ -390,41 +671,246 @@ class Memory:
390
671
  _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
391
672
  for _snapshot, _mem in zip(
392
673
  [_profile_snapshot, _state_snapshot, _dynamic_snapshot],
393
- [self._state, self._profile, self._dynamic],
674
+ [self.state, self.profile, self.dynamic],
394
675
  ):
395
676
  if _snapshot:
396
677
  await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
397
678
 
398
- @lock_decorator
399
- async def search(
400
- self, query: str, top_k: int = 3, filter: Optional[dict] = None
401
- ) -> str:
402
- """搜索相关记忆
679
+
680
+ class Memory:
681
+ """
682
+ A class to manage different types of memory (state, profile, dynamic).
683
+
684
+ Attributes:
685
+ _state (StateMemory): Stores state-related data.
686
+ _profile (ProfileMemory): Stores profile-related data.
687
+ _dynamic (DynamicMemory): Stores dynamically configured data.
688
+ """
689
+
690
+ def __init__(
691
+ self,
692
+ config: Optional[dict[Any, Any]] = None,
693
+ profile: Optional[dict[Any, Any]] = None,
694
+ base: Optional[dict[Any, Any]] = None,
695
+ activate_timestamp: bool = False,
696
+ embedding_model: Optional[Embeddings] = None,
697
+ faiss_query: Optional[FaissQuery] = None,
698
+ ) -> None:
699
+ """
700
+ Initializes the Memory with optional configuration.
403
701
 
404
702
  Args:
405
- query: 查询文本
406
- top_k: 返回最相关的记忆数量
407
- filter (dict, optional): 记忆的筛选条件,如 {"type":"dynamic", "key":"self_define_1",},默认为空
703
+ config (Optional[dict[Any, Any]], optional):
704
+ A configuration dictionary for dynamic memory. The dictionary format is:
705
+ - Key: The name of the dynamic memory field.
706
+ - Value: Can be one of two formats:
707
+ 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
708
+ 2. A callable that returns the default value when invoked (useful for complex default values).
709
+ Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
710
+ Defaults to None.
711
+ profile (Optional[dict[Any, Any]], optional): profile attribute dict.
712
+ base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
713
+ motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
714
+ activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
715
+ embedding_model (Embeddings): The embedding model for memory search.
716
+ faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
717
+ """
718
+ self.watchers: dict[str, list[Callable]] = {}
719
+ self._lock = asyncio.Lock()
720
+ self._agent_id: int = -1
721
+ self._simulator = None
722
+ self._embedding_model = embedding_model
723
+ self._faiss_query = faiss_query
724
+ self._semantic_templates: dict[str, str] = {}
725
+ _dynamic_config: dict[Any, Any] = {}
726
+ _state_config: dict[Any, Any] = {}
727
+ _profile_config: dict[Any, Any] = {}
728
+ self._embedding_fields: dict[str, bool] = {}
729
+ self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
408
730
 
409
- Returns:
410
- str: 格式化的相关记忆文本
731
+ if config is not None:
732
+ for k, v in config.items():
733
+ try:
734
+ # 处理不同长度的配置元组
735
+ if isinstance(v, tuple):
736
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
737
+ _type, _value, enable_embedding, template = v
738
+ self._embedding_fields[k] = enable_embedding
739
+ self._semantic_templates[k] = template
740
+ elif len(v) == 3: # (_type, _value, enable_embedding)
741
+ _type, _value, enable_embedding = v
742
+ self._embedding_fields[k] = enable_embedding
743
+ else: # (_type, _value)
744
+ _type, _value = v
745
+ self._embedding_fields[k] = False
746
+ else:
747
+ _type = type(v)
748
+ _value = v
749
+ self._embedding_fields[k] = False
750
+
751
+ # 处理类型转换
752
+ try:
753
+ if isinstance(_type, type):
754
+ _value = _type(_value)
755
+ else:
756
+ if isinstance(_type, deque):
757
+ _type.extend(_value)
758
+ _value = deepcopy(_type)
759
+ else:
760
+ logger.warning(f"type `{_type}` is not supported!")
761
+ except TypeError as e:
762
+ logger.warning(f"Type conversion failed for key {k}: {e}")
763
+ except TypeError as e:
764
+ if isinstance(v, type):
765
+ _value = v()
766
+ else:
767
+ _value = v
768
+ self._embedding_fields[k] = False
769
+
770
+ if (
771
+ k in PROFILE_ATTRIBUTES
772
+ or k in STATE_ATTRIBUTES
773
+ or k == TIME_STAMP_KEY
774
+ ):
775
+ logger.warning(f"key `{k}` already declared in memory!")
776
+ continue
777
+
778
+ _dynamic_config[k] = deepcopy(_value)
779
+
780
+ # 初始化各类记忆
781
+ self._dynamic = DynamicMemory(
782
+ required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
783
+ )
784
+
785
+ if profile is not None:
786
+ for k, v in profile.items():
787
+ if k not in PROFILE_ATTRIBUTES:
788
+ logger.warning(f"key `{k}` is not a correct `profile` field!")
789
+ continue
790
+ try:
791
+ # 处理配置元组格式
792
+ if isinstance(v, tuple):
793
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
794
+ _type, _value, enable_embedding, template = v
795
+ self._embedding_fields[k] = enable_embedding
796
+ self._semantic_templates[k] = template
797
+ elif len(v) == 3: # (_type, _value, enable_embedding)
798
+ _type, _value, enable_embedding = v
799
+ self._embedding_fields[k] = enable_embedding
800
+ else: # (_type, _value)
801
+ _type, _value = v
802
+ self._embedding_fields[k] = False
803
+
804
+ # 处理类型转换
805
+ try:
806
+ if isinstance(_type, type):
807
+ _value = _type(_value)
808
+ else:
809
+ if isinstance(_type, deque):
810
+ _type.extend(_value)
811
+ _value = deepcopy(_type)
812
+ else:
813
+ logger.warning(f"type `{_type}` is not supported!")
814
+ except TypeError as e:
815
+ logger.warning(f"Type conversion failed for key {k}: {e}")
816
+ else:
817
+ # 保持对简单键值对的兼容
818
+ _value = v
819
+ self._embedding_fields[k] = False
820
+ except TypeError as e:
821
+ if isinstance(v, type):
822
+ _value = v()
823
+ else:
824
+ _value = v
825
+ self._embedding_fields[k] = False
826
+
827
+ _profile_config[k] = deepcopy(_value)
828
+ self._profile = ProfileMemory(
829
+ msg=_profile_config, activate_timestamp=activate_timestamp
830
+ )
831
+ if base is not None:
832
+ for k, v in base.items():
833
+ if k not in STATE_ATTRIBUTES:
834
+ logger.warning(f"key `{k}` is not a correct `base` field!")
835
+ continue
836
+ _state_config[k] = v
837
+
838
+ self._state = StateMemory(
839
+ msg=_state_config, activate_timestamp=activate_timestamp
840
+ )
841
+
842
+ # 组合 StatusMemory,并传递 embedding_fields 信息
843
+ self._status = StatusMemory(
844
+ profile=self._profile, state=self._state, dynamic=self._dynamic
845
+ )
846
+ self._status.set_embedding_fields(self._embedding_fields)
847
+ self._status.set_search_components(self._faiss_query, self._embedding_model)
848
+
849
+ # 新增 StreamMemory
850
+ self._stream = StreamMemory()
851
+ self._stream.set_status_memory(self._status)
852
+ self._stream.set_search_components(self._faiss_query, self._embedding_model)
853
+
854
+ def set_search_components(
855
+ self,
856
+ faiss_query: FaissQuery,
857
+ embedding_model: Embeddings,
858
+ ):
859
+ self._embedding_model = embedding_model
860
+ self._faiss_query = faiss_query
861
+ self._stream.set_search_components(faiss_query, embedding_model)
862
+ self._status.set_search_components(faiss_query, embedding_model)
863
+
864
+ def set_agent_id(self, agent_id: int):
411
865
  """
412
- if not self._embedding_model:
413
- return "Embedding model not initialized"
414
- top_results: list[tuple[str, float, dict]] = (
415
- await self.faiss_query.similarity_search( # type:ignore
416
- query=query,
417
- agent_id=self.agent_id,
418
- k=top_k,
419
- return_score_type="similarity_score",
420
- filter=filter,
866
+ Set the FaissQuery of the agent.
867
+ """
868
+ self._agent_id = agent_id
869
+ self._stream.set_agent_id(agent_id)
870
+ self._status.set_agent_id(agent_id)
871
+
872
+ def set_simulator(self, simulator):
873
+ self._simulator = simulator
874
+ self._stream.set_simulator(simulator)
875
+ self._status.set_simulator(simulator)
876
+
877
+ @property
878
+ def status(self) -> StatusMemory:
879
+ return self._status
880
+
881
+ @property
882
+ def stream(self) -> StreamMemory:
883
+ return self._stream
884
+
885
+ @property
886
+ def embedding_model(
887
+ self,
888
+ ):
889
+ if self._embedding_model is None:
890
+ raise RuntimeError(
891
+ f"embedding_model before assignment, please `set_embedding_model` first!"
421
892
  )
422
- )
423
- # 格式化输出
424
- formatted_results = []
425
- for content, score, metadata in top_results:
426
- formatted_results.append(
427
- f"- [{metadata['type']}] {content} " f"(相关度: {score:.2f})"
893
+ return self._embedding_model
894
+
895
+ @property
896
+ def agent_id(
897
+ self,
898
+ ):
899
+ if self._agent_id < 0:
900
+ raise RuntimeError(
901
+ f"agent_id before assignment, please `set_agent_id` first!"
428
902
  )
903
+ return self._agent_id
429
904
 
430
- return "\n".join(formatted_results)
905
+ @property
906
+ def faiss_query(self) -> FaissQuery:
907
+ """FaissQuery"""
908
+ if self._faiss_query is None:
909
+ raise RuntimeError(
910
+ f"FaissQuery access before assignment, please `set_faiss_query` first!"
911
+ )
912
+ return self._faiss_query
913
+
914
+ async def initialize_embeddings(self):
915
+ """初始化embedding"""
916
+ await self._status.initialize_embeddings()