pycityagent 2.0.0a14__py3-none-any.whl → 2.0.0a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,27 +2,41 @@ import asyncio
2
2
  from datetime import datetime
3
3
  import json
4
4
  import logging
5
+ from pathlib import Path
5
6
  import uuid
6
7
  import fastavro
7
8
  import ray
8
9
  from uuid import UUID
9
- from pycityagent.agent import Agent, CitizenAgent
10
+ from pycityagent.agent import Agent, CitizenAgent, InstitutionAgent
10
11
  from pycityagent.economy.econ_client import EconomyClient
11
12
  from pycityagent.environment.simulator import Simulator
12
13
  from pycityagent.llm.llm import LLM
13
14
  from pycityagent.llm.llmconfig import LLMConfig
14
15
  from pycityagent.message import Messager
15
- from pycityagent.utils import STATUS_SCHEMA
16
+ from pycityagent.utils import STATUS_SCHEMA, PROFILE_SCHEMA, DIALOG_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
16
17
  from typing import Any
17
18
 
19
+ logger = logging.getLogger("pycityagent")
20
+
18
21
  @ray.remote
19
22
  class AgentGroup:
20
- def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, avro_file: dict):
23
+ def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int):
24
+ logger.setLevel(logging_level)
25
+ self._uuid = str(uuid.uuid4())
21
26
  self.agents = agents
22
27
  self.config = config
23
28
  self.exp_id = exp_id
24
- self.avro_file = avro_file
25
- self._uuid = uuid.uuid4()
29
+ self.enable_avro = enable_avro
30
+ self.avro_path = avro_path / f"{self._uuid}"
31
+ if enable_avro:
32
+ self.avro_path.mkdir(parents=True, exist_ok=True)
33
+ self.avro_file = {
34
+ "profile": self.avro_path / f"profile.avro",
35
+ "dialog": self.avro_path / f"dialog.avro",
36
+ "status": self.avro_path / f"status.avro",
37
+ "survey": self.avro_path / f"survey.avro",
38
+ }
39
+
26
40
  self.messager = Messager(
27
41
  hostname=config["simulator_request"]["mqtt"]["server"],
28
42
  port=config["simulator_request"]["mqtt"]["port"],
@@ -33,16 +47,16 @@ class AgentGroup:
33
47
  self.id2agent = {}
34
48
  # Step:1 prepare LLM client
35
49
  llmConfig = LLMConfig(config["llm_request"])
36
- logging.info("-----Creating LLM client in remote...")
50
+ logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
37
51
  self.llm = LLM(llmConfig)
38
52
 
39
53
  # Step:2 prepare Simulator
40
- logging.info("-----Creating Simulator in remote...")
54
+ logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
41
55
  self.simulator = Simulator(config["simulator_request"])
42
56
 
43
57
  # Step:3 prepare Economy client
44
58
  if "economy" in config["simulator_request"]:
45
- logging.info("-----Creating Economy client in remote...")
59
+ logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
46
60
  self.economy_client = EconomyClient(
47
61
  config["simulator_request"]["economy"]["server"]
48
62
  )
@@ -50,17 +64,22 @@ class AgentGroup:
50
64
  self.economy_client = None
51
65
 
52
66
  for agent in self.agents:
53
- agent.set_exp_id(self.exp_id)
67
+ agent.set_exp_id(self.exp_id) # type: ignore
54
68
  agent.set_llm_client(self.llm)
55
69
  agent.set_simulator(self.simulator)
56
70
  if self.economy_client is not None:
57
71
  agent.set_economy_client(self.economy_client)
58
72
  agent.set_messager(self.messager)
73
+ if self.enable_avro:
74
+ agent.set_avro_file(self.avro_file) # type: ignore
59
75
 
60
76
  async def init_agents(self):
77
+ logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
78
+ logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
61
79
  for agent in self.agents:
62
- await agent.bind_to_simulator()
80
+ await agent.bind_to_simulator() # type: ignore
63
81
  self.id2agent = {agent._uuid: agent for agent in self.agents}
82
+ logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
64
83
  await self.messager.connect()
65
84
  if self.messager.is_connected():
66
85
  await self.messager.start_listening()
@@ -74,27 +93,65 @@ class AgentGroup:
74
93
  await self.messager.subscribe(topic, agent)
75
94
  topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
76
95
  await self.messager.subscribe(topic, agent)
77
- self.initialized = True
78
96
  self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
79
-
97
+ if self.enable_avro:
98
+ logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
99
+ # profile
100
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
101
+ filename = self.avro_file["profile"]
102
+ with open(filename, "wb") as f:
103
+ profiles = []
104
+ for agent in self.agents:
105
+ profile = await agent.memory._profile.export()
106
+ profile = profile[0]
107
+ profile['id'] = agent._uuid
108
+ profiles.append(profile)
109
+ fastavro.writer(f, PROFILE_SCHEMA, profiles)
110
+
111
+ # dialog
112
+ filename = self.avro_file["dialog"]
113
+ with open(filename, "wb") as f:
114
+ dialogs = []
115
+ fastavro.writer(f, DIALOG_SCHEMA, dialogs)
116
+
117
+ # status
118
+ filename = self.avro_file["status"]
119
+ with open(filename, "wb") as f:
120
+ statuses = []
121
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
122
+ fastavro.writer(f, STATUS_SCHEMA, statuses)
123
+ else:
124
+ fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, statuses)
125
+
126
+ # survey
127
+ filename = self.avro_file["survey"]
128
+ with open(filename, "wb") as f:
129
+ surveys = []
130
+ fastavro.writer(f, SURVEY_SCHEMA, surveys)
131
+ self.initialized = True
132
+ logger.debug(f"-----AgentGroup {self._uuid} initialized")
133
+
80
134
  async def gather(self, content: str):
135
+ logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
81
136
  results = {}
82
137
  for agent in self.agents:
83
138
  results[agent._uuid] = await agent.memory.get(content)
84
139
  return results
85
140
 
86
141
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
142
+ logger.debug(f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}")
87
143
  agent = self.id2agent[target_agent_uuid]
88
144
  await agent.memory.update(target_key, content)
89
145
 
90
146
  async def message_dispatch(self):
147
+ logger.debug(f"-----Starting message dispatch for group {self._uuid}")
91
148
  while True:
92
149
  if not self.messager.is_connected():
93
- logging.warning("Messager is not connected. Skipping message processing.")
150
+ logger.warning("Messager is not connected. Skipping message processing.")
94
151
 
95
152
  # Step 1: 获取消息
96
153
  messages = await self.messager.fetch_messages()
97
- logging.info(f"Group {self._uuid} received {len(messages)} messages")
154
+ logger.info(f"Group {self._uuid} received {len(messages)} messages")
98
155
 
99
156
  # Step 2: 分发消息到对应的 Agent
100
157
  for message in messages:
@@ -109,8 +166,8 @@ class AgentGroup:
109
166
  # 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
110
167
  _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
111
168
 
112
- if uuid.UUID(agent_uuid) in self.id2agent:
113
- agent = self.id2agent[uuid.UUID(agent_uuid)]
169
+ if agent_uuid in self.id2agent:
170
+ agent = self.id2agent[agent_uuid]
114
171
  # topic_type: agent-chat, user-chat, user-survey, gather
115
172
  if topic_type == "agent-chat":
116
173
  await agent.handle_agent_chat_message(payload)
@@ -123,46 +180,73 @@ class AgentGroup:
123
180
 
124
181
  await asyncio.sleep(0.5)
125
182
 
183
+ async def save_status(self):
184
+ if self.enable_avro:
185
+ logger.debug(f"-----Saving status for group {self._uuid}")
186
+ avros = []
187
+ if not issubclass(type(self.agents[0]), InstitutionAgent):
188
+ for agent in self.agents:
189
+ position = await agent.memory.get("position")
190
+ lng = position["longlat_position"]["longitude"]
191
+ lat = position["longlat_position"]["latitude"]
192
+ if "aoi_position" in position:
193
+ parent_id = position["aoi_position"]["aoi_id"]
194
+ elif "lane_position" in position:
195
+ parent_id = position["lane_position"]["lane_id"]
196
+ else:
197
+ # BUG: 需要处理
198
+ parent_id = -1
199
+ needs = await agent.memory.get("needs")
200
+ action = await agent.memory.get("current_step")
201
+ action = action["intention"]
202
+ avro = {
203
+ "id": agent._uuid,
204
+ "day": await self.simulator.get_simulator_day(),
205
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
206
+ "lng": lng,
207
+ "lat": lat,
208
+ "parent_id": parent_id,
209
+ "action": action,
210
+ "hungry": needs["hungry"],
211
+ "tired": needs["tired"],
212
+ "safe": needs["safe"],
213
+ "social": needs["social"],
214
+ "created_at": int(datetime.now().timestamp() * 1000),
215
+ }
216
+ avros.append(avro)
217
+ with open(self.avro_file["status"], "a+b") as f:
218
+ fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
219
+ else:
220
+ for agent in self.agents:
221
+ avro = {
222
+ "id": agent._uuid,
223
+ "day": await self.simulator.get_simulator_day(),
224
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
225
+ "type": await agent.memory.get("type"),
226
+ "nominal_gdp": await agent.memory.get("nominal_gdp"),
227
+ "real_gdp": await agent.memory.get("real_gdp"),
228
+ "unemployment": await agent.memory.get("unemployment"),
229
+ "wages": await agent.memory.get("wages"),
230
+ "prices": await agent.memory.get("prices"),
231
+ "inventory": await agent.memory.get("inventory"),
232
+ "price": await agent.memory.get("price"),
233
+ "interest_rate": await agent.memory.get("interest_rate"),
234
+ "bracket_cutoffs": await agent.memory.get("bracket_cutoffs"),
235
+ "bracket_rates": await agent.memory.get("bracket_rates"),
236
+ "employees": await agent.memory.get("employees"),
237
+ "customers": await agent.memory.get("customers"),
238
+ }
239
+ avros.append(avro)
240
+ with open(self.avro_file["status"], "a+b") as f:
241
+ fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
242
+
126
243
  async def step(self):
127
244
  if not self.initialized:
128
245
  await self.init_agents()
129
246
 
130
247
  tasks = [agent.run() for agent in self.agents]
131
248
  await asyncio.gather(*tasks)
132
- avros = []
133
- for agent in self.agents:
134
- if not issubclass(type(agent), CitizenAgent):
135
- continue
136
- position = await agent.memory.get("position")
137
- lng = position["longlat_position"]["longitude"]
138
- lat = position["longlat_position"]["latitude"]
139
- if "aoi_position" in position:
140
- parent_id = position["aoi_position"]["aoi_id"]
141
- elif "lane_position" in position:
142
- parent_id = position["lane_position"]["lane_id"]
143
- else:
144
- # BUG: 需要处理
145
- parent_id = -1
146
- needs = await agent.memory.get("needs")
147
- action = await agent.memory.get("current_step")
148
- action = action["intention"]
149
- avro = {
150
- "id": str(agent._uuid), # uuid as string
151
- "day": await self.simulator.get_simulator_day(),
152
- "t": await self.simulator.get_simulator_second_from_start_of_day(),
153
- "lng": lng,
154
- "lat": lat,
155
- "parent_id": parent_id,
156
- "action": action,
157
- "hungry": needs["hungry"],
158
- "tired": needs["tired"],
159
- "safe": needs["safe"],
160
- "social": needs["social"],
161
- "created_at": int(datetime.now().timestamp() * 1000),
162
- }
163
- avros.append(avro)
164
- with open(self.avro_file["status"], "a+b") as f:
165
- fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
249
+ await self.save_status()
166
250
 
167
251
  async def run(self, day: int = 1):
168
252
  """运行模拟器
@@ -185,5 +269,5 @@ class AgentGroup:
185
269
  await self.step()
186
270
 
187
271
  except Exception as e:
188
- logging.error(f"模拟器运行错误: {str(e)}")
272
+ logger.error(f"模拟器运行错误: {str(e)}")
189
273
  raise