pycityagent 2.0.0a52__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a53__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +48 -62
- pycityagent/agent/agent_base.py +66 -53
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +44 -56
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +204 -119
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/simulator.py +17 -12
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +720 -272
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +92 -99
- pycityagent/simulation/simulation.py +115 -40
- pycityagent/tools/tool.py +7 -9
- pycityagent/workflow/block.py +11 -4
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
pycityagent/pycityagent-sim
CHANGED
Binary file
|
@@ -21,7 +21,6 @@ from ..llm.llm import LLM
|
|
21
21
|
from ..llm.llmconfig import LLMConfig
|
22
22
|
from ..memory import FaissQuery, Memory
|
23
23
|
from ..message import Messager
|
24
|
-
from ..metrics import MlflowClient
|
25
24
|
from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
|
26
25
|
STATUS_SCHEMA, SURVEY_SCHEMA)
|
27
26
|
|
@@ -40,12 +39,10 @@ class AgentGroup:
|
|
40
39
|
],
|
41
40
|
config: dict,
|
42
41
|
exp_id: str | UUID,
|
43
|
-
exp_name: str,
|
44
42
|
enable_avro: bool,
|
45
43
|
avro_path: Path,
|
46
44
|
enable_pgsql: bool,
|
47
45
|
pgsql_writer: ray.ObjectRef,
|
48
|
-
mlflow_run_id: str,
|
49
46
|
embedding_model: Embeddings,
|
50
47
|
logging_level: int,
|
51
48
|
agent_config_file: Optional[Union[str, list[str]]] = None,
|
@@ -116,19 +113,6 @@ class AgentGroup:
|
|
116
113
|
else:
|
117
114
|
self.economy_client = None
|
118
115
|
|
119
|
-
# Mlflow
|
120
|
-
_mlflow_config = config.get("metric_request", {}).get("mlflow")
|
121
|
-
if _mlflow_config:
|
122
|
-
logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
|
123
|
-
self.mlflow_client = MlflowClient(
|
124
|
-
config=_mlflow_config,
|
125
|
-
mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
|
126
|
-
experiment_name=exp_name,
|
127
|
-
run_id=mlflow_run_id,
|
128
|
-
)
|
129
|
-
else:
|
130
|
-
self.mlflow_client = None
|
131
|
-
|
132
116
|
# set FaissQuery
|
133
117
|
if self.embedding_model is not None:
|
134
118
|
self.faiss_query = FaissQuery(
|
@@ -142,7 +126,11 @@ class AgentGroup:
|
|
142
126
|
for j in range(number_of_agents_i):
|
143
127
|
memory_config_function_group_i = memory_config_function_group[i]
|
144
128
|
extra_attributes, profile, base = memory_config_function_group_i()
|
145
|
-
memory = Memory(
|
129
|
+
memory = Memory(
|
130
|
+
config=extra_attributes,
|
131
|
+
profile=profile,
|
132
|
+
base=base
|
133
|
+
)
|
146
134
|
agent = agent_class_i(
|
147
135
|
name=f"{agent_class_i.__name__}_{i}",
|
148
136
|
memory=memory,
|
@@ -151,18 +139,12 @@ class AgentGroup:
|
|
151
139
|
simulator=self.simulator,
|
152
140
|
)
|
153
141
|
agent.set_exp_id(self.exp_id) # type: ignore
|
154
|
-
if self.mlflow_client is not None:
|
155
|
-
agent.set_mlflow_client(self.mlflow_client)
|
156
142
|
if self.messager is not None:
|
157
143
|
agent.set_messager(self.messager)
|
158
144
|
if self.enable_avro:
|
159
145
|
agent.set_avro_file(self.avro_file) # type: ignore
|
160
146
|
if self.enable_pgsql:
|
161
147
|
agent.set_pgsql_writer(self._pgsql_writer)
|
162
|
-
if self.faiss_query is not None:
|
163
|
-
agent.memory.set_faiss_query(self.faiss_query)
|
164
|
-
if self.embedding_model is not None:
|
165
|
-
agent.memory.set_embedding_model(self.embedding_model)
|
166
148
|
if self.agent_config_file is not None and self.agent_config_file[i]:
|
167
149
|
agent.load_from_file(self.agent_config_file[i])
|
168
150
|
self.agents.append(agent)
|
@@ -193,11 +175,21 @@ class AgentGroup:
|
|
193
175
|
self.message_dispatch_task.cancel() # type: ignore
|
194
176
|
await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
|
195
177
|
|
178
|
+
async def insert_agents(self):
|
179
|
+
bind_tasks = []
|
180
|
+
for agent in self.agents:
|
181
|
+
bind_tasks.append(agent.bind_to_simulator()) # type: ignore
|
182
|
+
await asyncio.gather(*bind_tasks)
|
183
|
+
|
196
184
|
async def init_agents(self):
|
197
185
|
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
198
186
|
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
199
|
-
|
200
|
-
await
|
187
|
+
while True:
|
188
|
+
day = await self.simulator.get_simulator_day()
|
189
|
+
if day == 0:
|
190
|
+
break
|
191
|
+
await asyncio.sleep(1)
|
192
|
+
await self.insert_agents()
|
201
193
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
202
194
|
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
203
195
|
assert self.messager is not None
|
@@ -221,7 +213,7 @@ class AgentGroup:
|
|
221
213
|
with open(filename, "wb") as f:
|
222
214
|
profiles = []
|
223
215
|
for agent in self.agents:
|
224
|
-
profile = await agent.
|
216
|
+
profile = await agent.status.profile.export()
|
225
217
|
profile = profile[0]
|
226
218
|
profile["id"] = agent._uuid
|
227
219
|
profiles.append(profile)
|
@@ -252,7 +244,7 @@ class AgentGroup:
|
|
252
244
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
253
245
|
profiles: list[Any] = []
|
254
246
|
for agent in self.agents:
|
255
|
-
profile = await agent.
|
247
|
+
profile = await agent.status.profile.export()
|
256
248
|
profile = profile[0]
|
257
249
|
profile["id"] = agent._uuid
|
258
250
|
profiles.append(
|
@@ -271,7 +263,7 @@ class AgentGroup:
|
|
271
263
|
else:
|
272
264
|
profiles: list[Any] = []
|
273
265
|
for agent in self.agents:
|
274
|
-
profile = await agent.
|
266
|
+
profile = await agent.status.profile.export()
|
275
267
|
profile = profile[0]
|
276
268
|
profile["id"] = agent._uuid
|
277
269
|
profiles.append(
|
@@ -290,6 +282,16 @@ class AgentGroup:
|
|
290
282
|
await self._pgsql_writer.async_write_profile.remote( # type:ignore
|
291
283
|
profiles
|
292
284
|
)
|
285
|
+
if self.faiss_query is not None:
|
286
|
+
logger.debug(f"-----Initializing embeddings in AgentGroup {self._uuid} ...")
|
287
|
+
embedding_tasks = []
|
288
|
+
for agent in self.agents:
|
289
|
+
embedding_tasks.append(agent.memory.initialize_embeddings())
|
290
|
+
agent.memory.set_search_components(self.faiss_query, self.embedding_model)
|
291
|
+
agent.memory.set_simulator(self.simulator)
|
292
|
+
await asyncio.gather(*embedding_tasks)
|
293
|
+
logger.debug(f"-----Embedding initialized in AgentGroup {self._uuid} ...")
|
294
|
+
|
293
295
|
self.initialized = True
|
294
296
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
295
297
|
|
@@ -307,7 +309,7 @@ class AgentGroup:
|
|
307
309
|
if keys:
|
308
310
|
for key in keys:
|
309
311
|
assert values is not None
|
310
|
-
if not agent.
|
312
|
+
if not agent.status.get(key) == values[keys.index(key)]:
|
311
313
|
add = False
|
312
314
|
break
|
313
315
|
if add:
|
@@ -315,18 +317,21 @@ class AgentGroup:
|
|
315
317
|
elif keys:
|
316
318
|
for key in keys:
|
317
319
|
assert values is not None
|
318
|
-
if not agent.
|
320
|
+
if not agent.status.get(key) == values[keys.index(key)]:
|
319
321
|
add = False
|
320
322
|
break
|
321
323
|
if add:
|
322
324
|
filtered_uuids.append(agent._uuid)
|
323
325
|
return filtered_uuids
|
324
326
|
|
325
|
-
async def gather(self, content: str):
|
327
|
+
async def gather(self, content: str, target_agent_uuids: Optional[list[str]] = None):
|
326
328
|
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
327
329
|
results = {}
|
330
|
+
if target_agent_uuids is None:
|
331
|
+
target_agent_uuids = self.agent_uuids
|
328
332
|
for agent in self.agents:
|
329
|
-
|
333
|
+
if agent._uuid in target_agent_uuids:
|
334
|
+
results[agent._uuid] = await agent.status.get(content)
|
330
335
|
return results
|
331
336
|
|
332
337
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
@@ -334,7 +339,7 @@ class AgentGroup:
|
|
334
339
|
f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
|
335
340
|
)
|
336
341
|
agent = self.id2agent[target_agent_uuid]
|
337
|
-
await agent.
|
342
|
+
await agent.status.update(target_key, content)
|
338
343
|
|
339
344
|
async def message_dispatch(self):
|
340
345
|
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
@@ -394,7 +399,7 @@ class AgentGroup:
|
|
394
399
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
395
400
|
for agent in self.agents:
|
396
401
|
_date_time = datetime.now(timezone.utc)
|
397
|
-
position = await agent.
|
402
|
+
position = await agent.status.get("position")
|
398
403
|
x = position["xy_position"]["x"]
|
399
404
|
y = position["xy_position"]["y"]
|
400
405
|
lng, lat = self.projector(x, y, inverse=True)
|
@@ -404,8 +409,11 @@ class AgentGroup:
|
|
404
409
|
parent_id = position["lane_position"]["lane_id"]
|
405
410
|
else:
|
406
411
|
parent_id = -1
|
407
|
-
|
408
|
-
|
412
|
+
hunger_satisfaction = await agent.status.get("hunger_satisfaction")
|
413
|
+
energy_satisfaction = await agent.status.get("energy_satisfaction")
|
414
|
+
safety_satisfaction = await agent.status.get("safety_satisfaction")
|
415
|
+
social_satisfaction = await agent.status.get("social_satisfaction")
|
416
|
+
action = await agent.status.get("current_step")
|
409
417
|
action = action["intention"]
|
410
418
|
avro = {
|
411
419
|
"id": agent._uuid,
|
@@ -415,10 +423,10 @@ class AgentGroup:
|
|
415
423
|
"lat": lat,
|
416
424
|
"parent_id": parent_id,
|
417
425
|
"action": action,
|
418
|
-
"hungry":
|
419
|
-
"tired":
|
420
|
-
"safe":
|
421
|
-
"social":
|
426
|
+
"hungry": hunger_satisfaction,
|
427
|
+
"tired": energy_satisfaction,
|
428
|
+
"safe": safety_satisfaction,
|
429
|
+
"social": social_satisfaction,
|
422
430
|
"created_at": int(_date_time.timestamp() * 1000),
|
423
431
|
}
|
424
432
|
avros.append(avro)
|
@@ -429,54 +437,54 @@ class AgentGroup:
|
|
429
437
|
for agent in self.agents:
|
430
438
|
_date_time = datetime.now(timezone.utc)
|
431
439
|
try:
|
432
|
-
nominal_gdp = await agent.
|
440
|
+
nominal_gdp = await agent.status.get("nominal_gdp")
|
433
441
|
except:
|
434
442
|
nominal_gdp = []
|
435
443
|
try:
|
436
|
-
real_gdp = await agent.
|
444
|
+
real_gdp = await agent.status.get("real_gdp")
|
437
445
|
except:
|
438
446
|
real_gdp = []
|
439
447
|
try:
|
440
|
-
unemployment = await agent.
|
448
|
+
unemployment = await agent.status.get("unemployment")
|
441
449
|
except:
|
442
450
|
unemployment = []
|
443
451
|
try:
|
444
|
-
wages = await agent.
|
452
|
+
wages = await agent.status.get("wages")
|
445
453
|
except:
|
446
454
|
wages = []
|
447
455
|
try:
|
448
|
-
prices = await agent.
|
456
|
+
prices = await agent.status.get("prices")
|
449
457
|
except:
|
450
458
|
prices = []
|
451
459
|
try:
|
452
|
-
inventory = await agent.
|
460
|
+
inventory = await agent.status.get("inventory")
|
453
461
|
except:
|
454
462
|
inventory = 0
|
455
463
|
try:
|
456
|
-
price = await agent.
|
464
|
+
price = await agent.status.get("price")
|
457
465
|
except:
|
458
466
|
price = 0.0
|
459
467
|
try:
|
460
|
-
interest_rate = await agent.
|
468
|
+
interest_rate = await agent.status.get("interest_rate")
|
461
469
|
except:
|
462
470
|
interest_rate = 0.0
|
463
471
|
try:
|
464
|
-
bracket_cutoffs = await agent.
|
472
|
+
bracket_cutoffs = await agent.status.get("bracket_cutoffs")
|
465
473
|
except:
|
466
474
|
bracket_cutoffs = []
|
467
475
|
try:
|
468
|
-
bracket_rates = await agent.
|
476
|
+
bracket_rates = await agent.status.get("bracket_rates")
|
469
477
|
except:
|
470
478
|
bracket_rates = []
|
471
479
|
try:
|
472
|
-
employees = await agent.
|
480
|
+
employees = await agent.status.get("employees")
|
473
481
|
except:
|
474
482
|
employees = []
|
475
483
|
avro = {
|
476
484
|
"id": agent._uuid,
|
477
485
|
"day": _day,
|
478
486
|
"t": _t,
|
479
|
-
"type": await agent.
|
487
|
+
"type": await agent.status.get("type"),
|
480
488
|
"nominal_gdp": nominal_gdp,
|
481
489
|
"real_gdp": real_gdp,
|
482
490
|
"unemployment": unemployment,
|
@@ -519,7 +527,7 @@ class AgentGroup:
|
|
519
527
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
520
528
|
for agent in self.agents:
|
521
529
|
_date_time = datetime.now(timezone.utc)
|
522
|
-
position = await agent.
|
530
|
+
position = await agent.status.get("position")
|
523
531
|
x = position["xy_position"]["x"]
|
524
532
|
y = position["xy_position"]["y"]
|
525
533
|
lng, lat = self.projector(x, y, inverse=True)
|
@@ -528,10 +536,12 @@ class AgentGroup:
|
|
528
536
|
elif "lane_position" in position:
|
529
537
|
parent_id = position["lane_position"]["lane_id"]
|
530
538
|
else:
|
531
|
-
# BUG: 需要处理
|
532
539
|
parent_id = -1
|
533
|
-
|
534
|
-
|
540
|
+
hunger_satisfaction = await agent.status.get("hunger_satisfaction")
|
541
|
+
energy_satisfaction = await agent.status.get("energy_satisfaction")
|
542
|
+
safety_satisfaction = await agent.status.get("safety_satisfaction")
|
543
|
+
social_satisfaction = await agent.status.get("social_satisfaction")
|
544
|
+
action = await agent.status.get("current_step")
|
535
545
|
action = action["intention"]
|
536
546
|
_status_dict = {
|
537
547
|
"id": agent._uuid,
|
@@ -541,10 +551,10 @@ class AgentGroup:
|
|
541
551
|
"lat": lat,
|
542
552
|
"parent_id": parent_id,
|
543
553
|
"action": action,
|
544
|
-
"hungry":
|
545
|
-
"tired":
|
546
|
-
"safe":
|
547
|
-
"social":
|
554
|
+
"hungry": hunger_satisfaction,
|
555
|
+
"tired": energy_satisfaction,
|
556
|
+
"safe": safety_satisfaction,
|
557
|
+
"social": social_satisfaction,
|
548
558
|
"created_at": _date_time,
|
549
559
|
}
|
550
560
|
_statuses_time_list.append((_status_dict, _date_time))
|
@@ -552,54 +562,54 @@ class AgentGroup:
|
|
552
562
|
# institution
|
553
563
|
for agent in self.agents:
|
554
564
|
_date_time = datetime.now(timezone.utc)
|
555
|
-
position = await agent.
|
565
|
+
position = await agent.status.get("position")
|
556
566
|
x = position["xy_position"]["x"]
|
557
567
|
y = position["xy_position"]["y"]
|
558
568
|
lng, lat = self.projector(x, y, inverse=True)
|
559
569
|
# ATTENTION: no valid position for an institution
|
560
570
|
parent_id = -1
|
561
571
|
try:
|
562
|
-
nominal_gdp = await agent.
|
572
|
+
nominal_gdp = await agent.status.get("nominal_gdp")
|
563
573
|
except:
|
564
574
|
nominal_gdp = []
|
565
575
|
try:
|
566
|
-
real_gdp = await agent.
|
576
|
+
real_gdp = await agent.status.get("real_gdp")
|
567
577
|
except:
|
568
578
|
real_gdp = []
|
569
579
|
try:
|
570
|
-
unemployment = await agent.
|
580
|
+
unemployment = await agent.status.get("unemployment")
|
571
581
|
except:
|
572
582
|
unemployment = []
|
573
583
|
try:
|
574
|
-
wages = await agent.
|
584
|
+
wages = await agent.status.get("wages")
|
575
585
|
except:
|
576
586
|
wages = []
|
577
587
|
try:
|
578
|
-
prices = await agent.
|
588
|
+
prices = await agent.status.get("prices")
|
579
589
|
except:
|
580
590
|
prices = []
|
581
591
|
try:
|
582
|
-
inventory = await agent.
|
592
|
+
inventory = await agent.status.get("inventory")
|
583
593
|
except:
|
584
594
|
inventory = 0
|
585
595
|
try:
|
586
|
-
price = await agent.
|
596
|
+
price = await agent.status.get("price")
|
587
597
|
except:
|
588
598
|
price = 0.0
|
589
599
|
try:
|
590
|
-
interest_rate = await agent.
|
600
|
+
interest_rate = await agent.status.get("interest_rate")
|
591
601
|
except:
|
592
602
|
interest_rate = 0.0
|
593
603
|
try:
|
594
|
-
bracket_cutoffs = await agent.
|
604
|
+
bracket_cutoffs = await agent.status.get("bracket_cutoffs")
|
595
605
|
except:
|
596
606
|
bracket_cutoffs = []
|
597
607
|
try:
|
598
|
-
bracket_rates = await agent.
|
608
|
+
bracket_rates = await agent.status.get("bracket_rates")
|
599
609
|
except:
|
600
610
|
bracket_rates = []
|
601
611
|
try:
|
602
|
-
employees = await agent.
|
612
|
+
employees = await agent.status.get("employees")
|
603
613
|
except:
|
604
614
|
employees = []
|
605
615
|
_status_dict = {
|
@@ -610,7 +620,7 @@ class AgentGroup:
|
|
610
620
|
"lat": lat,
|
611
621
|
"parent_id": parent_id,
|
612
622
|
"action": "",
|
613
|
-
"type": await agent.
|
623
|
+
"type": await agent.status.get("type"),
|
614
624
|
"nominal_gdp": nominal_gdp,
|
615
625
|
"real_gdp": real_gdp,
|
616
626
|
"unemployment": unemployment,
|
@@ -653,35 +663,18 @@ class AgentGroup:
|
|
653
663
|
)
|
654
664
|
|
655
665
|
async def step(self):
|
656
|
-
if not self.initialized:
|
657
|
-
await self.init_agents()
|
658
|
-
|
659
|
-
tasks = [agent.run() for agent in self.agents]
|
660
|
-
await asyncio.gather(*tasks)
|
661
|
-
await self.save_status()
|
662
|
-
|
663
|
-
async def run(self, day: int = 1):
|
664
|
-
"""运行模拟器
|
665
|
-
|
666
|
-
Args:
|
667
|
-
day: 运行天数,默认为1天
|
668
|
-
"""
|
669
666
|
try:
|
670
|
-
|
671
|
-
|
672
|
-
start_time = int(start_time)
|
673
|
-
# 计算结束时间(秒)
|
674
|
-
end_time = start_time + day * 24 * 3600 # 将天数转换为秒
|
675
|
-
|
676
|
-
while True:
|
677
|
-
current_time = await self.simulator.get_time()
|
678
|
-
current_time = int(current_time)
|
679
|
-
if current_time >= end_time:
|
680
|
-
break
|
681
|
-
await self.step()
|
682
|
-
|
667
|
+
tasks = [agent.run() for agent in self.agents]
|
668
|
+
await asyncio.gather(*tasks)
|
683
669
|
except Exception as e:
|
684
670
|
import traceback
|
671
|
+
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
672
|
+
raise RuntimeError(str(e)) from e
|
685
673
|
|
674
|
+
async def save(self, day: int, t: int):
|
675
|
+
try:
|
676
|
+
await self.save_status(day, t)
|
677
|
+
except Exception as e:
|
678
|
+
import traceback
|
686
679
|
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
687
680
|
raise RuntimeError(str(e)) from e
|