pycityagent 2.0.0a52__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a53__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. pycityagent/agent/agent.py +48 -62
  2. pycityagent/agent/agent_base.py +66 -53
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +44 -56
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/nbsagent.py +6 -29
  17. pycityagent/cityagent/societyagent.py +204 -119
  18. pycityagent/environment/sim/client.py +10 -1
  19. pycityagent/environment/sim/clock_service.py +2 -2
  20. pycityagent/environment/sim/pause_service.py +61 -0
  21. pycityagent/environment/simulator.py +17 -12
  22. pycityagent/llm/embeddings.py +0 -24
  23. pycityagent/memory/faiss_query.py +29 -26
  24. pycityagent/memory/memory.py +720 -272
  25. pycityagent/pycityagent-sim +0 -0
  26. pycityagent/simulation/agentgroup.py +92 -99
  27. pycityagent/simulation/simulation.py +115 -40
  28. pycityagent/tools/tool.py +7 -9
  29. pycityagent/workflow/block.py +11 -4
  30. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +1 -1
  31. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
  32. pycityagent/cityagent/blocks/time_block.py +0 -116
  33. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
  34. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +0 -0
  35. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
  36. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
Binary file
@@ -21,7 +21,6 @@ from ..llm.llm import LLM
21
21
  from ..llm.llmconfig import LLMConfig
22
22
  from ..memory import FaissQuery, Memory
23
23
  from ..message import Messager
24
- from ..metrics import MlflowClient
25
24
  from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
26
25
  STATUS_SCHEMA, SURVEY_SCHEMA)
27
26
 
@@ -40,12 +39,10 @@ class AgentGroup:
40
39
  ],
41
40
  config: dict,
42
41
  exp_id: str | UUID,
43
- exp_name: str,
44
42
  enable_avro: bool,
45
43
  avro_path: Path,
46
44
  enable_pgsql: bool,
47
45
  pgsql_writer: ray.ObjectRef,
48
- mlflow_run_id: str,
49
46
  embedding_model: Embeddings,
50
47
  logging_level: int,
51
48
  agent_config_file: Optional[Union[str, list[str]]] = None,
@@ -116,19 +113,6 @@ class AgentGroup:
116
113
  else:
117
114
  self.economy_client = None
118
115
 
119
- # Mlflow
120
- _mlflow_config = config.get("metric_request", {}).get("mlflow")
121
- if _mlflow_config:
122
- logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
123
- self.mlflow_client = MlflowClient(
124
- config=_mlflow_config,
125
- mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
126
- experiment_name=exp_name,
127
- run_id=mlflow_run_id,
128
- )
129
- else:
130
- self.mlflow_client = None
131
-
132
116
  # set FaissQuery
133
117
  if self.embedding_model is not None:
134
118
  self.faiss_query = FaissQuery(
@@ -142,7 +126,11 @@ class AgentGroup:
142
126
  for j in range(number_of_agents_i):
143
127
  memory_config_function_group_i = memory_config_function_group[i]
144
128
  extra_attributes, profile, base = memory_config_function_group_i()
145
- memory = Memory(config=extra_attributes, profile=profile, base=base)
129
+ memory = Memory(
130
+ config=extra_attributes,
131
+ profile=profile,
132
+ base=base
133
+ )
146
134
  agent = agent_class_i(
147
135
  name=f"{agent_class_i.__name__}_{i}",
148
136
  memory=memory,
@@ -151,18 +139,12 @@ class AgentGroup:
151
139
  simulator=self.simulator,
152
140
  )
153
141
  agent.set_exp_id(self.exp_id) # type: ignore
154
- if self.mlflow_client is not None:
155
- agent.set_mlflow_client(self.mlflow_client)
156
142
  if self.messager is not None:
157
143
  agent.set_messager(self.messager)
158
144
  if self.enable_avro:
159
145
  agent.set_avro_file(self.avro_file) # type: ignore
160
146
  if self.enable_pgsql:
161
147
  agent.set_pgsql_writer(self._pgsql_writer)
162
- if self.faiss_query is not None:
163
- agent.memory.set_faiss_query(self.faiss_query)
164
- if self.embedding_model is not None:
165
- agent.memory.set_embedding_model(self.embedding_model)
166
148
  if self.agent_config_file is not None and self.agent_config_file[i]:
167
149
  agent.load_from_file(self.agent_config_file[i])
168
150
  self.agents.append(agent)
@@ -193,11 +175,21 @@ class AgentGroup:
193
175
  self.message_dispatch_task.cancel() # type: ignore
194
176
  await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
195
177
 
178
+ async def insert_agents(self):
179
+ bind_tasks = []
180
+ for agent in self.agents:
181
+ bind_tasks.append(agent.bind_to_simulator()) # type: ignore
182
+ await asyncio.gather(*bind_tasks)
183
+
196
184
  async def init_agents(self):
197
185
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
198
186
  logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
199
- for agent in self.agents:
200
- await agent.bind_to_simulator() # type: ignore
187
+ while True:
188
+ day = await self.simulator.get_simulator_day()
189
+ if day == 0:
190
+ break
191
+ await asyncio.sleep(1)
192
+ await self.insert_agents()
201
193
  self.id2agent = {agent._uuid: agent for agent in self.agents}
202
194
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
203
195
  assert self.messager is not None
@@ -221,7 +213,7 @@ class AgentGroup:
221
213
  with open(filename, "wb") as f:
222
214
  profiles = []
223
215
  for agent in self.agents:
224
- profile = await agent.memory._profile.export()
216
+ profile = await agent.status.profile.export()
225
217
  profile = profile[0]
226
218
  profile["id"] = agent._uuid
227
219
  profiles.append(profile)
@@ -252,7 +244,7 @@ class AgentGroup:
252
244
  if not issubclass(type(self.agents[0]), InstitutionAgent):
253
245
  profiles: list[Any] = []
254
246
  for agent in self.agents:
255
- profile = await agent.memory._profile.export()
247
+ profile = await agent.status.profile.export()
256
248
  profile = profile[0]
257
249
  profile["id"] = agent._uuid
258
250
  profiles.append(
@@ -271,7 +263,7 @@ class AgentGroup:
271
263
  else:
272
264
  profiles: list[Any] = []
273
265
  for agent in self.agents:
274
- profile = await agent.memory._profile.export()
266
+ profile = await agent.status.profile.export()
275
267
  profile = profile[0]
276
268
  profile["id"] = agent._uuid
277
269
  profiles.append(
@@ -290,6 +282,16 @@ class AgentGroup:
290
282
  await self._pgsql_writer.async_write_profile.remote( # type:ignore
291
283
  profiles
292
284
  )
285
+ if self.faiss_query is not None:
286
+ logger.debug(f"-----Initializing embeddings in AgentGroup {self._uuid} ...")
287
+ embedding_tasks = []
288
+ for agent in self.agents:
289
+ embedding_tasks.append(agent.memory.initialize_embeddings())
290
+ agent.memory.set_search_components(self.faiss_query, self.embedding_model)
291
+ agent.memory.set_simulator(self.simulator)
292
+ await asyncio.gather(*embedding_tasks)
293
+ logger.debug(f"-----Embedding initialized in AgentGroup {self._uuid} ...")
294
+
293
295
  self.initialized = True
294
296
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
295
297
 
@@ -307,7 +309,7 @@ class AgentGroup:
307
309
  if keys:
308
310
  for key in keys:
309
311
  assert values is not None
310
- if not agent.memory.get(key) == values[keys.index(key)]:
312
+ if not agent.status.get(key) == values[keys.index(key)]:
311
313
  add = False
312
314
  break
313
315
  if add:
@@ -315,18 +317,21 @@ class AgentGroup:
315
317
  elif keys:
316
318
  for key in keys:
317
319
  assert values is not None
318
- if not agent.memory.get(key) == values[keys.index(key)]:
320
+ if not agent.status.get(key) == values[keys.index(key)]:
319
321
  add = False
320
322
  break
321
323
  if add:
322
324
  filtered_uuids.append(agent._uuid)
323
325
  return filtered_uuids
324
326
 
325
- async def gather(self, content: str):
327
+ async def gather(self, content: str, target_agent_uuids: Optional[list[str]] = None):
326
328
  logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
327
329
  results = {}
330
+ if target_agent_uuids is None:
331
+ target_agent_uuids = self.agent_uuids
328
332
  for agent in self.agents:
329
- results[agent._uuid] = await agent.memory.get(content)
333
+ if agent._uuid in target_agent_uuids:
334
+ results[agent._uuid] = await agent.status.get(content)
330
335
  return results
331
336
 
332
337
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
@@ -334,7 +339,7 @@ class AgentGroup:
334
339
  f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
335
340
  )
336
341
  agent = self.id2agent[target_agent_uuid]
337
- await agent.memory.update(target_key, content)
342
+ await agent.status.update(target_key, content)
338
343
 
339
344
  async def message_dispatch(self):
340
345
  logger.debug(f"-----Starting message dispatch for group {self._uuid}")
@@ -394,7 +399,7 @@ class AgentGroup:
394
399
  if not issubclass(type(self.agents[0]), InstitutionAgent):
395
400
  for agent in self.agents:
396
401
  _date_time = datetime.now(timezone.utc)
397
- position = await agent.memory.get("position")
402
+ position = await agent.status.get("position")
398
403
  x = position["xy_position"]["x"]
399
404
  y = position["xy_position"]["y"]
400
405
  lng, lat = self.projector(x, y, inverse=True)
@@ -404,8 +409,11 @@ class AgentGroup:
404
409
  parent_id = position["lane_position"]["lane_id"]
405
410
  else:
406
411
  parent_id = -1
407
- needs = await agent.memory.get("needs")
408
- action = await agent.memory.get("current_step")
412
+ hunger_satisfaction = await agent.status.get("hunger_satisfaction")
413
+ energy_satisfaction = await agent.status.get("energy_satisfaction")
414
+ safety_satisfaction = await agent.status.get("safety_satisfaction")
415
+ social_satisfaction = await agent.status.get("social_satisfaction")
416
+ action = await agent.status.get("current_step")
409
417
  action = action["intention"]
410
418
  avro = {
411
419
  "id": agent._uuid,
@@ -415,10 +423,10 @@ class AgentGroup:
415
423
  "lat": lat,
416
424
  "parent_id": parent_id,
417
425
  "action": action,
418
- "hungry": needs["hungry"],
419
- "tired": needs["tired"],
420
- "safe": needs["safe"],
421
- "social": needs["social"],
426
+ "hungry": hunger_satisfaction,
427
+ "tired": energy_satisfaction,
428
+ "safe": safety_satisfaction,
429
+ "social": social_satisfaction,
422
430
  "created_at": int(_date_time.timestamp() * 1000),
423
431
  }
424
432
  avros.append(avro)
@@ -429,54 +437,54 @@ class AgentGroup:
429
437
  for agent in self.agents:
430
438
  _date_time = datetime.now(timezone.utc)
431
439
  try:
432
- nominal_gdp = await agent.memory.get("nominal_gdp")
440
+ nominal_gdp = await agent.status.get("nominal_gdp")
433
441
  except:
434
442
  nominal_gdp = []
435
443
  try:
436
- real_gdp = await agent.memory.get("real_gdp")
444
+ real_gdp = await agent.status.get("real_gdp")
437
445
  except:
438
446
  real_gdp = []
439
447
  try:
440
- unemployment = await agent.memory.get("unemployment")
448
+ unemployment = await agent.status.get("unemployment")
441
449
  except:
442
450
  unemployment = []
443
451
  try:
444
- wages = await agent.memory.get("wages")
452
+ wages = await agent.status.get("wages")
445
453
  except:
446
454
  wages = []
447
455
  try:
448
- prices = await agent.memory.get("prices")
456
+ prices = await agent.status.get("prices")
449
457
  except:
450
458
  prices = []
451
459
  try:
452
- inventory = await agent.memory.get("inventory")
460
+ inventory = await agent.status.get("inventory")
453
461
  except:
454
462
  inventory = 0
455
463
  try:
456
- price = await agent.memory.get("price")
464
+ price = await agent.status.get("price")
457
465
  except:
458
466
  price = 0.0
459
467
  try:
460
- interest_rate = await agent.memory.get("interest_rate")
468
+ interest_rate = await agent.status.get("interest_rate")
461
469
  except:
462
470
  interest_rate = 0.0
463
471
  try:
464
- bracket_cutoffs = await agent.memory.get("bracket_cutoffs")
472
+ bracket_cutoffs = await agent.status.get("bracket_cutoffs")
465
473
  except:
466
474
  bracket_cutoffs = []
467
475
  try:
468
- bracket_rates = await agent.memory.get("bracket_rates")
476
+ bracket_rates = await agent.status.get("bracket_rates")
469
477
  except:
470
478
  bracket_rates = []
471
479
  try:
472
- employees = await agent.memory.get("employees")
480
+ employees = await agent.status.get("employees")
473
481
  except:
474
482
  employees = []
475
483
  avro = {
476
484
  "id": agent._uuid,
477
485
  "day": _day,
478
486
  "t": _t,
479
- "type": await agent.memory.get("type"),
487
+ "type": await agent.status.get("type"),
480
488
  "nominal_gdp": nominal_gdp,
481
489
  "real_gdp": real_gdp,
482
490
  "unemployment": unemployment,
@@ -519,7 +527,7 @@ class AgentGroup:
519
527
  if not issubclass(type(self.agents[0]), InstitutionAgent):
520
528
  for agent in self.agents:
521
529
  _date_time = datetime.now(timezone.utc)
522
- position = await agent.memory.get("position")
530
+ position = await agent.status.get("position")
523
531
  x = position["xy_position"]["x"]
524
532
  y = position["xy_position"]["y"]
525
533
  lng, lat = self.projector(x, y, inverse=True)
@@ -528,10 +536,12 @@ class AgentGroup:
528
536
  elif "lane_position" in position:
529
537
  parent_id = position["lane_position"]["lane_id"]
530
538
  else:
531
- # BUG: 需要处理
532
539
  parent_id = -1
533
- needs = await agent.memory.get("needs")
534
- action = await agent.memory.get("current_step")
540
+ hunger_satisfaction = await agent.status.get("hunger_satisfaction")
541
+ energy_satisfaction = await agent.status.get("energy_satisfaction")
542
+ safety_satisfaction = await agent.status.get("safety_satisfaction")
543
+ social_satisfaction = await agent.status.get("social_satisfaction")
544
+ action = await agent.status.get("current_step")
535
545
  action = action["intention"]
536
546
  _status_dict = {
537
547
  "id": agent._uuid,
@@ -541,10 +551,10 @@ class AgentGroup:
541
551
  "lat": lat,
542
552
  "parent_id": parent_id,
543
553
  "action": action,
544
- "hungry": needs["hungry"],
545
- "tired": needs["tired"],
546
- "safe": needs["safe"],
547
- "social": needs["social"],
554
+ "hungry": hunger_satisfaction,
555
+ "tired": energy_satisfaction,
556
+ "safe": safety_satisfaction,
557
+ "social": social_satisfaction,
548
558
  "created_at": _date_time,
549
559
  }
550
560
  _statuses_time_list.append((_status_dict, _date_time))
@@ -552,54 +562,54 @@ class AgentGroup:
552
562
  # institution
553
563
  for agent in self.agents:
554
564
  _date_time = datetime.now(timezone.utc)
555
- position = await agent.memory.get("position")
565
+ position = await agent.status.get("position")
556
566
  x = position["xy_position"]["x"]
557
567
  y = position["xy_position"]["y"]
558
568
  lng, lat = self.projector(x, y, inverse=True)
559
569
  # ATTENTION: no valid position for an institution
560
570
  parent_id = -1
561
571
  try:
562
- nominal_gdp = await agent.memory.get("nominal_gdp")
572
+ nominal_gdp = await agent.status.get("nominal_gdp")
563
573
  except:
564
574
  nominal_gdp = []
565
575
  try:
566
- real_gdp = await agent.memory.get("real_gdp")
576
+ real_gdp = await agent.status.get("real_gdp")
567
577
  except:
568
578
  real_gdp = []
569
579
  try:
570
- unemployment = await agent.memory.get("unemployment")
580
+ unemployment = await agent.status.get("unemployment")
571
581
  except:
572
582
  unemployment = []
573
583
  try:
574
- wages = await agent.memory.get("wages")
584
+ wages = await agent.status.get("wages")
575
585
  except:
576
586
  wages = []
577
587
  try:
578
- prices = await agent.memory.get("prices")
588
+ prices = await agent.status.get("prices")
579
589
  except:
580
590
  prices = []
581
591
  try:
582
- inventory = await agent.memory.get("inventory")
592
+ inventory = await agent.status.get("inventory")
583
593
  except:
584
594
  inventory = 0
585
595
  try:
586
- price = await agent.memory.get("price")
596
+ price = await agent.status.get("price")
587
597
  except:
588
598
  price = 0.0
589
599
  try:
590
- interest_rate = await agent.memory.get("interest_rate")
600
+ interest_rate = await agent.status.get("interest_rate")
591
601
  except:
592
602
  interest_rate = 0.0
593
603
  try:
594
- bracket_cutoffs = await agent.memory.get("bracket_cutoffs")
604
+ bracket_cutoffs = await agent.status.get("bracket_cutoffs")
595
605
  except:
596
606
  bracket_cutoffs = []
597
607
  try:
598
- bracket_rates = await agent.memory.get("bracket_rates")
608
+ bracket_rates = await agent.status.get("bracket_rates")
599
609
  except:
600
610
  bracket_rates = []
601
611
  try:
602
- employees = await agent.memory.get("employees")
612
+ employees = await agent.status.get("employees")
603
613
  except:
604
614
  employees = []
605
615
  _status_dict = {
@@ -610,7 +620,7 @@ class AgentGroup:
610
620
  "lat": lat,
611
621
  "parent_id": parent_id,
612
622
  "action": "",
613
- "type": await agent.memory.get("type"),
623
+ "type": await agent.status.get("type"),
614
624
  "nominal_gdp": nominal_gdp,
615
625
  "real_gdp": real_gdp,
616
626
  "unemployment": unemployment,
@@ -653,35 +663,18 @@ class AgentGroup:
653
663
  )
654
664
 
655
665
  async def step(self):
656
- if not self.initialized:
657
- await self.init_agents()
658
-
659
- tasks = [agent.run() for agent in self.agents]
660
- await asyncio.gather(*tasks)
661
- await self.save_status()
662
-
663
- async def run(self, day: int = 1):
664
- """运行模拟器
665
-
666
- Args:
667
- day: 运行天数,默认为1天
668
- """
669
666
  try:
670
- # 获取开始时间
671
- start_time = await self.simulator.get_time()
672
- start_time = int(start_time)
673
- # 计算结束时间(秒)
674
- end_time = start_time + day * 24 * 3600 # 将天数转换为秒
675
-
676
- while True:
677
- current_time = await self.simulator.get_time()
678
- current_time = int(current_time)
679
- if current_time >= end_time:
680
- break
681
- await self.step()
682
-
667
+ tasks = [agent.run() for agent in self.agents]
668
+ await asyncio.gather(*tasks)
683
669
  except Exception as e:
684
670
  import traceback
671
+ logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
672
+ raise RuntimeError(str(e)) from e
685
673
 
674
+ async def save(self, day: int, t: int):
675
+ try:
676
+ await self.save_status(day, t)
677
+ except Exception as e:
678
+ import traceback
686
679
  logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
687
680
  raise RuntimeError(str(e)) from e