pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a53__cp311-cp311-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +48 -62
- pycityagent/agent/agent_base.py +66 -53
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +44 -56
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +204 -119
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/simulator.py +17 -12
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +720 -272
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +92 -99
- pycityagent/simulation/simulation.py +115 -40
- pycityagent/tools/tool.py +7 -9
- pycityagent/workflow/block.py +11 -4
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
pycityagent/pycityagent-sim
CHANGED
Binary file
|
@@ -21,7 +21,6 @@ from ..llm.llm import LLM
|
|
21
21
|
from ..llm.llmconfig import LLMConfig
|
22
22
|
from ..memory import FaissQuery, Memory
|
23
23
|
from ..message import Messager
|
24
|
-
from ..metrics import MlflowClient
|
25
24
|
from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
|
26
25
|
STATUS_SCHEMA, SURVEY_SCHEMA)
|
27
26
|
|
@@ -40,12 +39,10 @@ class AgentGroup:
|
|
40
39
|
],
|
41
40
|
config: dict,
|
42
41
|
exp_id: str | UUID,
|
43
|
-
exp_name: str,
|
44
42
|
enable_avro: bool,
|
45
43
|
avro_path: Path,
|
46
44
|
enable_pgsql: bool,
|
47
45
|
pgsql_writer: ray.ObjectRef,
|
48
|
-
mlflow_run_id: str,
|
49
46
|
embedding_model: Embeddings,
|
50
47
|
logging_level: int,
|
51
48
|
agent_config_file: Optional[Union[str, list[str]]] = None,
|
@@ -116,19 +113,6 @@ class AgentGroup:
|
|
116
113
|
else:
|
117
114
|
self.economy_client = None
|
118
115
|
|
119
|
-
# Mlflow
|
120
|
-
_mlflow_config = config.get("metric_request", {}).get("mlflow")
|
121
|
-
if _mlflow_config:
|
122
|
-
logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
|
123
|
-
self.mlflow_client = MlflowClient(
|
124
|
-
config=_mlflow_config,
|
125
|
-
mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
|
126
|
-
experiment_name=exp_name,
|
127
|
-
run_id=mlflow_run_id,
|
128
|
-
)
|
129
|
-
else:
|
130
|
-
self.mlflow_client = None
|
131
|
-
|
132
116
|
# set FaissQuery
|
133
117
|
if self.embedding_model is not None:
|
134
118
|
self.faiss_query = FaissQuery(
|
@@ -142,7 +126,11 @@ class AgentGroup:
|
|
142
126
|
for j in range(number_of_agents_i):
|
143
127
|
memory_config_function_group_i = memory_config_function_group[i]
|
144
128
|
extra_attributes, profile, base = memory_config_function_group_i()
|
145
|
-
memory = Memory(
|
129
|
+
memory = Memory(
|
130
|
+
config=extra_attributes,
|
131
|
+
profile=profile,
|
132
|
+
base=base
|
133
|
+
)
|
146
134
|
agent = agent_class_i(
|
147
135
|
name=f"{agent_class_i.__name__}_{i}",
|
148
136
|
memory=memory,
|
@@ -151,18 +139,12 @@ class AgentGroup:
|
|
151
139
|
simulator=self.simulator,
|
152
140
|
)
|
153
141
|
agent.set_exp_id(self.exp_id) # type: ignore
|
154
|
-
if self.mlflow_client is not None:
|
155
|
-
agent.set_mlflow_client(self.mlflow_client)
|
156
142
|
if self.messager is not None:
|
157
143
|
agent.set_messager(self.messager)
|
158
144
|
if self.enable_avro:
|
159
145
|
agent.set_avro_file(self.avro_file) # type: ignore
|
160
146
|
if self.enable_pgsql:
|
161
147
|
agent.set_pgsql_writer(self._pgsql_writer)
|
162
|
-
if self.faiss_query is not None:
|
163
|
-
agent.memory.set_faiss_query(self.faiss_query)
|
164
|
-
if self.embedding_model is not None:
|
165
|
-
agent.memory.set_embedding_model(self.embedding_model)
|
166
148
|
if self.agent_config_file is not None and self.agent_config_file[i]:
|
167
149
|
agent.load_from_file(self.agent_config_file[i])
|
168
150
|
self.agents.append(agent)
|
@@ -193,11 +175,21 @@ class AgentGroup:
|
|
193
175
|
self.message_dispatch_task.cancel() # type: ignore
|
194
176
|
await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
|
195
177
|
|
178
|
+
async def insert_agents(self):
|
179
|
+
bind_tasks = []
|
180
|
+
for agent in self.agents:
|
181
|
+
bind_tasks.append(agent.bind_to_simulator()) # type: ignore
|
182
|
+
await asyncio.gather(*bind_tasks)
|
183
|
+
|
196
184
|
async def init_agents(self):
|
197
185
|
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
198
186
|
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
199
|
-
|
200
|
-
await
|
187
|
+
while True:
|
188
|
+
day = await self.simulator.get_simulator_day()
|
189
|
+
if day == 0:
|
190
|
+
break
|
191
|
+
await asyncio.sleep(1)
|
192
|
+
await self.insert_agents()
|
201
193
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
202
194
|
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
203
195
|
assert self.messager is not None
|
@@ -221,7 +213,7 @@ class AgentGroup:
|
|
221
213
|
with open(filename, "wb") as f:
|
222
214
|
profiles = []
|
223
215
|
for agent in self.agents:
|
224
|
-
profile = await agent.
|
216
|
+
profile = await agent.status.profile.export()
|
225
217
|
profile = profile[0]
|
226
218
|
profile["id"] = agent._uuid
|
227
219
|
profiles.append(profile)
|
@@ -252,7 +244,7 @@ class AgentGroup:
|
|
252
244
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
253
245
|
profiles: list[Any] = []
|
254
246
|
for agent in self.agents:
|
255
|
-
profile = await agent.
|
247
|
+
profile = await agent.status.profile.export()
|
256
248
|
profile = profile[0]
|
257
249
|
profile["id"] = agent._uuid
|
258
250
|
profiles.append(
|
@@ -271,7 +263,7 @@ class AgentGroup:
|
|
271
263
|
else:
|
272
264
|
profiles: list[Any] = []
|
273
265
|
for agent in self.agents:
|
274
|
-
profile = await agent.
|
266
|
+
profile = await agent.status.profile.export()
|
275
267
|
profile = profile[0]
|
276
268
|
profile["id"] = agent._uuid
|
277
269
|
profiles.append(
|
@@ -290,6 +282,16 @@ class AgentGroup:
|
|
290
282
|
await self._pgsql_writer.async_write_profile.remote( # type:ignore
|
291
283
|
profiles
|
292
284
|
)
|
285
|
+
if self.faiss_query is not None:
|
286
|
+
logger.debug(f"-----Initializing embeddings in AgentGroup {self._uuid} ...")
|
287
|
+
embedding_tasks = []
|
288
|
+
for agent in self.agents:
|
289
|
+
embedding_tasks.append(agent.memory.initialize_embeddings())
|
290
|
+
agent.memory.set_search_components(self.faiss_query, self.embedding_model)
|
291
|
+
agent.memory.set_simulator(self.simulator)
|
292
|
+
await asyncio.gather(*embedding_tasks)
|
293
|
+
logger.debug(f"-----Embedding initialized in AgentGroup {self._uuid} ...")
|
294
|
+
|
293
295
|
self.initialized = True
|
294
296
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
295
297
|
|
@@ -307,7 +309,7 @@ class AgentGroup:
|
|
307
309
|
if keys:
|
308
310
|
for key in keys:
|
309
311
|
assert values is not None
|
310
|
-
if not agent.
|
312
|
+
if not agent.status.get(key) == values[keys.index(key)]:
|
311
313
|
add = False
|
312
314
|
break
|
313
315
|
if add:
|
@@ -315,18 +317,21 @@ class AgentGroup:
|
|
315
317
|
elif keys:
|
316
318
|
for key in keys:
|
317
319
|
assert values is not None
|
318
|
-
if not agent.
|
320
|
+
if not agent.status.get(key) == values[keys.index(key)]:
|
319
321
|
add = False
|
320
322
|
break
|
321
323
|
if add:
|
322
324
|
filtered_uuids.append(agent._uuid)
|
323
325
|
return filtered_uuids
|
324
326
|
|
325
|
-
async def gather(self, content: str):
|
327
|
+
async def gather(self, content: str, target_agent_uuids: Optional[list[str]] = None):
|
326
328
|
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
327
329
|
results = {}
|
330
|
+
if target_agent_uuids is None:
|
331
|
+
target_agent_uuids = self.agent_uuids
|
328
332
|
for agent in self.agents:
|
329
|
-
|
333
|
+
if agent._uuid in target_agent_uuids:
|
334
|
+
results[agent._uuid] = await agent.status.get(content)
|
330
335
|
return results
|
331
336
|
|
332
337
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
@@ -334,7 +339,7 @@ class AgentGroup:
|
|
334
339
|
f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
|
335
340
|
)
|
336
341
|
agent = self.id2agent[target_agent_uuid]
|
337
|
-
await agent.
|
342
|
+
await agent.status.update(target_key, content)
|
338
343
|
|
339
344
|
async def message_dispatch(self):
|
340
345
|
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
@@ -394,7 +399,7 @@ class AgentGroup:
|
|
394
399
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
395
400
|
for agent in self.agents:
|
396
401
|
_date_time = datetime.now(timezone.utc)
|
397
|
-
position = await agent.
|
402
|
+
position = await agent.status.get("position")
|
398
403
|
x = position["xy_position"]["x"]
|
399
404
|
y = position["xy_position"]["y"]
|
400
405
|
lng, lat = self.projector(x, y, inverse=True)
|
@@ -404,8 +409,11 @@ class AgentGroup:
|
|
404
409
|
parent_id = position["lane_position"]["lane_id"]
|
405
410
|
else:
|
406
411
|
parent_id = -1
|
407
|
-
|
408
|
-
|
412
|
+
hunger_satisfaction = await agent.status.get("hunger_satisfaction")
|
413
|
+
energy_satisfaction = await agent.status.get("energy_satisfaction")
|
414
|
+
safety_satisfaction = await agent.status.get("safety_satisfaction")
|
415
|
+
social_satisfaction = await agent.status.get("social_satisfaction")
|
416
|
+
action = await agent.status.get("current_step")
|
409
417
|
action = action["intention"]
|
410
418
|
avro = {
|
411
419
|
"id": agent._uuid,
|
@@ -415,10 +423,10 @@ class AgentGroup:
|
|
415
423
|
"lat": lat,
|
416
424
|
"parent_id": parent_id,
|
417
425
|
"action": action,
|
418
|
-
"hungry":
|
419
|
-
"tired":
|
420
|
-
"safe":
|
421
|
-
"social":
|
426
|
+
"hungry": hunger_satisfaction,
|
427
|
+
"tired": energy_satisfaction,
|
428
|
+
"safe": safety_satisfaction,
|
429
|
+
"social": social_satisfaction,
|
422
430
|
"created_at": int(_date_time.timestamp() * 1000),
|
423
431
|
}
|
424
432
|
avros.append(avro)
|
@@ -429,54 +437,54 @@ class AgentGroup:
|
|
429
437
|
for agent in self.agents:
|
430
438
|
_date_time = datetime.now(timezone.utc)
|
431
439
|
try:
|
432
|
-
nominal_gdp = await agent.
|
440
|
+
nominal_gdp = await agent.status.get("nominal_gdp")
|
433
441
|
except:
|
434
442
|
nominal_gdp = []
|
435
443
|
try:
|
436
|
-
real_gdp = await agent.
|
444
|
+
real_gdp = await agent.status.get("real_gdp")
|
437
445
|
except:
|
438
446
|
real_gdp = []
|
439
447
|
try:
|
440
|
-
unemployment = await agent.
|
448
|
+
unemployment = await agent.status.get("unemployment")
|
441
449
|
except:
|
442
450
|
unemployment = []
|
443
451
|
try:
|
444
|
-
wages = await agent.
|
452
|
+
wages = await agent.status.get("wages")
|
445
453
|
except:
|
446
454
|
wages = []
|
447
455
|
try:
|
448
|
-
prices = await agent.
|
456
|
+
prices = await agent.status.get("prices")
|
449
457
|
except:
|
450
458
|
prices = []
|
451
459
|
try:
|
452
|
-
inventory = await agent.
|
460
|
+
inventory = await agent.status.get("inventory")
|
453
461
|
except:
|
454
462
|
inventory = 0
|
455
463
|
try:
|
456
|
-
price = await agent.
|
464
|
+
price = await agent.status.get("price")
|
457
465
|
except:
|
458
466
|
price = 0.0
|
459
467
|
try:
|
460
|
-
interest_rate = await agent.
|
468
|
+
interest_rate = await agent.status.get("interest_rate")
|
461
469
|
except:
|
462
470
|
interest_rate = 0.0
|
463
471
|
try:
|
464
|
-
bracket_cutoffs = await agent.
|
472
|
+
bracket_cutoffs = await agent.status.get("bracket_cutoffs")
|
465
473
|
except:
|
466
474
|
bracket_cutoffs = []
|
467
475
|
try:
|
468
|
-
bracket_rates = await agent.
|
476
|
+
bracket_rates = await agent.status.get("bracket_rates")
|
469
477
|
except:
|
470
478
|
bracket_rates = []
|
471
479
|
try:
|
472
|
-
employees = await agent.
|
480
|
+
employees = await agent.status.get("employees")
|
473
481
|
except:
|
474
482
|
employees = []
|
475
483
|
avro = {
|
476
484
|
"id": agent._uuid,
|
477
485
|
"day": _day,
|
478
486
|
"t": _t,
|
479
|
-
"type": await agent.
|
487
|
+
"type": await agent.status.get("type"),
|
480
488
|
"nominal_gdp": nominal_gdp,
|
481
489
|
"real_gdp": real_gdp,
|
482
490
|
"unemployment": unemployment,
|
@@ -519,7 +527,7 @@ class AgentGroup:
|
|
519
527
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
520
528
|
for agent in self.agents:
|
521
529
|
_date_time = datetime.now(timezone.utc)
|
522
|
-
position = await agent.
|
530
|
+
position = await agent.status.get("position")
|
523
531
|
x = position["xy_position"]["x"]
|
524
532
|
y = position["xy_position"]["y"]
|
525
533
|
lng, lat = self.projector(x, y, inverse=True)
|
@@ -528,10 +536,12 @@ class AgentGroup:
|
|
528
536
|
elif "lane_position" in position:
|
529
537
|
parent_id = position["lane_position"]["lane_id"]
|
530
538
|
else:
|
531
|
-
# BUG: 需要处理
|
532
539
|
parent_id = -1
|
533
|
-
|
534
|
-
|
540
|
+
hunger_satisfaction = await agent.status.get("hunger_satisfaction")
|
541
|
+
energy_satisfaction = await agent.status.get("energy_satisfaction")
|
542
|
+
safety_satisfaction = await agent.status.get("safety_satisfaction")
|
543
|
+
social_satisfaction = await agent.status.get("social_satisfaction")
|
544
|
+
action = await agent.status.get("current_step")
|
535
545
|
action = action["intention"]
|
536
546
|
_status_dict = {
|
537
547
|
"id": agent._uuid,
|
@@ -541,10 +551,10 @@ class AgentGroup:
|
|
541
551
|
"lat": lat,
|
542
552
|
"parent_id": parent_id,
|
543
553
|
"action": action,
|
544
|
-
"hungry":
|
545
|
-
"tired":
|
546
|
-
"safe":
|
547
|
-
"social":
|
554
|
+
"hungry": hunger_satisfaction,
|
555
|
+
"tired": energy_satisfaction,
|
556
|
+
"safe": safety_satisfaction,
|
557
|
+
"social": social_satisfaction,
|
548
558
|
"created_at": _date_time,
|
549
559
|
}
|
550
560
|
_statuses_time_list.append((_status_dict, _date_time))
|
@@ -552,54 +562,54 @@ class AgentGroup:
|
|
552
562
|
# institution
|
553
563
|
for agent in self.agents:
|
554
564
|
_date_time = datetime.now(timezone.utc)
|
555
|
-
position = await agent.
|
565
|
+
position = await agent.status.get("position")
|
556
566
|
x = position["xy_position"]["x"]
|
557
567
|
y = position["xy_position"]["y"]
|
558
568
|
lng, lat = self.projector(x, y, inverse=True)
|
559
569
|
# ATTENTION: no valid position for an institution
|
560
570
|
parent_id = -1
|
561
571
|
try:
|
562
|
-
nominal_gdp = await agent.
|
572
|
+
nominal_gdp = await agent.status.get("nominal_gdp")
|
563
573
|
except:
|
564
574
|
nominal_gdp = []
|
565
575
|
try:
|
566
|
-
real_gdp = await agent.
|
576
|
+
real_gdp = await agent.status.get("real_gdp")
|
567
577
|
except:
|
568
578
|
real_gdp = []
|
569
579
|
try:
|
570
|
-
unemployment = await agent.
|
580
|
+
unemployment = await agent.status.get("unemployment")
|
571
581
|
except:
|
572
582
|
unemployment = []
|
573
583
|
try:
|
574
|
-
wages = await agent.
|
584
|
+
wages = await agent.status.get("wages")
|
575
585
|
except:
|
576
586
|
wages = []
|
577
587
|
try:
|
578
|
-
prices = await agent.
|
588
|
+
prices = await agent.status.get("prices")
|
579
589
|
except:
|
580
590
|
prices = []
|
581
591
|
try:
|
582
|
-
inventory = await agent.
|
592
|
+
inventory = await agent.status.get("inventory")
|
583
593
|
except:
|
584
594
|
inventory = 0
|
585
595
|
try:
|
586
|
-
price = await agent.
|
596
|
+
price = await agent.status.get("price")
|
587
597
|
except:
|
588
598
|
price = 0.0
|
589
599
|
try:
|
590
|
-
interest_rate = await agent.
|
600
|
+
interest_rate = await agent.status.get("interest_rate")
|
591
601
|
except:
|
592
602
|
interest_rate = 0.0
|
593
603
|
try:
|
594
|
-
bracket_cutoffs = await agent.
|
604
|
+
bracket_cutoffs = await agent.status.get("bracket_cutoffs")
|
595
605
|
except:
|
596
606
|
bracket_cutoffs = []
|
597
607
|
try:
|
598
|
-
bracket_rates = await agent.
|
608
|
+
bracket_rates = await agent.status.get("bracket_rates")
|
599
609
|
except:
|
600
610
|
bracket_rates = []
|
601
611
|
try:
|
602
|
-
employees = await agent.
|
612
|
+
employees = await agent.status.get("employees")
|
603
613
|
except:
|
604
614
|
employees = []
|
605
615
|
_status_dict = {
|
@@ -610,7 +620,7 @@ class AgentGroup:
|
|
610
620
|
"lat": lat,
|
611
621
|
"parent_id": parent_id,
|
612
622
|
"action": "",
|
613
|
-
"type": await agent.
|
623
|
+
"type": await agent.status.get("type"),
|
614
624
|
"nominal_gdp": nominal_gdp,
|
615
625
|
"real_gdp": real_gdp,
|
616
626
|
"unemployment": unemployment,
|
@@ -653,35 +663,18 @@ class AgentGroup:
|
|
653
663
|
)
|
654
664
|
|
655
665
|
async def step(self):
|
656
|
-
if not self.initialized:
|
657
|
-
await self.init_agents()
|
658
|
-
|
659
|
-
tasks = [agent.run() for agent in self.agents]
|
660
|
-
await asyncio.gather(*tasks)
|
661
|
-
await self.save_status()
|
662
|
-
|
663
|
-
async def run(self, day: int = 1):
|
664
|
-
"""运行模拟器
|
665
|
-
|
666
|
-
Args:
|
667
|
-
day: 运行天数,默认为1天
|
668
|
-
"""
|
669
666
|
try:
|
670
|
-
|
671
|
-
|
672
|
-
start_time = int(start_time)
|
673
|
-
# 计算结束时间(秒)
|
674
|
-
end_time = start_time + day * 24 * 3600 # 将天数转换为秒
|
675
|
-
|
676
|
-
while True:
|
677
|
-
current_time = await self.simulator.get_time()
|
678
|
-
current_time = int(current_time)
|
679
|
-
if current_time >= end_time:
|
680
|
-
break
|
681
|
-
await self.step()
|
682
|
-
|
667
|
+
tasks = [agent.run() for agent in self.agents]
|
668
|
+
await asyncio.gather(*tasks)
|
683
669
|
except Exception as e:
|
684
670
|
import traceback
|
671
|
+
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
672
|
+
raise RuntimeError(str(e)) from e
|
685
673
|
|
674
|
+
async def save(self, day: int, t: int):
|
675
|
+
try:
|
676
|
+
await self.save_status(day, t)
|
677
|
+
except Exception as e:
|
678
|
+
import traceback
|
686
679
|
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
687
680
|
raise RuntimeError(str(e)) from e
|