pycityagent 2.0.0a42__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a42.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a42.dist-info/METADATA +235 -0
- pycityagent-2.0.0a42.dist-info/RECORD +81 -0
- pycityagent-2.0.0a42.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a42.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a42.dist-info/top_level.txt +3 -0
| @@ -0,0 +1,448 @@ | |
| 1 | 
            +
            import asyncio
         | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            from collections import defaultdict
         | 
| 4 | 
            +
            from collections.abc import Callable, Sequence
         | 
| 5 | 
            +
            from copy import deepcopy
         | 
| 6 | 
            +
            from datetime import datetime
         | 
| 7 | 
            +
            from typing import Any, Literal, Optional, Union
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            from langchain_core.embeddings import Embeddings
         | 
| 11 | 
            +
            from pyparsing import deque
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from ..utils.decorators import lock_decorator
         | 
| 14 | 
            +
            from .const import *
         | 
| 15 | 
            +
            from .faiss_query import FaissQuery
         | 
| 16 | 
            +
            from .profile import ProfileMemory
         | 
| 17 | 
            +
            from .self_define import DynamicMemory
         | 
| 18 | 
            +
            from .state import StateMemory
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            # Module logger. It is requested by the package name "pycityagent" rather than
            # __name__, so records from this module go to the package-level logger.
            logger = logging.getLogger("pycityagent")
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class Memory:
         | 
| 24 | 
            +
                """
         | 
| 25 | 
            +
                A class to manage different types of memory (state, profile, dynamic).
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                Attributes:
         | 
| 28 | 
            +
                    _state (StateMemory): Stores state-related data.
         | 
| 29 | 
            +
                    _profile (ProfileMemory): Stores profile-related data.
         | 
| 30 | 
            +
                    _dynamic (DynamicMemory): Stores dynamically configured data.
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                def __init__(
                    self,
                    config: Optional[dict[Any, Any]] = None,
                    profile: Optional[dict[Any, Any]] = None,
                    base: Optional[dict[Any, Any]] = None,
                    motion: Optional[dict[Any, Any]] = None,
                    activate_timestamp: bool = False,
                    embedding_model: Optional[Embeddings] = None,
                    faiss_query: Optional[FaissQuery] = None,
                ) -> None:
                    """
                    Initializes the Memory with optional configuration.

                    Args:
                        config (Optional[dict[Any, Any]], optional):
                            A configuration dictionary for dynamic memory. The dictionary format is:
                            - Key: The name of the dynamic memory field.
                            - Value: One of these formats:
                                1. ``(type_spec, default)``: if ``type_spec`` is a type, the default
                                   becomes ``type_spec(default)``; if it is a ``deque``, the default
                                   becomes a copy of that deque extended with ``default``; for any
                                   other ``type_spec`` a warning is logged and ``default`` is kept.
                                   If the conversion raises ``TypeError``, ``default`` is kept as-is.
                                2. ``(type_spec, default, enable_embedding)``: as above, and
                                   ``enable_embedding`` marks the field to be embedded into the
                                   faiss query whenever it is updated.
                                3. A value that cannot be unpacked (e.g. a class or a scalar): a
                                   class is called with no arguments to build the default; any
                                   other object (including a non-class callable) is used as-is.
                            Note: keys that overlap with `PROFILE_ATTRIBUTES`, `STATE_ATTRIBUTES`
                            or `TIME_STAMP_KEY` are logged with a warning and ignored.
                            Defaults to None.
                        profile (Optional[dict[Any, Any]], optional): Profile attributes; keys not in
                            `PROFILE_ATTRIBUTES` are logged and ignored.
                        base (Optional[dict[Any, Any]], optional): Base attributes from City Simulator;
                            keys not in `STATE_ATTRIBUTES` are logged and ignored.
                        motion (Optional[dict[Any, Any]], optional): Motion attributes from City
                            Simulator; keys not in `STATE_ATTRIBUTES` are logged and ignored. For keys
                            present in both, the `base` value wins over the `motion` value.
                        activate_timestamp (bool): Whether to activate timestamp storage in the
                            underlying state/profile/dynamic memories.
                        embedding_model (Optional[Embeddings]): The embedding model for memory search.
                        faiss_query (Optional[FaissQuery]): The faiss_query of the agent. Defaults to None.
                    """
                    self.watchers: dict[str, list[Callable]] = {}
                    self._lock = asyncio.Lock()
                    self._agent_id: int = -1
                    self._embedding_model = embedding_model

                    _dynamic_config: dict[Any, Any] = {}
                    _state_config: dict[Any, Any] = {}
                    _profile_config: dict[Any, Any] = {}
                    # Which dynamic fields are embedded into the faiss query on update.
                    self._embedding_fields: dict[str, bool] = {}
                    # Field key -> doc id of its current vector, so an update can delete it.
                    self._embedding_field_to_doc_id: dict[Any, str] = defaultdict(str)
                    self._faiss_query = faiss_query

                    if config is not None:
                        for k, v in config.items():
                            # Reject reserved keys before parsing, so they never register an
                            # embedding flag for a field owned by the profile/state memory.
                            if (
                                k in PROFILE_ATTRIBUTES
                                or k in STATE_ATTRIBUTES
                                or k == TIME_STAMP_KEY
                            ):
                                logger.warning(f"key `{k}` already declared in memory!")
                                continue
                            try:
                                if isinstance(v, tuple) and len(v) == 3:
                                    _type, _value, enable_embedding = v
                                else:
                                    _type, _value = v
                                    enable_embedding = False
                            except TypeError:
                                # `v` is not unpackable: instantiate classes, keep anything else.
                                _value = v() if isinstance(v, type) else v
                                enable_embedding = False
                            else:
                                try:
                                    if isinstance(_type, type):
                                        _value = _type(_value)
                                    elif isinstance(_type, deque):
                                        # Extend a copy: extending `_type` itself would grow the
                                        # caller's (possibly shared) config on every construction.
                                        _seeded = deepcopy(_type)
                                        _seeded.extend(_value)
                                        _value = _seeded
                                    else:
                                        logger.warning(f"type `{_type}` is not supported!")
                                except TypeError as e:
                                    # Best effort: keep the raw default when it cannot be converted.
                                    logger.debug(
                                        f"cannot convert default of `{k}` with `{_type}`: {e}"
                                    )
                            self._embedding_fields[k] = enable_embedding
                            _dynamic_config[k] = deepcopy(_value)

                    # Initialize the dynamic (self-defined) memory.
                    self._dynamic = DynamicMemory(
                        required_attributes=_dynamic_config, activate_timestamp=activate_timestamp
                    )

                    if profile is not None:
                        for k, v in profile.items():
                            if k not in PROFILE_ATTRIBUTES:
                                logger.warning(f"key `{k}` is not a correct `profile` field!")
                                continue
                            _profile_config[k] = v
                    # `base` is applied after `motion`, so it wins on overlapping keys.
                    for _source_name, _source in (("motion", motion), ("base", base)):
                        if _source is None:
                            continue
                        for k, v in _source.items():
                            if k not in STATE_ATTRIBUTES:
                                logger.warning(
                                    f"key `{k}` is not a correct `{_source_name}` field!"
                                )
                                continue
                            _state_config[k] = v
                    self._state = StateMemory(
                        msg=_state_config, activate_timestamp=activate_timestamp
                    )
                    self._profile = ProfileMemory(
                        msg=_profile_config, activate_timestamp=activate_timestamp
                    )
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                def set_embedding_model(
         | 
| 148 | 
            +
                    self,
         | 
| 149 | 
            +
                    embedding_model: Embeddings,
         | 
| 150 | 
            +
                ):
         | 
| 151 | 
            +
                    self._embedding_model = embedding_model
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                @property
         | 
| 154 | 
            +
                def embedding_model(
         | 
| 155 | 
            +
                    self,
         | 
| 156 | 
            +
                ):
         | 
| 157 | 
            +
                    if self._embedding_model is None:
         | 
| 158 | 
            +
                        raise RuntimeError(
         | 
| 159 | 
            +
                            f"embedding_model before assignment, please `set_embedding_model` first!"
         | 
| 160 | 
            +
                        )
         | 
| 161 | 
            +
                    return self._embedding_model
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                def set_faiss_query(self, faiss_query: FaissQuery):
         | 
| 164 | 
            +
                    """
         | 
| 165 | 
            +
                    Set the FaissQuery of the agent.
         | 
| 166 | 
            +
                    """
         | 
| 167 | 
            +
                    self._faiss_query = faiss_query
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                @property
         | 
| 170 | 
            +
                def agent_id(
         | 
| 171 | 
            +
                    self,
         | 
| 172 | 
            +
                ):
         | 
| 173 | 
            +
                    if self._agent_id < 0:
         | 
| 174 | 
            +
                        raise RuntimeError(
         | 
| 175 | 
            +
                            f"agent_id before assignment, please `set_agent_id` first!"
         | 
| 176 | 
            +
                        )
         | 
| 177 | 
            +
                    return self._agent_id
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                def set_agent_id(self, agent_id: int):
         | 
| 180 | 
            +
                    """
         | 
| 181 | 
            +
                    Set the FaissQuery of the agent.
         | 
| 182 | 
            +
                    """
         | 
| 183 | 
            +
                    self._agent_id = agent_id
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                @property
         | 
| 186 | 
            +
                def faiss_query(self) -> FaissQuery:
         | 
| 187 | 
            +
                    """FaissQuery"""
         | 
| 188 | 
            +
                    if self._faiss_query is None:
         | 
| 189 | 
            +
                        raise RuntimeError(
         | 
| 190 | 
            +
                            f"FaissQuery access before assignment, please `set_faiss_query` first!"
         | 
| 191 | 
            +
                        )
         | 
| 192 | 
            +
                    return self._faiss_query
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                @lock_decorator
                async def get(
                    self,
                    key: Any,
                    mode: Union[Literal["read only"], Literal["read and write"]] = "read only",
                ) -> Any:
                    """
                    Look up ``key`` in the state, profile and dynamic memories, in that order.

                    Args:
                        key (Any): The key of the item to retrieve.
                        mode (Union[Literal["read only"], Literal["read and write"]], optional):
                            "read only" returns a deep copy of the value the section returned;
                            "read and write" returns that value unchanged. Defaults to "read only".

                    Returns:
                        Any: The value from the first memory section that holds ``key``.

                    Raises:
                        ValueError: If an invalid mode is provided.
                        KeyError: If the key is not found in any of the memory sections.
                    """
                    if mode not in ("read only", "read and write"):
                        raise ValueError(f"Invalid get mode `{mode}`!")
                    copy_result = mode == "read only"
                    sections = (self._state, self._profile, self._dynamic)
                    for section in sections:
                        try:
                            found = await section.get(key)
                            return deepcopy(found) if copy_result else found
                        except KeyError:
                            # Not in this section; try the next one.
                            continue
                    raise KeyError(f"No attribute `{key}` in memories!")
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                @lock_decorator
                async def update(
                    self,
                    key: Any,
                    value: Any,
                    mode: Union[Literal["replace"], Literal["merge"]] = "replace",
                    store_snapshot: bool = False,
                    protect_llm_read_only_fields: bool = True,
                ) -> None:
                    """
                    Update a memory value and refresh its embedding when required.

                    The key is looked up in the state, profile and dynamic memories, in that
                    order, and the first section holding it is updated. Registered watchers
                    for the key are scheduled as asyncio tasks afterwards.

                    Args:
                        key (Any): The key of the item to update.
                        value (Any): The new value ("replace") or the items to merge ("merge").
                        mode (Union[Literal["replace"], Literal["merge"]], optional):
                            "replace" overwrites the stored value. "merge" updates a stored
                            set/dict or extends a stored list/deque in place; any other stored
                            type falls back to "replace". Defaults to "replace".
                        store_snapshot (bool): Forwarded to the section's ``update`` call; not
                            used when "merge" modifies a container in place.
                        protect_llm_read_only_fields (bool): When True, keys in
                            ``STATE_ATTRIBUTES`` are not written and a warning is logged.

                    Raises:
                        ValueError: If ``mode`` is not "replace" or "merge".
                        KeyError: If the key is not found in any of the memory sections.
                        RuntimeError: If the key is marked for embedding but the embedding
                            model, faiss query or agent id has not been set.
                    """
                    if mode not in ("replace", "merge"):
                        raise ValueError(f"Invalid update mode `{mode}`!")
                    if protect_llm_read_only_fields and key in STATE_ATTRIBUTES:
                        logger.warning(f"Trying to write protected key `{key}`!")
                        return
                    for _mem in (self._state, self._profile, self._dynamic):
                        # Only the lookup signals "not in this section"; a KeyError raised later
                        # (e.g. by the vector store) must not be mistaken for a miss.
                        try:
                            original_value = await _mem.get(key)
                        except KeyError:
                            continue

                        # The value that now lives in memory; used for the embedding document.
                        stored_value = value
                        # NOTE(review): the in-place merges below assume `_mem.get` returns the
                        # stored object rather than a copy — confirm in the memory classes.
                        if mode == "replace":
                            await _mem.update(key, value, store_snapshot)
                        elif isinstance(original_value, set):
                            original_value.update(set(value))
                            stored_value = original_value
                        elif isinstance(original_value, dict):
                            original_value.update(dict(value))
                            stored_value = original_value
                        elif isinstance(original_value, list):
                            original_value.extend(list(value))
                            stored_value = original_value
                        elif isinstance(original_value, deque):
                            original_value.extend(deque(value))
                            stored_value = original_value
                        else:
                            logger.debug(
                                f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
                            )
                            await _mem.update(key, value, store_snapshot)

                        if self._embedding_fields.get(key, False) and self.embedding_model:
                            # Drop the field's previous vector (in both modes) so the store
                            # keeps a single, current document per field.
                            orig_doc_id = self._embedding_field_to_doc_id.get(key, "")
                            if orig_doc_id:
                                await self.faiss_query.delete_documents(
                                    to_delete_ids=[orig_doc_id],
                                )
                            doc_ids: list[str] = await self.faiss_query.add_documents(
                                agent_id=self.agent_id,
                                documents=f"{key}: {str(stored_value)}",
                                extra_tags={
                                    "type": self._get_memory_type(_mem),
                                    "key": key,
                                },
                            )
                            self._embedding_field_to_doc_id[key] = doc_ids[0]

                        if key in self.watchers:
                            # asyncio keeps only weak references to tasks; hold them until they
                            # finish so a watcher cannot be garbage-collected mid-execution.
                            pending = getattr(self, "_watcher_tasks", None)
                            if pending is None:
                                pending = self._watcher_tasks = set()
                            for callback in self.watchers[key]:
                                task = asyncio.create_task(callback())
                                pending.add(task)
                                task.add_done_callback(pending.discard)
                        return
                    raise KeyError(f"No attribute `{key}` in memories!")
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                def _get_memory_type(self, mem: Any) -> str:
         | 
| 304 | 
            +
                    """获取记忆类型"""
         | 
| 305 | 
            +
                    if mem is self._state:
         | 
| 306 | 
            +
                        return "state"
         | 
| 307 | 
            +
                    elif mem is self._profile:
         | 
| 308 | 
            +
                        return "profile"
         | 
| 309 | 
            +
                    else:
         | 
| 310 | 
            +
                        return "dynamic"
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                async def update_batch(
                    self,
                    content: Union[dict, Sequence[tuple[Any, Any]]],
                    mode: Union[Literal["replace"], Literal["merge"]] = "replace",
                    store_snapshot: bool = False,
                    protect_llm_read_only_fields: bool = True,
                ) -> None:
                    """
                    Apply several key/value updates in order, one ``self.update`` call per pair.

                    Args:
                        content (Union[dict, Sequence[tuple[Any, Any]]]): A dict, or a sequence of
                            ``(key, value)`` pairs, to apply in iteration order.
                        mode (Union[Literal["replace"], Literal["merge"]], optional): Update mode
                            forwarded to every ``self.update`` call. Defaults to "replace".
                        store_snapshot (bool): Forwarded only to the update of the FIRST pair;
                            every subsequent pair is updated with ``store_snapshot=False``.
                        protect_llm_read_only_fields (bool): Forwarded to every ``self.update`` call.

                    Raises:
                        TypeError: If ``content`` is neither a dict nor a sequence.
                    """
                    if isinstance(content, dict):
                        pairs: list[tuple[Any, Any]] = list(content.items())
                    elif isinstance(content, Sequence):
                        # Unpack every entry up front so a malformed entry fails before any update runs.
                        pairs = [(key, value) for key, value in content]
                    else:
                        raise TypeError(f"Invalid content type `{type(content)}`!")
                    for position, (key, value) in enumerate(pairs):
                        snapshot_flag = store_snapshot if position == 0 else False
                        await self.update(
                            key, value, mode, snapshot_flag, protect_llm_read_only_fields
                        )
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                @lock_decorator
                async def add_watcher(self, key: str, callback: Callable) -> None:
                    """
                    Register ``callback`` as a watcher for ``key``.

                    Watchers live in ``self.watchers``, a mapping from key to the list of
                    callbacks registered for it. A fresh list is created the first time a key
                    is watched, and callbacks are appended in registration order.

                    NOTE(review): the callbacks are presumably invoked by the update path when
                    ``key`` changes; that code is not visible here, so confirm there.

                    Args:
                        key (str): The memory key to watch.
                        callback (Callable): The callable to register for ``key``.
                    """
                    self.watchers.setdefault(key, []).append(callback)
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                @lock_decorator
                async def export(
                    self,
                ) -> tuple[Sequence[dict], Sequence[dict], Sequence[dict]]:
                    """
                    Export every memory section.

                    Returns:
                        tuple[Sequence[dict], Sequence[dict], Sequence[dict]]: The exports of
                        the profile, state and dynamic sections, in that order.
                    """
                    # Sections are awaited one after another, in the same order as the result tuple.
                    profile_export = await self._profile.export()
                    state_export = await self._state.export()
                    dynamic_export = await self._dynamic.export()
                    return profile_export, state_export, dynamic_export
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                @lock_decorator
                async def load(
                    self,
                    snapshots: tuple[Sequence[dict], Sequence[dict], Sequence[dict]],
                    reset_memory: bool = True,
                ) -> None:
                    """
                    Import snapshots into every memory section.

                    Args:
                        snapshots (tuple[Sequence[dict], Sequence[dict], Sequence[dict]]): A
                            ``(profile, state, dynamic)`` tuple, i.e. the layout returned by
                            ``export``.
                        reset_memory (bool): Forwarded to each section's ``load``; whether to
                            reset previous memory.

                    A section whose snapshot is empty/falsy is skipped and left untouched.
                    """
                    _profile_snapshot, _state_snapshot, _dynamic_snapshot = snapshots
                    # Each snapshot must go back into the section it was exported from.
                    # ``export`` returns (profile, state, dynamic), so pair them explicitly.
                    for _snapshot, _mem in (
                        (_profile_snapshot, self._profile),
                        (_state_snapshot, self._state),
                        (_dynamic_snapshot, self._dynamic),
                    ):
                        if _snapshot:
                            await _mem.load(snapshots=_snapshot, reset_memory=reset_memory)
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                # async def add(self, content: str, metadata: Optional[dict] = None) -> None:
         | 
| 399 | 
            +
                #     """添加新的记忆
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                #     Args:
         | 
| 402 | 
            +
                #         content: 记忆内容
         | 
| 403 | 
            +
                #         metadata: 相关元数据,如时间、地点等
         | 
| 404 | 
            +
                #     """
         | 
| 405 | 
            +
                #     embedding = await self.embedding_model.aembed_query(content)
         | 
| 406 | 
            +
                #     self.memories.append(
         | 
| 407 | 
            +
                #         {
         | 
| 408 | 
            +
                #             "content": content,
         | 
| 409 | 
            +
                #             "metadata": metadata or {},
         | 
| 410 | 
            +
                #             "timestamp": datetime.now(),
         | 
| 411 | 
            +
                #             "embedding": embedding,
         | 
| 412 | 
            +
                #         }
         | 
| 413 | 
            +
                #     )
         | 
| 414 | 
            +
                #     self.embeddings.append(embedding)
         | 
| 415 | 
            +
             | 
| 416 | 
            +
                @lock_decorator
                async def search(
                    self, query: str, top_k: int = 3, filter: Optional[dict] = None
                ) -> str:
                    """Search for memories related to ``query``.

                    Args:
                        query: The query text.
                        top_k: Number of most relevant memories to return.
                        filter (dict, optional): Filter conditions on the memories, e.g.
                            ``{"type": "dynamic", "key": "self_define_1"}``. Defaults to None.

                    Returns:
                        str: One line per hit, giving the hit's ``type`` metadata, its content
                        and its similarity score (two decimals), joined with newlines. If no
                        embedding model is configured, the string
                        ``"Embedding model not initialized"`` is returned instead.
                    """
                    if not self._embedding_model:
                        return "Embedding model not initialized"
                    hits: list[tuple[str, float, dict]] = (
                        await self.faiss_query.similarity_search(  # type:ignore
                            query=query,
                            agent_id=self.agent_id,
                            k=top_k,
                            return_score_type="similarity_score",
                            filter=filter,
                        )
                    )
                    # Render each (content, score, metadata) hit as a bullet line with its relevance score.
                    return "\n".join(
                        f"- [{meta['type']}] {text} " f"(相关度: {similarity:.2f})"
                        for text, similarity, meta in hits
                    )
         | 
| @@ -0,0 +1,170 @@ | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Base class of memory
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import asyncio
         | 
| 6 | 
            +
            import logging
         | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            from abc import ABC, abstractmethod
         | 
| 9 | 
            +
            from collections.abc import Callable, Sequence
         | 
| 10 | 
            +
            from typing import Any, Optional, Union
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from .const import *
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            # Module logger used by the memory classes in this file; it logs under the
            # "pycityagent" logger name rather than this module's __name__.
            logger = logging.getLogger("pycityagent")
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            class MemoryUnit:
                """A dict-backed memory record that mirrors its keys as attributes.

                For every stored key ``k``, the value is also written to the instance
                attribute ``f"{SELF_DEFINE_PREFIX}{k}"`` and exposed through a property
                named ``k``. Coroutines that touch the content hold an ``asyncio.Lock``.
                """

                def __init__(
                    self,
                    content: Optional[dict] = None,
                    required_attributes: Optional[dict] = None,
                    activate_timestamp: bool = False,
                ) -> None:
                    """
                    Args:
                        content: Initial key/value pairs. Applied after
                            ``required_attributes``, so it wins on key collisions.
                        required_attributes: Baseline key/value pairs for the unit.
                        activate_timestamp: If True, store ``time.time()`` under
                            ``TIME_STAMP_KEY`` (unless the initial content already has that key)
                            and refresh it on every ``update``.
                    """
                    self._content = {}
                    self._lock = asyncio.Lock()
                    self._activate_timestamp = activate_timestamp
                    if required_attributes is not None:
                        self._content.update(required_attributes)
                    if content is not None:
                        self._content.update(content)
                    if activate_timestamp and TIME_STAMP_KEY not in self._content:
                        self._content[TIME_STAMP_KEY] = time.time()
                    for _prop, _value in self._content.items():
                        self._set_attribute(_prop, _value)

                def __getitem__(self, key: Any) -> Any:
                    """Return the stored value for ``key`` (raises KeyError if absent)."""
                    return self._content[key]

                def _create_property(self, property_name: str, property_value: Any):
                    """Install a property ``property_name`` backed by a prefixed instance attribute."""

                    def _getter(self):
                        return getattr(self, f"{SELF_DEFINE_PREFIX}{property_name}", None)

                    def _setter(self, value):
                        setattr(self, f"{SELF_DEFINE_PREFIX}{property_name}", value)

                    # NOTE(review): the property is installed on the CLASS, so it is visible on
                    # every MemoryUnit instance (reading None where the backing attribute is
                    # unset) and can shadow a same-named class attribute or method such as
                    # ``update``. Left as-is because callers may rely on class-level access.
                    setattr(self.__class__, property_name, property(_getter, _setter))
                    setattr(self, f"{SELF_DEFINE_PREFIX}{property_name}", property_value)

                def _set_attribute(self, property_name: str, property_value: Any):
                    """Write ``property_value`` to the backing attribute, creating the property on first use."""
                    if not hasattr(self, f"{SELF_DEFINE_PREFIX}{property_name}"):
                        self._create_property(property_name, property_value)
                    else:
                        setattr(self, f"{SELF_DEFINE_PREFIX}{property_name}", property_value)

                async def update(self, content: dict) -> None:
                    """Merge ``content`` into the unit and refresh every mirrored attribute.

                    A debug message is logged for each existing key whose value type changes.
                    When timestamps are active, ``TIME_STAMP_KEY`` is set to the current time
                    in both the content dict and the mirrored attribute.
                    """
                    # ``async with`` releases the lock even if logging or an attribute write raises.
                    async with self._lock:
                        for k, v in content.items():
                            if k in self._content:
                                orig_v = self._content[k]
                                orig_type, new_type = type(orig_v), type(v)
                                if not orig_type == new_type:
                                    logger.debug(
                                        f"Type warning: The type of the value for key '{k}' is changing from `{orig_type.__name__}` to `{new_type.__name__}`!"
                                    )
                        self._content.update(content)
                        if self._activate_timestamp:
                            # Store the refreshed timestamp in the content as well, so that
                            # ``self[TIME_STAMP_KEY]`` and ``dict_values()`` agree with the attribute.
                            self._content[TIME_STAMP_KEY] = time.time()
                        for _prop, _value in self._content.items():
                            self._set_attribute(_prop, _value)

                async def clear(self) -> None:
                    """Drop all stored content.

                    Only the content dict is reset; attributes mirrored by earlier writes
                    are left in place.
                    """
                    async with self._lock:
                        self._content = {}

                async def top_k_values(
                    self,
                    key: Any,
                    metric: Callable[[Any], Any],
                    top_k: Optional[int] = None,
                    preserve_order: bool = True,
                ) -> Union[Sequence[Any], Any]:
                    """Return the ``top_k`` items of the sequence stored under ``key``, ranked by ``metric``.

                    Args:
                        key: Content key whose value is expected to be a sequence.
                        metric: Scoring function applied to each item; higher scores rank first.
                            Its result must support unary negation (e.g. a number).
                        top_k: Number of items to return; ``None`` means all of them.
                        preserve_order: If True, the selected items keep their original
                            relative order; otherwise they are returned best-first.

                    Returns:
                        A list of the selected items, or the stored value unchanged (with a
                        warning) when it is not a ``Sequence``.

                    Raises:
                        KeyError: If ``key`` is not stored.
                    """
                    # The lock is released on every exit path, including the early return
                    # for non-sequence values and exceptions raised by ``metric``.
                    async with self._lock:
                        values = self._content[key]
                        if not isinstance(values, Sequence):
                            logger.warning(
                                f"the value stored in key `{key}` is not `sequence`, return value `{values}` instead!"
                            )
                            return values
                        # Rank by descending metric while remembering each item's original index.
                        _sorted_values_with_idx = sorted(
                            enumerate(values), key=lambda i_v: -metric(i_v[1])
                        )
                        top_k = len(values) if top_k is None else top_k
                        if len(_sorted_values_with_idx) < top_k:
                            logger.debug(
                                f"Length of values {len(_sorted_values_with_idx)} is less than top_k {top_k}, returning all values."
                            )
                    _selected = _sorted_values_with_idx[:top_k]
                    if preserve_order:
                        # Restore the original relative order of the selected items.
                        _selected = sorted(_selected, key=lambda i_v: i_v[0])
                    return [i_v[1] for i_v in _selected]

                async def dict_values(
                    self,
                ) -> dict[Any, Any]:
                    """Return the content dict itself (not a copy); mutating it mutates the unit."""
                    return self._content
         | 
| 117 | 
            +
             | 
| 118 | 
            +
             | 
| 119 | 
            +
            class MemoryBase(ABC):
                """Abstract base for memory containers.

                Subclasses keep their entries in ``self._memories`` (a dict; its key order is
                treated as insertion/recency order by ``_fetch_recent_memory``) and receive an
                ``asyncio.Lock`` in ``self._lock``.
                """

                def __init__(self) -> None:
                    self._memories: dict[Any, dict] = {}
                    self._lock = asyncio.Lock()

                @abstractmethod
                async def add(self, msg: Union[Any, Sequence[Any]]) -> None:
                    """Add one memory or a sequence of memories."""
                    raise NotImplementedError

                @abstractmethod
                async def pop(self, index: int) -> Any:
                    """Remove the memory at ``index``; exact semantics are defined by subclasses."""
                    pass

                @abstractmethod
                async def load(
                    self, snapshots: Union[Any, Sequence[Any]], reset_memory: bool = False
                ) -> None:
                    """Load snapshot(s), optionally resetting existing memory first."""
                    raise NotImplementedError

                @abstractmethod
                async def export(
                    self,
                ) -> Sequence[Any]:
                    """Export the stored memories."""
                    raise NotImplementedError

                @abstractmethod
                async def reset(self) -> None:
                    """Reset the memory container."""
                    raise NotImplementedError

                def _fetch_recent_memory(self, recent_n: Optional[int] = None) -> Sequence[Any]:
                    """Return the keys of the most recently inserted memories.

                    Args:
                        recent_n: How many of the last keys to return. ``None`` returns all
                            keys; ``0`` or a negative value returns an empty list.

                    Returns:
                        Sequence[Any]: The selected keys, oldest first.
                    """
                    _memories = self._memories
                    _list_units = list(_memories.keys())
                    if recent_n is None:
                        return _list_units
                    if recent_n <= 0:
                        # ``_list_units[-0:]`` is the WHOLE list, so zero must be handled explicitly.
                        return []
                    if len(_memories) < recent_n:
                        logger.debug(
                            f"Length of memory {len(_memories)} is less than recent_n {recent_n}, returning all available memories."
                        )
                    return _list_units[-recent_n:]

                # interact
                @abstractmethod
                async def get(self, key: Any):
                    """Fetch the memory stored under ``key``."""
                    raise NotImplementedError

                @abstractmethod
                async def update(self, key: Any, value: Any, store_snapshot: bool):
                    """Set ``key`` to ``value``, optionally storing a snapshot."""
                    raise NotImplementedError

                def __getitem__(self, index: Any) -> Any:
                    """Return the memory key at position ``index`` in insertion order."""
                    return list(self._memories.keys())[index]
         |