pycityagent-2.0.0a43-cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. pycityagent/__init__.py +23 -0
  2. pycityagent/agent.py +833 -0
  3. pycityagent/cli/wrapper.py +44 -0
  4. pycityagent/economy/__init__.py +5 -0
  5. pycityagent/economy/econ_client.py +355 -0
  6. pycityagent/environment/__init__.py +7 -0
  7. pycityagent/environment/interact/__init__.py +0 -0
  8. pycityagent/environment/interact/interact.py +198 -0
  9. pycityagent/environment/message/__init__.py +0 -0
  10. pycityagent/environment/sence/__init__.py +0 -0
  11. pycityagent/environment/sence/static.py +416 -0
  12. pycityagent/environment/sidecar/__init__.py +8 -0
  13. pycityagent/environment/sidecar/sidecarv2.py +109 -0
  14. pycityagent/environment/sim/__init__.py +29 -0
  15. pycityagent/environment/sim/aoi_service.py +39 -0
  16. pycityagent/environment/sim/client.py +126 -0
  17. pycityagent/environment/sim/clock_service.py +44 -0
  18. pycityagent/environment/sim/economy_services.py +192 -0
  19. pycityagent/environment/sim/lane_service.py +111 -0
  20. pycityagent/environment/sim/light_service.py +122 -0
  21. pycityagent/environment/sim/person_service.py +295 -0
  22. pycityagent/environment/sim/road_service.py +39 -0
  23. pycityagent/environment/sim/sim_env.py +145 -0
  24. pycityagent/environment/sim/social_service.py +59 -0
  25. pycityagent/environment/simulator.py +331 -0
  26. pycityagent/environment/utils/__init__.py +14 -0
  27. pycityagent/environment/utils/base64.py +16 -0
  28. pycityagent/environment/utils/const.py +244 -0
  29. pycityagent/environment/utils/geojson.py +24 -0
  30. pycityagent/environment/utils/grpc.py +57 -0
  31. pycityagent/environment/utils/map_utils.py +157 -0
  32. pycityagent/environment/utils/port.py +11 -0
  33. pycityagent/environment/utils/protobuf.py +41 -0
  34. pycityagent/llm/__init__.py +11 -0
  35. pycityagent/llm/embeddings.py +231 -0
  36. pycityagent/llm/llm.py +377 -0
  37. pycityagent/llm/llmconfig.py +13 -0
  38. pycityagent/llm/utils.py +6 -0
  39. pycityagent/memory/__init__.py +13 -0
  40. pycityagent/memory/const.py +43 -0
  41. pycityagent/memory/faiss_query.py +302 -0
  42. pycityagent/memory/memory.py +448 -0
  43. pycityagent/memory/memory_base.py +170 -0
  44. pycityagent/memory/profile.py +165 -0
  45. pycityagent/memory/self_define.py +165 -0
  46. pycityagent/memory/state.py +173 -0
  47. pycityagent/memory/utils.py +28 -0
  48. pycityagent/message/__init__.py +3 -0
  49. pycityagent/message/messager.py +88 -0
  50. pycityagent/metrics/__init__.py +6 -0
  51. pycityagent/metrics/mlflow_client.py +147 -0
  52. pycityagent/metrics/utils/const.py +0 -0
  53. pycityagent/pycityagent-sim +0 -0
  54. pycityagent/pycityagent-ui +0 -0
  55. pycityagent/simulation/__init__.py +8 -0
  56. pycityagent/simulation/agentgroup.py +580 -0
  57. pycityagent/simulation/simulation.py +634 -0
  58. pycityagent/simulation/storage/pg.py +184 -0
  59. pycityagent/survey/__init__.py +4 -0
  60. pycityagent/survey/manager.py +54 -0
  61. pycityagent/survey/models.py +120 -0
  62. pycityagent/utils/__init__.py +11 -0
  63. pycityagent/utils/avro_schema.py +109 -0
  64. pycityagent/utils/decorators.py +99 -0
  65. pycityagent/utils/parsers/__init__.py +13 -0
  66. pycityagent/utils/parsers/code_block_parser.py +37 -0
  67. pycityagent/utils/parsers/json_parser.py +86 -0
  68. pycityagent/utils/parsers/parser_base.py +60 -0
  69. pycityagent/utils/pg_query.py +92 -0
  70. pycityagent/utils/survey_util.py +53 -0
  71. pycityagent/workflow/__init__.py +26 -0
  72. pycityagent/workflow/block.py +211 -0
  73. pycityagent/workflow/prompt.py +79 -0
  74. pycityagent/workflow/tool.py +240 -0
  75. pycityagent/workflow/trigger.py +163 -0
  76. pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
  77. pycityagent-2.0.0a43.dist-info/METADATA +235 -0
  78. pycityagent-2.0.0a43.dist-info/RECORD +81 -0
  79. pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
  80. pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
  81. pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
@@ -0,0 +1,580 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import time
5
+ import uuid
6
+ from datetime import datetime, timezone
7
+ from pathlib import Path
8
+ from typing import Any
9
+ from uuid import UUID
10
+
11
+ import fastavro
12
+ import pyproj
13
+ import ray
14
+ from langchain_core.embeddings import Embeddings
15
+
16
+ from ..agent import Agent, CitizenAgent, InstitutionAgent
17
+ from ..economy.econ_client import EconomyClient
18
+ from ..environment.simulator import Simulator
19
+ from ..llm.llm import LLM
20
+ from ..llm.llmconfig import LLMConfig
21
+ from ..memory import FaissQuery
22
+ from ..message import Messager
23
+ from ..metrics import MlflowClient
24
+ from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
25
+ STATUS_SCHEMA, SURVEY_SCHEMA)
26
+
# Shared "pycityagent" logger; AgentGroup.__init__ sets its level per actor.
logger: logging.Logger = logging.getLogger("pycityagent")
28
+
29
+
@ray.remote
class AgentGroup:
    """Ray actor that hosts a group of agents and drives them through the simulation.

    The group builds its own LLM client, ``Simulator`` connection, optional
    economy and MLflow clients, optional ``FaissQuery`` index and a
    ``Messager`` actor, and injects them into every hosted agent. It also
    dispatches incoming MQTT messages to the agents and persists agent
    profiles and status snapshots to Avro files and/or a PostgreSQL writer
    actor.

    NOTE: the citizen/institution branches inspect only ``self.agents[0]``,
    so a group is assumed to hold a single kind of agent.
    """

    # MQTT topic suffix -> name of the agent coroutine that handles it.
    _TOPIC_HANDLERS = {
        "agent-chat": "handle_agent_chat_message",
        "user-chat": "handle_user_chat_message",
        "user-survey": "handle_user_survey_message",
        "gather": "handle_gather_message",
    }

    # Institution metrics read from memory, paired with the factory that
    # builds the fallback value used when the read fails.
    _INSTITUTION_METRICS = (
        ("nominal_gdp", list),
        ("real_gdp", list),
        ("unemployment", list),
        ("wages", list),
        ("prices", list),
        ("inventory", int),
        ("price", float),
        ("interest_rate", float),
        ("bracket_cutoffs", list),
        ("bracket_rates", list),
        ("employees", list),
    )

    # Status fields emitted as the leading elements of a PostgreSQL status
    # row. All remaining fields except ``created_at`` are packed into one
    # JSON string, and ``created_at`` is appended last.
    _PG_STATUS_BASIC_KEYS = ("id", "day", "t", "lng", "lat", "parent_id", "action")

    def __init__(
        self,
        agents: list[Agent],
        config: dict,
        exp_id: "str | UUID",
        exp_name: str,
        enable_avro: bool,
        avro_path: Path,
        enable_pgsql: bool,
        pgsql_writer: "ray.actor.ActorHandle",
        mlflow_run_id: str,
        embedding_model: "Embeddings | None",
        logging_level: int,
    ):
        """Build the group's shared clients and inject them into every agent.

        Args:
            agents: Agents hosted by this group.
            config: Experiment configuration. Reads ``llm_request`` and
                ``simulator_request``. The latter must contain ``mqtt`` with
                ``server``/``port`` and optionally ``username``/``password``.
                An optional ``simulator_request.economy.server`` enables the
                economy client. The optional ``metric_request.mlflow`` section
                enables MLflow.
            exp_id: Experiment id; used in MQTT topics and passed to agents.
                (String annotation so the module still imports on Python 3.9,
                where ``str | UUID`` cannot be evaluated.)
            exp_name: Experiment name; used for the MLflow experiment/run.
            enable_avro: Persist profiles/status into Avro files.
            avro_path: Root directory for Avro output. This group writes into
                ``avro_path / <group uuid>``.
            enable_pgsql: Persist profiles/status through ``pgsql_writer``.
            pgsql_writer: Ray actor handle exposing ``async_write_profile`` and
                ``async_write_status``.
            mlflow_run_id: MLflow run id forwarded to ``MlflowClient``.
            embedding_model: Optional embedding model. When given, a
                ``FaissQuery`` is built and both are attached to agent memories.
            logging_level: Level applied to the ``pycityagent`` logger.
        """
        logger.setLevel(logging_level)
        self._uuid = str(uuid.uuid4())
        self.agents = agents
        self.config = config
        self.exp_id = exp_id
        self.enable_avro = enable_avro
        self.enable_pgsql = enable_pgsql
        self.embedding_model = embedding_model
        if enable_avro:
            # One output sub-directory per group, named after the group uuid.
            self.avro_path = avro_path / self._uuid
            self.avro_path.mkdir(parents=True, exist_ok=True)
            self.avro_file = {
                name: self.avro_path / f"{name}.avro"
                for name in ("profile", "dialog", "status", "survey")
            }

        mqtt_config = config["simulator_request"]["mqtt"]
        self.messager = Messager.remote(
            hostname=mqtt_config["server"],  # type:ignore
            port=mqtt_config["port"],
            username=mqtt_config.get("username", None),
            password=mqtt_config.get("password", None),
        )
        self.message_dispatch_task = None
        self._pgsql_writer = pgsql_writer
        # Pending status write from the previous save_status() call. It is only
        # awaited on the next call, so the SQL I/O overlaps with agent computation.
        self._last_asyncio_pg_task = None
        self.initialized = False
        self.id2agent = {}

        # Step 1: prepare the LLM client
        llm_config = LLMConfig(config["llm_request"])
        logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
        self.llm = LLM(llm_config)

        # Step 2: prepare the Simulator and the map projection used for lng/lat
        logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
        self.simulator = Simulator(config["simulator_request"])
        self.projector = pyproj.Proj(self.simulator.map.header["projection"])

        # Step 3: prepare the Economy client (optional)
        if "economy" in config["simulator_request"]:
            logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
            self.economy_client = EconomyClient(
                config["simulator_request"]["economy"]["server"]
            )
        else:
            self.economy_client = None

        # Step 4: prepare the MLflow client (optional)
        mlflow_config = (config.get("metric_request") or {}).get("mlflow")
        if mlflow_config:
            logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
            self.mlflow_client = MlflowClient(
                config=mlflow_config,
                mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
                experiment_name=exp_name,
                run_id=mlflow_run_id,
            )
        else:
            self.mlflow_client = None

        # Step 5: prepare FaissQuery (only when an embedding model is given)
        self.faiss_query = (
            FaissQuery(embeddings=self.embedding_model)
            if self.embedding_model is not None
            else None
        )

        for agent in self.agents:
            agent.set_exp_id(self.exp_id)  # type: ignore
            agent.set_llm_client(self.llm)
            agent.set_simulator(self.simulator)
            if self.economy_client is not None:
                agent.set_economy_client(self.economy_client)
            if self.mlflow_client is not None:
                agent.set_mlflow_client(self.mlflow_client)
            agent.set_messager(self.messager)
            if self.enable_avro:
                agent.set_avro_file(self.avro_file)  # type: ignore
            if self.enable_pgsql:
                agent.set_pgsql_writer(self._pgsql_writer)
            if self.faiss_query is not None:
                agent.memory.set_faiss_query(self.faiss_query)
            if self.embedding_model is not None:
                agent.memory.set_embedding_model(self.embedding_model)

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Cancel the background message-dispatch task, if one was started.

        ``message_dispatch_task`` is ``None`` until :meth:`init_agents` runs,
        so there may be nothing to cancel.
        """
        task = self.message_dispatch_task
        if task is None:
            return
        task.cancel()
        # Wait for the task to finish unwinding; its CancelledError is collected.
        await asyncio.gather(task, return_exceptions=True)

    def _is_institution_group(self) -> bool:
        """Return True if the first agent is an InstitutionAgent.

        An empty group counts as a citizen group, so callers never index into
        an empty ``self.agents``.
        """
        return bool(self.agents) and isinstance(self.agents[0], InstitutionAgent)

    async def _export_profiles(self) -> list[dict]:
        """Export each agent's profile dict with its ``id`` set to the agent uuid."""
        profiles = []
        for agent in self.agents:
            profile = (await agent.memory._profile.export())[0]
            profile["id"] = agent._uuid
            profiles.append(profile)
        return profiles

    def _create_avro_files(self, profiles: list[dict], is_institution: bool) -> None:
        """(Re)create this group's Avro files, each starting with a schema header.

        The profile file is written only for citizen groups and receives
        ``profiles``. The dialog/status/survey files are created empty so that
        records can be appended later (``save_status`` appends to the status file).
        """
        if not is_institution:
            with open(self.avro_file["profile"], "wb") as f:
                fastavro.writer(f, PROFILE_SCHEMA, profiles)
        status_schema = INSTITUTION_STATUS_SCHEMA if is_institution else STATUS_SCHEMA
        for name, schema in (
            ("dialog", DIALOG_SCHEMA),
            ("status", status_schema),
            ("survey", SURVEY_SCHEMA),
        ):
            with open(self.avro_file[name], "wb") as f:
                fastavro.writer(f, schema, [])

    async def init_agents(self):
        """Bind agents to the simulator and messager and create output sinks.

        Called lazily by :meth:`step` on first use. It starts the background
        :meth:`message_dispatch` task. Depending on ``enable_avro`` and
        ``enable_pgsql``, it also creates the group's Avro files and writes
        agent profiles to PostgreSQL.
        """
        logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
        logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
        for agent in self.agents:
            await agent.bind_to_simulator()  # type: ignore
        self.id2agent = {agent._uuid: agent for agent in self.agents}

        logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
        await self.messager.connect.remote()
        # Await the ObjectRef rather than calling ray.get(). A blocking ray.get()
        # inside an async actor method stalls the actor's event loop.
        if await self.messager.is_connected.remote():
            await self.messager.start_listening.remote()
            topics = []
            agent_ids = []
            for agent in self.agents:
                agent.set_messager(self.messager)
                topics.append((f"exps/{self.exp_id}/agents/{agent._uuid}/#", 1))
                agent_ids.append(agent.uuid)
            await self.messager.subscribe.remote(topics, agent_ids)
        self.message_dispatch_task = asyncio.create_task(self.message_dispatch())

        is_institution = self._is_institution_group()
        # Export profiles once and share them between both sinks. The Avro sink
        # stores profiles only for citizen groups.
        if self.enable_pgsql or (self.enable_avro and not is_institution):
            profiles = await self._export_profiles()
        else:
            profiles = []

        if self.enable_avro:
            logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
            self._create_avro_files(profiles, is_institution)

        if self.enable_pgsql:
            # Row layout: (agent uuid, name, JSON of all remaining profile fields).
            profile_rows = [
                (
                    profile["id"],
                    profile.get("name", ""),
                    json.dumps(
                        {k: v for k, v in profile.items() if k not in {"id", "name"}}
                    ),
                )
                for profile in profiles
            ]
            await self._pgsql_writer.async_write_profile.remote(  # type:ignore
                profile_rows
            )
        self.initialized = True
        logger.debug(f"-----AgentGroup {self._uuid} initialized")

    async def gather(self, content: str):
        """Return ``{agent uuid: agent.memory.get(content)}`` for every agent."""
        logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
        results = {}
        for agent in self.agents:
            results[agent._uuid] = await agent.memory.get(content)
        return results

    async def update(self, target_agent_uuid: str, target_key: str, content: Any):
        """Set ``target_key`` to ``content`` in one agent's memory.

        Raises:
            KeyError: if ``target_agent_uuid`` is not in ``id2agent``. That
                mapping stays empty until :meth:`init_agents` has run.
        """
        logger.debug(
            f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
        )
        agent = self.id2agent[target_agent_uuid]
        await agent.memory.update(target_key, content)

    async def _dispatch_message(self, message) -> None:
        """Route one fetched MQTT message to the handler of its target agent.

        Expected topic layout: ``exps/{exp_id}/agents/{agent_uuid}/{topic_type}``.
        Messages for unknown agents or unknown topic types are ignored.
        """
        topic = message.topic.value
        payload = message.payload
        # Byte payloads are decoded as UTF-8 JSON; other payloads pass through unchanged.
        if isinstance(payload, bytes):
            payload = json.loads(payload.decode("utf-8"))

        parts = topic.strip("/").split("/")
        if len(parts) != 5:
            # The "#" wildcard subscription can deliver deeper or shallower topics.
            logger.warning(f"Group {self._uuid} ignoring message on unexpected topic {topic!r}")
            return
        _, _, _, agent_uuid, topic_type = parts

        agent = self.id2agent.get(agent_uuid)
        handler_name = self._TOPIC_HANDLERS.get(topic_type)
        if agent is None or handler_name is None:
            return
        await getattr(agent, handler_name)(payload)

    async def message_dispatch(self):
        """Poll the messager every 0.5 s and dispatch messages to agents.

        The loop exits once the messager reports it is disconnected. If one
        message fails (bad topic, bad JSON, handler error), the error is logged
        and the loop continues, so a single bad message cannot silently stop
        all further dispatching.
        """
        logger.debug(f"-----Starting message dispatch for group {self._uuid}")
        while True:
            # Await instead of ray.get() to avoid blocking the actor's event loop.
            if not await self.messager.is_connected.remote():
                logger.warning(
                    "Messager is not connected. Skipping message processing."
                )
                break

            messages = await self.messager.fetch_messages.remote()
            logger.info(f"Group {self._uuid} received {len(messages)} messages")

            for message in messages:
                try:
                    await self._dispatch_message(message)
                except Exception:
                    logger.exception(
                        f"Group {self._uuid} failed to dispatch a message; skipping it"
                    )

            await asyncio.sleep(0.5)

    def _locate(self, position: dict) -> tuple:
        """Return ``(lng, lat, parent_id)`` for a memory ``position`` dict.

        ``parent_id`` is the AOI id, else the lane id, else -1 when neither an
        ``aoi_position`` nor a ``lane_position`` entry is present.
        """
        x = position["xy_position"]["x"]
        y = position["xy_position"]["y"]
        lng, lat = self.projector(x, y, inverse=True)
        if "aoi_position" in position:
            parent_id = position["aoi_position"]["aoi_id"]
        elif "lane_position" in position:
            parent_id = position["lane_position"]["lane_id"]
        else:
            parent_id = -1
        return lng, lat, parent_id

    async def _collect_citizen_status(self, agent) -> dict:
        """Build a citizen status record (without ``created_at``)."""
        position = await agent.memory.get("position")
        lng, lat, parent_id = self._locate(position)
        needs = await agent.memory.get("needs")
        current_step = await agent.memory.get("current_step")
        return {
            "id": agent._uuid,
            "day": await self.simulator.get_simulator_day(),
            "t": await self.simulator.get_simulator_second_from_start_of_day(),
            "lng": lng,
            "lat": lat,
            "parent_id": parent_id,
            "action": current_step["intention"],
            "hungry": needs["hungry"],
            "tired": needs["tired"],
            "safe": needs["safe"],
            "social": needs["social"],
        }

    async def _collect_institution_status(self, agent) -> dict:
        """Build an institution status record (without location or ``created_at``).

        A failed metric read falls back to that metric's zero value
        (``[]``, ``0`` or ``0.0``). ``except Exception`` is used so that
        task cancellation is not swallowed.
        """
        metrics = {}
        for key, default_factory in self._INSTITUTION_METRICS:
            try:
                metrics[key] = await agent.memory.get(key)
            except Exception:
                metrics[key] = default_factory()
        return {
            "id": agent._uuid,
            "day": await self.simulator.get_simulator_day(),
            "t": await self.simulator.get_simulator_second_from_start_of_day(),
            "type": await agent.memory.get("type"),
            **metrics,
        }

    async def _institution_lng_lat(self, agent) -> tuple:
        """Best-effort projected ``(lng, lat)`` of an institution; ``(-1, -1)`` if unavailable."""
        try:
            position = await agent.memory.get("position")
            x = position["xy_position"]["x"]
            y = position["xy_position"]["y"]
        except Exception:
            return -1, -1
        lng, lat = self.projector(x, y, inverse=True)
        return lng, lat

    def _append_avro_status(self, snapshots: list, is_institution: bool) -> None:
        """Append one Avro status record per ``(agent, status, captured_at)`` snapshot."""
        if is_institution:
            records = [status for _, status, _ in snapshots]
            schema = INSTITUTION_STATUS_SCHEMA
        else:
            # Citizen records carry created_at as epoch milliseconds.
            records = [
                {**status, "created_at": int(captured_at.timestamp() * 1000)}
                for _, status, captured_at in snapshots
            ]
            schema = STATUS_SCHEMA
        # "a+b" lets fastavro append blocks to the file created in init_agents.
        with open(self.avro_file["status"], "a+b") as f:
            fastavro.writer(f, schema, records, codec="snappy")

    def _to_pgsql_status_row(self, record: dict) -> tuple:
        """Flatten a status record into ``(*basic fields, JSON of the rest, created_at)``."""
        basic = [record[k] for k in self._PG_STATUS_BASIC_KEYS]
        extra = json.dumps(
            {
                k: v
                for k, v in record.items()
                if k not in self._PG_STATUS_BASIC_KEYS and k != "created_at"
            }
        )
        return (*basic, extra, record["created_at"])

    async def _write_pgsql_status(self, snapshots: list, is_institution: bool) -> None:
        """Queue a PostgreSQL status write for the given snapshots.

        Institutions have no parent/action. Their lng/lat are projected on a
        best-effort basis, so the row no longer depends on ``enable_avro``.
        """
        rows = []
        for agent, status, captured_at in snapshots:
            record = dict(status)
            if is_institution:
                lng, lat = await self._institution_lng_lat(agent)
                record.update(lng=lng, lat=lat, parent_id=-1, action="")
            record["created_at"] = captured_at
            rows.append(self._to_pgsql_status_row(record))
        # At most one status write per group is in flight: finish the previous
        # one, then issue this one without awaiting it.
        if self._last_asyncio_pg_task is not None:
            await self._last_asyncio_pg_task
        self._last_asyncio_pg_task = self._pgsql_writer.async_write_status.remote(  # type:ignore
            rows
        )

    async def save_status(self):
        """Snapshot every agent's status and persist it to the enabled sinks.

        Each agent's snapshot timestamp (UTC) is taken before its memory reads.
        Avro records are appended to the group's status file. PostgreSQL rows
        are queued on the writer actor, and the previous write is awaited first.
        """
        if not self.agents or not (self.enable_avro or self.enable_pgsql):
            return
        is_institution = self._is_institution_group()

        snapshots: list[tuple[Any, dict, datetime]] = []
        for agent in self.agents:
            captured_at = datetime.now(timezone.utc)
            if is_institution:
                status = await self._collect_institution_status(agent)
            else:
                status = await self._collect_citizen_status(agent)
            snapshots.append((agent, status, captured_at))

        if self.enable_avro:
            logger.debug(f"-----Saving status for group {self._uuid}")
            self._append_avro_status(snapshots, is_institution)
        if self.enable_pgsql:
            await self._write_pgsql_status(snapshots, is_institution)

    async def step(self):
        """Run one step for all agents concurrently, then save their status.

        Initializes the group on the first call.
        """
        if not self.initialized:
            await self.init_agents()

        await asyncio.gather(*(agent.run() for agent in self.agents))
        await self.save_status()

    async def run(self, day: int = 1):
        """Step the group until ``day`` simulated days have elapsed.

        Args:
            day: Number of simulated days to run (default 1).

        Raises:
            RuntimeError: wrapping any exception raised while running.
        """
        try:
            start_time = int(await self.simulator.get_time())
            # Simulator time is in seconds: convert days to seconds.
            end_time = start_time + day * 24 * 3600
            while int(await self.simulator.get_time()) < end_time:
                await self.step()
        except Exception as e:
            logger.error(f"模拟器运行错误: {str(e)}")
            raise RuntimeError(str(e)) from e