pycityagent 2.0.0a13__py3-none-any.whl → 2.0.0a15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +14 -0
- pycityagent/agent.py +164 -63
- pycityagent/economy/econ_client.py +2 -0
- pycityagent/environment/simulator.py +5 -4
- pycityagent/memory/const.py +1 -0
- pycityagent/memory/memory.py +8 -7
- pycityagent/memory/memory_base.py +6 -4
- pycityagent/message/messager.py +8 -7
- pycityagent/simulation/agentgroup.py +136 -14
- pycityagent/simulation/simulation.py +212 -42
- pycityagent/survey/manager.py +58 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +7 -0
- pycityagent/utils/avro_schema.py +110 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/tool.py +0 -3
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/METADATA +3 -1
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/RECORD +20 -21
- pycityagent/simulation/interview.py +0 -40
- pycityagent/simulation/survey/manager.py +0 -68
- pycityagent/simulation/survey/models.py +0 -52
- pycityagent/simulation/ui/__init__.py +0 -3
- pycityagent/simulation/ui/interface.py +0 -602
- /pycityagent/{simulation/survey → survey}/__init__.py +0 -0
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/WHEEL +0 -0
pycityagent/message/messager.py
CHANGED
@@ -5,6 +5,7 @@ import logging
|
|
5
5
|
import math
|
6
6
|
from aiomqtt import Client
|
7
7
|
|
8
|
+
logger = logging.getLogger("pycityagent")
|
8
9
|
|
9
10
|
class Messager:
|
10
11
|
def __init__(
|
@@ -21,15 +22,15 @@ class Messager:
|
|
21
22
|
try:
|
22
23
|
await self.client.__aenter__()
|
23
24
|
self.connected = True
|
24
|
-
|
25
|
+
logger.info("Connected to MQTT Broker")
|
25
26
|
except Exception as e:
|
26
27
|
self.connected = False
|
27
|
-
|
28
|
+
logger.error(f"Failed to connect to MQTT Broker: {e}")
|
28
29
|
|
29
30
|
async def disconnect(self):
|
30
31
|
await self.client.__aexit__(None, None, None)
|
31
32
|
self.connected = False
|
32
|
-
|
33
|
+
logger.info("Disconnected from MQTT Broker")
|
33
34
|
|
34
35
|
def is_connected(self):
|
35
36
|
"""检查是否成功连接到 Broker"""
|
@@ -37,13 +38,13 @@ class Messager:
|
|
37
38
|
|
38
39
|
async def subscribe(self, topic, agent):
|
39
40
|
if not self.is_connected():
|
40
|
-
|
41
|
+
logger.error(
|
41
42
|
f"Cannot subscribe to {topic} because not connected to the Broker."
|
42
43
|
)
|
43
44
|
return
|
44
45
|
await self.client.subscribe(topic)
|
45
46
|
self.subscribers[topic] = agent
|
46
|
-
|
47
|
+
logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
|
47
48
|
|
48
49
|
async def receive_messages(self):
|
49
50
|
"""监听并将消息存入队列"""
|
@@ -61,11 +62,11 @@ class Messager:
|
|
61
62
|
"""通过 Messager 发送消息"""
|
62
63
|
message = json.dumps(payload, default=str)
|
63
64
|
await self.client.publish(topic, message)
|
64
|
-
|
65
|
+
logger.info(f"Message sent to {topic}: {message}")
|
65
66
|
|
66
67
|
async def start_listening(self):
|
67
68
|
"""启动消息监听任务"""
|
68
69
|
if self.is_connected():
|
69
70
|
asyncio.create_task(self.receive_messages())
|
70
71
|
else:
|
71
|
-
|
72
|
+
logger.error("Cannot start listening because not connected to the Broker.")
|
@@ -1,24 +1,42 @@
|
|
1
1
|
import asyncio
|
2
|
+
from datetime import datetime
|
2
3
|
import json
|
3
4
|
import logging
|
5
|
+
from pathlib import Path
|
4
6
|
import uuid
|
7
|
+
import fastavro
|
5
8
|
import ray
|
6
9
|
from uuid import UUID
|
7
|
-
from pycityagent.agent import Agent
|
10
|
+
from pycityagent.agent import Agent, CitizenAgent, InstitutionAgent
|
8
11
|
from pycityagent.economy.econ_client import EconomyClient
|
9
12
|
from pycityagent.environment.simulator import Simulator
|
10
13
|
from pycityagent.llm.llm import LLM
|
11
14
|
from pycityagent.llm.llmconfig import LLMConfig
|
12
15
|
from pycityagent.message import Messager
|
16
|
+
from pycityagent.utils import STATUS_SCHEMA, PROFILE_SCHEMA, DIALOG_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
|
13
17
|
from typing import Any
|
14
18
|
|
19
|
+
logger = logging.getLogger("pycityagent")
|
20
|
+
|
15
21
|
@ray.remote
|
16
22
|
class AgentGroup:
|
17
|
-
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID):
|
23
|
+
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
|
24
|
+
logger.setLevel(logging_level)
|
25
|
+
self._uuid = str(uuid.uuid4())
|
18
26
|
self.agents = agents
|
19
27
|
self.config = config
|
20
28
|
self.exp_id = exp_id
|
21
|
-
self.
|
29
|
+
self.enable_avro = enable_avro
|
30
|
+
self.avro_path = avro_path / f"{self._uuid}"
|
31
|
+
if enable_avro:
|
32
|
+
self.avro_path.mkdir(parents=True, exist_ok=True)
|
33
|
+
self.avro_file = {
|
34
|
+
"profile": self.avro_path / f"profile.avro",
|
35
|
+
"dialog": self.avro_path / f"dialog.avro",
|
36
|
+
"status": self.avro_path / f"status.avro",
|
37
|
+
"survey": self.avro_path / f"survey.avro",
|
38
|
+
}
|
39
|
+
|
22
40
|
self.messager = Messager(
|
23
41
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
24
42
|
port=config["simulator_request"]["mqtt"]["port"],
|
@@ -29,16 +47,16 @@ class AgentGroup:
|
|
29
47
|
self.id2agent = {}
|
30
48
|
# Step:1 prepare LLM client
|
31
49
|
llmConfig = LLMConfig(config["llm_request"])
|
32
|
-
|
50
|
+
logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
|
33
51
|
self.llm = LLM(llmConfig)
|
34
52
|
|
35
53
|
# Step:2 prepare Simulator
|
36
|
-
|
54
|
+
logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
|
37
55
|
self.simulator = Simulator(config["simulator_request"])
|
38
56
|
|
39
57
|
# Step:3 prepare Economy client
|
40
58
|
if "economy" in config["simulator_request"]:
|
41
|
-
|
59
|
+
logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
|
42
60
|
self.economy_client = EconomyClient(
|
43
61
|
config["simulator_request"]["economy"]["server"]
|
44
62
|
)
|
@@ -52,11 +70,16 @@ class AgentGroup:
|
|
52
70
|
if self.economy_client is not None:
|
53
71
|
agent.set_economy_client(self.economy_client)
|
54
72
|
agent.set_messager(self.messager)
|
73
|
+
if self.enable_avro:
|
74
|
+
agent.set_avro_file(self.avro_file)
|
55
75
|
|
56
76
|
async def init_agents(self):
|
77
|
+
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
78
|
+
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
57
79
|
for agent in self.agents:
|
58
80
|
await agent.bind_to_simulator()
|
59
81
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
82
|
+
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
60
83
|
await self.messager.connect()
|
61
84
|
if self.messager.is_connected():
|
62
85
|
await self.messager.start_listening()
|
@@ -70,27 +93,65 @@ class AgentGroup:
|
|
70
93
|
await self.messager.subscribe(topic, agent)
|
71
94
|
topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
|
72
95
|
await self.messager.subscribe(topic, agent)
|
73
|
-
self.initialized = True
|
74
96
|
self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
|
75
|
-
|
97
|
+
if self.enable_avro:
|
98
|
+
logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
|
99
|
+
# profile
|
100
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
101
|
+
filename = self.avro_file["profile"]
|
102
|
+
with open(filename, "wb") as f:
|
103
|
+
profiles = []
|
104
|
+
for agent in self.agents:
|
105
|
+
profile = await agent.memory._profile.export()
|
106
|
+
profile = profile[0]
|
107
|
+
profile['id'] = agent._uuid
|
108
|
+
profiles.append(profile)
|
109
|
+
fastavro.writer(f, PROFILE_SCHEMA, profiles)
|
110
|
+
|
111
|
+
# dialog
|
112
|
+
filename = self.avro_file["dialog"]
|
113
|
+
with open(filename, "wb") as f:
|
114
|
+
dialogs = []
|
115
|
+
fastavro.writer(f, DIALOG_SCHEMA, dialogs)
|
116
|
+
|
117
|
+
# status
|
118
|
+
filename = self.avro_file["status"]
|
119
|
+
with open(filename, "wb") as f:
|
120
|
+
statuses = []
|
121
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
122
|
+
fastavro.writer(f, STATUS_SCHEMA, statuses)
|
123
|
+
else:
|
124
|
+
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, statuses)
|
125
|
+
|
126
|
+
# survey
|
127
|
+
filename = self.avro_file["survey"]
|
128
|
+
with open(filename, "wb") as f:
|
129
|
+
surveys = []
|
130
|
+
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
131
|
+
self.initialized = True
|
132
|
+
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
133
|
+
|
76
134
|
async def gather(self, content: str):
|
135
|
+
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
77
136
|
results = {}
|
78
137
|
for agent in self.agents:
|
79
138
|
results[agent._uuid] = await agent.memory.get(content)
|
80
139
|
return results
|
81
140
|
|
82
141
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
142
|
+
logger.debug(f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}")
|
83
143
|
agent = self.id2agent[target_agent_uuid]
|
84
144
|
await agent.memory.update(target_key, content)
|
85
145
|
|
86
146
|
async def message_dispatch(self):
|
147
|
+
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
87
148
|
while True:
|
88
149
|
if not self.messager.is_connected():
|
89
|
-
|
150
|
+
logger.warning("Messager is not connected. Skipping message processing.")
|
90
151
|
|
91
152
|
# Step 1: 获取消息
|
92
153
|
messages = await self.messager.fetch_messages()
|
93
|
-
|
154
|
+
logger.info(f"Group {self._uuid} received {len(messages)} messages")
|
94
155
|
|
95
156
|
# Step 2: 分发消息到对应的 Agent
|
96
157
|
for message in messages:
|
@@ -105,8 +166,8 @@ class AgentGroup:
|
|
105
166
|
# 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
|
106
167
|
_, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
|
107
168
|
|
108
|
-
if
|
109
|
-
agent = self.id2agent[
|
169
|
+
if agent_uuid in self.id2agent:
|
170
|
+
agent = self.id2agent[agent_uuid]
|
110
171
|
# topic_type: agent-chat, user-chat, user-survey, gather
|
111
172
|
if topic_type == "agent-chat":
|
112
173
|
await agent.handle_agent_chat_message(payload)
|
@@ -117,7 +178,67 @@ class AgentGroup:
|
|
117
178
|
elif topic_type == "gather":
|
118
179
|
await agent.handle_gather_message(payload)
|
119
180
|
|
120
|
-
await asyncio.sleep(
|
181
|
+
await asyncio.sleep(0.5)
|
182
|
+
|
183
|
+
async def save_status(self):
|
184
|
+
if self.enable_avro:
|
185
|
+
logger.debug(f"-----Saving status for group {self._uuid}")
|
186
|
+
avros = []
|
187
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
188
|
+
for agent in self.agents:
|
189
|
+
position = await agent.memory.get("position")
|
190
|
+
lng = position["longlat_position"]["longitude"]
|
191
|
+
lat = position["longlat_position"]["latitude"]
|
192
|
+
if "aoi_position" in position:
|
193
|
+
parent_id = position["aoi_position"]["aoi_id"]
|
194
|
+
elif "lane_position" in position:
|
195
|
+
parent_id = position["lane_position"]["lane_id"]
|
196
|
+
else:
|
197
|
+
# BUG: 需要处理
|
198
|
+
parent_id = -1
|
199
|
+
needs = await agent.memory.get("needs")
|
200
|
+
action = await agent.memory.get("current_step")
|
201
|
+
action = action["intention"]
|
202
|
+
avro = {
|
203
|
+
"id": agent._uuid,
|
204
|
+
"day": await self.simulator.get_simulator_day(),
|
205
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
206
|
+
"lng": lng,
|
207
|
+
"lat": lat,
|
208
|
+
"parent_id": parent_id,
|
209
|
+
"action": action,
|
210
|
+
"hungry": needs["hungry"],
|
211
|
+
"tired": needs["tired"],
|
212
|
+
"safe": needs["safe"],
|
213
|
+
"social": needs["social"],
|
214
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
215
|
+
}
|
216
|
+
avros.append(avro)
|
217
|
+
with open(self.avro_file["status"], "a+b") as f:
|
218
|
+
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
219
|
+
else:
|
220
|
+
for agent in self.agents:
|
221
|
+
avro = {
|
222
|
+
"id": agent._uuid,
|
223
|
+
"day": await self.simulator.get_simulator_day(),
|
224
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
225
|
+
"type": await agent.memory.get("type"),
|
226
|
+
"nominal_gdp": await agent.memory.get("nominal_gdp"),
|
227
|
+
"real_gdp": await agent.memory.get("real_gdp"),
|
228
|
+
"unemployment": await agent.memory.get("unemployment"),
|
229
|
+
"wages": await agent.memory.get("wages"),
|
230
|
+
"prices": await agent.memory.get("prices"),
|
231
|
+
"inventory": await agent.memory.get("inventory"),
|
232
|
+
"price": await agent.memory.get("price"),
|
233
|
+
"interest_rate": await agent.memory.get("interest_rate"),
|
234
|
+
"bracket_cutoffs": await agent.memory.get("bracket_cutoffs"),
|
235
|
+
"bracket_rates": await agent.memory.get("bracket_rates"),
|
236
|
+
"employees": await agent.memory.get("employees"),
|
237
|
+
"customers": await agent.memory.get("customers"),
|
238
|
+
}
|
239
|
+
avros.append(avro)
|
240
|
+
with open(self.avro_file["status"], "a+b") as f:
|
241
|
+
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
|
121
242
|
|
122
243
|
async def step(self):
|
123
244
|
if not self.initialized:
|
@@ -125,6 +246,7 @@ class AgentGroup:
|
|
125
246
|
|
126
247
|
tasks = [agent.run() for agent in self.agents]
|
127
248
|
await asyncio.gather(*tasks)
|
249
|
+
await self.save_status()
|
128
250
|
|
129
251
|
async def run(self, day: int = 1):
|
130
252
|
"""运行模拟器
|
@@ -147,5 +269,5 @@ class AgentGroup:
|
|
147
269
|
await self.step()
|
148
270
|
|
149
271
|
except Exception as e:
|
150
|
-
|
272
|
+
logger.error(f"模拟器运行错误: {str(e)}")
|
151
273
|
raise
|