pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +83 -62
- pycityagent/agent/agent_base.py +81 -54
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +45 -57
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +325 -127
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +18 -14
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +733 -247
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +137 -96
- pycityagent/simulation/simulation.py +184 -38
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/tools/tool.py +7 -9
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +14 -7
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@ from langchain_core.embeddings import Embeddings
|
|
16
16
|
|
17
17
|
from ..agent import Agent, InstitutionAgent
|
18
18
|
from ..economy.econ_client import EconomyClient
|
19
|
-
from ..environment
|
19
|
+
from ..environment import Simulator
|
20
20
|
from ..llm.llm import LLM
|
21
21
|
from ..llm.llmconfig import LLMConfig
|
22
22
|
from ..memory import FaissQuery, Memory
|
@@ -39,12 +39,13 @@ class AgentGroup:
|
|
39
39
|
list[Callable[[], tuple[dict, dict, dict]]],
|
40
40
|
],
|
41
41
|
config: dict,
|
42
|
-
exp_id: str | UUID,
|
43
42
|
exp_name: str,
|
43
|
+
exp_id: str | UUID,
|
44
44
|
enable_avro: bool,
|
45
45
|
avro_path: Path,
|
46
46
|
enable_pgsql: bool,
|
47
47
|
pgsql_writer: ray.ObjectRef,
|
48
|
+
message_interceptor: ray.ObjectRef,
|
48
49
|
mlflow_run_id: str,
|
49
50
|
embedding_model: Embeddings,
|
50
51
|
logging_level: int,
|
@@ -80,6 +81,18 @@ class AgentGroup:
|
|
80
81
|
}
|
81
82
|
if self.enable_pgsql:
|
82
83
|
pass
|
84
|
+
# Mlflow
|
85
|
+
_mlflow_config = config.get("metric_request", {}).get("mlflow")
|
86
|
+
if _mlflow_config:
|
87
|
+
logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
|
88
|
+
self.mlflow_client = MlflowClient(
|
89
|
+
config=_mlflow_config,
|
90
|
+
mlflow_run_name=f"{exp_name}_{1000*int(time.time())}",
|
91
|
+
experiment_name=exp_name,
|
92
|
+
run_id=mlflow_run_id,
|
93
|
+
)
|
94
|
+
else:
|
95
|
+
self.mlflow_client = None
|
83
96
|
|
84
97
|
# prepare Messager
|
85
98
|
if "mqtt" in config["simulator_request"]:
|
@@ -94,6 +107,7 @@ class AgentGroup:
|
|
94
107
|
|
95
108
|
self.message_dispatch_task = None
|
96
109
|
self._pgsql_writer = pgsql_writer
|
110
|
+
self._message_interceptor = message_interceptor
|
97
111
|
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
98
112
|
self.initialized = False
|
99
113
|
self.id2agent = {}
|
@@ -116,19 +130,6 @@ class AgentGroup:
|
|
116
130
|
else:
|
117
131
|
self.economy_client = None
|
118
132
|
|
119
|
-
# Mlflow
|
120
|
-
_mlflow_config = config.get("metric_request", {}).get("mlflow")
|
121
|
-
if _mlflow_config:
|
122
|
-
logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
|
123
|
-
self.mlflow_client = MlflowClient(
|
124
|
-
config=_mlflow_config,
|
125
|
-
mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
|
126
|
-
experiment_name=exp_name,
|
127
|
-
run_id=mlflow_run_id,
|
128
|
-
)
|
129
|
-
else:
|
130
|
-
self.mlflow_client = None
|
131
|
-
|
132
133
|
# set FaissQuery
|
133
134
|
if self.embedding_model is not None:
|
134
135
|
self.faiss_query = FaissQuery(
|
@@ -144,27 +145,25 @@ class AgentGroup:
|
|
144
145
|
extra_attributes, profile, base = memory_config_function_group_i()
|
145
146
|
memory = Memory(config=extra_attributes, profile=profile, base=base)
|
146
147
|
agent = agent_class_i(
|
147
|
-
name=f"{agent_class_i.__name__}_{i}",
|
148
|
+
name=f"{agent_class_i.__name__}_{i}", # type: ignore
|
148
149
|
memory=memory,
|
149
150
|
llm_client=self.llm,
|
150
151
|
economy_client=self.economy_client,
|
151
152
|
simulator=self.simulator,
|
152
153
|
)
|
153
154
|
agent.set_exp_id(self.exp_id) # type: ignore
|
154
|
-
if self.mlflow_client is not None:
|
155
|
-
agent.set_mlflow_client(self.mlflow_client)
|
156
155
|
if self.messager is not None:
|
157
156
|
agent.set_messager(self.messager)
|
157
|
+
if self.mlflow_client is not None:
|
158
|
+
agent.set_mlflow_client(self.mlflow_client) # type: ignore
|
158
159
|
if self.enable_avro:
|
159
160
|
agent.set_avro_file(self.avro_file) # type: ignore
|
160
161
|
if self.enable_pgsql:
|
161
162
|
agent.set_pgsql_writer(self._pgsql_writer)
|
162
|
-
if self.faiss_query is not None:
|
163
|
-
agent.memory.set_faiss_query(self.faiss_query)
|
164
|
-
if self.embedding_model is not None:
|
165
|
-
agent.memory.set_embedding_model(self.embedding_model)
|
166
163
|
if self.agent_config_file is not None and self.agent_config_file[i]:
|
167
164
|
agent.load_from_file(self.agent_config_file[i])
|
165
|
+
if self._message_interceptor is not None:
|
166
|
+
agent.set_message_interceptor(self._message_interceptor)
|
168
167
|
self.agents.append(agent)
|
169
168
|
self.id2agent[agent._uuid] = agent
|
170
169
|
|
@@ -193,11 +192,21 @@ class AgentGroup:
|
|
193
192
|
self.message_dispatch_task.cancel() # type: ignore
|
194
193
|
await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
|
195
194
|
|
195
|
+
async def insert_agents(self):
|
196
|
+
bind_tasks = []
|
197
|
+
for agent in self.agents:
|
198
|
+
bind_tasks.append(agent.bind_to_simulator()) # type: ignore
|
199
|
+
await asyncio.gather(*bind_tasks)
|
200
|
+
|
196
201
|
async def init_agents(self):
|
197
202
|
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
198
203
|
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
199
|
-
|
200
|
-
await
|
204
|
+
while True:
|
205
|
+
day = await self.simulator.get_simulator_day()
|
206
|
+
if day == 0:
|
207
|
+
break
|
208
|
+
await asyncio.sleep(1)
|
209
|
+
await self.insert_agents()
|
201
210
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
202
211
|
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
203
212
|
assert self.messager is not None
|
@@ -221,7 +230,7 @@ class AgentGroup:
|
|
221
230
|
with open(filename, "wb") as f:
|
222
231
|
profiles = []
|
223
232
|
for agent in self.agents:
|
224
|
-
profile = await agent.
|
233
|
+
profile = await agent.status.profile.export()
|
225
234
|
profile = profile[0]
|
226
235
|
profile["id"] = agent._uuid
|
227
236
|
profiles.append(profile)
|
@@ -252,7 +261,7 @@ class AgentGroup:
|
|
252
261
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
253
262
|
profiles: list[Any] = []
|
254
263
|
for agent in self.agents:
|
255
|
-
profile = await agent.
|
264
|
+
profile = await agent.status.profile.export()
|
256
265
|
profile = profile[0]
|
257
266
|
profile["id"] = agent._uuid
|
258
267
|
profiles.append(
|
@@ -271,7 +280,7 @@ class AgentGroup:
|
|
271
280
|
else:
|
272
281
|
profiles: list[Any] = []
|
273
282
|
for agent in self.agents:
|
274
|
-
profile = await agent.
|
283
|
+
profile = await agent.status.profile.export()
|
275
284
|
profile = profile[0]
|
276
285
|
profile["id"] = agent._uuid
|
277
286
|
profiles.append(
|
@@ -290,6 +299,18 @@ class AgentGroup:
|
|
290
299
|
await self._pgsql_writer.async_write_profile.remote( # type:ignore
|
291
300
|
profiles
|
292
301
|
)
|
302
|
+
if self.faiss_query is not None:
|
303
|
+
logger.debug(f"-----Initializing embeddings in AgentGroup {self._uuid} ...")
|
304
|
+
embedding_tasks = []
|
305
|
+
for agent in self.agents:
|
306
|
+
embedding_tasks.append(agent.memory.initialize_embeddings())
|
307
|
+
agent.memory.set_search_components(
|
308
|
+
self.faiss_query, self.embedding_model
|
309
|
+
)
|
310
|
+
agent.memory.set_simulator(self.simulator)
|
311
|
+
await asyncio.gather(*embedding_tasks)
|
312
|
+
logger.debug(f"-----Embedding initialized in AgentGroup {self._uuid} ...")
|
313
|
+
|
293
314
|
self.initialized = True
|
294
315
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
295
316
|
|
@@ -307,7 +328,7 @@ class AgentGroup:
|
|
307
328
|
if keys:
|
308
329
|
for key in keys:
|
309
330
|
assert values is not None
|
310
|
-
if not agent.
|
331
|
+
if not agent.status.get(key) == values[keys.index(key)]:
|
311
332
|
add = False
|
312
333
|
break
|
313
334
|
if add:
|
@@ -315,18 +336,23 @@ class AgentGroup:
|
|
315
336
|
elif keys:
|
316
337
|
for key in keys:
|
317
338
|
assert values is not None
|
318
|
-
if not agent.
|
339
|
+
if not agent.status.get(key) == values[keys.index(key)]:
|
319
340
|
add = False
|
320
341
|
break
|
321
342
|
if add:
|
322
343
|
filtered_uuids.append(agent._uuid)
|
323
344
|
return filtered_uuids
|
324
345
|
|
325
|
-
async def gather(
|
346
|
+
async def gather(
|
347
|
+
self, content: str, target_agent_uuids: Optional[list[str]] = None
|
348
|
+
):
|
326
349
|
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
327
350
|
results = {}
|
351
|
+
if target_agent_uuids is None:
|
352
|
+
target_agent_uuids = self.agent_uuids
|
328
353
|
for agent in self.agents:
|
329
|
-
|
354
|
+
if agent._uuid in target_agent_uuids:
|
355
|
+
results[agent._uuid] = await agent.status.get(content)
|
330
356
|
return results
|
331
357
|
|
332
358
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
@@ -334,7 +360,7 @@ class AgentGroup:
|
|
334
360
|
f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
|
335
361
|
)
|
336
362
|
agent = self.id2agent[target_agent_uuid]
|
337
|
-
await agent.
|
363
|
+
await agent.status.update(target_key, content)
|
338
364
|
|
339
365
|
async def message_dispatch(self):
|
340
366
|
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
@@ -394,7 +420,7 @@ class AgentGroup:
|
|
394
420
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
395
421
|
for agent in self.agents:
|
396
422
|
_date_time = datetime.now(timezone.utc)
|
397
|
-
position = await agent.
|
423
|
+
position = await agent.status.get("position")
|
398
424
|
x = position["xy_position"]["x"]
|
399
425
|
y = position["xy_position"]["y"]
|
400
426
|
lng, lat = self.projector(x, y, inverse=True)
|
@@ -404,8 +430,11 @@ class AgentGroup:
|
|
404
430
|
parent_id = position["lane_position"]["lane_id"]
|
405
431
|
else:
|
406
432
|
parent_id = -1
|
407
|
-
|
408
|
-
|
433
|
+
hunger_satisfaction = await agent.status.get("hunger_satisfaction")
|
434
|
+
energy_satisfaction = await agent.status.get("energy_satisfaction")
|
435
|
+
safety_satisfaction = await agent.status.get("safety_satisfaction")
|
436
|
+
social_satisfaction = await agent.status.get("social_satisfaction")
|
437
|
+
action = await agent.status.get("current_step")
|
409
438
|
action = action["intention"]
|
410
439
|
avro = {
|
411
440
|
"id": agent._uuid,
|
@@ -415,10 +444,10 @@ class AgentGroup:
|
|
415
444
|
"lat": lat,
|
416
445
|
"parent_id": parent_id,
|
417
446
|
"action": action,
|
418
|
-
"hungry":
|
419
|
-
"tired":
|
420
|
-
"safe":
|
421
|
-
"social":
|
447
|
+
"hungry": hunger_satisfaction,
|
448
|
+
"tired": energy_satisfaction,
|
449
|
+
"safe": safety_satisfaction,
|
450
|
+
"social": social_satisfaction,
|
422
451
|
"created_at": int(_date_time.timestamp() * 1000),
|
423
452
|
}
|
424
453
|
avros.append(avro)
|
@@ -429,54 +458,54 @@ class AgentGroup:
|
|
429
458
|
for agent in self.agents:
|
430
459
|
_date_time = datetime.now(timezone.utc)
|
431
460
|
try:
|
432
|
-
nominal_gdp = await agent.
|
461
|
+
nominal_gdp = await agent.status.get("nominal_gdp")
|
433
462
|
except:
|
434
463
|
nominal_gdp = []
|
435
464
|
try:
|
436
|
-
real_gdp = await agent.
|
465
|
+
real_gdp = await agent.status.get("real_gdp")
|
437
466
|
except:
|
438
467
|
real_gdp = []
|
439
468
|
try:
|
440
|
-
unemployment = await agent.
|
469
|
+
unemployment = await agent.status.get("unemployment")
|
441
470
|
except:
|
442
471
|
unemployment = []
|
443
472
|
try:
|
444
|
-
wages = await agent.
|
473
|
+
wages = await agent.status.get("wages")
|
445
474
|
except:
|
446
475
|
wages = []
|
447
476
|
try:
|
448
|
-
prices = await agent.
|
477
|
+
prices = await agent.status.get("prices")
|
449
478
|
except:
|
450
479
|
prices = []
|
451
480
|
try:
|
452
|
-
inventory = await agent.
|
481
|
+
inventory = await agent.status.get("inventory")
|
453
482
|
except:
|
454
483
|
inventory = 0
|
455
484
|
try:
|
456
|
-
price = await agent.
|
485
|
+
price = await agent.status.get("price")
|
457
486
|
except:
|
458
487
|
price = 0.0
|
459
488
|
try:
|
460
|
-
interest_rate = await agent.
|
489
|
+
interest_rate = await agent.status.get("interest_rate")
|
461
490
|
except:
|
462
491
|
interest_rate = 0.0
|
463
492
|
try:
|
464
|
-
bracket_cutoffs = await agent.
|
493
|
+
bracket_cutoffs = await agent.status.get("bracket_cutoffs")
|
465
494
|
except:
|
466
495
|
bracket_cutoffs = []
|
467
496
|
try:
|
468
|
-
bracket_rates = await agent.
|
497
|
+
bracket_rates = await agent.status.get("bracket_rates")
|
469
498
|
except:
|
470
499
|
bracket_rates = []
|
471
500
|
try:
|
472
|
-
employees = await agent.
|
501
|
+
employees = await agent.status.get("employees")
|
473
502
|
except:
|
474
503
|
employees = []
|
475
504
|
avro = {
|
476
505
|
"id": agent._uuid,
|
477
506
|
"day": _day,
|
478
507
|
"t": _t,
|
479
|
-
"type": await agent.
|
508
|
+
"type": await agent.status.get("type"),
|
480
509
|
"nominal_gdp": nominal_gdp,
|
481
510
|
"real_gdp": real_gdp,
|
482
511
|
"unemployment": unemployment,
|
@@ -514,12 +543,17 @@ class AgentGroup:
|
|
514
543
|
]:
|
515
544
|
if key not in _status_dict:
|
516
545
|
_status_dict[key] = ""
|
546
|
+
for key in [
|
547
|
+
"friend_ids",
|
548
|
+
]:
|
549
|
+
if key not in _status_dict:
|
550
|
+
_status_dict[key] = []
|
517
551
|
_status_dict["created_at"] = _date_time
|
518
552
|
else:
|
519
553
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
520
554
|
for agent in self.agents:
|
521
555
|
_date_time = datetime.now(timezone.utc)
|
522
|
-
position = await agent.
|
556
|
+
position = await agent.status.get("position")
|
523
557
|
x = position["xy_position"]["x"]
|
524
558
|
y = position["xy_position"]["y"]
|
525
559
|
lng, lat = self.projector(x, y, inverse=True)
|
@@ -528,10 +562,21 @@ class AgentGroup:
|
|
528
562
|
elif "lane_position" in position:
|
529
563
|
parent_id = position["lane_position"]["lane_id"]
|
530
564
|
else:
|
531
|
-
# BUG: 需要处理
|
532
565
|
parent_id = -1
|
533
|
-
|
534
|
-
|
566
|
+
hunger_satisfaction = await agent.status.get(
|
567
|
+
"hunger_satisfaction"
|
568
|
+
)
|
569
|
+
energy_satisfaction = await agent.status.get(
|
570
|
+
"energy_satisfaction"
|
571
|
+
)
|
572
|
+
safety_satisfaction = await agent.status.get(
|
573
|
+
"safety_satisfaction"
|
574
|
+
)
|
575
|
+
social_satisfaction = await agent.status.get(
|
576
|
+
"social_satisfaction"
|
577
|
+
)
|
578
|
+
friend_ids = await agent.status.get("friends")
|
579
|
+
action = await agent.status.get("current_step")
|
535
580
|
action = action["intention"]
|
536
581
|
_status_dict = {
|
537
582
|
"id": agent._uuid,
|
@@ -540,11 +585,14 @@ class AgentGroup:
|
|
540
585
|
"lng": lng,
|
541
586
|
"lat": lat,
|
542
587
|
"parent_id": parent_id,
|
588
|
+
"friend_ids": [
|
589
|
+
str(_friend_id) for _friend_id in friend_ids
|
590
|
+
],
|
543
591
|
"action": action,
|
544
|
-
"hungry":
|
545
|
-
"tired":
|
546
|
-
"safe":
|
547
|
-
"social":
|
592
|
+
"hungry": hunger_satisfaction,
|
593
|
+
"tired": energy_satisfaction,
|
594
|
+
"safe": safety_satisfaction,
|
595
|
+
"social": social_satisfaction,
|
548
596
|
"created_at": _date_time,
|
549
597
|
}
|
550
598
|
_statuses_time_list.append((_status_dict, _date_time))
|
@@ -552,56 +600,60 @@ class AgentGroup:
|
|
552
600
|
# institution
|
553
601
|
for agent in self.agents:
|
554
602
|
_date_time = datetime.now(timezone.utc)
|
555
|
-
position = await agent.
|
603
|
+
position = await agent.status.get("position")
|
556
604
|
x = position["xy_position"]["x"]
|
557
605
|
y = position["xy_position"]["y"]
|
558
606
|
lng, lat = self.projector(x, y, inverse=True)
|
559
607
|
# ATTENTION: no valid position for an institution
|
560
608
|
parent_id = -1
|
561
609
|
try:
|
562
|
-
nominal_gdp = await agent.
|
610
|
+
nominal_gdp = await agent.status.get("nominal_gdp")
|
563
611
|
except:
|
564
612
|
nominal_gdp = []
|
565
613
|
try:
|
566
|
-
real_gdp = await agent.
|
614
|
+
real_gdp = await agent.status.get("real_gdp")
|
567
615
|
except:
|
568
616
|
real_gdp = []
|
569
617
|
try:
|
570
|
-
unemployment = await agent.
|
618
|
+
unemployment = await agent.status.get("unemployment")
|
571
619
|
except:
|
572
620
|
unemployment = []
|
573
621
|
try:
|
574
|
-
wages = await agent.
|
622
|
+
wages = await agent.status.get("wages")
|
575
623
|
except:
|
576
624
|
wages = []
|
577
625
|
try:
|
578
|
-
prices = await agent.
|
626
|
+
prices = await agent.status.get("prices")
|
579
627
|
except:
|
580
628
|
prices = []
|
581
629
|
try:
|
582
|
-
inventory = await agent.
|
630
|
+
inventory = await agent.status.get("inventory")
|
583
631
|
except:
|
584
632
|
inventory = 0
|
585
633
|
try:
|
586
|
-
price = await agent.
|
634
|
+
price = await agent.status.get("price")
|
587
635
|
except:
|
588
636
|
price = 0.0
|
589
637
|
try:
|
590
|
-
interest_rate = await agent.
|
638
|
+
interest_rate = await agent.status.get("interest_rate")
|
591
639
|
except:
|
592
640
|
interest_rate = 0.0
|
593
641
|
try:
|
594
|
-
bracket_cutoffs = await agent.
|
642
|
+
bracket_cutoffs = await agent.status.get("bracket_cutoffs")
|
595
643
|
except:
|
596
644
|
bracket_cutoffs = []
|
597
645
|
try:
|
598
|
-
bracket_rates = await agent.
|
646
|
+
bracket_rates = await agent.status.get("bracket_rates")
|
599
647
|
except:
|
600
648
|
bracket_rates = []
|
601
649
|
try:
|
602
|
-
employees = await agent.
|
650
|
+
employees = await agent.status.get("employees")
|
603
651
|
except:
|
604
652
|
employees = []
|
653
|
+
try:
|
654
|
+
friend_ids = await agent.status.get("friends")
|
655
|
+
except:
|
656
|
+
friend_ids = []
|
605
657
|
_status_dict = {
|
606
658
|
"id": agent._uuid,
|
607
659
|
"day": _day,
|
@@ -609,8 +661,11 @@ class AgentGroup:
|
|
609
661
|
"lng": lng,
|
610
662
|
"lat": lat,
|
611
663
|
"parent_id": parent_id,
|
664
|
+
"friend_ids": [
|
665
|
+
str(_friend_id) for _friend_id in friend_ids
|
666
|
+
],
|
612
667
|
"action": "",
|
613
|
-
"type": await agent.
|
668
|
+
"type": await agent.status.get("type"),
|
614
669
|
"nominal_gdp": nominal_gdp,
|
615
670
|
"real_gdp": real_gdp,
|
616
671
|
"unemployment": unemployment,
|
@@ -634,6 +689,7 @@ class AgentGroup:
|
|
634
689
|
"lng",
|
635
690
|
"lat",
|
636
691
|
"parent_id",
|
692
|
+
"friend_ids",
|
637
693
|
"action",
|
638
694
|
"created_at",
|
639
695
|
]
|
@@ -653,33 +709,18 @@ class AgentGroup:
|
|
653
709
|
)
|
654
710
|
|
655
711
|
async def step(self):
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
await self.save_status()
|
712
|
+
try:
|
713
|
+
tasks = [agent.run() for agent in self.agents]
|
714
|
+
await asyncio.gather(*tasks)
|
715
|
+
except Exception as e:
|
716
|
+
import traceback
|
662
717
|
|
663
|
-
|
664
|
-
|
718
|
+
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
719
|
+
raise RuntimeError(str(e)) from e
|
665
720
|
|
666
|
-
|
667
|
-
day: 运行天数,默认为1天
|
668
|
-
"""
|
721
|
+
async def save(self, day: int, t: int):
|
669
722
|
try:
|
670
|
-
|
671
|
-
start_time = await self.simulator.get_time()
|
672
|
-
start_time = int(start_time)
|
673
|
-
# 计算结束时间(秒)
|
674
|
-
end_time = start_time + day * 24 * 3600 # 将天数转换为秒
|
675
|
-
|
676
|
-
while True:
|
677
|
-
current_time = await self.simulator.get_time()
|
678
|
-
current_time = int(current_time)
|
679
|
-
if current_time >= end_time:
|
680
|
-
break
|
681
|
-
await self.step()
|
682
|
-
|
723
|
+
await self.save_status(day, t)
|
683
724
|
except Exception as e:
|
684
725
|
import traceback
|
685
726
|
|