pycityagent 2.0.0a14__py3-none-any.whl → 2.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +14 -0
- pycityagent/agent.py +80 -63
- pycityagent/environment/simulator.py +5 -4
- pycityagent/memory/memory.py +8 -7
- pycityagent/memory/memory_base.py +6 -4
- pycityagent/message/messager.py +8 -7
- pycityagent/simulation/agentgroup.py +135 -51
- pycityagent/simulation/simulation.py +206 -92
- pycityagent/survey/manager.py +10 -14
- pycityagent/survey/models.py +24 -24
- pycityagent/utils/__init__.py +2 -2
- pycityagent/utils/avro_schema.py +26 -1
- pycityagent/workflow/tool.py +1 -4
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/METADATA +3 -2
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/RECORD +16 -16
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/WHEEL +0 -0
@@ -2,27 +2,41 @@ import asyncio
|
|
2
2
|
from datetime import datetime
|
3
3
|
import json
|
4
4
|
import logging
|
5
|
+
from pathlib import Path
|
5
6
|
import uuid
|
6
7
|
import fastavro
|
7
8
|
import ray
|
8
9
|
from uuid import UUID
|
9
|
-
from pycityagent.agent import Agent, CitizenAgent
|
10
|
+
from pycityagent.agent import Agent, CitizenAgent, InstitutionAgent
|
10
11
|
from pycityagent.economy.econ_client import EconomyClient
|
11
12
|
from pycityagent.environment.simulator import Simulator
|
12
13
|
from pycityagent.llm.llm import LLM
|
13
14
|
from pycityagent.llm.llmconfig import LLMConfig
|
14
15
|
from pycityagent.message import Messager
|
15
|
-
from pycityagent.utils import STATUS_SCHEMA
|
16
|
+
from pycityagent.utils import STATUS_SCHEMA, PROFILE_SCHEMA, DIALOG_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
|
16
17
|
from typing import Any
|
17
18
|
|
19
|
+
logger = logging.getLogger("pycityagent")
|
20
|
+
|
18
21
|
@ray.remote
|
19
22
|
class AgentGroup:
|
20
|
-
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID,
|
23
|
+
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int):
|
24
|
+
logger.setLevel(logging_level)
|
25
|
+
self._uuid = str(uuid.uuid4())
|
21
26
|
self.agents = agents
|
22
27
|
self.config = config
|
23
28
|
self.exp_id = exp_id
|
24
|
-
self.
|
25
|
-
self.
|
29
|
+
self.enable_avro = enable_avro
|
30
|
+
self.avro_path = avro_path / f"{self._uuid}"
|
31
|
+
if enable_avro:
|
32
|
+
self.avro_path.mkdir(parents=True, exist_ok=True)
|
33
|
+
self.avro_file = {
|
34
|
+
"profile": self.avro_path / f"profile.avro",
|
35
|
+
"dialog": self.avro_path / f"dialog.avro",
|
36
|
+
"status": self.avro_path / f"status.avro",
|
37
|
+
"survey": self.avro_path / f"survey.avro",
|
38
|
+
}
|
39
|
+
|
26
40
|
self.messager = Messager(
|
27
41
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
28
42
|
port=config["simulator_request"]["mqtt"]["port"],
|
@@ -33,16 +47,16 @@ class AgentGroup:
|
|
33
47
|
self.id2agent = {}
|
34
48
|
# Step:1 prepare LLM client
|
35
49
|
llmConfig = LLMConfig(config["llm_request"])
|
36
|
-
|
50
|
+
logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
|
37
51
|
self.llm = LLM(llmConfig)
|
38
52
|
|
39
53
|
# Step:2 prepare Simulator
|
40
|
-
|
54
|
+
logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
|
41
55
|
self.simulator = Simulator(config["simulator_request"])
|
42
56
|
|
43
57
|
# Step:3 prepare Economy client
|
44
58
|
if "economy" in config["simulator_request"]:
|
45
|
-
|
59
|
+
logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
|
46
60
|
self.economy_client = EconomyClient(
|
47
61
|
config["simulator_request"]["economy"]["server"]
|
48
62
|
)
|
@@ -50,17 +64,22 @@ class AgentGroup:
|
|
50
64
|
self.economy_client = None
|
51
65
|
|
52
66
|
for agent in self.agents:
|
53
|
-
agent.set_exp_id(self.exp_id)
|
67
|
+
agent.set_exp_id(self.exp_id) # type: ignore
|
54
68
|
agent.set_llm_client(self.llm)
|
55
69
|
agent.set_simulator(self.simulator)
|
56
70
|
if self.economy_client is not None:
|
57
71
|
agent.set_economy_client(self.economy_client)
|
58
72
|
agent.set_messager(self.messager)
|
73
|
+
if self.enable_avro:
|
74
|
+
agent.set_avro_file(self.avro_file) # type: ignore
|
59
75
|
|
60
76
|
async def init_agents(self):
|
77
|
+
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
78
|
+
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
61
79
|
for agent in self.agents:
|
62
|
-
await agent.bind_to_simulator()
|
80
|
+
await agent.bind_to_simulator() # type: ignore
|
63
81
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
82
|
+
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
64
83
|
await self.messager.connect()
|
65
84
|
if self.messager.is_connected():
|
66
85
|
await self.messager.start_listening()
|
@@ -74,27 +93,65 @@ class AgentGroup:
|
|
74
93
|
await self.messager.subscribe(topic, agent)
|
75
94
|
topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
|
76
95
|
await self.messager.subscribe(topic, agent)
|
77
|
-
self.initialized = True
|
78
96
|
self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
|
79
|
-
|
97
|
+
if self.enable_avro:
|
98
|
+
logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
|
99
|
+
# profile
|
100
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
101
|
+
filename = self.avro_file["profile"]
|
102
|
+
with open(filename, "wb") as f:
|
103
|
+
profiles = []
|
104
|
+
for agent in self.agents:
|
105
|
+
profile = await agent.memory._profile.export()
|
106
|
+
profile = profile[0]
|
107
|
+
profile['id'] = agent._uuid
|
108
|
+
profiles.append(profile)
|
109
|
+
fastavro.writer(f, PROFILE_SCHEMA, profiles)
|
110
|
+
|
111
|
+
# dialog
|
112
|
+
filename = self.avro_file["dialog"]
|
113
|
+
with open(filename, "wb") as f:
|
114
|
+
dialogs = []
|
115
|
+
fastavro.writer(f, DIALOG_SCHEMA, dialogs)
|
116
|
+
|
117
|
+
# status
|
118
|
+
filename = self.avro_file["status"]
|
119
|
+
with open(filename, "wb") as f:
|
120
|
+
statuses = []
|
121
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
122
|
+
fastavro.writer(f, STATUS_SCHEMA, statuses)
|
123
|
+
else:
|
124
|
+
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, statuses)
|
125
|
+
|
126
|
+
# survey
|
127
|
+
filename = self.avro_file["survey"]
|
128
|
+
with open(filename, "wb") as f:
|
129
|
+
surveys = []
|
130
|
+
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
131
|
+
self.initialized = True
|
132
|
+
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
133
|
+
|
80
134
|
async def gather(self, content: str):
|
135
|
+
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
81
136
|
results = {}
|
82
137
|
for agent in self.agents:
|
83
138
|
results[agent._uuid] = await agent.memory.get(content)
|
84
139
|
return results
|
85
140
|
|
86
141
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
142
|
+
logger.debug(f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}")
|
87
143
|
agent = self.id2agent[target_agent_uuid]
|
88
144
|
await agent.memory.update(target_key, content)
|
89
145
|
|
90
146
|
async def message_dispatch(self):
|
147
|
+
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
91
148
|
while True:
|
92
149
|
if not self.messager.is_connected():
|
93
|
-
|
150
|
+
logger.warning("Messager is not connected. Skipping message processing.")
|
94
151
|
|
95
152
|
# Step 1: 获取消息
|
96
153
|
messages = await self.messager.fetch_messages()
|
97
|
-
|
154
|
+
logger.info(f"Group {self._uuid} received {len(messages)} messages")
|
98
155
|
|
99
156
|
# Step 2: 分发消息到对应的 Agent
|
100
157
|
for message in messages:
|
@@ -109,8 +166,8 @@ class AgentGroup:
|
|
109
166
|
# 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
|
110
167
|
_, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
|
111
168
|
|
112
|
-
if
|
113
|
-
agent = self.id2agent[
|
169
|
+
if agent_uuid in self.id2agent:
|
170
|
+
agent = self.id2agent[agent_uuid]
|
114
171
|
# topic_type: agent-chat, user-chat, user-survey, gather
|
115
172
|
if topic_type == "agent-chat":
|
116
173
|
await agent.handle_agent_chat_message(payload)
|
@@ -123,46 +180,73 @@ class AgentGroup:
|
|
123
180
|
|
124
181
|
await asyncio.sleep(0.5)
|
125
182
|
|
183
|
+
async def save_status(self):
|
184
|
+
if self.enable_avro:
|
185
|
+
logger.debug(f"-----Saving status for group {self._uuid}")
|
186
|
+
avros = []
|
187
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
188
|
+
for agent in self.agents:
|
189
|
+
position = await agent.memory.get("position")
|
190
|
+
lng = position["longlat_position"]["longitude"]
|
191
|
+
lat = position["longlat_position"]["latitude"]
|
192
|
+
if "aoi_position" in position:
|
193
|
+
parent_id = position["aoi_position"]["aoi_id"]
|
194
|
+
elif "lane_position" in position:
|
195
|
+
parent_id = position["lane_position"]["lane_id"]
|
196
|
+
else:
|
197
|
+
# BUG: 需要处理
|
198
|
+
parent_id = -1
|
199
|
+
needs = await agent.memory.get("needs")
|
200
|
+
action = await agent.memory.get("current_step")
|
201
|
+
action = action["intention"]
|
202
|
+
avro = {
|
203
|
+
"id": agent._uuid,
|
204
|
+
"day": await self.simulator.get_simulator_day(),
|
205
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
206
|
+
"lng": lng,
|
207
|
+
"lat": lat,
|
208
|
+
"parent_id": parent_id,
|
209
|
+
"action": action,
|
210
|
+
"hungry": needs["hungry"],
|
211
|
+
"tired": needs["tired"],
|
212
|
+
"safe": needs["safe"],
|
213
|
+
"social": needs["social"],
|
214
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
215
|
+
}
|
216
|
+
avros.append(avro)
|
217
|
+
with open(self.avro_file["status"], "a+b") as f:
|
218
|
+
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
219
|
+
else:
|
220
|
+
for agent in self.agents:
|
221
|
+
avro = {
|
222
|
+
"id": agent._uuid,
|
223
|
+
"day": await self.simulator.get_simulator_day(),
|
224
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
225
|
+
"type": await agent.memory.get("type"),
|
226
|
+
"nominal_gdp": await agent.memory.get("nominal_gdp"),
|
227
|
+
"real_gdp": await agent.memory.get("real_gdp"),
|
228
|
+
"unemployment": await agent.memory.get("unemployment"),
|
229
|
+
"wages": await agent.memory.get("wages"),
|
230
|
+
"prices": await agent.memory.get("prices"),
|
231
|
+
"inventory": await agent.memory.get("inventory"),
|
232
|
+
"price": await agent.memory.get("price"),
|
233
|
+
"interest_rate": await agent.memory.get("interest_rate"),
|
234
|
+
"bracket_cutoffs": await agent.memory.get("bracket_cutoffs"),
|
235
|
+
"bracket_rates": await agent.memory.get("bracket_rates"),
|
236
|
+
"employees": await agent.memory.get("employees"),
|
237
|
+
"customers": await agent.memory.get("customers"),
|
238
|
+
}
|
239
|
+
avros.append(avro)
|
240
|
+
with open(self.avro_file["status"], "a+b") as f:
|
241
|
+
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
|
242
|
+
|
126
243
|
async def step(self):
|
127
244
|
if not self.initialized:
|
128
245
|
await self.init_agents()
|
129
246
|
|
130
247
|
tasks = [agent.run() for agent in self.agents]
|
131
248
|
await asyncio.gather(*tasks)
|
132
|
-
|
133
|
-
for agent in self.agents:
|
134
|
-
if not issubclass(type(agent), CitizenAgent):
|
135
|
-
continue
|
136
|
-
position = await agent.memory.get("position")
|
137
|
-
lng = position["longlat_position"]["longitude"]
|
138
|
-
lat = position["longlat_position"]["latitude"]
|
139
|
-
if "aoi_position" in position:
|
140
|
-
parent_id = position["aoi_position"]["aoi_id"]
|
141
|
-
elif "lane_position" in position:
|
142
|
-
parent_id = position["lane_position"]["lane_id"]
|
143
|
-
else:
|
144
|
-
# BUG: 需要处理
|
145
|
-
parent_id = -1
|
146
|
-
needs = await agent.memory.get("needs")
|
147
|
-
action = await agent.memory.get("current_step")
|
148
|
-
action = action["intention"]
|
149
|
-
avro = {
|
150
|
-
"id": str(agent._uuid), # uuid as string
|
151
|
-
"day": await self.simulator.get_simulator_day(),
|
152
|
-
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
153
|
-
"lng": lng,
|
154
|
-
"lat": lat,
|
155
|
-
"parent_id": parent_id,
|
156
|
-
"action": action,
|
157
|
-
"hungry": needs["hungry"],
|
158
|
-
"tired": needs["tired"],
|
159
|
-
"safe": needs["safe"],
|
160
|
-
"social": needs["social"],
|
161
|
-
"created_at": int(datetime.now().timestamp() * 1000),
|
162
|
-
}
|
163
|
-
avros.append(avro)
|
164
|
-
with open(self.avro_file["status"], "a+b") as f:
|
165
|
-
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
249
|
+
await self.save_status()
|
166
250
|
|
167
251
|
async def run(self, day: int = 1):
|
168
252
|
"""运行模拟器
|
@@ -185,5 +269,5 @@ class AgentGroup:
|
|
185
269
|
await self.step()
|
186
270
|
|
187
271
|
except Exception as e:
|
188
|
-
|
272
|
+
logger.error(f"模拟器运行错误: {str(e)}")
|
189
273
|
raise
|