pycityagent 2.0.0a39__py3-none-any.whl → 2.0.0a41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +1 -1
- pycityagent/message/messager.py +14 -12
- pycityagent/simulation/agentgroup.py +8 -8
- pycityagent/simulation/simulation.py +4 -4
- pycityagent/simulation/storage/pg.py +6 -1
- {pycityagent-2.0.0a39.dist-info → pycityagent-2.0.0a41.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a39.dist-info → pycityagent-2.0.0a41.dist-info}/RECORD +8 -8
- {pycityagent-2.0.0a39.dist-info → pycityagent-2.0.0a41.dist-info}/WHEEL +0 -0
    
        pycityagent/agent.py
    CHANGED
    
    | @@ -464,7 +464,7 @@ class Agent(ABC): | |
| 464 464 | 
             
                    if self._messager is None:
         | 
| 465 465 | 
             
                        raise RuntimeError("Messager is not set")
         | 
| 466 466 | 
             
                    topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
         | 
| 467 | 
            -
                    await self._messager.send_message(topic, payload)
         | 
| 467 | 
            +
                    await self._messager.send_message.remote(topic, payload)
         | 
| 468 468 |  | 
| 469 469 | 
             
                async def send_message_to_agent(
         | 
| 470 470 | 
             
                    self, to_agent_uuid: str, content: str, type: str = "social"
         | 
    
        pycityagent/message/messager.py
    CHANGED
    
    | @@ -4,32 +4,37 @@ import logging | |
| 4 4 | 
             
            import math
         | 
| 5 5 | 
             
            from typing import Any, List, Union
         | 
| 6 6 | 
             
            from aiomqtt import Client
         | 
| 7 | 
            +
            import ray
         | 
| 7 8 |  | 
| 8 9 | 
             
            logger = logging.getLogger("pycityagent")
         | 
| 9 10 |  | 
| 11 | 
            +
            @ray.remote
         | 
| 10 12 | 
             
            class Messager:
         | 
| 11 13 | 
             
                def __init__(
         | 
| 12 | 
            -
                    self, hostname:str, port:int=1883, username=None, password=None, timeout= | 
| 14 | 
            +
                    self, hostname:str, port:int=1883, username=None, password=None, timeout=60
         | 
| 13 15 | 
             
                ):
         | 
| 14 16 | 
             
                    self.client = Client(
         | 
| 15 17 | 
             
                        hostname, port=port, username=username, password=password, timeout=timeout
         | 
| 16 18 | 
             
                    )
         | 
| 17 19 | 
             
                    self.connected = False  # 是否已连接标志
         | 
| 18 20 | 
             
                    self.message_queue = asyncio.Queue()  # 用于存储接收到的消息
         | 
| 19 | 
            -
                    self.subscribers = {}  # 订阅者信息,topic -> Agent 映射
         | 
| 20 21 | 
             
                    self.receive_messages_task = None
         | 
| 21 22 |  | 
| 22 23 | 
             
                async def __aexit__(self, exc_type, exc_value, traceback):
         | 
| 23 24 | 
             
                    await self.stop()
         | 
| 24 25 |  | 
| 25 26 | 
             
                async def connect(self):
         | 
| 26 | 
            -
                     | 
| 27 | 
            -
                         | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
                         | 
| 27 | 
            +
                    for i in range(3):
         | 
| 28 | 
            +
                        try:
         | 
| 29 | 
            +
                            await self.client.__aenter__()
         | 
| 30 | 
            +
                            self.connected = True
         | 
| 31 | 
            +
                            logger.info("Connected to MQTT Broker")
         | 
| 32 | 
            +
                            return
         | 
| 33 | 
            +
                        except Exception as e:
         | 
| 34 | 
            +
                            logger.error(f"Attempt {i+1}: Failed to connect to MQTT Broker: {e}")
         | 
| 35 | 
            +
                            await asyncio.sleep(10)
         | 
| 36 | 
            +
                    self.connected = False
         | 
| 37 | 
            +
                    logger.error("All connection attempts failed.")
         | 
| 33 38 |  | 
| 34 39 | 
             
                async def disconnect(self):
         | 
| 35 40 | 
             
                    await self.client.__aexit__(None, None, None)
         | 
| @@ -50,9 +55,6 @@ class Messager: | |
| 50 55 | 
             
                        topics = [topics]
         | 
| 51 56 | 
             
                    if not isinstance(agents, list):
         | 
| 52 57 | 
             
                        agents = [agents]
         | 
| 53 | 
            -
                    for topic, agent in zip(topics, agents):
         | 
| 54 | 
            -
                        self.subscribers[topic] = agent
         | 
| 55 | 
            -
                        logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
         | 
| 56 58 | 
             
                    await self.client.subscribe(topics, qos=1)
         | 
| 57 59 |  | 
| 58 60 | 
             
                async def receive_messages(self):
         | 
| @@ -63,7 +63,7 @@ class AgentGroup: | |
| 63 63 | 
             
                    if self.enable_pgsql:
         | 
| 64 64 | 
             
                        pass
         | 
| 65 65 |  | 
| 66 | 
            -
                    self.messager = Messager(
         | 
| 66 | 
            +
                    self.messager = Messager.remote(
         | 
| 67 67 | 
             
                        hostname=config["simulator_request"]["mqtt"]["server"],
         | 
| 68 68 | 
             
                        port=config["simulator_request"]["mqtt"]["port"],
         | 
| 69 69 | 
             
                        username=config["simulator_request"]["mqtt"].get("username", None),
         | 
| @@ -144,17 +144,17 @@ class AgentGroup: | |
| 144 144 | 
             
                        await agent.bind_to_simulator()  # type: ignore
         | 
| 145 145 | 
             
                    self.id2agent = {agent._uuid: agent for agent in self.agents}
         | 
| 146 146 | 
             
                    logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
         | 
| 147 | 
            -
                    await self.messager.connect()
         | 
| 148 | 
            -
                    if self.messager.is_connected():
         | 
| 149 | 
            -
                        await self.messager.start_listening()
         | 
| 147 | 
            +
                    await self.messager.connect.remote()
         | 
| 148 | 
            +
                    if ray.get(self.messager.is_connected.remote()):
         | 
| 149 | 
            +
                        await self.messager.start_listening.remote()
         | 
| 150 150 | 
             
                        topics = []
         | 
| 151 151 | 
             
                        agents = []
         | 
| 152 152 | 
             
                        for agent in self.agents:
         | 
| 153 153 | 
             
                            agent.set_messager(self.messager)
         | 
| 154 154 | 
             
                            topic = (f"exps/{self.exp_id}/agents/{agent._uuid}/#", 1)
         | 
| 155 155 | 
             
                            topics.append(topic)
         | 
| 156 | 
            -
                            agents.append(agent)
         | 
| 157 | 
            -
                        await self.messager.subscribe(topics, agents)
         | 
| 156 | 
            +
                            agents.append(agent.uuid)
         | 
| 157 | 
            +
                        await self.messager.subscribe.remote(topics, agents)
         | 
| 158 158 | 
             
                    self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
         | 
| 159 159 | 
             
                    if self.enable_avro:
         | 
| 160 160 | 
             
                        logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
         | 
| @@ -253,14 +253,14 @@ class AgentGroup: | |
| 253 253 | 
             
                async def message_dispatch(self):
         | 
| 254 254 | 
             
                    logger.debug(f"-----Starting message dispatch for group {self._uuid}")
         | 
| 255 255 | 
             
                    while True:
         | 
| 256 | 
            -
                        if not self.messager.is_connected():
         | 
| 256 | 
            +
                        if not ray.get(self.messager.is_connected.remote()):
         | 
| 257 257 | 
             
                            logger.warning(
         | 
| 258 258 | 
             
                                "Messager is not connected. Skipping message processing."
         | 
| 259 259 | 
             
                            )
         | 
| 260 260 | 
             
                            break
         | 
| 261 261 |  | 
| 262 262 | 
             
                        # Step 1: 获取消息
         | 
| 263 | 
            -
                        messages = await self.messager.fetch_messages()
         | 
| 263 | 
            +
                        messages = await self.messager.fetch_messages.remote()
         | 
| 264 264 | 
             
                        logger.info(f"Group {self._uuid} received {len(messages)} messages")
         | 
| 265 265 |  | 
| 266 266 | 
             
                        # Step 2: 分发消息到对应的 Agent
         | 
| @@ -69,7 +69,7 @@ class AgentSimulation: | |
| 69 69 | 
             
                    self._loop = asyncio.get_event_loop()
         | 
| 70 70 | 
             
                    # self._last_asyncio_pg_task = None  # 将SQL写入的IO隐藏到计算任务后
         | 
| 71 71 |  | 
| 72 | 
            -
                    self._messager = Messager(
         | 
| 72 | 
            +
                    self._messager = Messager.remote(
         | 
| 73 73 | 
             
                        hostname=config["simulator_request"]["mqtt"]["server"],
         | 
| 74 74 | 
             
                        port=config["simulator_request"]["mqtt"]["port"],
         | 
| 75 75 | 
             
                        username=config["simulator_request"]["mqtt"].get("username", None),
         | 
| @@ -206,7 +206,7 @@ class AgentSimulation: | |
| 206 206 | 
             
                        group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
         | 
| 207 207 | 
             
                        memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
         | 
| 208 208 | 
             
                    """
         | 
| 209 | 
            -
                    await self._messager.connect()
         | 
| 209 | 
            +
                    await self._messager.connect.remote()
         | 
| 210 210 | 
             
                    if not isinstance(agent_count, list):
         | 
| 211 211 | 
             
                        agent_count = [agent_count]
         | 
| 212 212 |  | 
| @@ -499,7 +499,7 @@ class AgentSimulation: | |
| 499 499 | 
             
                    }
         | 
| 500 500 | 
             
                    for uuid in agent_uuids:
         | 
| 501 501 | 
             
                        topic = self._user_survey_topics[uuid]
         | 
| 502 | 
            -
                        await self._messager.send_message(topic, payload)
         | 
| 502 | 
            +
                        await self._messager.send_message.remote(topic, payload)
         | 
| 503 503 |  | 
| 504 504 | 
             
                async def send_interview_message(
         | 
| 505 505 | 
             
                    self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
         | 
| @@ -516,7 +516,7 @@ class AgentSimulation: | |
| 516 516 | 
             
                        agent_uuids = [agent_uuids]
         | 
| 517 517 | 
             
                    for uuid in agent_uuids:
         | 
| 518 518 | 
             
                        topic = self._user_chat_topics[uuid]
         | 
| 519 | 
            -
                        await self._messager.send_message(topic, payload)
         | 
| 519 | 
            +
                        await self._messager.send_message.remote(topic, payload)
         | 
| 520 520 |  | 
| 521 521 | 
             
                async def step(self):
         | 
| 522 522 | 
             
                    """运行一步, 即每个智能体执行一次forward"""
         | 
| @@ -139,7 +139,12 @@ class PgWriter: | |
| 139 139 | 
             
                            exec_str = "SELECT * FROM {table_name} WHERE id=%s".format(
         | 
| 140 140 | 
             
                                table_name=table_name
         | 
| 141 141 | 
             
                            ), (self.exp_id,)
         | 
| 142 | 
            -
                            await cur.execute( | 
| 142 | 
            +
                            await cur.execute(
         | 
| 143 | 
            +
                                "SELECT * FROM {table_name} WHERE id=%s".format(
         | 
| 144 | 
            +
                                    table_name=table_name
         | 
| 145 | 
            +
                                ),
         | 
| 146 | 
            +
                                (self.exp_id,),
         | 
| 147 | 
            +
                            )  # type:ignore
         | 
| 143 148 | 
             
                            logger.debug(f"table:{table_name} sql: {exec_str}")
         | 
| 144 149 | 
             
                            record_exists = await cur.fetchall()
         | 
| 145 150 | 
             
                            if record_exists:
         | 
| @@ -1,5 +1,5 @@ | |
| 1 1 | 
             
            pycityagent/__init__.py,sha256=fv0mzNGbHBF6m550yYqnuUpB8iQPWS-7EatYRK7DO4s,693
         | 
| 2 | 
            -
            pycityagent/agent.py,sha256= | 
| 2 | 
            +
            pycityagent/agent.py,sha256=KO8yJiVOZTWlYAd57gLEgbSNzQ26cw27k1NT8pAFHNY,29917
         | 
| 3 3 | 
             
            pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
         | 
| 4 4 | 
             
            pycityagent/economy/econ_client.py,sha256=GuHK9ZBnhqW3Z7F8ViDJn_iN73yOBbbwFyJv1wLEBDk,12211
         | 
| 5 5 | 
             
            pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
         | 
| @@ -45,14 +45,14 @@ pycityagent/memory/self_define.py,sha256=vpZ6CIxR2grNXEIOScdpsSc59FBg0mOKelwQuTE | |
| 45 45 | 
             
            pycityagent/memory/state.py,sha256=TYItiyDtehMEQaSBN7PpNrnNxdDM5jGppr9R9Ufv3kA,5134
         | 
| 46 46 | 
             
            pycityagent/memory/utils.py,sha256=oJWLdPeJy_jcdKcDTo9JAH9kDZhqjoQhhv_zT9qWC0w,877
         | 
| 47 47 | 
             
            pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
         | 
| 48 | 
            -
            pycityagent/message/messager.py,sha256= | 
| 48 | 
            +
            pycityagent/message/messager.py,sha256=ePu1LDZZBDMhxoVoX4-LGTcizCVsOuw4T1GRBXHdM7E,3125
         | 
| 49 49 | 
             
            pycityagent/metrics/__init__.py,sha256=X08PaBbGVAd7_PRGLREXWxaqm7nS82WBQpD1zvQzcqc,128
         | 
| 50 50 | 
             
            pycityagent/metrics/mlflow_client.py,sha256=g_tHxWkWTDijtbGL74-HmiYzWVKb1y8-w12QrY9jL30,4449
         | 
| 51 51 | 
             
            pycityagent/metrics/utils/const.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 52 52 | 
             
            pycityagent/simulation/__init__.py,sha256=P5czbcg2d8S0nbbnsQXFIhwzO4CennAhZM8OmKvAeYw,194
         | 
| 53 | 
            -
            pycityagent/simulation/agentgroup.py,sha256= | 
| 54 | 
            -
            pycityagent/simulation/simulation.py,sha256= | 
| 55 | 
            -
            pycityagent/simulation/storage/pg.py,sha256= | 
| 53 | 
            +
            pycityagent/simulation/agentgroup.py,sha256=3UIMbA-CpoHxtxkFWkzEpI6LMclD3CwvxS2WjpqjaWE,25356
         | 
| 54 | 
            +
            pycityagent/simulation/simulation.py,sha256=LGWcooTa5n3YVUeChv3M1T6keUJK1a4Ddo-g32WK1o0,23274
         | 
| 55 | 
            +
            pycityagent/simulation/storage/pg.py,sha256=jjdYvqKmDms3weqALqOO__gxWF-Z0YcqeD85XbP4Qks,8455
         | 
| 56 56 | 
             
            pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
         | 
| 57 57 | 
             
            pycityagent/survey/manager.py,sha256=S5IkwTdelsdtZETChRcfCEczzwSrry_Fly9MY4s3rbk,1681
         | 
| 58 58 | 
             
            pycityagent/survey/models.py,sha256=YE50UUt5qJ0O_lIUsSY6XFCGUTkJVNu_L1gAhaCJ2fs,3546
         | 
| @@ -70,6 +70,6 @@ pycityagent/workflow/block.py,sha256=C2aWdVRffb3LknP955GvPcBMsm3VPXN9ZuAtCgITFTo | |
| 70 70 | 
             
            pycityagent/workflow/prompt.py,sha256=6jI0Rq54JLv3-IXqZLYug62vse10wTI83xvf4ZX42nk,2929
         | 
| 71 71 | 
             
            pycityagent/workflow/tool.py,sha256=xADxhNgVsjNiMxlhdwn3xGUstFOkLEG8P67ez8VmwSI,8555
         | 
| 72 72 | 
             
            pycityagent/workflow/trigger.py,sha256=Df-MOBEDWBbM-v0dFLQLXteLsipymT4n8vqexmK2GiQ,5643
         | 
| 73 | 
            -
            pycityagent-2.0. | 
| 74 | 
            -
            pycityagent-2.0. | 
| 75 | 
            -
            pycityagent-2.0. | 
| 73 | 
            +
            pycityagent-2.0.0a41.dist-info/METADATA,sha256=jR-wJCsCj1qthzVzdEcXZOrFpG_v7TXyZ8Sgh-2yTHU,8046
         | 
| 74 | 
            +
            pycityagent-2.0.0a41.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
         | 
| 75 | 
            +
            pycityagent-2.0.0a41.dist-info/RECORD,,
         | 
| 
            File without changes
         |