pycityagent 2.0.0a13__py3-none-any.whl → 2.0.0a15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,6 +5,7 @@ import logging
5
5
  import math
6
6
  from aiomqtt import Client
7
7
 
8
+ logger = logging.getLogger("pycityagent")
8
9
 
9
10
  class Messager:
10
11
  def __init__(
@@ -21,15 +22,15 @@ class Messager:
21
22
  try:
22
23
  await self.client.__aenter__()
23
24
  self.connected = True
24
- logging.info("Connected to MQTT Broker")
25
+ logger.info("Connected to MQTT Broker")
25
26
  except Exception as e:
26
27
  self.connected = False
27
- logging.error(f"Failed to connect to MQTT Broker: {e}")
28
+ logger.error(f"Failed to connect to MQTT Broker: {e}")
28
29
 
29
30
  async def disconnect(self):
30
31
  await self.client.__aexit__(None, None, None)
31
32
  self.connected = False
32
- logging.info("Disconnected from MQTT Broker")
33
+ logger.info("Disconnected from MQTT Broker")
33
34
 
34
35
  def is_connected(self):
35
36
  """检查是否成功连接到 Broker"""
@@ -37,13 +38,13 @@ class Messager:
37
38
 
38
39
  async def subscribe(self, topic, agent):
39
40
  if not self.is_connected():
40
- logging.error(
41
+ logger.error(
41
42
  f"Cannot subscribe to {topic} because not connected to the Broker."
42
43
  )
43
44
  return
44
45
  await self.client.subscribe(topic)
45
46
  self.subscribers[topic] = agent
46
- logging.info(f"Subscribed to {topic} for Agent {agent._uuid}")
47
+ logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
47
48
 
48
49
  async def receive_messages(self):
49
50
  """监听并将消息存入队列"""
@@ -61,11 +62,11 @@ class Messager:
61
62
  """通过 Messager 发送消息"""
62
63
  message = json.dumps(payload, default=str)
63
64
  await self.client.publish(topic, message)
64
- logging.info(f"Message sent to {topic}: {message}")
65
+ logger.info(f"Message sent to {topic}: {message}")
65
66
 
66
67
  async def start_listening(self):
67
68
  """启动消息监听任务"""
68
69
  if self.is_connected():
69
70
  asyncio.create_task(self.receive_messages())
70
71
  else:
71
- logging.error("Cannot start listening because not connected to the Broker.")
72
+ logger.error("Cannot start listening because not connected to the Broker.")
@@ -1,24 +1,42 @@
1
1
  import asyncio
2
+ from datetime import datetime
2
3
  import json
3
4
  import logging
5
+ from pathlib import Path
4
6
  import uuid
7
+ import fastavro
5
8
  import ray
6
9
  from uuid import UUID
7
- from pycityagent.agent import Agent
10
+ from pycityagent.agent import Agent, CitizenAgent, InstitutionAgent
8
11
  from pycityagent.economy.econ_client import EconomyClient
9
12
  from pycityagent.environment.simulator import Simulator
10
13
  from pycityagent.llm.llm import LLM
11
14
  from pycityagent.llm.llmconfig import LLMConfig
12
15
  from pycityagent.message import Messager
16
+ from pycityagent.utils import STATUS_SCHEMA, PROFILE_SCHEMA, DIALOG_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
13
17
  from typing import Any
14
18
 
19
+ logger = logging.getLogger("pycityagent")
20
+
15
21
  @ray.remote
16
22
  class AgentGroup:
17
- def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID):
23
+ def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
24
+ logger.setLevel(logging_level)
25
+ self._uuid = str(uuid.uuid4())
18
26
  self.agents = agents
19
27
  self.config = config
20
28
  self.exp_id = exp_id
21
- self._uuid = uuid.uuid4()
29
+ self.enable_avro = enable_avro
30
+ self.avro_path = avro_path / f"{self._uuid}"
31
+ if enable_avro:
32
+ self.avro_path.mkdir(parents=True, exist_ok=True)
33
+ self.avro_file = {
34
+ "profile": self.avro_path / f"profile.avro",
35
+ "dialog": self.avro_path / f"dialog.avro",
36
+ "status": self.avro_path / f"status.avro",
37
+ "survey": self.avro_path / f"survey.avro",
38
+ }
39
+
22
40
  self.messager = Messager(
23
41
  hostname=config["simulator_request"]["mqtt"]["server"],
24
42
  port=config["simulator_request"]["mqtt"]["port"],
@@ -29,16 +47,16 @@ class AgentGroup:
29
47
  self.id2agent = {}
30
48
  # Step:1 prepare LLM client
31
49
  llmConfig = LLMConfig(config["llm_request"])
32
- logging.info("-----Creating LLM client in remote...")
50
+ logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
33
51
  self.llm = LLM(llmConfig)
34
52
 
35
53
  # Step:2 prepare Simulator
36
- logging.info("-----Creating Simulator in remote...")
54
+ logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
37
55
  self.simulator = Simulator(config["simulator_request"])
38
56
 
39
57
  # Step:3 prepare Economy client
40
58
  if "economy" in config["simulator_request"]:
41
- logging.info("-----Creating Economy client in remote...")
59
+ logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
42
60
  self.economy_client = EconomyClient(
43
61
  config["simulator_request"]["economy"]["server"]
44
62
  )
@@ -52,11 +70,16 @@ class AgentGroup:
52
70
  if self.economy_client is not None:
53
71
  agent.set_economy_client(self.economy_client)
54
72
  agent.set_messager(self.messager)
73
+ if self.enable_avro:
74
+ agent.set_avro_file(self.avro_file)
55
75
 
56
76
  async def init_agents(self):
77
+ logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
78
+ logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
57
79
  for agent in self.agents:
58
80
  await agent.bind_to_simulator()
59
81
  self.id2agent = {agent._uuid: agent for agent in self.agents}
82
+ logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
60
83
  await self.messager.connect()
61
84
  if self.messager.is_connected():
62
85
  await self.messager.start_listening()
@@ -70,27 +93,65 @@ class AgentGroup:
70
93
  await self.messager.subscribe(topic, agent)
71
94
  topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
72
95
  await self.messager.subscribe(topic, agent)
73
- self.initialized = True
74
96
  self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
75
-
97
+ if self.enable_avro:
98
+ logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
99
+ # profile
100
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
101
+ filename = self.avro_file["profile"]
102
+ with open(filename, "wb") as f:
103
+ profiles = []
104
+ for agent in self.agents:
105
+ profile = await agent.memory._profile.export()
106
+ profile = profile[0]
107
+ profile['id'] = agent._uuid
108
+ profiles.append(profile)
109
+ fastavro.writer(f, PROFILE_SCHEMA, profiles)
110
+
111
+ # dialog
112
+ filename = self.avro_file["dialog"]
113
+ with open(filename, "wb") as f:
114
+ dialogs = []
115
+ fastavro.writer(f, DIALOG_SCHEMA, dialogs)
116
+
117
+ # status
118
+ filename = self.avro_file["status"]
119
+ with open(filename, "wb") as f:
120
+ statuses = []
121
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
122
+ fastavro.writer(f, STATUS_SCHEMA, statuses)
123
+ else:
124
+ fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, statuses)
125
+
126
+ # survey
127
+ filename = self.avro_file["survey"]
128
+ with open(filename, "wb") as f:
129
+ surveys = []
130
+ fastavro.writer(f, SURVEY_SCHEMA, surveys)
131
+ self.initialized = True
132
+ logger.debug(f"-----AgentGroup {self._uuid} initialized")
133
+
76
134
  async def gather(self, content: str):
135
+ logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
77
136
  results = {}
78
137
  for agent in self.agents:
79
138
  results[agent._uuid] = await agent.memory.get(content)
80
139
  return results
81
140
 
82
141
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
142
+ logger.debug(f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}")
83
143
  agent = self.id2agent[target_agent_uuid]
84
144
  await agent.memory.update(target_key, content)
85
145
 
86
146
  async def message_dispatch(self):
147
+ logger.debug(f"-----Starting message dispatch for group {self._uuid}")
87
148
  while True:
88
149
  if not self.messager.is_connected():
89
- logging.warning("Messager is not connected. Skipping message processing.")
150
+ logger.warning("Messager is not connected. Skipping message processing.")
90
151
 
91
152
  # Step 1: 获取消息
92
153
  messages = await self.messager.fetch_messages()
93
- logging.info(f"Group {self._uuid} received {len(messages)} messages")
154
+ logger.info(f"Group {self._uuid} received {len(messages)} messages")
94
155
 
95
156
  # Step 2: 分发消息到对应的 Agent
96
157
  for message in messages:
@@ -105,8 +166,8 @@ class AgentGroup:
105
166
  # 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
106
167
  _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
107
168
 
108
- if uuid.UUID(agent_uuid) in self.id2agent:
109
- agent = self.id2agent[uuid.UUID(agent_uuid)]
169
+ if agent_uuid in self.id2agent:
170
+ agent = self.id2agent[agent_uuid]
110
171
  # topic_type: agent-chat, user-chat, user-survey, gather
111
172
  if topic_type == "agent-chat":
112
173
  await agent.handle_agent_chat_message(payload)
@@ -117,7 +178,67 @@ class AgentGroup:
117
178
  elif topic_type == "gather":
118
179
  await agent.handle_gather_message(payload)
119
180
 
120
- await asyncio.sleep(1)
181
+ await asyncio.sleep(0.5)
182
+
183
+ async def save_status(self):
184
+ if self.enable_avro:
185
+ logger.debug(f"-----Saving status for group {self._uuid}")
186
+ avros = []
187
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
188
+ for agent in self.agents:
189
+ position = await agent.memory.get("position")
190
+ lng = position["longlat_position"]["longitude"]
191
+ lat = position["longlat_position"]["latitude"]
192
+ if "aoi_position" in position:
193
+ parent_id = position["aoi_position"]["aoi_id"]
194
+ elif "lane_position" in position:
195
+ parent_id = position["lane_position"]["lane_id"]
196
+ else:
197
+ # BUG: 需要处理
198
+ parent_id = -1
199
+ needs = await agent.memory.get("needs")
200
+ action = await agent.memory.get("current_step")
201
+ action = action["intention"]
202
+ avro = {
203
+ "id": agent._uuid,
204
+ "day": await self.simulator.get_simulator_day(),
205
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
206
+ "lng": lng,
207
+ "lat": lat,
208
+ "parent_id": parent_id,
209
+ "action": action,
210
+ "hungry": needs["hungry"],
211
+ "tired": needs["tired"],
212
+ "safe": needs["safe"],
213
+ "social": needs["social"],
214
+ "created_at": int(datetime.now().timestamp() * 1000),
215
+ }
216
+ avros.append(avro)
217
+ with open(self.avro_file["status"], "a+b") as f:
218
+ fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
219
+ else:
220
+ for agent in self.agents:
221
+ avro = {
222
+ "id": agent._uuid,
223
+ "day": await self.simulator.get_simulator_day(),
224
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
225
+ "type": await agent.memory.get("type"),
226
+ "nominal_gdp": await agent.memory.get("nominal_gdp"),
227
+ "real_gdp": await agent.memory.get("real_gdp"),
228
+ "unemployment": await agent.memory.get("unemployment"),
229
+ "wages": await agent.memory.get("wages"),
230
+ "prices": await agent.memory.get("prices"),
231
+ "inventory": await agent.memory.get("inventory"),
232
+ "price": await agent.memory.get("price"),
233
+ "interest_rate": await agent.memory.get("interest_rate"),
234
+ "bracket_cutoffs": await agent.memory.get("bracket_cutoffs"),
235
+ "bracket_rates": await agent.memory.get("bracket_rates"),
236
+ "employees": await agent.memory.get("employees"),
237
+ "customers": await agent.memory.get("customers"),
238
+ }
239
+ avros.append(avro)
240
+ with open(self.avro_file["status"], "a+b") as f:
241
+ fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
121
242
 
122
243
  async def step(self):
123
244
  if not self.initialized:
@@ -125,6 +246,7 @@ class AgentGroup:
125
246
 
126
247
  tasks = [agent.run() for agent in self.agents]
127
248
  await asyncio.gather(*tasks)
249
+ await self.save_status()
128
250
 
129
251
  async def run(self, day: int = 1):
130
252
  """运行模拟器
@@ -147,5 +269,5 @@ class AgentGroup:
147
269
  await self.step()
148
270
 
149
271
  except Exception as e:
150
- logging.error(f"模拟器运行错误: {str(e)}")
272
+ logger.error(f"模拟器运行错误: {str(e)}")
151
273
  raise