pycityagent 2.0.0a43__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a43.dist-info/METADATA +235 -0
- pycityagent-2.0.0a43.dist-info/RECORD +81 -0
- pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
@@ -0,0 +1,580 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import time
|
5
|
+
import uuid
|
6
|
+
from datetime import datetime, timezone
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any
|
9
|
+
from uuid import UUID
|
10
|
+
|
11
|
+
import fastavro
|
12
|
+
import pyproj
|
13
|
+
import ray
|
14
|
+
from langchain_core.embeddings import Embeddings
|
15
|
+
|
16
|
+
from ..agent import Agent, CitizenAgent, InstitutionAgent
|
17
|
+
from ..economy.econ_client import EconomyClient
|
18
|
+
from ..environment.simulator import Simulator
|
19
|
+
from ..llm.llm import LLM
|
20
|
+
from ..llm.llmconfig import LLMConfig
|
21
|
+
from ..memory import FaissQuery
|
22
|
+
from ..message import Messager
|
23
|
+
from ..metrics import MlflowClient
|
24
|
+
from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
|
25
|
+
STATUS_SCHEMA, SURVEY_SCHEMA)
|
26
|
+
|
27
|
+
# Package-wide logger shared by this module; AgentGroup.__init__ sets its
# level from the `logging_level` argument.
logger = logging.getLogger("pycityagent")
|
28
|
+
|
29
|
+
|
30
|
+
@ray.remote
class AgentGroup:
    """Ray actor that hosts a group of agents and the clients they share.

    On construction the group creates one LLM client and one ``Simulator``.
    It also creates an ``EconomyClient`` when ``config["simulator_request"]``
    has an ``"economy"`` entry, and an ``MlflowClient`` when
    ``config["metric_request"]["mlflow"]`` is set. A ``FaissQuery`` is
    created when an embedding model is given, plus one remote ``Messager``
    actor. All of these are injected into every hosted agent.

    Profiles and per-step statuses are written to Avro files, forwarded to
    a PostgreSQL writer actor, or both.

    NOTE(review): citizen vs. institution handling is decided by looking only
    at the first agent, so groups are presumably expected to be homogeneous;
    confirm with the code that builds groups.
    """

    # Last MQTT topic segment -> name of the agent coroutine that handles it.
    _TOPIC_HANDLERS = {
        "agent-chat": "handle_agent_chat_message",
        "user-chat": "handle_user_chat_message",
        "user-survey": "handle_user_survey_message",
        "gather": "handle_gather_message",
    }

    # (memory key, default factory) for institution economic indicators.
    # The tuple order fixes the key order of emitted status records. Using
    # factories (not shared literals) gives every record its own fresh default.
    _INSTITUTION_ECON_FIELDS = (
        ("nominal_gdp", list),
        ("real_gdp", list),
        ("unemployment", list),
        ("wages", list),
        ("prices", list),
        ("inventory", int),
        ("price", float),
        ("interest_rate", float),
        ("bracket_cutoffs", list),
        ("bracket_rates", list),
        ("employees", list),
    )

    # Keys emitted as individual fields of each PostgreSQL status row. All
    # remaining keys are JSON-encoded into one field placed just before
    # ``created_at``.
    _PG_STATUS_BASIC_KEYS = (
        "id",
        "day",
        "t",
        "lng",
        "lat",
        "parent_id",
        "action",
        "created_at",
    )

    def __init__(
        self,
        agents: list[Agent],
        config: dict,
        exp_id: str | UUID,
        exp_name: str,
        enable_avro: bool,
        avro_path: Path,
        enable_pgsql: bool,
        pgsql_writer: ray.ObjectRef,
        mlflow_run_id: str,
        embedding_model: Embeddings,
        logging_level: int,
    ):
        """Create the shared clients and wire them into every agent.

        Args:
            agents: Agents hosted by this group.
            config: Experiment configuration. Reads ``"llm_request"`` and
                ``"simulator_request"`` (including its ``"mqtt"`` entry and an
                optional ``"economy"`` entry). Also reads an optional
                ``"metric_request"``.
            exp_id: Experiment identifier, given to agents and used in MQTT
                topics.
            exp_name: Experiment name, used for the MLflow experiment and
                run name.
            enable_avro: Whether to write Avro files under ``avro_path``.
            avro_path: Base directory; this group writes into
                ``avro_path / <group uuid>``.
            enable_pgsql: Whether to forward profiles and statuses to
                ``pgsql_writer``.
            pgsql_writer: Remote writer whose ``async_write_profile`` and
                ``async_write_status`` methods are invoked with ``.remote``.
            mlflow_run_id: MLflow run to attach to when MLflow is configured.
            embedding_model: Embedding model for agent memory; ``None``
                disables the FAISS query helper.
            logging_level: Level applied to the ``"pycityagent"`` logger.
        """
        logger.setLevel(logging_level)
        self._uuid = str(uuid.uuid4())
        self.agents = agents
        self.config = config
        self.exp_id = exp_id
        self.enable_avro = enable_avro
        self.enable_pgsql = enable_pgsql
        self.embedding_model = embedding_model
        if enable_avro:
            # Each group gets its own sub-directory, keyed by the group uuid.
            self.avro_path = avro_path / f"{self._uuid}"
            self.avro_path.mkdir(parents=True, exist_ok=True)
            self.avro_file = {
                "profile": self.avro_path / "profile.avro",
                "dialog": self.avro_path / "dialog.avro",
                "status": self.avro_path / "status.avro",
                "survey": self.avro_path / "survey.avro",
            }

        mqtt_config = config["simulator_request"]["mqtt"]
        self.messager = Messager.remote(
            hostname=mqtt_config["server"],  # type:ignore
            port=mqtt_config["port"],
            username=mqtt_config.get("username", None),
            password=mqtt_config.get("password", None),
        )
        # Only set once init_agents() has a connected messager.
        self.message_dispatch_task = None
        self._pgsql_writer = pgsql_writer
        # Pending status write from the previous save_status() call. It is
        # awaited lazily so the SQL I/O overlaps with the next compute step.
        self._last_asyncio_pg_task = None
        self.initialized = False
        self.id2agent = {}

        # Step 1: prepare the LLM client.
        llmConfig = LLMConfig(config["llm_request"])
        logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
        self.llm = LLM(llmConfig)

        # Step 2: prepare the simulator and the projection used to turn
        # simulator x/y into lng/lat.
        logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
        self.simulator = Simulator(config["simulator_request"])
        self.projector = pyproj.Proj(self.simulator.map.header["projection"])

        # Step 3: prepare the optional economy client.
        if "economy" in config["simulator_request"]:
            logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
            self.economy_client = EconomyClient(
                config["simulator_request"]["economy"]["server"]
            )
        else:
            self.economy_client = None

        # Optional MLflow client.
        _mlflow_config = config.get("metric_request", {}).get("mlflow")
        if _mlflow_config:
            logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
            self.mlflow_client = MlflowClient(
                config=_mlflow_config,
                mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
                experiment_name=exp_name,
                run_id=mlflow_run_id,
            )
        else:
            self.mlflow_client = None

        # Optional FAISS query helper shared by all agents' memories.
        if self.embedding_model is not None:
            self.faiss_query = FaissQuery(
                embeddings=self.embedding_model,
            )
        else:
            self.faiss_query = None

        for agent in self.agents:
            agent.set_exp_id(self.exp_id)  # type: ignore
            agent.set_llm_client(self.llm)
            agent.set_simulator(self.simulator)
            if self.economy_client is not None:
                agent.set_economy_client(self.economy_client)
            if self.mlflow_client is not None:
                agent.set_mlflow_client(self.mlflow_client)
            agent.set_messager(self.messager)
            if self.enable_avro:
                agent.set_avro_file(self.avro_file)  # type: ignore
            if self.enable_pgsql:
                agent.set_pgsql_writer(self._pgsql_writer)
            if self.faiss_query is not None:
                agent.memory.set_faiss_query(self.faiss_query)
            if self.embedding_model is not None:
                agent.memory.set_embedding_model(self.embedding_model)

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Cancel the background message-dispatch task, if one was started.

        The task is only created when the messager connected during
        ``init_agents``, so it may legitimately still be ``None`` here.
        """
        if self.message_dispatch_task is None:
            return
        self.message_dispatch_task.cancel()
        # return_exceptions=True absorbs the CancelledError raised by the task.
        await asyncio.gather(self.message_dispatch_task, return_exceptions=True)

    def _is_institution_group(self) -> bool:
        """Return True when the group's first agent is an ``InstitutionAgent``.

        An empty group is treated as a citizen group.
        """
        return bool(self.agents) and isinstance(self.agents[0], InstitutionAgent)

    async def _export_profiles(self) -> list[dict]:
        """Export every agent's profile dict, tagged with its ``_uuid`` under ``"id"``."""
        profiles = []
        for agent in self.agents:
            profile = (await agent.memory._profile.export())[0]
            profile["id"] = agent._uuid
            profiles.append(profile)
        return profiles

    async def init_agents(self):
        """Bind agents to the simulator and messager, then create output sinks.

        The steps are:

        * Bind agents to the simulator.
        * Subscribe one MQTT topic per agent and start message dispatch,
          but only if the messager connects.
        * When Avro is enabled, (re)create the Avro files with their
          schemas; the profile file is written only for citizen groups.
        * When PostgreSQL is enabled, send all profiles to the writer.
        """
        logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
        logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
        for agent in self.agents:
            await agent.bind_to_simulator()  # type: ignore
        self.id2agent = {agent._uuid: agent for agent in self.agents}

        logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
        await self.messager.connect.remote()
        # Await the ObjectRef instead of ray.get(), which would block this
        # async actor's event loop.
        if await self.messager.is_connected.remote():
            await self.messager.start_listening.remote()
            topics = []
            agent_ids = []
            for agent in self.agents:
                agent.set_messager(self.messager)
                topics.append((f"exps/{self.exp_id}/agents/{agent._uuid}/#", 1))
                agent_ids.append(agent.uuid)
            await self.messager.subscribe.remote(topics, agent_ids)
            self.message_dispatch_task = asyncio.create_task(self.message_dispatch())

        is_institution = self._is_institution_group()
        # Exported at most once and shared by the Avro and PostgreSQL sinks.
        profiles = None

        if self.enable_avro:
            logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
            if not is_institution:
                profiles = await self._export_profiles()
                with open(self.avro_file["profile"], "wb") as f:
                    fastavro.writer(f, PROFILE_SCHEMA, profiles)
            with open(self.avro_file["dialog"], "wb") as f:
                fastavro.writer(f, DIALOG_SCHEMA, [])
            status_schema = (
                INSTITUTION_STATUS_SCHEMA if is_institution else STATUS_SCHEMA
            )
            with open(self.avro_file["status"], "wb") as f:
                fastavro.writer(f, status_schema, [])
            with open(self.avro_file["survey"], "wb") as f:
                fastavro.writer(f, SURVEY_SCHEMA, [])

        if self.enable_pgsql:
            if profiles is None:
                profiles = await self._export_profiles()
            # Row shape: (agent id, name, JSON of all other profile fields).
            rows: list[Any] = [
                (
                    profile["id"],
                    profile.get("name", ""),
                    json.dumps(
                        {k: v for k, v in profile.items() if k not in {"id", "name"}}
                    ),
                )
                for profile in profiles
            ]
            await self._pgsql_writer.async_write_profile.remote(rows)  # type:ignore

        self.initialized = True
        logger.debug(f"-----AgentGroup {self._uuid} initialized")

    async def gather(self, content: str):
        """Return ``{agent._uuid: agent.memory.get(content)}`` for every agent."""
        logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
        results = {}
        for agent in self.agents:
            results[agent._uuid] = await agent.memory.get(content)
        return results

    async def update(self, target_agent_uuid: str, target_key: str, content: Any):
        """Set ``target_key`` to ``content`` in the memory of one agent.

        Raises:
            KeyError: If ``target_agent_uuid`` is not in this group
                (``id2agent`` is filled by ``init_agents``).
        """
        logger.debug(
            f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
        )
        agent = self.id2agent[target_agent_uuid]
        await agent.memory.update(target_key, content)

    async def _dispatch_message(self, message) -> None:
        """Decode one MQTT message and route it to the addressed agent's handler.

        Topics are expected to look like
        ``exps/{exp_id}/agents/{agent_uuid}/{topic_type}``. Messages for
        unknown agents or unknown topic types are ignored.

        Raises:
            ValueError: If the payload is not valid UTF-8 or JSON.
            Exception: Whatever the agent's handler raises.
        """
        topic = message.topic.value
        payload = message.payload
        # Payloads may arrive as raw bytes; decode before JSON parsing.
        if isinstance(payload, bytes):
            payload = payload.decode("utf-8")
        payload = json.loads(payload)

        parts = topic.strip("/").split("/")
        if len(parts) != 5:
            logger.warning(f"Group {self._uuid} ignoring message on unexpected topic {topic!r}")
            return
        _, _, _, agent_uuid, topic_type = parts

        agent = self.id2agent.get(agent_uuid)
        handler_name = self._TOPIC_HANDLERS.get(topic_type)
        if agent is None or handler_name is None:
            return
        await getattr(agent, handler_name)(payload)

    async def message_dispatch(self):
        """Poll the messager every 0.5 s and dispatch messages until it disconnects.

        A single bad message (malformed topic or payload, or a handler
        error) is logged and skipped. It does not terminate the loop.
        """
        logger.debug(f"-----Starting message dispatch for group {self._uuid}")
        while True:
            # Await instead of ray.get() so the actor's event loop keeps running.
            if not await self.messager.is_connected.remote():
                logger.warning(
                    "Messager is not connected. Skipping message processing."
                )
                break

            # Step 1: fetch pending messages.
            messages = await self.messager.fetch_messages.remote()
            logger.info(f"Group {self._uuid} received {len(messages)} messages")

            # Step 2: route each message to its agent.
            for message in messages:
                try:
                    await self._dispatch_message(message)
                except Exception:
                    # CancelledError is a BaseException, so cancellation still propagates.
                    logger.exception(
                        f"Group {self._uuid} failed to dispatch a message; skipping it"
                    )

            await asyncio.sleep(0.5)

    async def _memory_get_or(self, agent, key: str, default_factory):
        """Return ``agent.memory.get(key)``, or ``default_factory()`` if the lookup raises.

        ``Exception`` is caught rather than everything, so task cancellation
        is not swallowed.
        """
        try:
            return await agent.memory.get(key)
        except Exception:
            return default_factory()

    async def _citizen_status(self, agent) -> dict[str, Any]:
        """Build a citizen agent's status record, without ``created_at``.

        Simulator x/y are projected back to lng/lat. ``parent_id`` is the
        AOI id, else the lane id. It is -1 when the position has neither
        (NOTE(review): flagged as a known gap in the original code).
        """
        position = await agent.memory.get("position")
        xy = position["xy_position"]
        lng, lat = self.projector(xy["x"], xy["y"], inverse=True)
        if "aoi_position" in position:
            parent_id = position["aoi_position"]["aoi_id"]
        elif "lane_position" in position:
            parent_id = position["lane_position"]["lane_id"]
        else:
            parent_id = -1
        needs = await agent.memory.get("needs")
        current_step = await agent.memory.get("current_step")
        return {
            "id": agent._uuid,
            "day": await self.simulator.get_simulator_day(),
            "t": await self.simulator.get_simulator_second_from_start_of_day(),
            "lng": lng,
            "lat": lat,
            "parent_id": parent_id,
            "action": current_step["intention"],
            "hungry": needs["hungry"],
            "tired": needs["tired"],
            "safe": needs["safe"],
            "social": needs["social"],
        }

    async def _institution_status(self, agent) -> dict[str, Any]:
        """Build an institution agent's status record (Avro shape, no position fields).

        Economic indicators missing from memory fall back to the defaults
        in ``_INSTITUTION_ECON_FIELDS``. ``"type"`` has no fallback.
        """
        econ = {
            key: await self._memory_get_or(agent, key, factory)
            for key, factory in self._INSTITUTION_ECON_FIELDS
        }
        return {
            "id": agent._uuid,
            "day": await self.simulator.get_simulator_day(),
            "t": await self.simulator.get_simulator_second_from_start_of_day(),
            "type": await agent.memory.get("type"),
            **econ,
        }

    async def save_status(self):
        """Snapshot every agent's status and persist it.

        With Avro enabled, the snapshot is appended to ``status.avro``.
        Citizen records carry ``created_at`` as epoch milliseconds.

        With PostgreSQL enabled, the records are normalised into row tuples
        and sent to the writer. The Avro records are reused when available;
        otherwise the status is collected here. The previous write is
        awaited first, and the new write is left pending so its I/O
        overlaps with the next step.
        """
        is_institution = self._is_institution_group()
        records: list[tuple[dict, datetime]] = []

        if self.enable_avro:
            logger.debug(f"-----Saving status for group {self._uuid}")
            for agent in self.agents:
                date_time = datetime.now(timezone.utc)
                if is_institution:
                    record = await self._institution_status(agent)
                else:
                    record = await self._citizen_status(agent)
                    record["created_at"] = int(date_time.timestamp() * 1000)
                records.append((record, date_time))
            schema = INSTITUTION_STATUS_SCHEMA if is_institution else STATUS_SCHEMA
            # "a+b" makes fastavro append blocks to the existing container file.
            with open(self.avro_file["status"], "a+b") as f:
                fastavro.writer(f, schema, [r for r, _ in records], codec="snappy")

        if not self.enable_pgsql:
            return

        if records:
            # Reuse the Avro records. Fill the fields institutions lack, and
            # switch created_at from epoch-ms to the datetime object.
            for record, date_time in records:
                for key in ("lng", "lat", "parent_id"):
                    record.setdefault(key, -1)
                record.setdefault("action", "")
                record["created_at"] = date_time
        else:
            for agent in self.agents:
                date_time = datetime.now(timezone.utc)
                if is_institution:
                    # NOTE(review): this path projects the institution's position
                    # into lng/lat, whereas the Avro+PostgreSQL path above
                    # stores -1 for them. Confirm which one is intended.
                    position = await agent.memory.get("position")
                    xy = position["xy_position"]
                    lng, lat = self.projector(xy["x"], xy["y"], inverse=True)
                    record = await self._institution_status(agent)
                    # Institutions have no valid parent AOI/lane nor action.
                    record.update(lng=lng, lat=lat, parent_id=-1, action="")
                else:
                    record = await self._citizen_status(agent)
                record["created_at"] = date_time
                records.append((record, date_time))

        basic_keys = self._PG_STATUS_BASIC_KEYS
        rows: list[tuple] = []
        for record, _ in records:
            extra = json.dumps({k: v for k, v in record.items() if k not in basic_keys})
            rows.append(
                (*(record[k] for k in basic_keys if k != "created_at"), extra, record["created_at"])
            )

        if self._last_asyncio_pg_task is not None:
            await self._last_asyncio_pg_task
        self._last_asyncio_pg_task = self._pgsql_writer.async_write_status.remote(  # type:ignore
            rows
        )

    async def step(self):
        """Run every agent once, concurrently, then save their statuses.

        Calls ``init_agents`` first if the group has not been initialized.
        """
        if not self.initialized:
            await self.init_agents()

        tasks = [agent.run() for agent in self.agents]
        await asyncio.gather(*tasks)
        await self.save_status()

    async def run(self, day: int = 1):
        """Step the group until ``day`` simulated days have elapsed.

        Args:
            day: Number of simulated days to run; defaults to 1.

        Raises:
            RuntimeError: Wrapping any exception raised while running.
        """
        try:
            start_time = int(await self.simulator.get_time())
            # End time in simulator seconds (days converted to seconds).
            end_time = start_time + day * 24 * 3600
            while int(await self.simulator.get_time()) < end_time:
                await self.step()
        except Exception as e:
            # logger.exception keeps the traceback that RuntimeError would hide.
            logger.exception(f"模拟器运行错误: {str(e)}")
            raise RuntimeError(str(e)) from e