pycityagent 2.0.0a34__py3-none-any.whl → 2.0.0a36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/agent.py CHANGED
@@ -687,6 +687,7 @@ class InstitutionAgent(Agent):
687
687
  # TODO: More general id generation
688
688
  _id = random.randint(100000, 999999)
689
689
  self._agent_id = _id
690
+ self.memory.set_agent_id(_id)
690
691
  await self.memory.update("id", _id, protect_llm_read_only_fields=False)
691
692
  try:
692
693
  await self._economy_client.remove_orgs([self._agent_id])
@@ -14,7 +14,6 @@ from pycityproto.city.person.v2 import person_pb2 as person_pb2
14
14
  from pycityproto.city.person.v2 import person_service_pb2 as person_service
15
15
  from pymongo import MongoClient
16
16
  from shapely.geometry import Point
17
- from shapely.strtree import STRtree
18
17
 
19
18
  from .sim import CityClient, ControlSimEnv
20
19
  from .utils.const import *
@@ -87,12 +86,6 @@ class Simulator:
87
86
  - Simulator map object
88
87
  """
89
88
 
90
- self.pois_matrix: dict[str, list[list[list]]] = {}
91
- """
92
- pois的基于区块的划分——方便快速粗略地查询poi
93
- 通过Simulator.set_pois_matrix()初始化
94
- """
95
-
96
89
  self.time: int = 0
97
90
  """
98
91
  - 模拟城市当前时间
@@ -102,12 +95,11 @@ class Simulator:
102
95
  self.map_x_gap = None
103
96
  self.map_y_gap = None
104
97
  self._bbox: tuple[float, float, float, float] = (-1, -1, -1, -1)
105
- self.poi_matrix_centers = []
106
98
  self._lock = asyncio.Lock()
107
99
  # poi id dict
108
- self.poi_id_2_aoi_id: dict[int, int] = {}
109
- # poi STRtree
110
- self.set_poi_tree()
100
+ self.poi_id_2_aoi_id: dict[int, int] = {
101
+ poi["id"]: poi["aoi_id"] for _, poi in self.map.pois.items()
102
+ }
111
103
 
112
104
  # * Agent相关
113
105
  def find_agents_by_area(self, req: dict, status=None):
@@ -137,35 +129,21 @@ class Simulator:
137
129
  resp.motions = motions # type: ignore
138
130
  return resp
139
131
 
140
- def set_poi_tree(
141
- self,
142
- ):
143
- """
144
- 初始化pois_tree
145
- """
146
- poi_geos = []
147
- tree_id_2_poi_and_catg: dict[int, tuple[dict, str]] = {}
148
- for tree_id, poi in enumerate(self.map.pois.values()):
149
- tree_id_2_poi_and_catg[tree_id] = (poi, poi["category"])
150
- poi_geos.append(Point([poi["position"][k] for k in ["x", "y"]]))
151
- self.poi_id_2_aoi_id[poi["id"]] = poi["aoi_id"]
152
- self.tree_id_2_poi_and_catg = tree_id_2_poi_and_catg
153
- self.pois_tree = STRtree(poi_geos)
154
-
155
132
  def get_poi_categories(
156
133
  self,
157
134
  center: Optional[Union[tuple[float, float], Point]] = None,
158
135
  radius: Optional[float] = None,
159
136
  ) -> list[str]:
160
- if center is not None and radius is not None:
161
- if not isinstance(center, Point):
162
- center = Point(center)
163
- indices = self.pois_tree.query(center.buffer(radius))
164
- else:
165
- indices = list(self.tree_id_2_poi_and_catg.keys())
166
- categories = []
167
- for index in indices:
168
- _, catg = self.tree_id_2_poi_and_catg[index]
137
+ categories: list[str] = []
138
+ if center is None:
139
+ center = (0, 0)
140
+ _pois: list[dict] = self.map.query_pois( # type:ignore
141
+ center=center,
142
+ radius=radius,
143
+ return_distance=False,
144
+ )
145
+ for poi in _pois:
146
+ catg = poi["category"]
169
147
  categories.append(catg.split("|")[-1])
170
148
  return list(set(categories))
171
149
 
@@ -327,7 +305,7 @@ class Simulator:
327
305
  center: Union[tuple[float, float], Point],
328
306
  radius: float,
329
307
  poi_type: Union[str, list[str]],
330
- ):
308
+ ) -> list[dict]:
331
309
  if isinstance(poi_type, str):
332
310
  poi_type = [poi_type]
333
311
  transformed_poi_type = []
@@ -337,14 +315,16 @@ class Simulator:
337
315
  else:
338
316
  transformed_poi_type += self.poi_cate[t]
339
317
  poi_type_set = set(transformed_poi_type)
340
- if not isinstance(center, Point):
341
- center = Point(center)
342
318
  # 获取半径内的poi
343
- indices = self.pois_tree.query(center.buffer(radius))
319
+ _pois: list[dict] = self.map.query_pois( # type:ignore
320
+ center=center,
321
+ radius=radius,
322
+ return_distance=False,
323
+ )
344
324
  # 过滤掉不满足类别前缀的poi
345
325
  pois = []
346
- for index in indices:
347
- poi, catg = self.tree_id_2_poi_and_catg[index]
326
+ for poi in _pois:
327
+ catg = poi["category"]
348
328
  if catg.split("|")[-1] not in poi_type_set:
349
329
  continue
350
330
  pois.append(poi)
@@ -1,8 +1,8 @@
1
1
  import asyncio
2
- from collections import defaultdict
3
2
  import json
4
3
  import logging
5
4
  import math
5
+ from typing import Any, List, Union
6
6
  from aiomqtt import Client
7
7
 
8
8
  logger = logging.getLogger("pycityagent")
@@ -17,6 +17,10 @@ class Messager:
17
17
  self.connected = False # 是否已连接标志
18
18
  self.message_queue = asyncio.Queue() # 用于存储接收到的消息
19
19
  self.subscribers = {} # 订阅者信息,topic -> Agent 映射
20
+ self.receive_messages_task = None
21
+
22
+ async def __aexit__(self, exc_type, exc_value, traceback):
23
+ await self.stop()
20
24
 
21
25
  async def connect(self):
22
26
  try:
@@ -36,15 +40,20 @@ class Messager:
36
40
  """检查是否成功连接到 Broker"""
37
41
  return self.connected
38
42
 
39
- async def subscribe(self, topic, agent):
43
+ async def subscribe(self, topics: Union[str, List[str]], agents: Union[Any, List[Any]]):
40
44
  if not self.is_connected():
41
45
  logger.error(
42
- f"Cannot subscribe to {topic} because not connected to the Broker."
46
+ f"Cannot subscribe to {topics} because not connected to the Broker."
43
47
  )
44
48
  return
45
- await self.client.subscribe(topic)
46
- self.subscribers[topic] = agent
47
- logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
49
+ if not isinstance(topics, list):
50
+ topics = [topics]
51
+ if not isinstance(agents, list):
52
+ agents = [agents]
53
+ for topic, agent in zip(topics, agents):
54
+ self.subscribers[topic] = agent
55
+ logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
56
+ await self.client.subscribe(topics, qos=1)
48
57
 
49
58
  async def receive_messages(self):
50
59
  """监听并将消息存入队列"""
@@ -61,12 +70,17 @@ class Messager:
61
70
  async def send_message(self, topic: str, payload: dict):
62
71
  """通过 Messager 发送消息"""
63
72
  message = json.dumps(payload, default=str)
64
- await self.client.publish(topic, message)
73
+ await self.client.publish(topic=topic, payload=message, qos=1)
65
74
  logger.info(f"Message sent to {topic}: {message}")
66
75
 
67
76
  async def start_listening(self):
68
77
  """启动消息监听任务"""
69
78
  if self.is_connected():
70
- asyncio.create_task(self.receive_messages())
79
+ self.receive_messages_task = asyncio.create_task(self.receive_messages())
71
80
  else:
72
81
  logger.error("Cannot start listening because not connected to the Broker.")
82
+
83
+ async def stop(self):
84
+ self.receive_messages_task.cancel()
85
+ await asyncio.gather(self.receive_messages_task, return_exceptions=True)
86
+ await self.disconnect()
@@ -69,6 +69,7 @@ class AgentGroup:
69
69
  username=config["simulator_request"]["mqtt"].get("username", None),
70
70
  password=config["simulator_request"]["mqtt"].get("password", None),
71
71
  )
72
+ self.message_dispatch_task = None
72
73
  self._pgsql_writer = pgsql_writer
73
74
  self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
74
75
  self.initialized = False
@@ -132,6 +133,10 @@ class AgentGroup:
132
133
  if self.embedding_model is not None:
133
134
  agent.memory.set_embedding_model(self.embedding_model)
134
135
 
136
+ async def __aexit__(self, exc_type, exc_value, traceback):
137
+ self.message_dispatch_task.cancel()
138
+ await asyncio.gather(self.message_dispatch_task, return_exceptions=True)
139
+
135
140
  async def init_agents(self):
136
141
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
137
142
  logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
@@ -142,16 +147,14 @@ class AgentGroup:
142
147
  await self.messager.connect()
143
148
  if self.messager.is_connected():
144
149
  await self.messager.start_listening()
150
+ topics = []
151
+ agents = []
145
152
  for agent in self.agents:
146
153
  agent.set_messager(self.messager)
147
- topic = f"exps/{self.exp_id}/agents/{agent._uuid}/agent-chat"
148
- await self.messager.subscribe(topic, agent)
149
- topic = f"exps/{self.exp_id}/agents/{agent._uuid}/user-chat"
150
- await self.messager.subscribe(topic, agent)
151
- topic = f"exps/{self.exp_id}/agents/{agent._uuid}/user-survey"
152
- await self.messager.subscribe(topic, agent)
153
- topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
154
- await self.messager.subscribe(topic, agent)
154
+ topic = (f"exps/{self.exp_id}/agents/{agent._uuid}/#", 1)
155
+ topics.append(topic)
156
+ agents.append(agent)
157
+ await self.messager.subscribe(topics, agents)
155
158
  self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
156
159
  if self.enable_avro:
157
160
  logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
@@ -225,6 +228,7 @@ class AgentGroup:
225
228
  logger.warning(
226
229
  "Messager is not connected. Skipping message processing."
227
230
  )
231
+ break
228
232
 
229
233
  # Step 1: 获取消息
230
234
  messages = await self.messager.fetch_messages()
@@ -74,7 +74,6 @@ class AgentSimulation:
74
74
  username=config["simulator_request"]["mqtt"].get("username", None),
75
75
  password=config["simulator_request"]["mqtt"].get("password", None),
76
76
  )
77
- asyncio.create_task(self._messager.connect())
78
77
 
79
78
  # storage
80
79
  _storage_config: dict[str, Any] = config.get("storage", {})
@@ -202,6 +201,7 @@ class AgentSimulation:
202
201
  group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
203
202
  memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
204
203
  """
204
+ await self._messager.connect()
205
205
  if not isinstance(agent_count, list):
206
206
  agent_count = [agent_count]
207
207
 
@@ -535,13 +535,6 @@ class AgentSimulation:
535
535
  try:
536
536
  if self.enable_pgsql:
537
537
  worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
538
- # if self._last_asyncio_pg_task is not None:
539
- # await self._last_asyncio_pg_task
540
- # self._last_asyncio_pg_task = (
541
- # worker.async_update_exp_info.remote( # type:ignore
542
- # pg_exp_info
543
- # )
544
- # )
545
538
  pg_exp_info = {k: v for k, v in self._exp_info.items()}
546
539
  pg_exp_info["created_at"] = self._exp_created_time
547
540
  pg_exp_info["updated_at"] = self._exp_updated_time
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pycityagent
3
- Version: 2.0.0a34
3
+ Version: 2.0.0a36
4
4
  Summary: LLM-based城市环境agent构建库
5
5
  License: MIT
6
6
  Author: Yuwei Yan
@@ -29,7 +29,7 @@ Requires-Dist: langchain-community (>=0.3.13,<0.4.0)
29
29
  Requires-Dist: langchain-core (>=0.3.28,<0.4.0)
30
30
  Requires-Dist: matplotlib (==3.8.3)
31
31
  Requires-Dist: mlflow (>=2.19.0,<3.0.0)
32
- Requires-Dist: mosstool (==1.0.24)
32
+ Requires-Dist: mosstool (>=1.3.0,<2.0.0)
33
33
  Requires-Dist: networkx (==3.2.1)
34
34
  Requires-Dist: numpy (>=1.20.0,<2.0.0)
35
35
  Requires-Dist: openai (>=1.58.1,<2.0.0)
@@ -37,7 +37,7 @@ Requires-Dist: pandavro (>=1.8.0,<2.0.0)
37
37
  Requires-Dist: poetry (>=1.2.2)
38
38
  Requires-Dist: protobuf (<=4.24.0)
39
39
  Requires-Dist: psycopg[binary] (>=3.2.3,<4.0.0)
40
- Requires-Dist: pycitydata (==1.0.0)
40
+ Requires-Dist: pycitydata (>=1.0.3,<2.0.0)
41
41
  Requires-Dist: pycityproto (>=2.1.5,<3.0.0)
42
42
  Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
43
43
  Requires-Dist: ray (>=2.40.0,<3.0.0)
@@ -1,5 +1,5 @@
1
1
  pycityagent/__init__.py,sha256=fv0mzNGbHBF6m550yYqnuUpB8iQPWS-7EatYRK7DO4s,693
2
- pycityagent/agent.py,sha256=l8Oa95_K5JBWKzvZmbQe_QM_E_vaG-YstuuR55kgC6Y,29005
2
+ pycityagent/agent.py,sha256=r1uG4ib6fAmqzZXqGUiwswYdIOS8ybTFdbVqJWjhwuI,29047
3
3
  pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
4
4
  pycityagent/economy/econ_client.py,sha256=GuHK9ZBnhqW3Z7F8ViDJn_iN73yOBbbwFyJv1wLEBDk,12211
5
5
  pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
@@ -21,7 +21,7 @@ pycityagent/environment/sim/person_service.py,sha256=5r1F2Itn7dKJ2U4hSLovrk5p4qy
21
21
  pycityagent/environment/sim/road_service.py,sha256=bKyn3_me0sGmaJVyF6eNeFbdU-9C1yWsa9L7pieDJzg,1285
22
22
  pycityagent/environment/sim/sim_env.py,sha256=HI1LcS_FotDKQ6vBnx0e49prXSABOfA20aU9KM-ZkCY,4625
23
23
  pycityagent/environment/sim/social_service.py,sha256=9EFJAwVdUuUQkNkFRn9qZRDfD1brh2fqkvasnXUEBhQ,2014
24
- pycityagent/environment/simulator.py,sha256=KVfwSwVGXPqUHQGyD9jv_RXRgGal2k7NloUFVdmWE8I,12943
24
+ pycityagent/environment/simulator.py,sha256=1OUfODDzM4EN6Lw_Wzq4KeQb-EpcUBioZYc9fxfSPn0,12070
25
25
  pycityagent/environment/utils/__init__.py,sha256=1m4Q1EfGvNpUsa1bgQzzCyWhfkpElnskNImjjFD3Znc,237
26
26
  pycityagent/environment/utils/base64.py,sha256=hoREzQo3FXMN79pqQLO2jgsDEvudciomyKii7MWljAM,374
27
27
  pycityagent/environment/utils/const.py,sha256=1LqxnYJ8FSmq37fN5kIFlWLwycEDzFa8SFS-8plrFlU,5396
@@ -45,13 +45,13 @@ pycityagent/memory/self_define.py,sha256=vpZ6CIxR2grNXEIOScdpsSc59FBg0mOKelwQuTE
45
45
  pycityagent/memory/state.py,sha256=TYItiyDtehMEQaSBN7PpNrnNxdDM5jGppr9R9Ufv3kA,5134
46
46
  pycityagent/memory/utils.py,sha256=oJWLdPeJy_jcdKcDTo9JAH9kDZhqjoQhhv_zT9qWC0w,877
47
47
  pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
48
- pycityagent/message/messager.py,sha256=W_OVlNGcreHSBf6v-DrEnfNCXExB78ySr0w26MSncfU,2541
48
+ pycityagent/message/messager.py,sha256=gz-EZOGakgwQH8ZKabAGr1pT43E8B9z-s4dbNx-mxr4,3167
49
49
  pycityagent/metrics/__init__.py,sha256=X08PaBbGVAd7_PRGLREXWxaqm7nS82WBQpD1zvQzcqc,128
50
50
  pycityagent/metrics/mlflow_client.py,sha256=g_tHxWkWTDijtbGL74-HmiYzWVKb1y8-w12QrY9jL30,4449
51
51
  pycityagent/metrics/utils/const.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
52
52
  pycityagent/simulation/__init__.py,sha256=P5czbcg2d8S0nbbnsQXFIhwzO4CennAhZM8OmKvAeYw,194
53
- pycityagent/simulation/agentgroup.py,sha256=QwVbgqKYp42_wRS3K6C6E7Aq8cYBacabHmXIKRaxYyw,23955
54
- pycityagent/simulation/simulation.py,sha256=9kkdgXSEOAN8wiewVFyORksti4IdVNU0opObV6ZYa9k,23344
53
+ pycityagent/simulation/agentgroup.py,sha256=mZXznSf7VEHNjU0KM5TuPB-Z09DqMZ6VeQib7ciTLpo,23914
54
+ pycityagent/simulation/simulation.py,sha256=OVflF_Z_kZfE6c4z_6tHTGD53jq89Hewg8Yc5KflRI0,23008
55
55
  pycityagent/simulation/storage/pg.py,sha256=qGrYzJIAzjv8-d3-cle0rY0AN6XB6MgnHkFLBoLmKWU,7251
56
56
  pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
57
57
  pycityagent/survey/manager.py,sha256=S5IkwTdelsdtZETChRcfCEczzwSrry_Fly9MY4s3rbk,1681
@@ -70,6 +70,6 @@ pycityagent/workflow/block.py,sha256=C2aWdVRffb3LknP955GvPcBMsm3VPXN9ZuAtCgITFTo
70
70
  pycityagent/workflow/prompt.py,sha256=6jI0Rq54JLv3-IXqZLYug62vse10wTI83xvf4ZX42nk,2929
71
71
  pycityagent/workflow/tool.py,sha256=xADxhNgVsjNiMxlhdwn3xGUstFOkLEG8P67ez8VmwSI,8555
72
72
  pycityagent/workflow/trigger.py,sha256=Df-MOBEDWBbM-v0dFLQLXteLsipymT4n8vqexmK2GiQ,5643
73
- pycityagent-2.0.0a34.dist-info/METADATA,sha256=za2dplTaxTwxYovvgF_HioNdjMXvmpTdg6v4Nx8Q2vI,8033
74
- pycityagent-2.0.0a34.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
75
- pycityagent-2.0.0a34.dist-info/RECORD,,
73
+ pycityagent-2.0.0a36.dist-info/METADATA,sha256=aJH2tcaEccrMdJRwjgNmRS8bYla-E_TTMh1bgevmHEM,8046
74
+ pycityagent-2.0.0a36.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
75
+ pycityagent-2.0.0a36.dist-info/RECORD,,