pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. pycityagent/agent/agent.py +83 -62
  2. pycityagent/agent/agent_base.py +81 -54
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +45 -57
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/message_intercept.py +99 -0
  17. pycityagent/cityagent/nbsagent.py +6 -29
  18. pycityagent/cityagent/societyagent.py +325 -127
  19. pycityagent/cli/wrapper.py +4 -0
  20. pycityagent/economy/econ_client.py +0 -2
  21. pycityagent/environment/__init__.py +7 -1
  22. pycityagent/environment/sim/client.py +10 -1
  23. pycityagent/environment/sim/clock_service.py +2 -2
  24. pycityagent/environment/sim/pause_service.py +61 -0
  25. pycityagent/environment/sim/sim_env.py +34 -46
  26. pycityagent/environment/simulator.py +18 -14
  27. pycityagent/llm/embeddings.py +0 -24
  28. pycityagent/llm/llm.py +18 -10
  29. pycityagent/memory/faiss_query.py +29 -26
  30. pycityagent/memory/memory.py +733 -247
  31. pycityagent/message/__init__.py +8 -1
  32. pycityagent/message/message_interceptor.py +322 -0
  33. pycityagent/message/messager.py +42 -11
  34. pycityagent/pycityagent-sim +0 -0
  35. pycityagent/simulation/agentgroup.py +137 -96
  36. pycityagent/simulation/simulation.py +184 -38
  37. pycityagent/simulation/storage/pg.py +2 -2
  38. pycityagent/tools/tool.py +7 -9
  39. pycityagent/utils/__init__.py +7 -2
  40. pycityagent/utils/pg_query.py +1 -0
  41. pycityagent/utils/survey_util.py +26 -23
  42. pycityagent/workflow/block.py +14 -7
  43. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
  44. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
  45. pycityagent/cityagent/blocks/time_block.py +0 -116
  46. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
  47. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
  48. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
  49. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from collections import defaultdict
4
- from collections.abc import Callable, Sequence
4
+ from collections.abc import Callable, Coroutine, Sequence
5
5
  from copy import deepcopy
6
- from datetime import datetime
7
- from typing import Any, Literal, Optional, Union
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from typing import Any, Dict, Literal, Optional, Union
8
9
 
9
- import numpy as np
10
10
  from langchain_core.embeddings import Embeddings
11
11
  from pyparsing import deque
12
12
 
@@ -20,176 +20,496 @@ from .state import StateMemory
20
20
  logger = logging.getLogger("pycityagent")
21
21
 
22
22
 
23
- class Memory:
24
- """
25
- A class to manage different types of memory (state, profile, dynamic).
23
+ class MemoryTag(str, Enum):
24
+ """记忆标签枚举类"""
26
25
 
27
- Attributes:
28
- _state (StateMemory): Stores state-related data.
29
- _profile (ProfileMemory): Stores profile-related data.
30
- _dynamic (DynamicMemory): Stores dynamically configured data.
31
- """
26
+ MOBILITY = "mobility"
27
+ SOCIAL = "social"
28
+ ECONOMY = "economy"
29
+ COGNITION = "cognition"
30
+ OTHER = "other"
31
+ EVENT = "event"
32
32
 
33
- def __init__(
33
+
34
+ @dataclass
35
+ class MemoryNode:
36
+ """记忆节点"""
37
+
38
+ tag: MemoryTag
39
+ day: int
40
+ t: int
41
+ location: str
42
+ description: str
43
+ cognition_id: Optional[int] = None # 关联的认知记忆ID
44
+ id: Optional[int] = None # 记忆ID
45
+
46
+
47
+ class StreamMemory:
48
+ """用于存储时序性的流式信息"""
49
+
50
+ def __init__(self, max_len: int = 1000):
51
+ self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
52
+ self._memory_id_counter: int = 0 # 用于生成唯一ID
53
+ self._faiss_query = None
54
+ self._embedding_model = None
55
+ self._agent_id = -1
56
+ self._status_memory = None
57
+ self._simulator = None
58
+
59
+ @property
60
+ def faiss_query(
34
61
  self,
35
- config: Optional[dict[Any, Any]] = None,
36
- profile: Optional[dict[Any, Any]] = None,
37
- base: Optional[dict[Any, Any]] = None,
38
- motion: Optional[dict[Any, Any]] = None,
39
- activate_timestamp: bool = False,
40
- embedding_model: Optional[Embeddings] = None,
41
- faiss_query: Optional[FaissQuery] = None,
42
- ) -> None:
43
- """
44
- Initializes the Memory with optional configuration.
62
+ ) -> FaissQuery:
63
+ assert self._faiss_query is not None
64
+ return self._faiss_query
45
65
 
46
- Args:
47
- config (Optional[dict[Any, Any]], optional):
48
- A configuration dictionary for dynamic memory. The dictionary format is:
49
- - Key: The name of the dynamic memory field.
50
- - Value: Can be one of two formats:
51
- 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
52
- 2. A callable that returns the default value when invoked (useful for complex default values).
53
- Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
54
- Defaults to None.
55
- profile (Optional[dict[Any, Any]], optional): profile attribute dict.
56
- base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
57
- motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
58
- activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
59
- embedding_model (Embeddings): The embedding model for memory search.
60
- faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
61
- """
62
- self.watchers: dict[str, list[Callable]] = {}
63
- self._lock = asyncio.Lock()
64
- self._agent_id: int = -1
65
- self._embedding_model = embedding_model
66
+ @property
67
+ def status_memory(
68
+ self,
69
+ ):
70
+ assert self._status_memory is not None
71
+ return self._status_memory
66
72
 
67
- _dynamic_config: dict[Any, Any] = {}
68
- _state_config: dict[Any, Any] = {}
69
- _profile_config: dict[Any, Any] = {}
70
- # 记录哪些字段需要embedding
71
- self._embedding_fields: dict[str, bool] = {}
72
- self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
73
+ def set_simulator(self, simulator):
74
+ self._simulator = simulator
75
+
76
+ def set_status_memory(self, status_memory):
77
+ self._status_memory = status_memory
78
+
79
+ def set_search_components(self, faiss_query, embedding_model):
80
+ """设置搜索所需的组件"""
73
81
  self._faiss_query = faiss_query
82
+ self._embedding_model = embedding_model
74
83
 
75
- if config is not None:
76
- for k, v in config.items():
77
- try:
78
- # 处理新的三元组格式
79
- if isinstance(v, tuple) and len(v) == 3:
80
- _type, _value, enable_embedding = v
81
- self._embedding_fields[k] = enable_embedding
82
- else:
83
- _type, _value = v
84
- self._embedding_fields[k] = False
84
+ def set_agent_id(self, agent_id: int):
85
+ """设置agent_id"""
86
+ self._agent_id = agent_id
85
87
 
86
- try:
87
- if isinstance(_type, type):
88
- _value = _type(_value)
89
- else:
90
- if isinstance(_type, deque):
91
- _type.extend(_value)
92
- _value = deepcopy(_type)
93
- else:
94
- logger.warning(f"type `{_type}` is not supported!")
95
- pass
96
- except TypeError as e:
97
- pass
98
- except TypeError as e:
99
- if isinstance(v, type):
100
- _value = v()
101
- else:
102
- _value = v
103
- self._embedding_fields[k] = False
88
+ async def _add_memory(self, tag: MemoryTag, description: str) -> int:
89
+ """添加记忆节点的通用方法,返回记忆节点ID"""
90
+ if self._simulator is not None:
91
+ day = int(await self._simulator.get_simulator_day())
92
+ t = int(await self._simulator.get_time())
93
+ else:
94
+ day = 1
95
+ t = 1
96
+ position = await self.status_memory.get("position")
97
+ if "aoi_position" in position:
98
+ location = position["aoi_position"]["aoi_id"]
99
+ elif "lane_position" in position:
100
+ location = position["lane_position"]["lane_id"]
101
+ else:
102
+ location = "unknown"
103
+
104
+ current_id = self._memory_id_counter
105
+ self._memory_id_counter += 1
106
+ memory_node = MemoryNode(
107
+ tag=tag,
108
+ day=day,
109
+ t=t,
110
+ location=location,
111
+ description=description,
112
+ id=current_id,
113
+ )
114
+ self._memories.append(memory_node)
115
+
116
+ # 为新记忆创建 embedding
117
+ if self._embedding_model and self._faiss_query:
118
+ await self.faiss_query.add_documents(
119
+ agent_id=self._agent_id,
120
+ documents=description,
121
+ extra_tags={
122
+ "type": "stream",
123
+ "tag": tag,
124
+ "day": day,
125
+ "time": t,
126
+ },
127
+ )
104
128
 
105
- if (
106
- k in PROFILE_ATTRIBUTES
107
- or k in STATE_ATTRIBUTES
108
- or k == TIME_STAMP_KEY
109
- ):
110
- logger.warning(f"key `{k}` already declared in memory!")
111
- continue
129
+ return current_id
130
+
131
+ async def add_cognition(self, description: str) -> int:
132
+ """添加认知记忆 Add cognition memory"""
133
+ return await self._add_memory(MemoryTag.COGNITION, description)
134
+
135
+ async def add_social(self, description: str) -> int:
136
+ """添加社交记忆 Add social memory"""
137
+ return await self._add_memory(MemoryTag.SOCIAL, description)
138
+
139
+ async def add_economy(self, description: str) -> int:
140
+ """添加经济记忆 Add economy memory"""
141
+ return await self._add_memory(MemoryTag.ECONOMY, description)
142
+
143
+ async def add_mobility(self, description: str) -> int:
144
+ """添加移动记忆 Add mobility memory"""
145
+ return await self._add_memory(MemoryTag.MOBILITY, description)
146
+
147
+ async def add_event(self, description: str) -> int:
148
+ """添加事件记忆 Add event memory"""
149
+ return await self._add_memory(MemoryTag.EVENT, description)
150
+
151
+ async def add_other(self, description: str) -> int:
152
+ """添加其他记忆 Add other memory"""
153
+ return await self._add_memory(MemoryTag.OTHER, description)
154
+
155
+ async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
156
+ """获取关联的认知记忆 Get related cognition memory"""
157
+ for memory in self._memories:
158
+ if memory.cognition_id == memory_id:
159
+ for cognition_memory in self._memories:
160
+ if (
161
+ cognition_memory.tag == MemoryTag.COGNITION
162
+ and memory.cognition_id is not None
163
+ ):
164
+ return cognition_memory
165
+ return None
166
+
167
+ async def format_memory(self, memories: list[MemoryNode]) -> str:
168
+ """格式化记忆"""
169
+ formatted_results = []
170
+ for memory in memories:
171
+ memory_tag = memory.tag
172
+ memory_day = memory.day
173
+ memory_time_seconds = memory.t
174
+ cognition_id = memory.cognition_id
175
+
176
+ # 格式化时间
177
+ if memory_time_seconds != "unknown":
178
+ hours = memory_time_seconds // 3600
179
+ minutes = (memory_time_seconds % 3600) // 60
180
+ seconds = memory_time_seconds % 60
181
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
182
+ else:
183
+ memory_time = "unknown"
184
+
185
+ memory_location = memory.location
186
+
187
+ # 添加认知信息(如果存在)
188
+ cognition_info = ""
189
+ if cognition_id is not None:
190
+ cognition_memory = await self.get_related_cognition(cognition_id)
191
+ if cognition_memory:
192
+ cognition_info = (
193
+ f"\n Related cognition: {cognition_memory.description}"
194
+ )
112
195
 
113
- _dynamic_config[k] = deepcopy(_value)
196
+ formatted_results.append(
197
+ f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
198
+ f"location: {memory_location}]{cognition_info}"
199
+ )
200
+ return "\n".join(formatted_results)
114
201
 
115
- # 初始化各类记忆
116
- self._dynamic = DynamicMemory(
117
- required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
118
- )
202
+ async def get_by_ids(
203
+ self, memory_ids: Union[int, list[int]]
204
+ ) -> Coroutine[Any, Any, str]:
205
+ """获取指定ID的记忆"""
206
+ memories = [memory for memory in self._memories if memory.id in memory_ids]
207
+ sorted_results = sorted(memories, key=lambda x: (x.day, x.t), reverse=True)
208
+ return self.format_memory(sorted_results)
119
209
 
120
- if profile is not None:
121
- for k, v in profile.items():
122
- if k not in PROFILE_ATTRIBUTES:
123
- logger.warning(f"key `{k}` is not a correct `profile` field!")
124
- continue
125
- _profile_config[k] = v
126
- if motion is not None:
127
- for k, v in motion.items():
128
- if k not in STATE_ATTRIBUTES:
129
- logger.warning(f"key `{k}` is not a correct `motion` field!")
130
- continue
131
- _state_config[k] = v
132
- if base is not None:
133
- for k, v in base.items():
134
- if k not in STATE_ATTRIBUTES:
135
- logger.warning(f"key `{k}` is not a correct `base` field!")
136
- continue
137
- _state_config[k] = v
138
- self._state = StateMemory(
139
- msg=_state_config, activate_timestamp=activate_timestamp
210
+ async def search(
211
+ self,
212
+ query: str,
213
+ tag: Optional[MemoryTag] = None,
214
+ top_k: int = 3,
215
+ day_range: Optional[tuple[int, int]] = None, # 新增参数
216
+ time_range: Optional[tuple[int, int]] = None, # 新增参数
217
+ ) -> str:
218
+ """Search stream memory
219
+
220
+ Args:
221
+ query: Query text
222
+ tag: Optional memory tag for filtering specific types of memories
223
+ top_k: Number of most relevant memories to return
224
+ day_range: Optional tuple of start and end days (start_day, end_day)
225
+ time_range: Optional tuple of start and end times (start_time, end_time)
226
+ """
227
+ if not self._embedding_model or not self._faiss_query:
228
+ return "Search components not initialized"
229
+
230
+ filter_dict: dict[str, Any] = {"type": "stream"}
231
+
232
+ if tag:
233
+ filter_dict["tag"] = tag
234
+
235
+ # 添加时间范围过滤
236
+ if day_range:
237
+ start_day, end_day = day_range
238
+ filter_dict["day"] = lambda x: start_day <= x <= end_day
239
+
240
+ if time_range:
241
+ start_time, end_time = time_range
242
+ filter_dict["time"] = lambda x: start_time <= x <= end_time
243
+
244
+ top_results = await self.faiss_query.similarity_search(
245
+ query=query,
246
+ agent_id=self._agent_id,
247
+ k=top_k,
248
+ return_score_type="similarity_score",
249
+ filter=filter_dict,
140
250
  )
141
- self._profile = ProfileMemory(
142
- msg=_profile_config, activate_timestamp=activate_timestamp
251
+
252
+ # 将结果按时间排序(先按天数,再按时间)
253
+ sorted_results = sorted(
254
+ top_results,
255
+ key=lambda x: (x[2].get("day", 0), x[2].get("time", 0)), # type:ignore
256
+ reverse=True,
143
257
  )
144
- # self.memories = [] # 存储记忆内容
145
- # self.embeddings = [] # 存储记忆的向量表示
146
258
 
147
- def set_embedding_model(
148
- self,
149
- embedding_model: Embeddings,
150
- ):
151
- self._embedding_model = embedding_model
259
+ formatted_results = []
260
+ for content, score, metadata in sorted_results: # type:ignore
261
+ memory_tag = metadata.get("tag", "unknown")
262
+ memory_day = metadata.get("day", "unknown")
263
+ memory_time_seconds = metadata.get("time", "unknown")
264
+ cognition_id = metadata.get("cognition_id", None)
265
+
266
+ # 格式化时间
267
+ if memory_time_seconds != "unknown":
268
+ hours = memory_time_seconds // 3600
269
+ minutes = (memory_time_seconds % 3600) // 60
270
+ seconds = memory_time_seconds % 60
271
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
272
+ else:
273
+ memory_time = "unknown"
274
+
275
+ memory_location = metadata.get("location", "unknown")
276
+
277
+ # 添加认知信息(如果存在)
278
+ cognition_info = ""
279
+ if cognition_id is not None:
280
+ cognition_memory = await self.get_related_cognition(cognition_id)
281
+ if cognition_memory:
282
+ cognition_info = (
283
+ f"\n Related cognition: {cognition_memory.description}"
284
+ )
152
285
 
153
- @property
154
- def embedding_model(
155
- self,
156
- ):
157
- if self._embedding_model is None:
158
- raise RuntimeError(
159
- f"embedding_model before assignment, please `set_embedding_model` first!"
286
+ formatted_results.append(
287
+ f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
288
+ f"location: {memory_location}]{cognition_info}"
160
289
  )
161
- return self._embedding_model
290
+ return "\n".join(formatted_results)
162
291
 
163
- def set_faiss_query(self, faiss_query: FaissQuery):
292
+ async def search_today(
293
+ self,
294
+ query: str = "", # 可选的查询文本
295
+ tag: Optional[MemoryTag] = None,
296
+ top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
297
+ ) -> str:
298
+ """Search all memory events from today
299
+
300
+ Args:
301
+ query: Optional query text, returns all memories of the day if empty
302
+ tag: Optional memory tag for filtering specific types of memories
303
+ top_k: Number of most relevant memories to return, defaults to 100
304
+
305
+ Returns:
306
+ str: Formatted text of today's memories
164
307
  """
165
- Set the FaissQuery of the agent.
308
+ if self._simulator is None:
309
+ return "Simulator not initialized"
310
+
311
+ current_day = int(await self._simulator.get_simulator_day())
312
+
313
+ # 使用 search 方法,设置 day_range 为当天
314
+ return await self.search(
315
+ query=query, tag=tag, top_k=top_k, day_range=(current_day, current_day)
316
+ )
317
+
318
+ async def add_cognition_to_memory(
319
+ self, memory_id: Union[int, list[int]], cognition: str
320
+ ) -> None:
321
+ """为已存在的记忆添加认知
322
+
323
+ Args:
324
+ memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
325
+ cognition: 认知描述
166
326
  """
167
- self._faiss_query = faiss_query
327
+ # 将单个ID转换为列表以统一处理
328
+ memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
329
+
330
+ # 找到所有对应的记忆
331
+ target_memories = []
332
+ for memory in self._memories:
333
+ if id(memory) in memory_ids:
334
+ target_memories.append(memory)
335
+
336
+ if not target_memories:
337
+ raise ValueError(f"No memories found with ids {memory_ids}")
338
+
339
+ # 添加认知记忆
340
+ cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
341
+
342
+ # 更新所有原记忆的认知ID
343
+ for target_memory in target_memories:
344
+ target_memory.cognition_id = cognition_id
345
+
346
+ async def get_all(self) -> list[MemoryNode]:
347
+ """获取所有流式信息"""
348
+ return list(self._memories)
349
+
350
+
351
+ class StatusMemory:
352
+ """组合现有的三种记忆类型"""
353
+
354
+ def __init__(
355
+ self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory
356
+ ):
357
+ self.profile = profile
358
+ self.state = state
359
+ self.dynamic = dynamic
360
+ self._faiss_query = None
361
+ self._embedding_model = None
362
+ self._simulator = None
363
+ self._agent_id = -1
364
+ self._semantic_templates = {} # 用户可配置的模板
365
+ self._embedding_fields = {} # 需要 embedding 的字段
366
+ self._embedding_field_to_doc_id = defaultdict(str) # 新增
367
+ self.watchers = {} # 新增
368
+ self._lock = asyncio.Lock() # 新增
168
369
 
169
370
  @property
170
- def agent_id(
371
+ def faiss_query(
171
372
  self,
172
- ):
173
- if self._agent_id < 0:
174
- raise RuntimeError(
175
- f"agent_id before assignment, please `set_agent_id` first!"
373
+ ) -> FaissQuery:
374
+ assert self._faiss_query is not None
375
+ return self._faiss_query
376
+
377
+ def set_simulator(self, simulator):
378
+ self._simulator = simulator
379
+
380
+ async def initialize_embeddings(self) -> None:
381
+ """初始化所有需要 embedding 的字段"""
382
+ if not self._embedding_model or not self._faiss_query:
383
+ logger.warning(
384
+ "Search components not initialized, skipping embeddings initialization"
176
385
  )
177
- return self._agent_id
386
+ return
387
+
388
+ # 获取所有状态信息
389
+ profile, state, dynamic = await self.export()
390
+
391
+ # 为每个需要 embedding 的字段创建 embedding
392
+ for key, value in profile[0].items():
393
+ if self.should_embed(key):
394
+ semantic_text = self._generate_semantic_text(key, value)
395
+ doc_ids = await self.faiss_query.add_documents(
396
+ agent_id=self._agent_id,
397
+ documents=semantic_text,
398
+ extra_tags={
399
+ "type": "profile_state",
400
+ "key": key,
401
+ },
402
+ )
403
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
404
+
405
+ for key, value in state[0].items():
406
+ if self.should_embed(key):
407
+ semantic_text = self._generate_semantic_text(key, value)
408
+ doc_ids = await self.faiss_query.add_documents(
409
+ agent_id=self._agent_id,
410
+ documents=semantic_text,
411
+ extra_tags={
412
+ "type": "profile_state",
413
+ "key": key,
414
+ },
415
+ )
416
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
417
+
418
+ for key, value in dynamic[0].items():
419
+ if self.should_embed(key):
420
+ semantic_text = self._generate_semantic_text(key, value)
421
+ doc_ids = await self.faiss_query.add_documents(
422
+ agent_id=self._agent_id,
423
+ documents=semantic_text,
424
+ extra_tags={
425
+ "type": "profile_state",
426
+ "key": key,
427
+ },
428
+ )
429
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
430
+
431
+ def _get_memory_type_by_key(self, key: str) -> str:
432
+ """根据键名确定记忆类型"""
433
+ try:
434
+ if key in self.profile.__dict__:
435
+ return "profile"
436
+ elif key in self.state.__dict__:
437
+ return "state"
438
+ else:
439
+ return "dynamic"
440
+ except:
441
+ return "dynamic"
442
+
443
+ def set_search_components(self, faiss_query, embedding_model):
444
+ """设置搜索所需的组件"""
445
+ self._faiss_query = faiss_query
446
+ self._embedding_model = embedding_model
178
447
 
179
448
  def set_agent_id(self, agent_id: int):
449
+ """设置agent_id"""
450
+ self._agent_id = agent_id
451
+
452
+ def set_semantic_templates(self, templates: Dict[str, str]):
453
+ """设置语义模板
454
+
455
+ Args:
456
+ templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
180
457
  """
181
- Set the FaissQuery of the agent.
458
+ self._semantic_templates = templates
459
+
460
+ def _generate_semantic_text(self, key: str, value: Any) -> str:
461
+ """生成语义文本
462
+
463
+ 如果key存在于模板中,使用自定义模板
464
+ 否则使用默认模板 "my {key} is {value}"
182
465
  """
183
- self._agent_id = agent_id
466
+ if key in self._semantic_templates:
467
+ return self._semantic_templates[key].format(value)
468
+ return f"Your {key} is {value}"
184
469
 
185
- @property
186
- def faiss_query(self) -> FaissQuery:
187
- """FaissQuery"""
188
- if self._faiss_query is None:
189
- raise RuntimeError(
190
- f"FaissQuery access before assignment, please `set_faiss_query` first!"
470
+ @lock_decorator
471
+ async def search(
472
+ self, query: str, top_k: int = 3, filter: Optional[dict] = None
473
+ ) -> str:
474
+ """搜索相关记忆
475
+
476
+ Args:
477
+ query: 查询文本
478
+ top_k: 返回最相关的记忆数量
479
+ filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
480
+
481
+ Returns:
482
+ str: 格式化的相关记忆文本
483
+ """
484
+ if not self._embedding_model:
485
+ return "Embedding model not initialized"
486
+
487
+ filter_dict = {"type": "profile_state"}
488
+ if filter is not None:
489
+ filter_dict.update(filter)
490
+ top_results: list[tuple[str, float, dict]] = (
491
+ await self.faiss_query.similarity_search( # type:ignore
492
+ query=query,
493
+ agent_id=self._agent_id,
494
+ k=top_k,
495
+ return_score_type="similarity_score",
496
+ filter=filter_dict,
191
497
  )
192
- return self._faiss_query
498
+ )
499
+ # 格式化输出
500
+ formatted_results = []
501
+ for content, score, metadata in top_results:
502
+ formatted_results.append(f"- {content} ")
503
+
504
+ return "\n".join(formatted_results)
505
+
506
+ def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
507
+ """设置需要 embedding 的字段"""
508
+ self._embedding_fields = embedding_fields
509
+
510
+ def should_embed(self, key: str) -> bool:
511
+ """判断字段是否需要 embedding"""
512
+ return self._embedding_fields.get(key, False)
193
513
 
194
514
  @lock_decorator
195
515
  async def get(
@@ -197,19 +517,18 @@ class Memory:
197
517
  key: Any,
198
518
  mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
199
519
  ) -> Any:
200
- """
201
- Retrieves a value from memory based on the given key and access mode.
520
+ """从记忆中获取值
202
521
 
203
522
  Args:
204
- key (Any): The key of the item to retrieve.
205
- mode (Union[Literal["read only"], Literal["read and write"]], optional): Access mode for the item. Defaults to "read only".
523
+ key: 要获取的键
524
+ mode: 访问模式,"read only" "read and write"
206
525
 
207
526
  Returns:
208
- Any: The value associated with the key.
527
+ 获取到的值
209
528
 
210
529
  Raises:
211
- ValueError: If an invalid mode is provided.
212
- KeyError: If the key is not found in any of the memory sections.
530
+ ValueError: 如果提供了无效的模式
531
+ KeyError: 如果在任何记忆部分都找不到该键
213
532
  """
214
533
  if mode == "read only":
215
534
  process_func = deepcopy
@@ -217,11 +536,12 @@ class Memory:
217
536
  process_func = lambda x: x
218
537
  else:
219
538
  raise ValueError(f"Invalid get mode `{mode}`!")
220
- for _mem in [self._state, self._profile, self._dynamic]:
539
+
540
+ for mem in [self.state, self.profile, self.dynamic]:
221
541
  try:
222
- value = await _mem.get(key)
542
+ value = await mem.get(key)
223
543
  return process_func(value)
224
- except KeyError as e:
544
+ except KeyError:
225
545
  continue
226
546
  raise KeyError(f"No attribute `{key}` in memories!")
227
547
 
@@ -239,32 +559,37 @@ class Memory:
239
559
  if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
240
560
  logger.warning(f"Trying to write protected key `{key}`!")
241
561
  return
242
- for _mem in [self._state, self._profile, self._dynamic]:
562
+
563
+ for mem in [self.state, self.profile, self.dynamic]:
243
564
  try:
244
- original_value = await _mem.get(key)
565
+ original_value = await mem.get(key)
245
566
  if mode == "replace":
246
- await _mem.update(key, value, store_snapshot)
247
- # 如果字段需要embedding,则更新embedding
248
- if self._embedding_fields.get(key, False) and self.embedding_model:
249
- memory_type = self._get_memory_type(_mem)
250
- # 覆盖更新删除原vector
567
+ await mem.update(key, value, store_snapshot)
568
+ if self.should_embed(key) and self._embedding_model:
569
+ semantic_text = self._generate_semantic_text(key, value)
570
+
571
+ # 删除旧的 embedding
251
572
  orig_doc_id = self._embedding_field_to_doc_id[key]
252
573
  if orig_doc_id:
253
574
  await self.faiss_query.delete_documents(
254
575
  to_delete_ids=[orig_doc_id],
255
576
  )
256
- doc_ids: list[str] = await self.faiss_query.add_documents(
257
- agent_id=self.agent_id,
258
- documents=f"{key}: {str(value)}",
577
+
578
+ # 添加新的 embedding
579
+ doc_ids = await self.faiss_query.add_documents(
580
+ agent_id=self._agent_id,
581
+ documents=semantic_text,
259
582
  extra_tags={
260
- "type": memory_type,
583
+ "type": self._get_memory_type(mem),
261
584
  "key": key,
262
585
  },
263
586
  )
264
587
  self._embedding_field_to_doc_id[key] = doc_ids[0]
588
+
265
589
  if key in self.watchers:
266
590
  for callback in self.watchers[key]:
267
591
  asyncio.create_task(callback())
592
+
268
593
  elif mode == "merge":
269
594
  if isinstance(original_value, set):
270
595
  original_value.update(set(value))
@@ -278,14 +603,14 @@ class Memory:
278
603
  logger.debug(
279
604
  f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
280
605
  )
281
- await _mem.update(key, value, store_snapshot)
282
- if self._embedding_fields.get(key, False) and self.embedding_model:
283
- memory_type = self._get_memory_type(_mem)
606
+ await mem.update(key, value, store_snapshot)
607
+ if self.should_embed(key) and self._embedding_model:
608
+ semantic_text = self._generate_semantic_text(key, value)
284
609
  doc_ids = await self.faiss_query.add_documents(
285
- agent_id=self.agent_id,
610
+ agent_id=self._agent_id,
286
611
  documents=f"{key}: {str(original_value)}",
287
612
  extra_tags={
288
- "type": memory_type,
613
+ "type": self._get_memory_type(mem),
289
614
  "key": key,
290
615
  },
291
616
  )
@@ -302,63 +627,20 @@ class Memory:
302
627
 
303
628
  def _get_memory_type(self, mem: Any) -> str:
304
629
  """获取记忆类型"""
305
- if mem is self._state:
630
+ if mem is self.state:
306
631
  return "state"
307
- elif mem is self._profile:
632
+ elif mem is self.profile:
308
633
  return "profile"
309
634
  else:
310
635
  return "dynamic"
311
636
 
312
- async def update_batch(
313
- self,
314
- content: Union[dict, Sequence[tuple[Any, Any]]],
315
- mode: Union[Literal["replace"], Literal["merge"]] = "replace",
316
- store_snapshot: bool = False,
317
- protect_llm_read_only_fields: bool = True,
318
- ) -> None:
319
- """
320
- Updates multiple values in the memory at once.
321
-
322
- Args:
323
- content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
324
- mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
325
- store_snapshot (bool): Whether to store a snapshot of the memory after the update.
326
- protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
327
-
328
- Raises:
329
- TypeError: If the content type is neither a dictionary nor a sequence of tuples.
330
- """
331
- if isinstance(content, dict):
332
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
333
- elif isinstance(content, Sequence):
334
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
335
- else:
336
- raise TypeError(f"Invalid content type `{type(content)}`!")
337
- for k, v in _list_content[:1]:
338
- await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
339
- for k, v in _list_content[1:]:
340
- await self.update(k, v, mode, False, protect_llm_read_only_fields)
341
-
342
637
  @lock_decorator
343
638
  async def add_watcher(self, key: str, callback: Callable) -> None:
344
- """
345
- Adds a callback function to be invoked when the value
346
- associated with the specified key in memory is updated.
347
-
348
- Args:
349
- key (str): The key for which the watcher is being registered.
350
- callback (Callable): A callable function that will be executed
351
- whenever the value associated with the specified key is updated.
352
-
353
- Notes:
354
- If the key does not already have any watchers, it will be
355
- initialized with an empty list before appending the callback.
356
- """
639
+ """添加值变更的监听器"""
357
640
  if key not in self.watchers:
358
641
  self.watchers[key] = []
359
642
  self.watchers[key].append(callback)
360
643
 
361
- @lock_decorator
362
644
  async def export(
363
645
  self,
364
646
  ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
@@ -369,12 +651,11 @@ class Memory:
369
651
  tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
370
652
  """
371
653
  return (
372
- await self._profile.export(),
373
- await self._state.export(),
374
- await self._dynamic.export(),
654
+ await self.profile.export(),
655
+ await self.state.export(),
656
+ await self.dynamic.export(),
375
657
  )
376
658
 
377
- @lock_decorator
378
659
  async def load(
379
660
  self,
380
661
  snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
@@ -390,41 +671,246 @@ class Memory:
390
671
  _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
391
672
  for _snapshot, _mem in zip(
392
673
  [_profile_snapshot, _state_snapshot, _dynamic_snapshot],
393
- [self._state, self._profile, self._dynamic],
674
+ [self.state, self.profile, self.dynamic],
394
675
  ):
395
676
  if _snapshot:
396
677
  await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
397
678
 
398
- @lock_decorator
399
- async def search(
400
- self, query: str, top_k: int = 3, filter: Optional[dict] = None
401
- ) -> str:
402
- """搜索相关记忆
679
+
680
+ class Memory:
681
+ """
682
+ A class to manage different types of memory (state, profile, dynamic).
683
+
684
+ Attributes:
685
+ _state (StateMemory): Stores state-related data.
686
+ _profile (ProfileMemory): Stores profile-related data.
687
+ _dynamic (DynamicMemory): Stores dynamically configured data.
688
+ """
689
+
690
+ def __init__(
691
+ self,
692
+ config: Optional[dict[Any, Any]] = None,
693
+ profile: Optional[dict[Any, Any]] = None,
694
+ base: Optional[dict[Any, Any]] = None,
695
+ activate_timestamp: bool = False,
696
+ embedding_model: Optional[Embeddings] = None,
697
+ faiss_query: Optional[FaissQuery] = None,
698
+ ) -> None:
699
+ """
700
+ Initializes the Memory with optional configuration.
403
701
 
404
702
  Args:
405
- query: 查询文本
406
- top_k: 返回最相关的记忆数量
407
- filter (dict, optional): 记忆的筛选条件,如 {"type":"dynamic", "key":"self_define_1",},默认为空
703
+ config (Optional[dict[Any, Any]], optional):
704
+ A configuration dictionary for dynamic memory. The dictionary format is:
705
+ - Key: The name of the dynamic memory field.
706
+ - Value: Can be one of two formats:
707
+ 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
708
+ 2. A callable that returns the default value when invoked (useful for complex default values).
709
+ Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
710
+ Defaults to None.
711
+ profile (Optional[dict[Any, Any]], optional): profile attribute dict.
712
+ base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
713
+ motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
714
+ activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
715
+ embedding_model (Embeddings): The embedding model for memory search.
716
+ faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
717
+ """
718
+ self.watchers: dict[str, list[Callable]] = {}
719
+ self._lock = asyncio.Lock()
720
+ self._agent_id: int = -1
721
+ self._simulator = None
722
+ self._embedding_model = embedding_model
723
+ self._faiss_query = faiss_query
724
+ self._semantic_templates: dict[str, str] = {}
725
+ _dynamic_config: dict[Any, Any] = {}
726
+ _state_config: dict[Any, Any] = {}
727
+ _profile_config: dict[Any, Any] = {}
728
+ self._embedding_fields: dict[str, bool] = {}
729
+ self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
408
730
 
409
- Returns:
410
- str: 格式化的相关记忆文本
731
+ if config is not None:
732
+ for k, v in config.items():
733
+ try:
734
+ # 处理不同长度的配置元组
735
+ if isinstance(v, tuple):
736
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
737
+ _type, _value, enable_embedding, template = v
738
+ self._embedding_fields[k] = enable_embedding
739
+ self._semantic_templates[k] = template
740
+ elif len(v) == 3: # (_type, _value, enable_embedding)
741
+ _type, _value, enable_embedding = v
742
+ self._embedding_fields[k] = enable_embedding
743
+ else: # (_type, _value)
744
+ _type, _value = v
745
+ self._embedding_fields[k] = False
746
+ else:
747
+ _type = type(v)
748
+ _value = v
749
+ self._embedding_fields[k] = False
750
+
751
+ # 处理类型转换
752
+ try:
753
+ if isinstance(_type, type):
754
+ _value = _type(_value)
755
+ else:
756
+ if isinstance(_type, deque):
757
+ _type.extend(_value)
758
+ _value = deepcopy(_type)
759
+ else:
760
+ logger.warning(f"type `{_type}` is not supported!")
761
+ except TypeError as e:
762
+ logger.warning(f"Type conversion failed for key {k}: {e}")
763
+ except TypeError as e:
764
+ if isinstance(v, type):
765
+ _value = v()
766
+ else:
767
+ _value = v
768
+ self._embedding_fields[k] = False
769
+
770
+ if (
771
+ k in PROFILE_ATTRIBUTES
772
+ or k in STATE_ATTRIBUTES
773
+ or k == TIME_STAMP_KEY
774
+ ):
775
+ logger.warning(f"key `{k}` already declared in memory!")
776
+ continue
777
+
778
+ _dynamic_config[k] = deepcopy(_value)
779
+
780
+ # 初始化各类记忆
781
+ self._dynamic = DynamicMemory(
782
+ required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
783
+ )
784
+
785
+ if profile is not None:
786
+ for k, v in profile.items():
787
+ if k not in PROFILE_ATTRIBUTES:
788
+ logger.warning(f"key `{k}` is not a correct `profile` field!")
789
+ continue
790
+ try:
791
+ # 处理配置元组格式
792
+ if isinstance(v, tuple):
793
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
794
+ _type, _value, enable_embedding, template = v
795
+ self._embedding_fields[k] = enable_embedding
796
+ self._semantic_templates[k] = template
797
+ elif len(v) == 3: # (_type, _value, enable_embedding)
798
+ _type, _value, enable_embedding = v
799
+ self._embedding_fields[k] = enable_embedding
800
+ else: # (_type, _value)
801
+ _type, _value = v
802
+ self._embedding_fields[k] = False
803
+
804
+ # 处理类型转换
805
+ try:
806
+ if isinstance(_type, type):
807
+ _value = _type(_value)
808
+ else:
809
+ if isinstance(_type, deque):
810
+ _type.extend(_value)
811
+ _value = deepcopy(_type)
812
+ else:
813
+ logger.warning(f"type `{_type}` is not supported!")
814
+ except TypeError as e:
815
+ logger.warning(f"Type conversion failed for key {k}: {e}")
816
+ else:
817
+ # 保持对简单键值对的兼容
818
+ _value = v
819
+ self._embedding_fields[k] = False
820
+ except TypeError as e:
821
+ if isinstance(v, type):
822
+ _value = v()
823
+ else:
824
+ _value = v
825
+ self._embedding_fields[k] = False
826
+
827
+ _profile_config[k] = deepcopy(_value)
828
+ self._profile = ProfileMemory(
829
+ msg=_profile_config, activate_timestamp=activate_timestamp
830
+ )
831
+ if base is not None:
832
+ for k, v in base.items():
833
+ if k not in STATE_ATTRIBUTES:
834
+ logger.warning(f"key `{k}` is not a correct `base` field!")
835
+ continue
836
+ _state_config[k] = v
837
+
838
+ self._state = StateMemory(
839
+ msg=_state_config, activate_timestamp=activate_timestamp
840
+ )
841
+
842
+ # 组合 StatusMemory,并传递 embedding_fields 信息
843
+ self._status = StatusMemory(
844
+ profile=self._profile, state=self._state, dynamic=self._dynamic
845
+ )
846
+ self._status.set_embedding_fields(self._embedding_fields)
847
+ self._status.set_search_components(self._faiss_query, self._embedding_model)
848
+
849
+ # 新增 StreamMemory
850
+ self._stream = StreamMemory()
851
+ self._stream.set_status_memory(self._status)
852
+ self._stream.set_search_components(self._faiss_query, self._embedding_model)
853
+
854
+ def set_search_components(
855
+ self,
856
+ faiss_query: FaissQuery,
857
+ embedding_model: Embeddings,
858
+ ):
859
+ self._embedding_model = embedding_model
860
+ self._faiss_query = faiss_query
861
+ self._stream.set_search_components(faiss_query, embedding_model)
862
+ self._status.set_search_components(faiss_query, embedding_model)
863
+
864
+ def set_agent_id(self, agent_id: int):
411
865
  """
412
- if not self._embedding_model:
413
- return "Embedding model not initialized"
414
- top_results: list[tuple[str, float, dict]] = (
415
- await self.faiss_query.similarity_search( # type:ignore
416
- query=query,
417
- agent_id=self.agent_id,
418
- k=top_k,
419
- return_score_type="similarity_score",
420
- filter=filter,
866
+ Set the FaissQuery of the agent.
867
+ """
868
+ self._agent_id = agent_id
869
+ self._stream.set_agent_id(agent_id)
870
+ self._status.set_agent_id(agent_id)
871
+
872
+ def set_simulator(self, simulator):
873
+ self._simulator = simulator
874
+ self._stream.set_simulator(simulator)
875
+ self._status.set_simulator(simulator)
876
+
877
+ @property
878
+ def status(self) -> StatusMemory:
879
+ return self._status
880
+
881
+ @property
882
+ def stream(self) -> StreamMemory:
883
+ return self._stream
884
+
885
+ @property
886
+ def embedding_model(
887
+ self,
888
+ ):
889
+ if self._embedding_model is None:
890
+ raise RuntimeError(
891
+ f"embedding_model before assignment, please `set_embedding_model` first!"
421
892
  )
422
- )
423
- # 格式化输出
424
- formatted_results = []
425
- for content, score, metadata in top_results:
426
- formatted_results.append(
427
- f"- [{metadata['type']}] {content} " f"(相关度: {score:.2f})"
893
+ return self._embedding_model
894
+
895
+ @property
896
+ def agent_id(
897
+ self,
898
+ ):
899
+ if self._agent_id < 0:
900
+ raise RuntimeError(
901
+ f"agent_id before assignment, please `set_agent_id` first!"
428
902
  )
903
+ return self._agent_id
429
904
 
430
- return "\n".join(formatted_results)
905
+ @property
906
+ def faiss_query(self) -> FaissQuery:
907
+ """FaissQuery"""
908
+ if self._faiss_query is None:
909
+ raise RuntimeError(
910
+ f"FaissQuery access before assignment, please `set_faiss_query` first!"
911
+ )
912
+ return self._faiss_query
913
+
914
+ async def initialize_embeddings(self):
915
+ """初始化embedding"""
916
+ await self._status.initialize_embeddings()