pycityagent 2.0.0a52__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a53__cp39-cp39-macosx_11_0_arm64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (36)
  1. pycityagent/agent/agent.py +48 -62
  2. pycityagent/agent/agent_base.py +66 -53
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +44 -56
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/nbsagent.py +6 -29
  17. pycityagent/cityagent/societyagent.py +204 -119
  18. pycityagent/environment/sim/client.py +10 -1
  19. pycityagent/environment/sim/clock_service.py +2 -2
  20. pycityagent/environment/sim/pause_service.py +61 -0
  21. pycityagent/environment/simulator.py +17 -12
  22. pycityagent/llm/embeddings.py +0 -24
  23. pycityagent/memory/faiss_query.py +29 -26
  24. pycityagent/memory/memory.py +720 -272
  25. pycityagent/pycityagent-sim +0 -0
  26. pycityagent/simulation/agentgroup.py +92 -99
  27. pycityagent/simulation/simulation.py +115 -40
  28. pycityagent/tools/tool.py +7 -9
  29. pycityagent/workflow/block.py +11 -4
  30. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +1 -1
  31. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
  32. pycityagent/cityagent/blocks/time_block.py +0 -116
  33. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
  34. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +0 -0
  35. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
  36. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,10 @@ import logging
3
3
  from collections import defaultdict
4
4
  from collections.abc import Callable, Sequence
5
5
  from copy import deepcopy
6
- from datetime import datetime
7
- from typing import Any, Literal, Optional, Union
6
+ from typing import Any, Literal, Optional, Union, Dict
7
+ from dataclasses import dataclass
8
+ from enum import Enum
8
9
 
9
- import numpy as np
10
10
  from langchain_core.embeddings import Embeddings
11
11
  from pyparsing import deque
12
12
 
@@ -19,197 +19,479 @@ from .state import StateMemory
19
19
 
20
20
  logger = logging.getLogger("pycityagent")
21
21
 
22
+ class MemoryTag(str, Enum):
23
+ """记忆标签枚举类"""
24
+ MOBILITY = "mobility"
25
+ SOCIAL = "social"
26
+ ECONOMY = "economy"
27
+ COGNITION = "cognition"
28
+ OTHER = "other"
29
+ EVENT = "event"
30
+
31
+ @dataclass
32
+ class MemoryNode:
33
+ """记忆节点"""
34
+ tag: MemoryTag
35
+ day: int
36
+ t: int
37
+ location: str
38
+ description: str
39
+ cognition_id: Optional[int] = None # 关联的认知记忆ID
40
+ id: Optional[int] = None # 记忆ID
41
+
42
+ class StreamMemory:
43
+ """用于存储时序性的流式信息"""
44
+ def __init__(self, max_len: int = 1000):
45
+ self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
46
+ self._memory_id_counter: int = 0 # 用于生成唯一ID
47
+ self._faiss_query = None
48
+ self._embedding_model = None
49
+ self._agent_id = -1
50
+ self._status_memory = None
51
+ self._simulator = None
52
+
53
+ def set_simulator(self, simulator):
54
+ self._simulator = simulator
55
+
56
+ def set_status_memory(self, status_memory):
57
+ self._status_memory = status_memory
58
+
59
+ def set_search_components(self, faiss_query, embedding_model):
60
+ """设置搜索所需的组件"""
61
+ self._faiss_query = faiss_query
62
+ self._embedding_model = embedding_model
22
63
 
23
- class Memory:
24
- """
25
- A class to manage different types of memory (state, profile, dynamic).
64
+ def set_agent_id(self, agent_id: int):
65
+ """设置agent_id"""
66
+ self._agent_id = agent_id
26
67
 
27
- Attributes:
28
- _state (StateMemory): Stores state-related data.
29
- _profile (ProfileMemory): Stores profile-related data.
30
- _dynamic (DynamicMemory): Stores dynamically configured data.
31
- """
68
+ async def _add_memory(self, tag: MemoryTag, description: str) -> int:
69
+ """添加记忆节点的通用方法,返回记忆节点ID"""
70
+ if self._simulator is not None:
71
+ day = int(await self._simulator.get_simulator_day())
72
+ t = int(await self._simulator.get_time())
73
+ else:
74
+ day = 1
75
+ t = 1
76
+ position = await self._status_memory.get("position")
77
+ if 'aoi_position' in position:
78
+ location = position['aoi_position']['aoi_id']
79
+ elif 'lane_position' in position:
80
+ location = position['lane_position']['lane_id']
81
+ else:
82
+ location = "unknown"
83
+
84
+ current_id = self._memory_id_counter
85
+ self._memory_id_counter += 1
86
+ memory_node = MemoryNode(
87
+ tag=tag,
88
+ day=day,
89
+ t=t,
90
+ location=location,
91
+ description=description,
92
+ id=current_id,
93
+ )
94
+ self._memories.append(memory_node)
95
+
96
+
97
+ # 为新记忆创建 embedding
98
+ if self._embedding_model and self._faiss_query:
99
+ await self._faiss_query.add_documents(
100
+ agent_id=self._agent_id,
101
+ documents=description,
102
+ extra_tags={
103
+ "type": "stream",
104
+ "tag": tag,
105
+ "day": day,
106
+ "time": t,
107
+ },
108
+ )
109
+
110
+ return current_id
111
+ async def add_cognition(self, description: str) -> None:
112
+ """添加认知记忆 Add cognition memory"""
113
+ return await self._add_memory(MemoryTag.COGNITION, description)
114
+
115
+ async def add_social(self, description: str) -> None:
116
+ """添加社交记忆 Add social memory"""
117
+ return await self._add_memory(MemoryTag.SOCIAL, description)
118
+
119
+ async def add_economy(self, description: str) -> None:
120
+ """添加经济记忆 Add economy memory"""
121
+ return await self._add_memory(MemoryTag.ECONOMY, description)
122
+
123
+ async def add_mobility(self, description: str) -> None:
124
+ """添加移动记忆 Add mobility memory"""
125
+ return await self._add_memory(MemoryTag.MOBILITY, description)
126
+
127
+ async def add_event(self, description: str) -> None:
128
+ """添加事件记忆 Add event memory"""
129
+ return await self._add_memory(MemoryTag.EVENT, description)
130
+
131
+ async def add_other(self, description: str) -> None:
132
+ """添加其他记忆 Add other memory"""
133
+ return await self._add_memory(MemoryTag.OTHER, description)
134
+
135
+ async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
136
+ """获取关联的认知记忆 Get related cognition memory"""
137
+ for memory in self._memories:
138
+ if memory.cognition_id == memory_id:
139
+ for cognition_memory in self._memories:
140
+ if (cognition_memory.tag == MemoryTag.COGNITION and
141
+ memory.cognition_id is not None):
142
+ return cognition_memory
143
+ return None
144
+
145
+ async def format_memory(self, memories: list[MemoryNode]) -> str:
146
+ """格式化记忆"""
147
+ formatted_results = []
148
+ for memory in memories:
149
+ memory_tag = memory.tag
150
+ memory_day = memory.day
151
+ memory_time_seconds = memory.t
152
+ cognition_id = memory.cognition_id
153
+
154
+ # 格式化时间
155
+ if memory_time_seconds != 'unknown':
156
+ hours = memory_time_seconds // 3600
157
+ minutes = (memory_time_seconds % 3600) // 60
158
+ seconds = memory_time_seconds % 60
159
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
160
+ else:
161
+ memory_time = 'unknown'
162
+
163
+ memory_location = memory.location
164
+
165
+ # 添加认知信息(如果存在)
166
+ cognition_info = ""
167
+ if cognition_id is not None:
168
+ cognition_memory = await self.get_related_cognition(cognition_id)
169
+ if cognition_memory:
170
+ cognition_info = f"\n Related cognition: {cognition_memory.description}"
171
+
172
+ formatted_results.append(
173
+ f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
174
+ f"location: {memory_location}]{cognition_info}"
175
+ )
176
+ return "\n".join(formatted_results)
32
177
 
33
- def __init__(
34
- self,
35
- config: Optional[dict[Any, Any]] = None,
36
- profile: Optional[dict[Any, Any]] = None,
37
- base: Optional[dict[Any, Any]] = None,
38
- motion: Optional[dict[Any, Any]] = None,
39
- activate_timestamp: bool = False,
40
- embedding_model: Optional[Embeddings] = None,
41
- faiss_query: Optional[FaissQuery] = None,
42
- ) -> None:
43
- """
44
- Initializes the Memory with optional configuration.
178
+ async def get_by_ids(self, memory_ids: Union[int, list[int]]) -> str:
179
+ """获取指定ID的记忆"""
180
+ memories = [memory for memory in self._memories if memory.id in memory_ids]
181
+ sorted_results = sorted(
182
+ memories,
183
+ key=lambda x: (x.day, x.t),
184
+ reverse=True
185
+ )
186
+ return self.format_memory(sorted_results)
45
187
 
188
+ async def search(
189
+ self,
190
+ query: str,
191
+ tag: Optional[MemoryTag] = None,
192
+ top_k: int = 3,
193
+ day_range: Optional[tuple[int, int]] = None, # 新增参数
194
+ time_range: Optional[tuple[int, int]] = None # 新增参数
195
+ ) -> str:
196
+ """Search stream memory
197
+
46
198
  Args:
47
- config (Optional[dict[Any, Any]], optional):
48
- A configuration dictionary for dynamic memory. The dictionary format is:
49
- - Key: The name of the dynamic memory field.
50
- - Value: Can be one of two formats:
51
- 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
52
- 2. A callable that returns the default value when invoked (useful for complex default values).
53
- Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
54
- Defaults to None.
55
- profile (Optional[dict[Any, Any]], optional): profile attribute dict.
56
- base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
57
- motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
58
- activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
59
- embedding_model (Embeddings): The embedding model for memory search.
60
- faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
199
+ query: Query text
200
+ tag: Optional memory tag for filtering specific types of memories
201
+ top_k: Number of most relevant memories to return
202
+ day_range: Optional tuple of start and end days (start_day, end_day)
203
+ time_range: Optional tuple of start and end times (start_time, end_time)
61
204
  """
62
- self.watchers: dict[str, list[Callable]] = {}
63
- self._lock = asyncio.Lock()
64
- self._agent_id: int = -1
65
- self._embedding_model = embedding_model
66
-
67
- _dynamic_config: dict[Any, Any] = {}
68
- _state_config: dict[Any, Any] = {}
69
- _profile_config: dict[Any, Any] = {}
70
- # 记录哪些字段需要embedding
71
- self._embedding_fields: dict[str, bool] = {}
72
- self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
73
- self._faiss_query = faiss_query
74
-
75
- if config is not None:
76
- for k, v in config.items():
77
- try:
78
- # 处理新的三元组格式
79
- if isinstance(v, tuple) and len(v) == 3:
80
- _type, _value, enable_embedding = v
81
- self._embedding_fields[k] = enable_embedding
82
- else:
83
- _type, _value = v
84
- self._embedding_fields[k] = False
85
-
86
- try:
87
- if isinstance(_type, type):
88
- _value = _type(_value)
89
- else:
90
- if isinstance(_type, deque):
91
- _type.extend(_value)
92
- _value = deepcopy(_type)
93
- else:
94
- logger.warning(f"type `{_type}` is not supported!")
95
- pass
96
- except TypeError as e:
97
- pass
98
- except TypeError as e:
99
- if isinstance(v, type):
100
- _value = v()
101
- else:
102
- _value = v
103
- self._embedding_fields[k] = False
104
-
105
- if (
106
- k in PROFILE_ATTRIBUTES
107
- or k in STATE_ATTRIBUTES
108
- or k == TIME_STAMP_KEY
109
- ):
110
- logger.warning(f"key `{k}` already declared in memory!")
111
- continue
112
-
113
- _dynamic_config[k] = deepcopy(_value)
114
-
115
- # 初始化各类记忆
116
- self._dynamic = DynamicMemory(
117
- required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
205
+ if not self._embedding_model or not self._faiss_query:
206
+ return "Search components not initialized"
207
+
208
+ filter_dict = {"type": "stream"}
209
+
210
+ if tag:
211
+ filter_dict["tag"] = tag
212
+
213
+ # 添加时间范围过滤
214
+ if day_range:
215
+ start_day, end_day = day_range
216
+ filter_dict["day"] = lambda x: start_day <= x <= end_day
217
+
218
+ if time_range:
219
+ start_time, end_time = time_range
220
+ filter_dict["time"] = lambda x: start_time <= x <= end_time
221
+
222
+ top_results = await self._faiss_query.similarity_search(
223
+ query=query,
224
+ agent_id=self._agent_id,
225
+ k=top_k,
226
+ return_score_type="similarity_score",
227
+ filter=filter_dict
118
228
  )
119
229
 
120
- if profile is not None:
121
- for k, v in profile.items():
122
- if k not in PROFILE_ATTRIBUTES:
123
- logger.warning(f"key `{k}` is not a correct `profile` field!")
124
- continue
125
- _profile_config[k] = v
126
- if motion is not None:
127
- for k, v in motion.items():
128
- if k not in STATE_ATTRIBUTES:
129
- logger.warning(f"key `{k}` is not a correct `motion` field!")
130
- continue
131
- _state_config[k] = v
132
- if base is not None:
133
- for k, v in base.items():
134
- if k not in STATE_ATTRIBUTES:
135
- logger.warning(f"key `{k}` is not a correct `base` field!")
136
- continue
137
- _state_config[k] = v
138
- self._state = StateMemory(
139
- msg=_state_config, activate_timestamp=activate_timestamp
230
+ # 将结果按时间排序(先按天数,再按时间)
231
+ sorted_results = sorted(
232
+ top_results,
233
+ key=lambda x: (x[2].get('day', 0), x[2].get('time', 0)),
234
+ reverse=True
140
235
  )
141
- self._profile = ProfileMemory(
142
- msg=_profile_config, activate_timestamp=activate_timestamp
143
- )
144
- # self.memories = [] # 存储记忆内容
145
- # self.embeddings = [] # 存储记忆的向量表示
146
-
147
- def set_embedding_model(
148
- self,
149
- embedding_model: Embeddings,
150
- ):
151
- self._embedding_model = embedding_model
152
-
153
- @property
154
- def embedding_model(
155
- self,
156
- ):
157
- if self._embedding_model is None:
158
- raise RuntimeError(
159
- f"embedding_model before assignment, please `set_embedding_model` first!"
236
+
237
+ formatted_results = []
238
+ for content, score, metadata in sorted_results:
239
+ memory_tag = metadata.get('tag', 'unknown')
240
+ memory_day = metadata.get('day', 'unknown')
241
+ memory_time_seconds = metadata.get('time', 'unknown')
242
+ cognition_id = metadata.get('cognition_id', None)
243
+
244
+ # 格式化时间
245
+ if memory_time_seconds != 'unknown':
246
+ hours = memory_time_seconds // 3600
247
+ minutes = (memory_time_seconds % 3600) // 60
248
+ seconds = memory_time_seconds % 60
249
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
250
+ else:
251
+ memory_time = 'unknown'
252
+
253
+ memory_location = metadata.get('location', 'unknown')
254
+
255
+ # 添加认知信息(如果存在)
256
+ cognition_info = ""
257
+ if cognition_id is not None:
258
+ cognition_memory = await self.get_related_cognition(cognition_id)
259
+ if cognition_memory:
260
+ cognition_info = f"\n Related cognition: {cognition_memory.description}"
261
+
262
+ formatted_results.append(
263
+ f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
264
+ f"location: {memory_location}]{cognition_info}"
160
265
  )
161
- return self._embedding_model
266
+ return "\n".join(formatted_results)
162
267
 
163
- def set_faiss_query(self, faiss_query: FaissQuery):
268
+ async def search_today(
269
+ self,
270
+ query: str = "", # 可选的查询文本
271
+ tag: Optional[MemoryTag] = None,
272
+ top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
273
+ ) -> str:
274
+ """Search all memory events from today
275
+
276
+ Args:
277
+ query: Optional query text, returns all memories of the day if empty
278
+ tag: Optional memory tag for filtering specific types of memories
279
+ top_k: Number of most relevant memories to return, defaults to 100
280
+
281
+ Returns:
282
+ str: Formatted text of today's memories
164
283
  """
165
- Set the FaissQuery of the agent.
284
+ if self._simulator is None:
285
+ return "Simulator not initialized"
286
+
287
+ current_day = int(await self._simulator.get_simulator_day())
288
+
289
+ # 使用 search 方法,设置 day_range 为当天
290
+ return await self.search(
291
+ query=query,
292
+ tag=tag,
293
+ top_k=top_k,
294
+ day_range=(current_day, current_day)
295
+ )
296
+
297
+ async def add_cognition_to_memory(self, memory_id: Union[int, list[int]], cognition: str) -> None:
298
+ """为已存在的记忆添加认知
299
+
300
+ Args:
301
+ memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
302
+ cognition: 认知描述
166
303
  """
167
- self._faiss_query = faiss_query
304
+ # 将单个ID转换为列表以统一处理
305
+ memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
306
+
307
+ # 找到所有对应的记忆
308
+ target_memories = []
309
+ for memory in self._memories:
310
+ if id(memory) in memory_ids:
311
+ target_memories.append(memory)
312
+
313
+ if not target_memories:
314
+ raise ValueError(f"No memories found with ids {memory_ids}")
315
+
316
+ # 添加认知记忆
317
+ cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
318
+
319
+ # 更新所有原记忆的认知ID
320
+ for target_memory in target_memories:
321
+ target_memory.cognition_id = cognition_id
322
+
323
+ async def get_all(self) -> list[MemoryNode]:
324
+ """获取所有流式信息"""
325
+ return list(self._memories)
326
+
327
+ class StatusMemory:
328
+ """组合现有的三种记忆类型"""
329
+ def __init__(self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory):
330
+ self.profile = profile
331
+ self.state = state
332
+ self.dynamic = dynamic
333
+ self._faiss_query = None
334
+ self._embedding_model = None
335
+ self._simulator = None
336
+ self._agent_id = -1
337
+ self._semantic_templates = {} # 用户可配置的模板
338
+ self._embedding_fields = {} # 需要 embedding 的字段
339
+ self._embedding_field_to_doc_id = defaultdict(str) # 新增
340
+ self.watchers = {} # 新增
341
+ self._lock = asyncio.Lock() # 新增
342
+
343
+ def set_simulator(self, simulator):
344
+ self._simulator = simulator
345
+
346
+ async def initialize_embeddings(self) -> None:
347
+ """初始化所有需要 embedding 的字段"""
348
+ if not self._embedding_model or not self._faiss_query:
349
+ logger.warning("Search components not initialized, skipping embeddings initialization")
350
+ return
351
+
352
+ # 获取所有状态信息
353
+ profile, state, dynamic = await self.export()
354
+
355
+ # 为每个需要 embedding 的字段创建 embedding
356
+ for key, value in profile[0].items():
357
+ if self.should_embed(key):
358
+ semantic_text = self._generate_semantic_text(key, value)
359
+ doc_ids = await self._faiss_query.add_documents(
360
+ agent_id=self._agent_id,
361
+ documents=semantic_text,
362
+ extra_tags={
363
+ "type": "profile_state",
364
+ "key": key,
365
+ },
366
+ )
367
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
368
+
369
+ for key, value in state[0].items():
370
+ if self.should_embed(key):
371
+ semantic_text = self._generate_semantic_text(key, value)
372
+ doc_ids = await self._faiss_query.add_documents(
373
+ agent_id=self._agent_id,
374
+ documents=semantic_text,
375
+ extra_tags={
376
+ "type": "profile_state",
377
+ "key": key,
378
+ },
379
+ )
380
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
381
+
382
+ for key, value in dynamic[0].items():
383
+ if self.should_embed(key):
384
+ semantic_text = self._generate_semantic_text(key, value)
385
+ doc_ids = await self._faiss_query.add_documents(
386
+ agent_id=self._agent_id,
387
+ documents=semantic_text,
388
+ extra_tags={
389
+ "type": "profile_state",
390
+ "key": key,
391
+ },
392
+ )
393
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
394
+
395
+ def _get_memory_type_by_key(self, key: str) -> str:
396
+ """根据键名确定记忆类型"""
397
+ try:
398
+ if key in self.profile.__dict__:
399
+ return "profile"
400
+ elif key in self.state.__dict__:
401
+ return "state"
402
+ else:
403
+ return "dynamic"
404
+ except:
405
+ return "dynamic"
168
406
 
169
- @property
170
- def agent_id(
171
- self,
172
- ):
173
- if self._agent_id < 0:
174
- raise RuntimeError(
175
- f"agent_id before assignment, please `set_agent_id` first!"
176
- )
177
- return self._agent_id
407
+ def set_search_components(self, faiss_query, embedding_model):
408
+ """设置搜索所需的组件"""
409
+ self._faiss_query = faiss_query
410
+ self._embedding_model = embedding_model
178
411
 
179
412
  def set_agent_id(self, agent_id: int):
413
+ """设置agent_id"""
414
+ self._agent_id = agent_id
415
+
416
+ def set_semantic_templates(self, templates: Dict[str, str]):
417
+ """设置语义模板
418
+
419
+ Args:
420
+ templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
180
421
  """
181
- Set the FaissQuery of the agent.
422
+ self._semantic_templates = templates
423
+
424
+ def _generate_semantic_text(self, key: str, value: Any) -> str:
425
+ """生成语义文本
426
+
427
+ 如果key存在于模板中,使用自定义模板
428
+ 否则使用默认模板 "my {key} is {value}"
182
429
  """
183
- self._agent_id = agent_id
430
+ if key in self._semantic_templates:
431
+ return self._semantic_templates[key].format(value)
432
+ return f"Your {key} is {value}"
433
+
434
+ @lock_decorator
435
+ async def search(
436
+ self, query: str, top_k: int = 3, filter: Optional[dict] = None
437
+ ) -> str:
438
+ """搜索相关记忆
184
439
 
185
- @property
186
- def faiss_query(self) -> FaissQuery:
187
- """FaissQuery"""
188
- if self._faiss_query is None:
189
- raise RuntimeError(
190
- f"FaissQuery access before assignment, please `set_faiss_query` first!"
440
+ Args:
441
+ query: 查询文本
442
+ top_k: 返回最相关的记忆数量
443
+ filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
444
+
445
+ Returns:
446
+ str: 格式化的相关记忆文本
447
+ """
448
+ if not self._embedding_model:
449
+ return "Embedding model not initialized"
450
+
451
+ filter_dict = {"type": "profile_state"}
452
+ if filter is not None:
453
+ filter_dict.update(filter)
454
+ top_results: list[tuple[str, float, dict]] = (
455
+ await self._faiss_query.similarity_search( # type:ignore
456
+ query=query,
457
+ agent_id=self._agent_id,
458
+ k=top_k,
459
+ return_score_type="similarity_score",
460
+ filter=filter_dict,
191
461
  )
192
- return self._faiss_query
462
+ )
463
+ # 格式化输出
464
+ formatted_results = []
465
+ for content, score, metadata in top_results:
466
+ formatted_results.append(
467
+ f"- {content} "
468
+ )
469
+
470
+ return "\n".join(formatted_results)
471
+
472
+ def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
473
+ """设置需要 embedding 的字段"""
474
+ self._embedding_fields = embedding_fields
475
+
476
+ def should_embed(self, key: str) -> bool:
477
+ """判断字段是否需要 embedding"""
478
+ return self._embedding_fields.get(key, False)
193
479
 
194
480
  @lock_decorator
195
- async def get(
196
- self,
197
- key: Any,
198
- mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
199
- ) -> Any:
200
- """
201
- Retrieves a value from memory based on the given key and access mode.
481
+ async def get(self, key: Any,
482
+ mode: Union[Literal["read only"], Literal["read and write"]] = "read only") -> Any:
483
+ """从记忆中获取值
202
484
 
203
485
  Args:
204
- key (Any): The key of the item to retrieve.
205
- mode (Union[Literal["read only"], Literal["read and write"]], optional): Access mode for the item. Defaults to "read only".
486
+ key: 要获取的键
487
+ mode: 访问模式,"read only" "read and write"
206
488
 
207
489
  Returns:
208
- Any: The value associated with the key.
490
+ 获取到的值
209
491
 
210
492
  Raises:
211
- ValueError: If an invalid mode is provided.
212
- KeyError: If the key is not found in any of the memory sections.
493
+ ValueError: 如果提供了无效的模式
494
+ KeyError: 如果在任何记忆部分都找不到该键
213
495
  """
214
496
  if mode == "read only":
215
497
  process_func = deepcopy
@@ -217,54 +499,56 @@ class Memory:
217
499
  process_func = lambda x: x
218
500
  else:
219
501
  raise ValueError(f"Invalid get mode `{mode}`!")
220
- for _mem in [self._state, self._profile, self._dynamic]:
502
+
503
+ for mem in [self.state, self.profile, self.dynamic]:
221
504
  try:
222
- value = await _mem.get(key)
505
+ value = await mem.get(key)
223
506
  return process_func(value)
224
- except KeyError as e:
507
+ except KeyError:
225
508
  continue
226
509
  raise KeyError(f"No attribute `{key}` in memories!")
227
510
 
228
511
  @lock_decorator
229
- async def update(
230
- self,
231
- key: Any,
232
- value: Any,
233
- mode: Union[Literal["replace"], Literal["merge"]] = "replace",
234
- store_snapshot: bool = False,
235
- protect_llm_read_only_fields: bool = True,
236
- ) -> None:
512
+ async def update(self, key: Any, value: Any,
513
+ mode: Union[Literal["replace"], Literal["merge"]] = "replace",
514
+ store_snapshot: bool = False,
515
+ protect_llm_read_only_fields: bool = True) -> None:
237
516
  """更新记忆值并在必要时更新embedding"""
238
517
  if protect_llm_read_only_fields:
239
518
  if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
240
519
  logger.warning(f"Trying to write protected key `{key}`!")
241
520
  return
242
- for _mem in [self._state, self._profile, self._dynamic]:
521
+
522
+ for mem in [self.state, self.profile, self.dynamic]:
243
523
  try:
244
- original_value = await _mem.get(key)
524
+ original_value = await mem.get(key)
245
525
  if mode == "replace":
246
- await _mem.update(key, value, store_snapshot)
247
- # 如果字段需要embedding,则更新embedding
248
- if self._embedding_fields.get(key, False) and self.embedding_model:
249
- memory_type = self._get_memory_type(_mem)
250
- # 覆盖更新删除原vector
526
+ await mem.update(key, value, store_snapshot)
527
+ if self.should_embed(key) and self._embedding_model:
528
+ semantic_text = self._generate_semantic_text(key, value)
529
+
530
+ # 删除旧的 embedding
251
531
  orig_doc_id = self._embedding_field_to_doc_id[key]
252
532
  if orig_doc_id:
253
- await self.faiss_query.delete_documents(
533
+ await self._faiss_query.delete_documents(
254
534
  to_delete_ids=[orig_doc_id],
255
535
  )
256
- doc_ids: list[str] = await self.faiss_query.add_documents(
257
- agent_id=self.agent_id,
258
- documents=f"{key}: {str(value)}",
536
+
537
+ # 添加新的 embedding
538
+ doc_ids = await self._faiss_query.add_documents(
539
+ agent_id=self._agent_id,
540
+ documents=semantic_text,
259
541
  extra_tags={
260
- "type": memory_type,
542
+ "type": self._get_memory_type(mem),
261
543
  "key": key,
262
544
  },
263
545
  )
264
546
  self._embedding_field_to_doc_id[key] = doc_ids[0]
547
+
265
548
  if key in self.watchers:
266
549
  for callback in self.watchers[key]:
267
550
  asyncio.create_task(callback())
551
+
268
552
  elif mode == "merge":
269
553
  if isinstance(original_value, set):
270
554
  original_value.update(set(value))
@@ -278,14 +562,14 @@ class Memory:
278
562
  logger.debug(
279
563
  f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
280
564
  )
281
- await _mem.update(key, value, store_snapshot)
282
- if self._embedding_fields.get(key, False) and self.embedding_model:
283
- memory_type = self._get_memory_type(_mem)
284
- doc_ids = await self.faiss_query.add_documents(
285
- agent_id=self.agent_id,
565
+ await mem.update(key, value, store_snapshot)
566
+ if self.should_embed(key) and self._embedding_model:
567
+ semantic_text = self._generate_semantic_text(key, value)
568
+ doc_ids = await self._faiss_query.add_documents(
569
+ agent_id=self._agent_id,
286
570
  documents=f"{key}: {str(original_value)}",
287
571
  extra_tags={
288
- "type": memory_type,
572
+ "type": self._get_memory_type(mem),
289
573
  "key": key,
290
574
  },
291
575
  )
@@ -302,63 +586,20 @@ class Memory:
302
586
 
303
587
  def _get_memory_type(self, mem: Any) -> str:
304
588
  """获取记忆类型"""
305
- if mem is self._state:
589
+ if mem is self.state:
306
590
  return "state"
307
- elif mem is self._profile:
591
+ elif mem is self.profile:
308
592
  return "profile"
309
593
  else:
310
594
  return "dynamic"
311
595
 
312
- async def update_batch(
313
- self,
314
- content: Union[dict, Sequence[tuple[Any, Any]]],
315
- mode: Union[Literal["replace"], Literal["merge"]] = "replace",
316
- store_snapshot: bool = False,
317
- protect_llm_read_only_fields: bool = True,
318
- ) -> None:
319
- """
320
- Updates multiple values in the memory at once.
321
-
322
- Args:
323
- content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
324
- mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
325
- store_snapshot (bool): Whether to store a snapshot of the memory after the update.
326
- protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
327
-
328
- Raises:
329
- TypeError: If the content type is neither a dictionary nor a sequence of tuples.
330
- """
331
- if isinstance(content, dict):
332
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
333
- elif isinstance(content, Sequence):
334
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
335
- else:
336
- raise TypeError(f"Invalid content type `{type(content)}`!")
337
- for k, v in _list_content[:1]:
338
- await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
339
- for k, v in _list_content[1:]:
340
- await self.update(k, v, mode, False, protect_llm_read_only_fields)
341
-
342
596
  @lock_decorator
343
597
  async def add_watcher(self, key: str, callback: Callable) -> None:
344
- """
345
- Adds a callback function to be invoked when the value
346
- associated with the specified key in memory is updated.
347
-
348
- Args:
349
- key (str): The key for which the watcher is being registered.
350
- callback (Callable): A callable function that will be executed
351
- whenever the value associated with the specified key is updated.
352
-
353
- Notes:
354
- If the key does not already have any watchers, it will be
355
- initialized with an empty list before appending the callback.
356
- """
598
+ """添加值变更的监听器"""
357
599
  if key not in self.watchers:
358
600
  self.watchers[key] = []
359
601
  self.watchers[key].append(callback)
360
602
 
361
- @lock_decorator
362
603
  async def export(
363
604
  self,
364
605
  ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
@@ -369,12 +610,11 @@ class Memory:
369
610
  tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
370
611
  """
371
612
  return (
372
- await self._profile.export(),
373
- await self._state.export(),
374
- await self._dynamic.export(),
613
+ await self.profile.export(),
614
+ await self.state.export(),
615
+ await self.dynamic.export(),
375
616
  )
376
617
 
377
- @lock_decorator
378
618
  async def load(
379
619
  self,
380
620
  snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
@@ -390,41 +630,249 @@ class Memory:
390
630
  _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
391
631
  for _snapshot, _mem in zip(
392
632
  [_profile_snapshot, _state_snapshot, _dynamic_snapshot],
393
- [self._state, self._profile, self._dynamic],
633
+ [self.state, self.profile, self.dynamic],
394
634
  ):
395
635
  if _snapshot:
396
636
  await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
397
637
 
398
- @lock_decorator
399
- async def search(
400
- self, query: str, top_k: int = 3, filter: Optional[dict] = None
401
- ) -> str:
402
- """搜索相关记忆
638
+ class Memory:
639
+ """
640
+ A class to manage different types of memory (state, profile, dynamic).
641
+
642
+ Attributes:
643
+ _state (StateMemory): Stores state-related data.
644
+ _profile (ProfileMemory): Stores profile-related data.
645
+ _dynamic (DynamicMemory): Stores dynamically configured data.
646
+ """
647
+
648
+ def __init__(
649
+ self,
650
+ config: Optional[dict[Any, Any]] = None,
651
+ profile: Optional[dict[Any, Any]] = None,
652
+ base: Optional[dict[Any, Any]] = None,
653
+ activate_timestamp: bool = False,
654
+ embedding_model: Optional[Embeddings] = None,
655
+ faiss_query: Optional[FaissQuery] = None,
656
+ ) -> None:
657
+ """
658
+ Initializes the Memory with optional configuration.
403
659
 
404
660
  Args:
405
- query: 查询文本
406
- top_k: 返回最相关的记忆数量
407
- filter (dict, optional): 记忆的筛选条件,如 {"type":"dynamic", "key":"self_define_1",},默认为空
661
+ config (Optional[dict[Any, Any]], optional):
662
+ A configuration dictionary for dynamic memory. The dictionary format is:
663
+ - Key: The name of the dynamic memory field.
664
+ - Value: Can be one of two formats:
665
+ 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
666
+ 2. A callable that returns the default value when invoked (useful for complex default values).
667
+ Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
668
+ Defaults to None.
669
+ profile (Optional[dict[Any, Any]], optional): profile attribute dict.
670
+ base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
671
+ motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
672
+ activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
673
+ embedding_model (Embeddings): The embedding model for memory search.
674
+ faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
675
+ """
676
+ self.watchers: dict[str, list[Callable]] = {}
677
+ self._lock = asyncio.Lock()
678
+ self._agent_id: int = -1
679
+ self._simulator = None
680
+ self._embedding_model = embedding_model
681
+ self._faiss_query = faiss_query
682
+ self._semantic_templates: dict[str, str] = {}
683
+ _dynamic_config: dict[Any, Any] = {}
684
+ _state_config: dict[Any, Any] = {}
685
+ _profile_config: dict[Any, Any] = {}
686
+ self._embedding_fields: dict[str, bool] = {}
687
+ self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
408
688
 
409
- Returns:
410
- str: 格式化的相关记忆文本
689
+ if config is not None:
690
+ for k, v in config.items():
691
+ try:
692
+ # 处理不同长度的配置元组
693
+ if isinstance(v, tuple):
694
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
695
+ _type, _value, enable_embedding, template = v
696
+ self._embedding_fields[k] = enable_embedding
697
+ self._semantic_templates[k] = template
698
+ elif len(v) == 3: # (_type, _value, enable_embedding)
699
+ _type, _value, enable_embedding = v
700
+ self._embedding_fields[k] = enable_embedding
701
+ else: # (_type, _value)
702
+ _type, _value = v
703
+ self._embedding_fields[k] = False
704
+ else:
705
+ _type = type(v)
706
+ _value = v
707
+ self._embedding_fields[k] = False
708
+
709
+ # 处理类型转换
710
+ try:
711
+ if isinstance(_type, type):
712
+ _value = _type(_value)
713
+ else:
714
+ if isinstance(_type, deque):
715
+ _type.extend(_value)
716
+ _value = deepcopy(_type)
717
+ else:
718
+ logger.warning(f"type `{_type}` is not supported!")
719
+ except TypeError as e:
720
+ logger.warning(f"Type conversion failed for key {k}: {e}")
721
+ except TypeError as e:
722
+ if isinstance(v, type):
723
+ _value = v()
724
+ else:
725
+ _value = v
726
+ self._embedding_fields[k] = False
727
+
728
+ if (
729
+ k in PROFILE_ATTRIBUTES
730
+ or k in STATE_ATTRIBUTES
731
+ or k == TIME_STAMP_KEY
732
+ ):
733
+ logger.warning(f"key `{k}` already declared in memory!")
734
+ continue
735
+
736
+ _dynamic_config[k] = deepcopy(_value)
737
+
738
+ # 初始化各类记忆
739
+ self._dynamic = DynamicMemory(
740
+ required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
741
+ )
742
+
743
+ if profile is not None:
744
+ for k, v in profile.items():
745
+ if k not in PROFILE_ATTRIBUTES:
746
+ logger.warning(f"key `{k}` is not a correct `profile` field!")
747
+ continue
748
+
749
+ try:
750
+ # 处理配置元组格式
751
+ if isinstance(v, tuple):
752
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
753
+ _type, _value, enable_embedding, template = v
754
+ self._embedding_fields[k] = enable_embedding
755
+ self._semantic_templates[k] = template
756
+ elif len(v) == 3: # (_type, _value, enable_embedding)
757
+ _type, _value, enable_embedding = v
758
+ self._embedding_fields[k] = enable_embedding
759
+ else: # (_type, _value)
760
+ _type, _value = v
761
+ self._embedding_fields[k] = False
762
+
763
+ # 处理类型转换
764
+ try:
765
+ if isinstance(_type, type):
766
+ _value = _type(_value)
767
+ else:
768
+ if isinstance(_type, deque):
769
+ _type.extend(_value)
770
+ _value = deepcopy(_type)
771
+ else:
772
+ logger.warning(f"type `{_type}` is not supported!")
773
+ except TypeError as e:
774
+ logger.warning(f"Type conversion failed for key {k}: {e}")
775
+ else:
776
+ # 保持对简单键值对的兼容
777
+ _value = v
778
+ self._embedding_fields[k] = False
779
+ except TypeError as e:
780
+ if isinstance(v, type):
781
+ _value = v()
782
+ else:
783
+ _value = v
784
+ self._embedding_fields[k] = False
785
+
786
+ _profile_config[k] = deepcopy(_value)
787
+ self._profile = ProfileMemory(
788
+ msg=_profile_config, activate_timestamp=activate_timestamp
789
+ )
790
+
791
+ if base is not None:
792
+ for k, v in base.items():
793
+ if k not in STATE_ATTRIBUTES:
794
+ logger.warning(f"key `{k}` is not a correct `base` field!")
795
+ continue
796
+ _state_config[k] = v
797
+
798
+ self._state = StateMemory(
799
+ msg=_state_config, activate_timestamp=activate_timestamp
800
+ )
801
+
802
+ # 组合 StatusMemory,并传递 embedding_fields 信息
803
+ self._status = StatusMemory(
804
+ profile=self._profile,
805
+ state=self._state,
806
+ dynamic=self._dynamic
807
+ )
808
+ self._status.set_embedding_fields(self._embedding_fields)
809
+ self._status.set_search_components(self._faiss_query, self._embedding_model)
810
+
811
+ # 新增 StreamMemory
812
+ self._stream = StreamMemory()
813
+ self._stream.set_status_memory(self._status)
814
+ self._stream.set_search_components(self._faiss_query, self._embedding_model)
815
+
816
+ def set_search_components(
817
+ self,
818
+ faiss_query: FaissQuery,
819
+ embedding_model: Embeddings,
820
+ ):
821
+ self._embedding_model = embedding_model
822
+ self._faiss_query = faiss_query
823
+ self._stream.set_search_components(faiss_query, embedding_model)
824
+ self._status.set_search_components(faiss_query, embedding_model)
825
+
826
+ def set_agent_id(self, agent_id: int):
411
827
  """
412
- if not self._embedding_model:
413
- return "Embedding model not initialized"
414
- top_results: list[tuple[str, float, dict]] = (
415
- await self.faiss_query.similarity_search( # type:ignore
416
- query=query,
417
- agent_id=self.agent_id,
418
- k=top_k,
419
- return_score_type="similarity_score",
420
- filter=filter,
828
+ Set the FaissQuery of the agent.
829
+ """
830
+ self._agent_id = agent_id
831
+ self._stream.set_agent_id(agent_id)
832
+ self._status.set_agent_id(agent_id)
833
+
834
+ def set_simulator(self, simulator):
835
+ self._simulator = simulator
836
+ self._stream.set_simulator(simulator)
837
+ self._status.set_simulator(simulator)
838
+
839
+ @property
840
+ def status(self) -> StatusMemory:
841
+ return self._status
842
+
843
+ @property
844
+ def stream(self) -> StreamMemory:
845
+ return self._stream
846
+
847
+ @property
848
+ def embedding_model(
849
+ self,
850
+ ):
851
+ if self._embedding_model is None:
852
+ raise RuntimeError(
853
+ f"embedding_model before assignment, please `set_embedding_model` first!"
421
854
  )
422
- )
423
- # 格式化输出
424
- formatted_results = []
425
- for content, score, metadata in top_results:
426
- formatted_results.append(
427
- f"- [{metadata['type']}] {content} " f"(相关度: {score:.2f})"
855
+ return self._embedding_model
856
+
857
+ @property
858
+ def agent_id(
859
+ self,
860
+ ):
861
+ if self._agent_id < 0:
862
+ raise RuntimeError(
863
+ f"agent_id before assignment, please `set_agent_id` first!"
428
864
  )
865
+ return self._agent_id
429
866
 
430
- return "\n".join(formatted_results)
867
+ @property
868
+ def faiss_query(self) -> FaissQuery:
869
+ """FaissQuery"""
870
+ if self._faiss_query is None:
871
+ raise RuntimeError(
872
+ f"FaissQuery access before assignment, please `set_faiss_query` first!"
873
+ )
874
+ return self._faiss_query
875
+
876
+ async def initialize_embeddings(self):
877
+ """初始化embedding"""
878
+ await self._status.initialize_embeddings()