pycityagent 2.0.0a51__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a53__cp39-cp39-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (36) hide show
  1. pycityagent/agent/agent.py +48 -62
  2. pycityagent/agent/agent_base.py +66 -53
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +44 -56
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/nbsagent.py +6 -29
  17. pycityagent/cityagent/societyagent.py +204 -119
  18. pycityagent/environment/sim/client.py +10 -1
  19. pycityagent/environment/sim/clock_service.py +2 -2
  20. pycityagent/environment/sim/pause_service.py +61 -0
  21. pycityagent/environment/simulator.py +17 -12
  22. pycityagent/llm/embeddings.py +0 -24
  23. pycityagent/memory/faiss_query.py +29 -26
  24. pycityagent/memory/memory.py +720 -272
  25. pycityagent/pycityagent-sim +0 -0
  26. pycityagent/simulation/agentgroup.py +92 -99
  27. pycityagent/simulation/simulation.py +115 -40
  28. pycityagent/tools/tool.py +7 -10
  29. pycityagent/workflow/block.py +11 -4
  30. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +2 -2
  31. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
  32. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +1 -1
  33. pycityagent/cityagent/blocks/time_block.py +0 -116
  34. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
  35. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
  36. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,10 @@ import logging
3
3
  from collections import defaultdict
4
4
  from collections.abc import Callable, Sequence
5
5
  from copy import deepcopy
6
- from datetime import datetime
7
- from typing import Any, Literal, Optional, Union
6
+ from typing import Any, Literal, Optional, Union, Dict
7
+ from dataclasses import dataclass
8
+ from enum import Enum
8
9
 
9
- import numpy as np
10
10
  from langchain_core.embeddings import Embeddings
11
11
  from pyparsing import deque
12
12
 
@@ -19,197 +19,479 @@ from .state import StateMemory
19
19
 
20
20
  logger = logging.getLogger("pycityagent")
21
21
 
22
+ class MemoryTag(str, Enum):
23
+ """记忆标签枚举类"""
24
+ MOBILITY = "mobility"
25
+ SOCIAL = "social"
26
+ ECONOMY = "economy"
27
+ COGNITION = "cognition"
28
+ OTHER = "other"
29
+ EVENT = "event"
30
+
31
+ @dataclass
32
+ class MemoryNode:
33
+ """记忆节点"""
34
+ tag: MemoryTag
35
+ day: int
36
+ t: int
37
+ location: str
38
+ description: str
39
+ cognition_id: Optional[int] = None # 关联的认知记忆ID
40
+ id: Optional[int] = None # 记忆ID
41
+
42
+ class StreamMemory:
43
+ """用于存储时序性的流式信息"""
44
+ def __init__(self, max_len: int = 1000):
45
+ self._memories: deque = deque(maxlen=max_len) # 限制最大存储量
46
+ self._memory_id_counter: int = 0 # 用于生成唯一ID
47
+ self._faiss_query = None
48
+ self._embedding_model = None
49
+ self._agent_id = -1
50
+ self._status_memory = None
51
+ self._simulator = None
52
+
53
+ def set_simulator(self, simulator):
54
+ self._simulator = simulator
55
+
56
+ def set_status_memory(self, status_memory):
57
+ self._status_memory = status_memory
58
+
59
+ def set_search_components(self, faiss_query, embedding_model):
60
+ """设置搜索所需的组件"""
61
+ self._faiss_query = faiss_query
62
+ self._embedding_model = embedding_model
22
63
 
23
- class Memory:
24
- """
25
- A class to manage different types of memory (state, profile, dynamic).
64
+ def set_agent_id(self, agent_id: int):
65
+ """设置agent_id"""
66
+ self._agent_id = agent_id
26
67
 
27
- Attributes:
28
- _state (StateMemory): Stores state-related data.
29
- _profile (ProfileMemory): Stores profile-related data.
30
- _dynamic (DynamicMemory): Stores dynamically configured data.
31
- """
68
+ async def _add_memory(self, tag: MemoryTag, description: str) -> int:
69
+ """添加记忆节点的通用方法,返回记忆节点ID"""
70
+ if self._simulator is not None:
71
+ day = int(await self._simulator.get_simulator_day())
72
+ t = int(await self._simulator.get_time())
73
+ else:
74
+ day = 1
75
+ t = 1
76
+ position = await self._status_memory.get("position")
77
+ if 'aoi_position' in position:
78
+ location = position['aoi_position']['aoi_id']
79
+ elif 'lane_position' in position:
80
+ location = position['lane_position']['lane_id']
81
+ else:
82
+ location = "unknown"
83
+
84
+ current_id = self._memory_id_counter
85
+ self._memory_id_counter += 1
86
+ memory_node = MemoryNode(
87
+ tag=tag,
88
+ day=day,
89
+ t=t,
90
+ location=location,
91
+ description=description,
92
+ id=current_id,
93
+ )
94
+ self._memories.append(memory_node)
95
+
96
+
97
+ # 为新记忆创建 embedding
98
+ if self._embedding_model and self._faiss_query:
99
+ await self._faiss_query.add_documents(
100
+ agent_id=self._agent_id,
101
+ documents=description,
102
+ extra_tags={
103
+ "type": "stream",
104
+ "tag": tag,
105
+ "day": day,
106
+ "time": t,
107
+ },
108
+ )
109
+
110
+ return current_id
111
+ async def add_cognition(self, description: str) -> None:
112
+ """添加认知记忆 Add cognition memory"""
113
+ return await self._add_memory(MemoryTag.COGNITION, description)
114
+
115
+ async def add_social(self, description: str) -> None:
116
+ """添加社交记忆 Add social memory"""
117
+ return await self._add_memory(MemoryTag.SOCIAL, description)
118
+
119
+ async def add_economy(self, description: str) -> None:
120
+ """添加经济记忆 Add economy memory"""
121
+ return await self._add_memory(MemoryTag.ECONOMY, description)
122
+
123
+ async def add_mobility(self, description: str) -> None:
124
+ """添加移动记忆 Add mobility memory"""
125
+ return await self._add_memory(MemoryTag.MOBILITY, description)
126
+
127
+ async def add_event(self, description: str) -> None:
128
+ """添加事件记忆 Add event memory"""
129
+ return await self._add_memory(MemoryTag.EVENT, description)
130
+
131
+ async def add_other(self, description: str) -> None:
132
+ """添加其他记忆 Add other memory"""
133
+ return await self._add_memory(MemoryTag.OTHER, description)
134
+
135
+ async def get_related_cognition(self, memory_id: int) -> Optional[MemoryNode]:
136
+ """获取关联的认知记忆 Get related cognition memory"""
137
+ for memory in self._memories:
138
+ if memory.cognition_id == memory_id:
139
+ for cognition_memory in self._memories:
140
+ if (cognition_memory.tag == MemoryTag.COGNITION and
141
+ memory.cognition_id is not None):
142
+ return cognition_memory
143
+ return None
144
+
145
+ async def format_memory(self, memories: list[MemoryNode]) -> str:
146
+ """格式化记忆"""
147
+ formatted_results = []
148
+ for memory in memories:
149
+ memory_tag = memory.tag
150
+ memory_day = memory.day
151
+ memory_time_seconds = memory.t
152
+ cognition_id = memory.cognition_id
153
+
154
+ # 格式化时间
155
+ if memory_time_seconds != 'unknown':
156
+ hours = memory_time_seconds // 3600
157
+ minutes = (memory_time_seconds % 3600) // 60
158
+ seconds = memory_time_seconds % 60
159
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
160
+ else:
161
+ memory_time = 'unknown'
162
+
163
+ memory_location = memory.location
164
+
165
+ # 添加认知信息(如果存在)
166
+ cognition_info = ""
167
+ if cognition_id is not None:
168
+ cognition_memory = await self.get_related_cognition(cognition_id)
169
+ if cognition_memory:
170
+ cognition_info = f"\n Related cognition: {cognition_memory.description}"
171
+
172
+ formatted_results.append(
173
+ f"- [{memory_tag}]: {memory.description} [day: {memory_day}, time: {memory_time}, "
174
+ f"location: {memory_location}]{cognition_info}"
175
+ )
176
+ return "\n".join(formatted_results)
32
177
 
33
- def __init__(
34
- self,
35
- config: Optional[dict[Any, Any]] = None,
36
- profile: Optional[dict[Any, Any]] = None,
37
- base: Optional[dict[Any, Any]] = None,
38
- motion: Optional[dict[Any, Any]] = None,
39
- activate_timestamp: bool = False,
40
- embedding_model: Optional[Embeddings] = None,
41
- faiss_query: Optional[FaissQuery] = None,
42
- ) -> None:
43
- """
44
- Initializes the Memory with optional configuration.
178
+ async def get_by_ids(self, memory_ids: Union[int, list[int]]) -> str:
179
+ """获取指定ID的记忆"""
180
+ memories = [memory for memory in self._memories if memory.id in memory_ids]
181
+ sorted_results = sorted(
182
+ memories,
183
+ key=lambda x: (x.day, x.t),
184
+ reverse=True
185
+ )
186
+ return self.format_memory(sorted_results)
45
187
 
188
+ async def search(
189
+ self,
190
+ query: str,
191
+ tag: Optional[MemoryTag] = None,
192
+ top_k: int = 3,
193
+ day_range: Optional[tuple[int, int]] = None, # 新增参数
194
+ time_range: Optional[tuple[int, int]] = None # 新增参数
195
+ ) -> str:
196
+ """Search stream memory
197
+
46
198
  Args:
47
- config (Optional[dict[Any, Any]], optional):
48
- A configuration dictionary for dynamic memory. The dictionary format is:
49
- - Key: The name of the dynamic memory field.
50
- - Value: Can be one of two formats:
51
- 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
52
- 2. A callable that returns the default value when invoked (useful for complex default values).
53
- Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
54
- Defaults to None.
55
- profile (Optional[dict[Any, Any]], optional): profile attribute dict.
56
- base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
57
- motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
58
- activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
59
- embedding_model (Embeddings): The embedding model for memory search.
60
- faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
199
+ query: Query text
200
+ tag: Optional memory tag for filtering specific types of memories
201
+ top_k: Number of most relevant memories to return
202
+ day_range: Optional tuple of start and end days (start_day, end_day)
203
+ time_range: Optional tuple of start and end times (start_time, end_time)
61
204
  """
62
- self.watchers: dict[str, list[Callable]] = {}
63
- self._lock = asyncio.Lock()
64
- self._agent_id: int = -1
65
- self._embedding_model = embedding_model
66
-
67
- _dynamic_config: dict[Any, Any] = {}
68
- _state_config: dict[Any, Any] = {}
69
- _profile_config: dict[Any, Any] = {}
70
- # 记录哪些字段需要embedding
71
- self._embedding_fields: dict[str, bool] = {}
72
- self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
73
- self._faiss_query = faiss_query
74
-
75
- if config is not None:
76
- for k, v in config.items():
77
- try:
78
- # 处理新的三元组格式
79
- if isinstance(v, tuple) and len(v) == 3:
80
- _type, _value, enable_embedding = v
81
- self._embedding_fields[k] = enable_embedding
82
- else:
83
- _type, _value = v
84
- self._embedding_fields[k] = False
85
-
86
- try:
87
- if isinstance(_type, type):
88
- _value = _type(_value)
89
- else:
90
- if isinstance(_type, deque):
91
- _type.extend(_value)
92
- _value = deepcopy(_type)
93
- else:
94
- logger.warning(f"type `{_type}` is not supported!")
95
- pass
96
- except TypeError as e:
97
- pass
98
- except TypeError as e:
99
- if isinstance(v, type):
100
- _value = v()
101
- else:
102
- _value = v
103
- self._embedding_fields[k] = False
104
-
105
- if (
106
- k in PROFILE_ATTRIBUTES
107
- or k in STATE_ATTRIBUTES
108
- or k == TIME_STAMP_KEY
109
- ):
110
- logger.warning(f"key `{k}` already declared in memory!")
111
- continue
112
-
113
- _dynamic_config[k] = deepcopy(_value)
114
-
115
- # 初始化各类记忆
116
- self._dynamic = DynamicMemory(
117
- required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
205
+ if not self._embedding_model or not self._faiss_query:
206
+ return "Search components not initialized"
207
+
208
+ filter_dict = {"type": "stream"}
209
+
210
+ if tag:
211
+ filter_dict["tag"] = tag
212
+
213
+ # 添加时间范围过滤
214
+ if day_range:
215
+ start_day, end_day = day_range
216
+ filter_dict["day"] = lambda x: start_day <= x <= end_day
217
+
218
+ if time_range:
219
+ start_time, end_time = time_range
220
+ filter_dict["time"] = lambda x: start_time <= x <= end_time
221
+
222
+ top_results = await self._faiss_query.similarity_search(
223
+ query=query,
224
+ agent_id=self._agent_id,
225
+ k=top_k,
226
+ return_score_type="similarity_score",
227
+ filter=filter_dict
118
228
  )
119
229
 
120
- if profile is not None:
121
- for k, v in profile.items():
122
- if k not in PROFILE_ATTRIBUTES:
123
- logger.warning(f"key `{k}` is not a correct `profile` field!")
124
- continue
125
- _profile_config[k] = v
126
- if motion is not None:
127
- for k, v in motion.items():
128
- if k not in STATE_ATTRIBUTES:
129
- logger.warning(f"key `{k}` is not a correct `motion` field!")
130
- continue
131
- _state_config[k] = v
132
- if base is not None:
133
- for k, v in base.items():
134
- if k not in STATE_ATTRIBUTES:
135
- logger.warning(f"key `{k}` is not a correct `base` field!")
136
- continue
137
- _state_config[k] = v
138
- self._state = StateMemory(
139
- msg=_state_config, activate_timestamp=activate_timestamp
230
+ # 将结果按时间排序(先按天数,再按时间)
231
+ sorted_results = sorted(
232
+ top_results,
233
+ key=lambda x: (x[2].get('day', 0), x[2].get('time', 0)),
234
+ reverse=True
140
235
  )
141
- self._profile = ProfileMemory(
142
- msg=_profile_config, activate_timestamp=activate_timestamp
143
- )
144
- # self.memories = [] # 存储记忆内容
145
- # self.embeddings = [] # 存储记忆的向量表示
146
-
147
- def set_embedding_model(
148
- self,
149
- embedding_model: Embeddings,
150
- ):
151
- self._embedding_model = embedding_model
152
-
153
- @property
154
- def embedding_model(
155
- self,
156
- ):
157
- if self._embedding_model is None:
158
- raise RuntimeError(
159
- f"embedding_model before assignment, please `set_embedding_model` first!"
236
+
237
+ formatted_results = []
238
+ for content, score, metadata in sorted_results:
239
+ memory_tag = metadata.get('tag', 'unknown')
240
+ memory_day = metadata.get('day', 'unknown')
241
+ memory_time_seconds = metadata.get('time', 'unknown')
242
+ cognition_id = metadata.get('cognition_id', None)
243
+
244
+ # 格式化时间
245
+ if memory_time_seconds != 'unknown':
246
+ hours = memory_time_seconds // 3600
247
+ minutes = (memory_time_seconds % 3600) // 60
248
+ seconds = memory_time_seconds % 60
249
+ memory_time = f"{hours:02d}:{minutes:02d}:{seconds:02d}"
250
+ else:
251
+ memory_time = 'unknown'
252
+
253
+ memory_location = metadata.get('location', 'unknown')
254
+
255
+ # 添加认知信息(如果存在)
256
+ cognition_info = ""
257
+ if cognition_id is not None:
258
+ cognition_memory = await self.get_related_cognition(cognition_id)
259
+ if cognition_memory:
260
+ cognition_info = f"\n Related cognition: {cognition_memory.description}"
261
+
262
+ formatted_results.append(
263
+ f"- [{memory_tag}]: {content} [day: {memory_day}, time: {memory_time}, "
264
+ f"location: {memory_location}]{cognition_info}"
160
265
  )
161
- return self._embedding_model
266
+ return "\n".join(formatted_results)
162
267
 
163
- def set_faiss_query(self, faiss_query: FaissQuery):
268
+ async def search_today(
269
+ self,
270
+ query: str = "", # 可选的查询文本
271
+ tag: Optional[MemoryTag] = None,
272
+ top_k: int = 100, # 默认返回较大数量以确保获取当天所有记忆
273
+ ) -> str:
274
+ """Search all memory events from today
275
+
276
+ Args:
277
+ query: Optional query text, returns all memories of the day if empty
278
+ tag: Optional memory tag for filtering specific types of memories
279
+ top_k: Number of most relevant memories to return, defaults to 100
280
+
281
+ Returns:
282
+ str: Formatted text of today's memories
164
283
  """
165
- Set the FaissQuery of the agent.
284
+ if self._simulator is None:
285
+ return "Simulator not initialized"
286
+
287
+ current_day = int(await self._simulator.get_simulator_day())
288
+
289
+ # 使用 search 方法,设置 day_range 为当天
290
+ return await self.search(
291
+ query=query,
292
+ tag=tag,
293
+ top_k=top_k,
294
+ day_range=(current_day, current_day)
295
+ )
296
+
297
+ async def add_cognition_to_memory(self, memory_id: Union[int, list[int]], cognition: str) -> None:
298
+ """为已存在的记忆添加认知
299
+
300
+ Args:
301
+ memory_id: 要添加认知的记忆ID,可以是单个ID或ID列表
302
+ cognition: 认知描述
166
303
  """
167
- self._faiss_query = faiss_query
304
+ # 将单个ID转换为列表以统一处理
305
+ memory_ids = [memory_id] if isinstance(memory_id, int) else memory_id
306
+
307
+ # 找到所有对应的记忆
308
+ target_memories = []
309
+ for memory in self._memories:
310
+ if id(memory) in memory_ids:
311
+ target_memories.append(memory)
312
+
313
+ if not target_memories:
314
+ raise ValueError(f"No memories found with ids {memory_ids}")
315
+
316
+ # 添加认知记忆
317
+ cognition_id = await self._add_memory(MemoryTag.COGNITION, cognition)
318
+
319
+ # 更新所有原记忆的认知ID
320
+ for target_memory in target_memories:
321
+ target_memory.cognition_id = cognition_id
322
+
323
+ async def get_all(self) -> list[MemoryNode]:
324
+ """获取所有流式信息"""
325
+ return list(self._memories)
326
+
327
+ class StatusMemory:
328
+ """组合现有的三种记忆类型"""
329
+ def __init__(self, profile: ProfileMemory, state: StateMemory, dynamic: DynamicMemory):
330
+ self.profile = profile
331
+ self.state = state
332
+ self.dynamic = dynamic
333
+ self._faiss_query = None
334
+ self._embedding_model = None
335
+ self._simulator = None
336
+ self._agent_id = -1
337
+ self._semantic_templates = {} # 用户可配置的模板
338
+ self._embedding_fields = {} # 需要 embedding 的字段
339
+ self._embedding_field_to_doc_id = defaultdict(str) # 新增
340
+ self.watchers = {} # 新增
341
+ self._lock = asyncio.Lock() # 新增
342
+
343
+ def set_simulator(self, simulator):
344
+ self._simulator = simulator
345
+
346
+ async def initialize_embeddings(self) -> None:
347
+ """初始化所有需要 embedding 的字段"""
348
+ if not self._embedding_model or not self._faiss_query:
349
+ logger.warning("Search components not initialized, skipping embeddings initialization")
350
+ return
351
+
352
+ # 获取所有状态信息
353
+ profile, state, dynamic = await self.export()
354
+
355
+ # 为每个需要 embedding 的字段创建 embedding
356
+ for key, value in profile[0].items():
357
+ if self.should_embed(key):
358
+ semantic_text = self._generate_semantic_text(key, value)
359
+ doc_ids = await self._faiss_query.add_documents(
360
+ agent_id=self._agent_id,
361
+ documents=semantic_text,
362
+ extra_tags={
363
+ "type": "profile_state",
364
+ "key": key,
365
+ },
366
+ )
367
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
368
+
369
+ for key, value in state[0].items():
370
+ if self.should_embed(key):
371
+ semantic_text = self._generate_semantic_text(key, value)
372
+ doc_ids = await self._faiss_query.add_documents(
373
+ agent_id=self._agent_id,
374
+ documents=semantic_text,
375
+ extra_tags={
376
+ "type": "profile_state",
377
+ "key": key,
378
+ },
379
+ )
380
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
381
+
382
+ for key, value in dynamic[0].items():
383
+ if self.should_embed(key):
384
+ semantic_text = self._generate_semantic_text(key, value)
385
+ doc_ids = await self._faiss_query.add_documents(
386
+ agent_id=self._agent_id,
387
+ documents=semantic_text,
388
+ extra_tags={
389
+ "type": "profile_state",
390
+ "key": key,
391
+ },
392
+ )
393
+ self._embedding_field_to_doc_id[key] = doc_ids[0]
394
+
395
+ def _get_memory_type_by_key(self, key: str) -> str:
396
+ """根据键名确定记忆类型"""
397
+ try:
398
+ if key in self.profile.__dict__:
399
+ return "profile"
400
+ elif key in self.state.__dict__:
401
+ return "state"
402
+ else:
403
+ return "dynamic"
404
+ except:
405
+ return "dynamic"
168
406
 
169
- @property
170
- def agent_id(
171
- self,
172
- ):
173
- if self._agent_id < 0:
174
- raise RuntimeError(
175
- f"agent_id before assignment, please `set_agent_id` first!"
176
- )
177
- return self._agent_id
407
+ def set_search_components(self, faiss_query, embedding_model):
408
+ """设置搜索所需的组件"""
409
+ self._faiss_query = faiss_query
410
+ self._embedding_model = embedding_model
178
411
 
179
412
  def set_agent_id(self, agent_id: int):
413
+ """设置agent_id"""
414
+ self._agent_id = agent_id
415
+
416
+ def set_semantic_templates(self, templates: Dict[str, str]):
417
+ """设置语义模板
418
+
419
+ Args:
420
+ templates: 键值对形式的模板字典,如 {"name": "my name is {}", "age": "I am {} years old"}
180
421
  """
181
- Set the FaissQuery of the agent.
422
+ self._semantic_templates = templates
423
+
424
+ def _generate_semantic_text(self, key: str, value: Any) -> str:
425
+ """生成语义文本
426
+
427
+ 如果key存在于模板中,使用自定义模板
428
+ 否则使用默认模板 "my {key} is {value}"
182
429
  """
183
- self._agent_id = agent_id
430
+ if key in self._semantic_templates:
431
+ return self._semantic_templates[key].format(value)
432
+ return f"Your {key} is {value}"
433
+
434
+ @lock_decorator
435
+ async def search(
436
+ self, query: str, top_k: int = 3, filter: Optional[dict] = None
437
+ ) -> str:
438
+ """搜索相关记忆
184
439
 
185
- @property
186
- def faiss_query(self) -> FaissQuery:
187
- """FaissQuery"""
188
- if self._faiss_query is None:
189
- raise RuntimeError(
190
- f"FaissQuery access before assignment, please `set_faiss_query` first!"
440
+ Args:
441
+ query: 查询文本
442
+ top_k: 返回最相关的记忆数量
443
+ filter (dict, optional): 记忆的筛选条件,如 {"key":"self_define_1",},默认为空
444
+
445
+ Returns:
446
+ str: 格式化的相关记忆文本
447
+ """
448
+ if not self._embedding_model:
449
+ return "Embedding model not initialized"
450
+
451
+ filter_dict = {"type": "profile_state"}
452
+ if filter is not None:
453
+ filter_dict.update(filter)
454
+ top_results: list[tuple[str, float, dict]] = (
455
+ await self._faiss_query.similarity_search( # type:ignore
456
+ query=query,
457
+ agent_id=self._agent_id,
458
+ k=top_k,
459
+ return_score_type="similarity_score",
460
+ filter=filter_dict,
191
461
  )
192
- return self._faiss_query
462
+ )
463
+ # 格式化输出
464
+ formatted_results = []
465
+ for content, score, metadata in top_results:
466
+ formatted_results.append(
467
+ f"- {content} "
468
+ )
469
+
470
+ return "\n".join(formatted_results)
471
+
472
+ def set_embedding_fields(self, embedding_fields: Dict[str, bool]):
473
+ """设置需要 embedding 的字段"""
474
+ self._embedding_fields = embedding_fields
475
+
476
+ def should_embed(self, key: str) -> bool:
477
+ """判断字段是否需要 embedding"""
478
+ return self._embedding_fields.get(key, False)
193
479
 
194
480
  @lock_decorator
195
- async def get(
196
- self,
197
- key: Any,
198
- mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
199
- ) -> Any:
200
- """
201
- Retrieves a value from memory based on the given key and access mode.
481
+ async def get(self, key: Any,
482
+ mode: Union[Literal["read only"], Literal["read and write"]] = "read only") -> Any:
483
+ """从记忆中获取值
202
484
 
203
485
  Args:
204
- key (Any): The key of the item to retrieve.
205
- mode (Union[Literal["read only"], Literal["read and write"]], optional): Access mode for the item. Defaults to "read only".
486
+ key: 要获取的键
487
+ mode: 访问模式,"read only" "read and write"
206
488
 
207
489
  Returns:
208
- Any: The value associated with the key.
490
+ 获取到的值
209
491
 
210
492
  Raises:
211
- ValueError: If an invalid mode is provided.
212
- KeyError: If the key is not found in any of the memory sections.
493
+ ValueError: 如果提供了无效的模式
494
+ KeyError: 如果在任何记忆部分都找不到该键
213
495
  """
214
496
  if mode == "read only":
215
497
  process_func = deepcopy
@@ -217,54 +499,56 @@ class Memory:
217
499
  process_func = lambda x: x
218
500
  else:
219
501
  raise ValueError(f"Invalid get mode `{mode}`!")
220
- for _mem in [self._state, self._profile, self._dynamic]:
502
+
503
+ for mem in [self.state, self.profile, self.dynamic]:
221
504
  try:
222
- value = await _mem.get(key)
505
+ value = await mem.get(key)
223
506
  return process_func(value)
224
- except KeyError as e:
507
+ except KeyError:
225
508
  continue
226
509
  raise KeyError(f"No attribute `{key}` in memories!")
227
510
 
228
511
  @lock_decorator
229
- async def update(
230
- self,
231
- key: Any,
232
- value: Any,
233
- mode: Union[Literal["replace"], Literal["merge"]] = "replace",
234
- store_snapshot: bool = False,
235
- protect_llm_read_only_fields: bool = True,
236
- ) -> None:
512
+ async def update(self, key: Any, value: Any,
513
+ mode: Union[Literal["replace"], Literal["merge"]] = "replace",
514
+ store_snapshot: bool = False,
515
+ protect_llm_read_only_fields: bool = True) -> None:
237
516
  """更新记忆值并在必要时更新embedding"""
238
517
  if protect_llm_read_only_fields:
239
518
  if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
240
519
  logger.warning(f"Trying to write protected key `{key}`!")
241
520
  return
242
- for _mem in [self._state, self._profile, self._dynamic]:
521
+
522
+ for mem in [self.state, self.profile, self.dynamic]:
243
523
  try:
244
- original_value = await _mem.get(key)
524
+ original_value = await mem.get(key)
245
525
  if mode == "replace":
246
- await _mem.update(key, value, store_snapshot)
247
- # 如果字段需要embedding,则更新embedding
248
- if self._embedding_fields.get(key, False) and self.embedding_model:
249
- memory_type = self._get_memory_type(_mem)
250
- # 覆盖更新删除原vector
526
+ await mem.update(key, value, store_snapshot)
527
+ if self.should_embed(key) and self._embedding_model:
528
+ semantic_text = self._generate_semantic_text(key, value)
529
+
530
+ # 删除旧的 embedding
251
531
  orig_doc_id = self._embedding_field_to_doc_id[key]
252
532
  if orig_doc_id:
253
- await self.faiss_query.delete_documents(
533
+ await self._faiss_query.delete_documents(
254
534
  to_delete_ids=[orig_doc_id],
255
535
  )
256
- doc_ids: list[str] = await self.faiss_query.add_documents(
257
- agent_id=self.agent_id,
258
- documents=f"{key}: {str(value)}",
536
+
537
+ # 添加新的 embedding
538
+ doc_ids = await self._faiss_query.add_documents(
539
+ agent_id=self._agent_id,
540
+ documents=semantic_text,
259
541
  extra_tags={
260
- "type": memory_type,
542
+ "type": self._get_memory_type(mem),
261
543
  "key": key,
262
544
  },
263
545
  )
264
546
  self._embedding_field_to_doc_id[key] = doc_ids[0]
547
+
265
548
  if key in self.watchers:
266
549
  for callback in self.watchers[key]:
267
550
  asyncio.create_task(callback())
551
+
268
552
  elif mode == "merge":
269
553
  if isinstance(original_value, set):
270
554
  original_value.update(set(value))
@@ -278,14 +562,14 @@ class Memory:
278
562
  logger.debug(
279
563
  f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
280
564
  )
281
- await _mem.update(key, value, store_snapshot)
282
- if self._embedding_fields.get(key, False) and self.embedding_model:
283
- memory_type = self._get_memory_type(_mem)
284
- doc_ids = await self.faiss_query.add_documents(
285
- agent_id=self.agent_id,
565
+ await mem.update(key, value, store_snapshot)
566
+ if self.should_embed(key) and self._embedding_model:
567
+ semantic_text = self._generate_semantic_text(key, value)
568
+ doc_ids = await self._faiss_query.add_documents(
569
+ agent_id=self._agent_id,
286
570
  documents=f"{key}: {str(original_value)}",
287
571
  extra_tags={
288
- "type": memory_type,
572
+ "type": self._get_memory_type(mem),
289
573
  "key": key,
290
574
  },
291
575
  )
@@ -302,63 +586,20 @@ class Memory:
302
586
 
303
587
  def _get_memory_type(self, mem: Any) -> str:
304
588
  """获取记忆类型"""
305
- if mem is self._state:
589
+ if mem is self.state:
306
590
  return "state"
307
- elif mem is self._profile:
591
+ elif mem is self.profile:
308
592
  return "profile"
309
593
  else:
310
594
  return "dynamic"
311
595
 
312
- async def update_batch(
313
- self,
314
- content: Union[dict, Sequence[tuple[Any, Any]]],
315
- mode: Union[Literal["replace"], Literal["merge"]] = "replace",
316
- store_snapshot: bool = False,
317
- protect_llm_read_only_fields: bool = True,
318
- ) -> None:
319
- """
320
- Updates multiple values in the memory at once.
321
-
322
- Args:
323
- content (Union[dict, Sequence[tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
324
- mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
325
- store_snapshot (bool): Whether to store a snapshot of the memory after the update.
326
- protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.
327
-
328
- Raises:
329
- TypeError: If the content type is neither a dictionary nor a sequence of tuples.
330
- """
331
- if isinstance(content, dict):
332
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content.items()]
333
- elif isinstance(content, Sequence):
334
- _list_content: list[tuple[Any, Any]] = [(k, v) for k, v in content]
335
- else:
336
- raise TypeError(f"Invalid content type `{type(content)}`!")
337
- for k, v in _list_content[:1]:
338
- await self.update(k, v, mode, store_snapshot, protect_llm_read_only_fields)
339
- for k, v in _list_content[1:]:
340
- await self.update(k, v, mode, False, protect_llm_read_only_fields)
341
-
342
596
  @lock_decorator
343
597
  async def add_watcher(self, key: str, callback: Callable) -> None:
344
- """
345
- Adds a callback function to be invoked when the value
346
- associated with the specified key in memory is updated.
347
-
348
- Args:
349
- key (str): The key for which the watcher is being registered.
350
- callback (Callable): A callable function that will be executed
351
- whenever the value associated with the specified key is updated.
352
-
353
- Notes:
354
- If the key does not already have any watchers, it will be
355
- initialized with an empty list before appending the callback.
356
- """
598
+ """添加值变更的监听器"""
357
599
  if key not in self.watchers:
358
600
  self.watchers[key] = []
359
601
  self.watchers[key].append(callback)
360
602
 
361
- @lock_decorator
362
603
  async def export(
363
604
  self,
364
605
  ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
@@ -369,12 +610,11 @@ class Memory:
369
610
  tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
370
611
  """
371
612
  return (
372
- await self._profile.export(),
373
- await self._state.export(),
374
- await self._dynamic.export(),
613
+ await self.profile.export(),
614
+ await self.state.export(),
615
+ await self.dynamic.export(),
375
616
  )
376
617
 
377
- @lock_decorator
378
618
  async def load(
379
619
  self,
380
620
  snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
@@ -390,41 +630,249 @@ class Memory:
390
630
  _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
391
631
  for _snapshot, _mem in zip(
392
632
  [_profile_snapshot, _state_snapshot, _dynamic_snapshot],
393
- [self._state, self._profile, self._dynamic],
633
+ [self.state, self.profile, self.dynamic],
394
634
  ):
395
635
  if _snapshot:
396
636
  await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
397
637
 
398
- @lock_decorator
399
- async def search(
400
- self, query: str, top_k: int = 3, filter: Optional[dict] = None
401
- ) -> str:
402
- """搜索相关记忆
638
+ class Memory:
639
+ """
640
+ A class to manage different types of memory (state, profile, dynamic).
641
+
642
+ Attributes:
643
+ _state (StateMemory): Stores state-related data.
644
+ _profile (ProfileMemory): Stores profile-related data.
645
+ _dynamic (DynamicMemory): Stores dynamically configured data.
646
+ """
647
+
648
+ def __init__(
649
+ self,
650
+ config: Optional[dict[Any, Any]] = None,
651
+ profile: Optional[dict[Any, Any]] = None,
652
+ base: Optional[dict[Any, Any]] = None,
653
+ activate_timestamp: bool = False,
654
+ embedding_model: Optional[Embeddings] = None,
655
+ faiss_query: Optional[FaissQuery] = None,
656
+ ) -> None:
657
+ """
658
+ Initializes the Memory with optional configuration.
403
659
 
404
660
  Args:
405
- query: 查询文本
406
- top_k: 返回最相关的记忆数量
407
- filter (dict, optional): 记忆的筛选条件,如 {"type":"dynamic", "key":"self_define_1",},默认为空
661
+ config (Optional[dict[Any, Any]], optional):
662
+ A configuration dictionary for dynamic memory. The dictionary format is:
663
+ - Key: The name of the dynamic memory field.
664
+ - Value: Can be one of two formats:
665
+ 1. A tuple where the first element is a variable type (e.g., int, str, etc.), and the second element is the default value for this field.
666
+ 2. A callable that returns the default value when invoked (useful for complex default values).
667
+ Note: If a key in `config` overlaps with predefined attributes in `PROFILE_ATTRIBUTES` or `STATE_ATTRIBUTES`, a warning will be logged, and the key will be ignored.
668
+ Defaults to None.
669
+ profile (Optional[dict[Any, Any]], optional): profile attribute dict.
670
+ base (Optional[dict[Any, Any]], optional): base attribute dict from City Simulator.
671
+ motion (Optional[dict[Any, Any]], optional): motion attribute dict from City Simulator.
672
+ activate_timestamp (bool): Whether activate timestamp storage in MemoryUnit
673
+ embedding_model (Embeddings): The embedding model for memory search.
674
+ faiss_query (FaissQuery): The faiss_query of the agent. Defaults to None.
675
+ """
676
+ self.watchers: dict[str, list[Callable]] = {}
677
+ self._lock = asyncio.Lock()
678
+ self._agent_id: int = -1
679
+ self._simulator = None
680
+ self._embedding_model = embedding_model
681
+ self._faiss_query = faiss_query
682
+ self._semantic_templates: dict[str, str] = {}
683
+ _dynamic_config: dict[Any, Any] = {}
684
+ _state_config: dict[Any, Any] = {}
685
+ _profile_config: dict[Any, Any] = {}
686
+ self._embedding_fields: dict[str, bool] = {}
687
+ self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
408
688
 
409
- Returns:
410
- str: 格式化的相关记忆文本
689
+ if config is not None:
690
+ for k, v in config.items():
691
+ try:
692
+ # 处理不同长度的配置元组
693
+ if isinstance(v, tuple):
694
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
695
+ _type, _value, enable_embedding, template = v
696
+ self._embedding_fields[k] = enable_embedding
697
+ self._semantic_templates[k] = template
698
+ elif len(v) == 3: # (_type, _value, enable_embedding)
699
+ _type, _value, enable_embedding = v
700
+ self._embedding_fields[k] = enable_embedding
701
+ else: # (_type, _value)
702
+ _type, _value = v
703
+ self._embedding_fields[k] = False
704
+ else:
705
+ _type = type(v)
706
+ _value = v
707
+ self._embedding_fields[k] = False
708
+
709
+ # 处理类型转换
710
+ try:
711
+ if isinstance(_type, type):
712
+ _value = _type(_value)
713
+ else:
714
+ if isinstance(_type, deque):
715
+ _type.extend(_value)
716
+ _value = deepcopy(_type)
717
+ else:
718
+ logger.warning(f"type `{_type}` is not supported!")
719
+ except TypeError as e:
720
+ logger.warning(f"Type conversion failed for key {k}: {e}")
721
+ except TypeError as e:
722
+ if isinstance(v, type):
723
+ _value = v()
724
+ else:
725
+ _value = v
726
+ self._embedding_fields[k] = False
727
+
728
+ if (
729
+ k in PROFILE_ATTRIBUTES
730
+ or k in STATE_ATTRIBUTES
731
+ or k == TIME_STAMP_KEY
732
+ ):
733
+ logger.warning(f"key `{k}` already declared in memory!")
734
+ continue
735
+
736
+ _dynamic_config[k] = deepcopy(_value)
737
+
738
+ # 初始化各类记忆
739
+ self._dynamic = DynamicMemory(
740
+ required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
741
+ )
742
+
743
+ if profile is not None:
744
+ for k, v in profile.items():
745
+ if k not in PROFILE_ATTRIBUTES:
746
+ logger.warning(f"key `{k}` is not a correct `profile` field!")
747
+ continue
748
+
749
+ try:
750
+ # 处理配置元组格式
751
+ if isinstance(v, tuple):
752
+ if len(v) == 4: # (_type, _value, enable_embedding, template)
753
+ _type, _value, enable_embedding, template = v
754
+ self._embedding_fields[k] = enable_embedding
755
+ self._semantic_templates[k] = template
756
+ elif len(v) == 3: # (_type, _value, enable_embedding)
757
+ _type, _value, enable_embedding = v
758
+ self._embedding_fields[k] = enable_embedding
759
+ else: # (_type, _value)
760
+ _type, _value = v
761
+ self._embedding_fields[k] = False
762
+
763
+ # 处理类型转换
764
+ try:
765
+ if isinstance(_type, type):
766
+ _value = _type(_value)
767
+ else:
768
+ if isinstance(_type, deque):
769
+ _type.extend(_value)
770
+ _value = deepcopy(_type)
771
+ else:
772
+ logger.warning(f"type `{_type}` is not supported!")
773
+ except TypeError as e:
774
+ logger.warning(f"Type conversion failed for key {k}: {e}")
775
+ else:
776
+ # 保持对简单键值对的兼容
777
+ _value = v
778
+ self._embedding_fields[k] = False
779
+ except TypeError as e:
780
+ if isinstance(v, type):
781
+ _value = v()
782
+ else:
783
+ _value = v
784
+ self._embedding_fields[k] = False
785
+
786
+ _profile_config[k] = deepcopy(_value)
787
+ self._profile = ProfileMemory(
788
+ msg=_profile_config, activate_timestamp=activate_timestamp
789
+ )
790
+
791
+ if base is not None:
792
+ for k, v in base.items():
793
+ if k not in STATE_ATTRIBUTES:
794
+ logger.warning(f"key `{k}` is not a correct `base` field!")
795
+ continue
796
+ _state_config[k] = v
797
+
798
+ self._state = StateMemory(
799
+ msg=_state_config, activate_timestamp=activate_timestamp
800
+ )
801
+
802
+ # 组合 StatusMemory,并传递 embedding_fields 信息
803
+ self._status = StatusMemory(
804
+ profile=self._profile,
805
+ state=self._state,
806
+ dynamic=self._dynamic
807
+ )
808
+ self._status.set_embedding_fields(self._embedding_fields)
809
+ self._status.set_search_components(self._faiss_query, self._embedding_model)
810
+
811
+ # 新增 StreamMemory
812
+ self._stream = StreamMemory()
813
+ self._stream.set_status_memory(self._status)
814
+ self._stream.set_search_components(self._faiss_query, self._embedding_model)
815
+
816
+ def set_search_components(
817
+ self,
818
+ faiss_query: FaissQuery,
819
+ embedding_model: Embeddings,
820
+ ):
821
+ self._embedding_model = embedding_model
822
+ self._faiss_query = faiss_query
823
+ self._stream.set_search_components(faiss_query, embedding_model)
824
+ self._status.set_search_components(faiss_query, embedding_model)
825
+
826
+ def set_agent_id(self, agent_id: int):
411
827
  """
412
- if not self._embedding_model:
413
- return "Embedding model not initialized"
414
- top_results: list[tuple[str, float, dict]] = (
415
- await self.faiss_query.similarity_search( # type:ignore
416
- query=query,
417
- agent_id=self.agent_id,
418
- k=top_k,
419
- return_score_type="similarity_score",
420
- filter=filter,
828
+ Set the FaissQuery of the agent.
829
+ """
830
+ self._agent_id = agent_id
831
+ self._stream.set_agent_id(agent_id)
832
+ self._status.set_agent_id(agent_id)
833
+
834
+ def set_simulator(self, simulator):
835
+ self._simulator = simulator
836
+ self._stream.set_simulator(simulator)
837
+ self._status.set_simulator(simulator)
838
+
839
+ @property
840
+ def status(self) -> StatusMemory:
841
+ return self._status
842
+
843
+ @property
844
+ def stream(self) -> StreamMemory:
845
+ return self._stream
846
+
847
+ @property
848
+ def embedding_model(
849
+ self,
850
+ ):
851
+ if self._embedding_model is None:
852
+ raise RuntimeError(
853
+ f"embedding_model before assignment, please `set_embedding_model` first!"
421
854
  )
422
- )
423
- # 格式化输出
424
- formatted_results = []
425
- for content, score, metadata in top_results:
426
- formatted_results.append(
427
- f"- [{metadata['type']}] {content} " f"(相关度: {score:.2f})"
855
+ return self._embedding_model
856
+
857
+ @property
858
+ def agent_id(
859
+ self,
860
+ ):
861
+ if self._agent_id < 0:
862
+ raise RuntimeError(
863
+ f"agent_id before assignment, please `set_agent_id` first!"
428
864
  )
865
+ return self._agent_id
429
866
 
430
- return "\n".join(formatted_results)
867
+ @property
868
+ def faiss_query(self) -> FaissQuery:
869
+ """FaissQuery"""
870
+ if self._faiss_query is None:
871
+ raise RuntimeError(
872
+ f"FaissQuery access before assignment, please `set_faiss_query` first!"
873
+ )
874
+ return self._faiss_query
875
+
876
+ async def initialize_embeddings(self):
877
+ """初始化embedding"""
878
+ await self._status.initialize_embeddings()