pycityagent 1.0.0__py3-none-any.whl → 2.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +7 -3
- pycityagent/agent.py +180 -284
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +307 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/interact.py +141 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/{brain → environment/sence}/static.py +1 -1
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +27 -0
- pycityagent/environment/sim/aoi_service.py +38 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +43 -0
- pycityagent/environment/sim/economy_services.py +191 -0
- pycityagent/environment/sim/lane_service.py +110 -0
- pycityagent/environment/sim/light_service.py +120 -0
- pycityagent/environment/sim/person_service.py +294 -0
- pycityagent/environment/sim/road_service.py +38 -0
- pycityagent/environment/sim/social_service.py +58 -0
- pycityagent/environment/simulator.py +369 -0
- pycityagent/environment/utils/__init__.py +8 -0
- pycityagent/environment/utils/geojson.py +26 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/protobuf.py +39 -0
- pycityagent/llm/__init__.py +6 -0
- pycityagent/llm/embedding.py +136 -0
- pycityagent/llm/llm.py +430 -0
- pycityagent/llm/llmconfig.py +15 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +11 -0
- pycityagent/memory/const.py +41 -0
- pycityagent/memory/memory.py +453 -0
- pycityagent/memory/memory_base.py +168 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +27 -0
- pycityagent/message/__init__.py +0 -0
- pycityagent/simulation/__init__.py +7 -0
- pycityagent/simulation/interview.py +36 -0
- pycityagent/simulation/simulation.py +286 -0
- pycityagent/simulation/survey/__init__.py +9 -0
- pycityagent/simulation/survey/manager.py +67 -0
- pycityagent/simulation/survey/models.py +49 -0
- pycityagent/simulation/ui/__init__.py +3 -0
- pycityagent/simulation/ui/interface.py +602 -0
- pycityagent/utils/__init__.py +0 -0
- pycityagent/utils/decorators.py +89 -0
- pycityagent/utils/parsers/__init__.py +12 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/workflow/__init__.py +22 -0
- pycityagent/workflow/block.py +137 -0
- pycityagent/workflow/prompt.py +72 -0
- pycityagent/workflow/tool.py +246 -0
- pycityagent/workflow/trigger.py +66 -0
- pycityagent-2.0.0a1.dist-info/METADATA +208 -0
- pycityagent-2.0.0a1.dist-info/RECORD +65 -0
- {pycityagent-1.0.0.dist-info → pycityagent-2.0.0a1.dist-info}/WHEEL +1 -2
- pycityagent/ac/__init__.py +0 -6
- pycityagent/ac/ac.py +0 -50
- pycityagent/ac/action.py +0 -14
- pycityagent/ac/controled.py +0 -13
- pycityagent/ac/converse.py +0 -31
- pycityagent/ac/idle.py +0 -17
- pycityagent/ac/shop.py +0 -80
- pycityagent/ac/trip.py +0 -37
- pycityagent/brain/__init__.py +0 -10
- pycityagent/brain/brain.py +0 -52
- pycityagent/brain/brainfc.py +0 -10
- pycityagent/brain/memory.py +0 -541
- pycityagent/brain/persistence/social.py +0 -1
- pycityagent/brain/persistence/spatial.py +0 -14
- pycityagent/brain/reason/shop.py +0 -37
- pycityagent/brain/reason/social.py +0 -148
- pycityagent/brain/reason/trip.py +0 -67
- pycityagent/brain/reason/user.py +0 -122
- pycityagent/brain/retrive/social.py +0 -6
- pycityagent/brain/scheduler.py +0 -408
- pycityagent/brain/sence.py +0 -375
- pycityagent/cc/__init__.py +0 -5
- pycityagent/cc/cc.py +0 -102
- pycityagent/cc/conve.py +0 -6
- pycityagent/cc/idle.py +0 -20
- pycityagent/cc/shop.py +0 -6
- pycityagent/cc/trip.py +0 -13
- pycityagent/cc/user.py +0 -13
- pycityagent/hubconnector/__init__.py +0 -3
- pycityagent/hubconnector/hubconnector.py +0 -137
- pycityagent/image/__init__.py +0 -3
- pycityagent/image/image.py +0 -158
- pycityagent/simulator.py +0 -161
- pycityagent/st/__init__.py +0 -4
- pycityagent/st/st.py +0 -96
- pycityagent/urbanllm/__init__.py +0 -3
- pycityagent/urbanllm/urbanllm.py +0 -132
- pycityagent-1.0.0.dist-info/LICENSE +0 -21
- pycityagent-1.0.0.dist-info/METADATA +0 -181
- pycityagent-1.0.0.dist-info/RECORD +0 -48
- pycityagent-1.0.0.dist-info/top_level.txt +0 -1
- /pycityagent/{brain/persistence/__init__.py → config.py} +0 -0
- /pycityagent/{brain/reason → environment/interact}/__init__.py +0 -0
- /pycityagent/{brain/retrive → environment/message}/__init__.py +0 -0
@@ -0,0 +1,453 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
from copy import deepcopy
|
4
|
+
from datetime import datetime
|
5
|
+
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
|
6
|
+
Tuple, Union)
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
from pyparsing import deque
|
10
|
+
|
11
|
+
from ..utils.decorators import lock_decorator
|
12
|
+
from .const import *
|
13
|
+
from .profile import ProfileMemory
|
14
|
+
from .self_define import DynamicMemory
|
15
|
+
from .state import StateMemory
|
16
|
+
|
17
|
+
|
18
|
+
class Memory:
    """
    A class to manage different types of memory (state, profile, dynamic).

    Attributes:
        _state (StateMemory): Stores state-related data.
        _profile (ProfileMemory): Stores profile-related data.
        _dynamic (DynamicMemory): Stores dynamically configured data.
        watchers (Dict[str, List[Callable]]): Per-key callbacks; each is called
            with no arguments and its coroutine scheduled as a task after the
            key is updated.
        embedding_model (Any): Object whose `embed(text)` coroutine produces the
            vectors used by `search` and `add`.
    """

    def __init__(
        self,
        config: Optional[Dict[Any, Any]] = None,
        profile: Optional[Dict[Any, Any]] = None,
        base: Optional[Dict[Any, Any]] = None,
        motion: Optional[Dict[Any, Any]] = None,
        activate_timestamp: bool = False,
        embedding_model: Any = None,
    ) -> None:
        """
        Initializes the Memory with optional configuration.

        Args:
            config (Optional[Dict[Any, Any]], optional):
                A configuration dictionary for dynamic memory. The dictionary format is:
                - Key: The name of the dynamic memory field.
                - Value: Can be one of these formats:
                    1. A tuple `(type, default)`. If `type` is a type, the default
                       value is `type(default)`. If `type` is a `deque` instance, the
                       default value is a copy of that deque extended with `default`
                       (the deque passed in is not modified).
                    2. A tuple `(type, default, enable_embedding)`: as above, and
                       `enable_embedding` marks the field for embedding-based search.
                    3. A callable (e.g. a class or a factory function) that returns the
                       default value when invoked with no arguments.
                  A value that cannot be unpacked into a pair and is not callable is
                  used as the default value directly.
                Note: If a key in `config` overlaps with predefined attributes in
                `PROFILE_ATTRIBUTES`, `STATE_ATTRIBUTES` or equals `TIME_STAMP_KEY`,
                a warning will be logged, and the key will be ignored.
                Defaults to None.
            profile (Optional[Dict[Any, Any]], optional): profile attribute dict. Keys
                not in `PROFILE_ATTRIBUTES` are ignored with a warning.
            base (Optional[Dict[Any, Any]], optional): base attribute dict from City
                Simulator. Keys not in `STATE_ATTRIBUTES` are ignored with a warning.
            motion (Optional[Dict[Any, Any]], optional): motion attribute dict from City
                Simulator. Keys not in `STATE_ATTRIBUTES` are ignored with a warning;
                on keys shared with `base`, the `base` value wins.
            activate_timestamp (bool): Whether to activate timestamp storage in MemoryUnit.
            embedding_model (Any): The embedding model for memory search.
        """
        self.watchers: Dict[str, List[Callable]] = {}
        self._lock = asyncio.Lock()
        self.embedding_model = embedding_model
        # Strong references to in-flight watcher tasks: the event loop keeps only
        # weak references to tasks, so unreferenced ones may vanish mid-execution.
        self._watcher_tasks: set = set()

        # Embedding vectors per memory section, keyed by field name.
        self._embeddings = {"state": {}, "profile": {}, "dynamic": {}}
        # Which fields should get an embedding generated when they are updated.
        self._embedding_fields: Dict[str, bool] = {}

        _dynamic_config: Dict[Any, Any] = {}
        if config is not None:
            for k, v in config.items():
                _value, enable_embedding = self._parse_dynamic_config_value(v)
                self._embedding_fields[k] = enable_embedding
                if (
                    k in PROFILE_ATTRIBUTES
                    or k in STATE_ATTRIBUTES
                    or k == TIME_STAMP_KEY
                ):
                    logging.warning(f"key `{k}` already declared in memory!")
                    continue
                _dynamic_config[k] = deepcopy(_value)

        self._dynamic = DynamicMemory(
            required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
        )

        _profile_config: Dict[Any, Any] = {}
        _state_config: Dict[Any, Any] = {}
        self._collect_known_fields(profile, PROFILE_ATTRIBUTES, "profile", _profile_config)
        # `motion` first, then `base`, so `base` overrides shared keys.
        self._collect_known_fields(motion, STATE_ATTRIBUTES, "motion", _state_config)
        self._collect_known_fields(base, STATE_ATTRIBUTES, "base", _state_config)
        self._state = StateMemory(
            msg=_state_config, activate_timestamp=activate_timestamp
        )
        self._profile = ProfileMemory(
            msg=_profile_config, activate_timestamp=activate_timestamp
        )
        self.memories = []  # free-form memory entries added via `add`
        self.embeddings = []  # embeddings of `memories`, in the same order

    @staticmethod
    def _parse_dynamic_config_value(v: Any) -> Tuple[Any, bool]:
        """
        Resolve one `config` entry into `(default_value, enable_embedding)`.

        See `__init__` for the accepted formats.
        """
        if isinstance(v, tuple) and len(v) == 3:
            _type, _value, enable_embedding = v
        else:
            try:
                _type, _value = v
            except (TypeError, ValueError):
                # Not a `(type, default)` pair: a callable is a factory for the
                # default value; anything else is the default value itself.
                return (v() if callable(v) else v), False
            enable_embedding = False

        try:
            if isinstance(_type, type):
                _value = _type(_value)
            elif isinstance(_type, deque):
                # Extend a copy so a deque template shared between configs is
                # never mutated (and never accumulates values across calls).
                _template = deepcopy(_type)
                _template.extend(_value)
                _value = _template
            else:
                logging.warning(f"type `{_type}` is not supported!")
        except TypeError:
            # Best effort: keep the raw default when conversion is impossible.
            logging.debug(
                f"Cannot convert default `{_value!r}` with `{_type}`, keeping it as is."
            )
        return _value, enable_embedding

    @staticmethod
    def _collect_known_fields(
        source: Optional[Dict[Any, Any]],
        allowed: Any,
        label: str,
        target: Dict[Any, Any],
    ) -> None:
        """Copy entries of `source` whose keys are in `allowed` into `target`, warning on others."""
        if source is None:
            return
        for k, v in source.items():
            if k not in allowed:
                logging.warning(f"key `{k}` is not a correct `{label}` field!")
                continue
            target[k] = v

    @staticmethod
    def _read_processor(mode: str) -> Callable[[Any], Any]:
        """Return `deepcopy` for "read only" and the identity for "read and write"."""
        if mode == "read only":
            return deepcopy
        if mode == "read and write":
            return lambda x: x
        raise ValueError(f"Invalid get mode `{mode}`!")

    @lock_decorator
    async def get(
        self,
        key: Any,
        mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
    ) -> Any:
        """
        Retrieves a value from memory based on the given key and access mode.

        Sections are searched in the order state, profile, dynamic.

        Args:
            key (Any): The key of the item to retrieve.
            mode (Union[Literal["read only"], Literal["read and write"]], optional):
                Access mode for the item. "read only" returns a deep copy;
                "read and write" returns the object as given by the section.
                Defaults to "read only".

        Returns:
            Any: The value associated with the key.

        Raises:
            ValueError: If an invalid mode is provided.
            KeyError: If the key is not found in any of the memory sections.
        """
        process_func = self._read_processor(mode)
        for _mem in [self._state, self._profile, self._dynamic]:
            try:
                value = await _mem.get(key)
            except KeyError:
                continue
            return process_func(value)
        raise KeyError(f"No attribute `{key}` in memories!")

    @lock_decorator
    async def update(
        self,
        key: Any,
        value: Any,
        mode: Union[Literal["replace"], Literal["merge"]] = "replace",
        store_snapshot: bool = False,
        protect_llm_read_only_fields: bool = True,
    ) -> None:
        """
        Update a memory value and refresh its embedding when enabled.

        The key is looked up in the state, profile and dynamic sections, in that
        order, and the first section containing it is updated. Afterwards the
        field's embedding is regenerated (if an embedding model is set and the
        field is embedding-enabled) and the key's watchers are scheduled.

        Args:
            key (Any): The key of the item to update.
            value (Any): The new value ("replace") or the items to add ("merge").
            mode (Union[Literal["replace"], Literal["merge"]], optional): "replace"
                writes `value` through the section's `update`. "merge" extends a
                set/dict/list/deque value obtained from the section's `get` in
                place (NOTE(review): this persists only if that `get` returns the
                stored object), and falls back to "replace" for other types.
            store_snapshot (bool): Passed through to the section's `update`.
            protect_llm_read_only_fields (bool): When True, keys in
                `STATE_ATTRIBUTES` are not written; a warning is logged instead.

        Raises:
            ValueError: If an invalid mode is provided.
            KeyError: If the key is not found in any of the memory sections.
        """
        if protect_llm_read_only_fields and key in STATE_ATTRIBUTES:
            logging.warning(f"Trying to write protected key `{key}`!")
            return
        if mode not in ("replace", "merge"):
            raise ValueError(f"Invalid update mode `{mode}`!")
        for _mem in [self._state, self._profile, self._dynamic]:
            # Only the lookup decides "key not in this section"; KeyErrors from
            # the update itself must propagate instead of being misreported.
            try:
                original_value = await _mem.get(key)
            except KeyError:
                continue
            if mode == "replace":
                await _mem.update(key, value, store_snapshot)
                new_value = value
            else:
                new_value = await self._merge_value(
                    _mem, key, original_value, value, store_snapshot
                )
            await self._on_value_updated(_mem, key, new_value)
            return
        raise KeyError(f"No attribute `{key}` in memories!")

    @staticmethod
    async def _merge_value(
        mem: Any, key: Any, original_value: Any, value: Any, store_snapshot: bool
    ) -> Any:
        """Merge `value` into `original_value`; return the resulting current value."""
        if isinstance(original_value, set):
            original_value.update(set(value))
        elif isinstance(original_value, dict):
            original_value.update(dict(value))
        elif isinstance(original_value, list):
            original_value.extend(list(value))
        elif isinstance(original_value, deque):
            original_value.extend(deque(value))
        else:
            logging.debug(
                f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
            )
            await mem.update(key, value, store_snapshot)
            # The replaced value, not the stale original, is now current.
            return value
        return original_value

    async def _on_value_updated(self, mem: Any, key: Any, new_value: Any) -> None:
        """Refresh the field's embedding if enabled and schedule its watchers."""
        if self.embedding_model and self._embedding_fields.get(key, False):
            memory_type = self._get_memory_type(mem)
            self._embeddings[memory_type][key] = await self._generate_embedding(
                f"{key}: {str(new_value)}"
            )
        for callback in self.watchers.get(key, []):
            task = asyncio.create_task(callback())
            self._watcher_tasks.add(task)
            task.add_done_callback(self._watcher_tasks.discard)

    def _get_memory_type(self, mem: Any) -> str:
        """Return the section name ("state", "profile" or "dynamic") of `mem`."""
        if mem is self._state:
            return "state"
        elif mem is self._profile:
            return "profile"
        else:
            return "dynamic"

    async def _generate_embedding(self, text: str) -> np.ndarray:
        """
        Generate the vector representation of a text.

        Args:
            text: The input text.

        Returns:
            np.ndarray: Whatever `embedding_model.embed(text)` returns
            (expected to be a vector).

        Raises:
            RuntimeError: If `embedding_model` is not set.
        """
        if not self.embedding_model:
            raise RuntimeError("Embedding model not initialized")

        return await self.embedding_model.embed(text)

    async def search(self, query: str, top_k: int = 3) -> str:
        """
        Search embedding-enabled fields for those most similar to `query`.

        Args:
            query: The query text.
            top_k: Maximum number of matching fields to return.

        Returns:
            str: One line per match, sorted by descending cosine similarity, or a
            notice string when no embedding model is configured.
        """
        if not self.embedding_model:
            return "Embedding model not initialized"

        query_embedding = await self._generate_embedding(query)
        all_results = []

        # Iterate over snapshots: `self.get` awaits, and a concurrent `update`
        # may add embeddings to these dicts in the meantime.
        for memory_type, embeddings in list(self._embeddings.items()):
            for key, embedding in list(embeddings.items()):
                similarity = self._cosine_similarity(query_embedding, embedding)
                value = await self.get(key)

                all_results.append(
                    {
                        "type": memory_type,
                        "key": key,
                        "content": f"{key}: {str(value)}",
                        "similarity": similarity,
                    }
                )

        all_results.sort(key=lambda x: x["similarity"], reverse=True)

        formatted_results = [
            f"- [{result['type']}] {result['content']} "
            f"(相关度: {result['similarity']:.2f})"
            for result in all_results[:top_k]
        ]
        return "\n".join(formatted_results)

    async def update_batch(
        self,
        content: Union[Dict, Sequence[Tuple[Any, Any]]],
        mode: Union[Literal["replace"], Literal["merge"]] = "replace",
        store_snapshot: bool = False,
        protect_llm_read_only_fields: bool = True,
    ) -> None:
        """
        Updates multiple values in the memory at once.

        Args:
            content (Union[Dict, Sequence[Tuple[Any, Any]]]): A dictionary or sequence of tuples containing the keys and values to update.
            mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode. Defaults to "replace".
            store_snapshot (bool): Whether to store a snapshot of the memory after the update.
                Only passed to the first item's update; the rest use False.
            protect_llm_read_only_fields (bool): Whether to protect non-self define fields from being updated.

        Raises:
            TypeError: If the content type is neither a dictionary nor a sequence of tuples.
        """
        if isinstance(content, dict):
            _list_content: List[Tuple[Any, Any]] = list(content.items())
        elif isinstance(content, Sequence):
            _list_content = [(k, v) for k, v in content]
        else:
            raise TypeError(f"Invalid content type `{type(content)}`!")
        for index, (k, v) in enumerate(_list_content):
            # Presumably so a batch produces at most one snapshot — confirm.
            _snapshot = store_snapshot if index == 0 else False
            await self.update(k, v, mode, _snapshot, protect_llm_read_only_fields)

    @lock_decorator
    async def add_watcher(self, key: str, callback: Callable) -> None:
        """
        Adds a callback function to be invoked when the value
        associated with the specified key in memory is updated.

        Args:
            key (str): The key for which the watcher is being registered.
            callback (Callable): A zero-argument callable returning a coroutine;
                it is scheduled as a task whenever the key is updated.

        Notes:
            If the key does not already have any watchers, it will be
            initialized with an empty list before appending the callback.
        """
        self.watchers.setdefault(key, []).append(callback)

    @lock_decorator
    async def export(
        self,
    ) -> Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]:
        """
        Exports the current state of all memory sections.

        Returns:
            Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]: A tuple containing the exported data of profile, state, and dynamic memory sections.
        """
        return (
            await self._profile.export(),
            await self._state.export(),
            await self._dynamic.export(),
        )

    @lock_decorator
    async def load(
        self,
        snapshots: Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]],
        reset_memory: bool = True,
    ) -> None:
        """
        Import the snapshot memories of all sections.

        Args:
            snapshots (Tuple[Sequence[Dict], Sequence[Dict], Sequence[Dict]]): The
                exported snapshots, in the (profile, state, dynamic) order produced
                by `export`.
            reset_memory (bool): Whether to reset previous memory.
        """
        _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
        # Pair each snapshot with its own section (same order as `export`).
        for _snapshot, _mem in zip(
            [_profile_snapshot, _state_snapshot, _dynamic_snapshot],
            [self._profile, self._state, self._dynamic],
        ):
            if _snapshot:
                await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)

    @lock_decorator
    async def get_top_k(
        self,
        key: Any,
        metric: Callable[[Any], Any],
        top_k: Optional[int] = None,
        mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
        preserve_order: bool = True,
    ) -> Any:
        """
        Retrieves the top-k items from the memory based on the given key and metric.

        Args:
            key (Any): The key of the item to retrieve.
            metric (Callable[[Any], Any]): A callable function that defines the metric for ranking the items.
            top_k (Optional[int], optional): The number of top items to retrieve. Defaults to None (all items).
            mode (Union[Literal["read only"], Literal["read and write"]], optional): Access mode for the item. Defaults to "read only".
            preserve_order (bool): Whether preserve original order in output values.

        Returns:
            Any: The top-k items based on the specified metric.

        Raises:
            ValueError: If an invalid mode is provided.
            KeyError: If the key is not found in any of the memory sections.
        """
        process_func = self._read_processor(mode)
        for _mem in [self._state, self._profile, self._dynamic]:
            try:
                value = await _mem.get_top_k(key, metric, top_k, preserve_order)
            except KeyError:
                continue
            return process_func(value)
        raise KeyError(f"No attribute `{key}` in memories!")

    async def add(self, content: str, metadata: Optional[dict] = None) -> None:
        """
        Add a new free-form memory entry together with its embedding.

        Args:
            content: The memory content.
            metadata: Related metadata, such as time or location.

        Raises:
            RuntimeError: If `embedding_model` is not set.
        """
        embedding = await self._generate_embedding(content)
        self.memories.append(
            {
                "content": content,
                "metadata": metadata or {},
                "timestamp": datetime.now(),
                "embedding": embedding,
            }
        )
        self.embeddings.append(embedding)

    def _cosine_similarity(self, v1: np.ndarray, v2: np.ndarray) -> float:
        """Return the cosine similarity of two vectors, or 0.0 if either is all zeros."""
        norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
        if norm_product == 0:
            # Avoid a division by zero (nan) for zero vectors.
            return 0.0
        return float(np.dot(v1, v2) / norm_product)
|
@@ -0,0 +1,168 @@
|
|
1
|
+
"""
|
2
|
+
Base class of memory
|
3
|
+
"""
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import logging
|
7
|
+
import time
|
8
|
+
from abc import ABC, abstractmethod
|
9
|
+
from typing import (Any, Callable, Dict, Iterable, List, Optional, Sequence,
|
10
|
+
Tuple, Union)
|
11
|
+
|
12
|
+
from .const import *
|
13
|
+
|
14
|
+
|
15
|
+
class MemoryUnit:
    """
    One unit of key/value memory content.

    Every key in the content is mirrored as an instance attribute named
    `f"{SELF_DEFINE_PREFIX}{key}"` and exposed through a property named `key`
    that is installed on the class (and therefore visible on all instances).
    """

    def __init__(
        self,
        content: Optional[Dict] = None,
        required_attributes: Optional[Dict] = None,
        activate_timestamp: bool = False,
    ) -> None:
        """
        Args:
            content: Initial key/value pairs; they override `required_attributes`
                on shared keys.
            required_attributes: Baseline key/value pairs applied first.
            activate_timestamp: When True, store `time.time()` under
                `TIME_STAMP_KEY` (unless already provided) and refresh it on
                every `update`.
        """
        self._content = {}
        self._lock = asyncio.Lock()
        self._activate_timestamp = activate_timestamp
        if required_attributes is not None:
            self._content.update(required_attributes)
        if content is not None:
            self._content.update(content)
        if activate_timestamp and TIME_STAMP_KEY not in self._content:
            self._content[TIME_STAMP_KEY] = time.time()
        for _prop, _value in self._content.items():
            self._set_attribute(_prop, _value)

    def __getitem__(self, key: Any) -> Any:
        """Return the stored value for `key` (raises KeyError if missing)."""
        return self._content[key]

    def _create_property(self, property_name: str, property_value: Any):
        """Install a class-level property for `property_name` and store its value."""
        storage_name = f"{SELF_DEFINE_PREFIX}{property_name}"
        cls = self.__class__
        collides = property_name in vars(self) or (
            hasattr(cls, property_name)
            and not isinstance(getattr(cls, property_name), property)
        )
        if collides:
            # A class property here would shadow a method or internal attribute
            # (e.g. a key named `update` or `_content`) for every instance; the
            # value stays reachable via `self[key]` and the prefixed attribute.
            logging.warning(
                f"Key `{property_name}` collides with an existing `{cls.__name__}` attribute; skipping property creation."
            )
        else:

            def _getter(inst):
                return getattr(inst, storage_name, None)

            def _setter(inst, value):
                setattr(inst, storage_name, value)

            setattr(cls, property_name, property(_getter, _setter))
        setattr(self, storage_name, property_value)

    def _set_attribute(self, property_name: str, property_value: Any):
        """Store `property_value`, creating the backing property on first use."""
        if not hasattr(self, f"{SELF_DEFINE_PREFIX}{property_name}"):
            self._create_property(property_name, property_value)
        else:
            setattr(self, f"{SELF_DEFINE_PREFIX}{property_name}", property_value)

    async def update(self, content: Dict) -> None:
        """
        Merge `content` into this unit and refresh the mirrored attributes.

        Type changes of existing keys are logged at debug level. When
        timestamps are active, `TIME_STAMP_KEY` is refreshed in both the
        content dict and the mirrored attribute.
        """
        async with self._lock:
            for k, v in content.items():
                if k in self._content:
                    orig_type, new_type = type(self._content[k]), type(v)
                    if orig_type != new_type:
                        logging.debug(
                            f"Type warning: The type of the value for key '{k}' is changing from `{orig_type.__name__}` to `{new_type.__name__}`!"
                        )
            self._content.update(content)
            for _prop, _value in self._content.items():
                self._set_attribute(_prop, _value)
            if self._activate_timestamp:
                now = time.time()
                # Keep `_content` in sync so `self[TIME_STAMP_KEY]` is fresh too.
                self._content[TIME_STAMP_KEY] = now
                self._set_attribute(TIME_STAMP_KEY, now)

    async def clear(self) -> None:
        """Drop all stored content (mirrored attributes are left as they are)."""
        async with self._lock:
            self._content = {}

    async def top_k_values(
        self,
        key: Any,
        metric: Callable[[Any], Any],
        top_k: Optional[int] = None,
        preserve_order: bool = True,
    ) -> Union[Sequence[Any], Any]:
        """
        Return the `top_k` elements of the sequence stored at `key`, ranked by `metric`.

        Args:
            key: Key whose value is ranked.
            metric: Score function; higher scores rank first, ties keep their
                original relative order.
            top_k: Number of elements to return; None returns all.
            preserve_order: If True, the selected elements are returned in their
                original order; otherwise in descending score order.

        Returns:
            The selected elements as a list, or the stored value unchanged (with
            a warning) if it is not a sequence.

        Raises:
            KeyError: If `key` is not stored. The lock is released in every case.
        """
        async with self._lock:
            values = self._content[key]
            if not isinstance(values, Sequence):
                logging.warning(
                    f"the value stored in key `{key}` is not `sequence`, return value `{values}` instead!"
                )
                return values
            # `sorted` is stable with reverse=True, so ties keep insertion order.
            ranked = sorted(
                enumerate(values), key=lambda i_v: metric(i_v[1]), reverse=True
            )
            top_k = len(values) if top_k is None else top_k
            if len(ranked) < top_k:
                logging.debug(
                    f"Length of values {len(ranked)} is less than top_k {top_k}, returning all values."
                )
        selected = ranked[:top_k]
        if preserve_order:
            selected = sorted(selected, key=lambda i_v: i_v[0])
        return [v for _, v in selected]

    async def dict_values(
        self,
    ) -> Dict[Any, Any]:
        """Return the live content dict (not a copy)."""
        return self._content
|
115
|
+
|
116
|
+
|
117
|
+
class MemoryBase(ABC):
    """
    Abstract base class for a memory section.

    Subclasses keep their memory units as the keys of `self._memories`;
    `_fetch_recent_memory` treats later-inserted keys as more recent.
    """

    def __init__(self) -> None:
        self._memories: Dict[Any, Dict] = {}
        self._lock = asyncio.Lock()

    @abstractmethod
    async def add(self, msg: Union[Any, Sequence[Any]]) -> None:
        """Add one or more memory entries."""
        raise NotImplementedError

    @abstractmethod
    async def pop(self, index: int) -> Any:
        """Remove and return the memory entry at `index`."""
        raise NotImplementedError

    @abstractmethod
    async def load(
        self, snapshots: Union[Any, Sequence[Any]], reset_memory: bool = False
    ) -> None:
        """Load previously exported snapshots, optionally resetting current memory."""
        raise NotImplementedError

    @abstractmethod
    async def export(
        self,
    ) -> Sequence[Any]:
        """Export the memory contents as snapshots."""
        raise NotImplementedError

    @abstractmethod
    async def reset(self) -> None:
        """Reset the memory contents."""
        raise NotImplementedError

    def _fetch_recent_memory(self, recent_n: Optional[int] = None) -> Sequence[Any]:
        """
        Return the most recent memory units (keys of `self._memories`).

        Args:
            recent_n: How many of the last-inserted units to return. None returns
                all units; 0 or a negative value returns an empty list. If fewer
                units exist, all of them are returned.
        """
        _list_units = list(self._memories.keys())
        if recent_n is None:
            return _list_units
        if recent_n <= 0:
            # `_list_units[-0:]` would return everything, not nothing.
            return []
        if len(_list_units) < recent_n:
            logging.debug(
                f"Length of memory {len(_list_units)} is less than recent_n {recent_n}, returning all available memories."
            )
        return _list_units[-recent_n:]

    # Key/value interaction interface.
    @abstractmethod
    async def get(self, key: Any):
        """Return the value stored under `key`."""
        raise NotImplementedError

    @abstractmethod
    async def update(self, key: Any, value: Any, store_snapshot: bool):
        """Write `value` under `key`, optionally storing a snapshot."""
        raise NotImplementedError

    def __getitem__(self, index: Any) -> Any:
        """Return the memory unit at position `index` (insertion order)."""
        return list(self._memories.keys())[index]
|