pycityagent 2.0.0a51__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a53__cp311-cp311-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (36) hide show
  1. pycityagent/agent/agent.py +48 -62
  2. pycityagent/agent/agent_base.py +66 -53
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +44 -56
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/nbsagent.py +6 -29
  17. pycityagent/cityagent/societyagent.py +204 -119
  18. pycityagent/environment/sim/client.py +10 -1
  19. pycityagent/environment/sim/clock_service.py +2 -2
  20. pycityagent/environment/sim/pause_service.py +61 -0
  21. pycityagent/environment/simulator.py +17 -12
  22. pycityagent/llm/embeddings.py +0 -24
  23. pycityagent/memory/faiss_query.py +29 -26
  24. pycityagent/memory/memory.py +720 -272
  25. pycityagent/pycityagent-sim +0 -0
  26. pycityagent/simulation/agentgroup.py +92 -99
  27. pycityagent/simulation/simulation.py +115 -40
  28. pycityagent/tools/tool.py +7 -10
  29. pycityagent/workflow/block.py +11 -4
  30. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/METADATA +2 -2
  31. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/RECORD +35 -35
  32. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/WHEEL +1 -1
  33. pycityagent/cityagent/blocks/time_block.py +0 -116
  34. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/LICENSE +0 -0
  35. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/entry_points.txt +0 -0
  36. {pycityagent-2.0.0a51.dist-info → pycityagent-2.0.0a53.dist-info}/top_level.txt +0 -0
Binary file
@@ -21,7 +21,6 @@ from ..llm.llm import LLM
21
21
  from ..llm.llmconfig import LLMConfig
22
22
  from ..memory import FaissQuery, Memory
23
23
  from ..message import Messager
24
- from ..metrics import MlflowClient
25
24
  from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
26
25
  STATUS_SCHEMA, SURVEY_SCHEMA)
27
26
 
@@ -40,12 +39,10 @@ class AgentGroup:
40
39
  ],
41
40
  config: dict,
42
41
  exp_id: str | UUID,
43
- exp_name: str,
44
42
  enable_avro: bool,
45
43
  avro_path: Path,
46
44
  enable_pgsql: bool,
47
45
  pgsql_writer: ray.ObjectRef,
48
- mlflow_run_id: str,
49
46
  embedding_model: Embeddings,
50
47
  logging_level: int,
51
48
  agent_config_file: Optional[Union[str, list[str]]] = None,
@@ -116,19 +113,6 @@ class AgentGroup:
116
113
  else:
117
114
  self.economy_client = None
118
115
 
119
- # Mlflow
120
- _mlflow_config = config.get("metric_request", {}).get("mlflow")
121
- if _mlflow_config:
122
- logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
123
- self.mlflow_client = MlflowClient(
124
- config=_mlflow_config,
125
- mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
126
- experiment_name=exp_name,
127
- run_id=mlflow_run_id,
128
- )
129
- else:
130
- self.mlflow_client = None
131
-
132
116
  # set FaissQuery
133
117
  if self.embedding_model is not None:
134
118
  self.faiss_query = FaissQuery(
@@ -142,7 +126,11 @@ class AgentGroup:
142
126
  for j in range(number_of_agents_i):
143
127
  memory_config_function_group_i = memory_config_function_group[i]
144
128
  extra_attributes, profile, base = memory_config_function_group_i()
145
- memory = Memory(config=extra_attributes, profile=profile, base=base)
129
+ memory = Memory(
130
+ config=extra_attributes,
131
+ profile=profile,
132
+ base=base
133
+ )
146
134
  agent = agent_class_i(
147
135
  name=f"{agent_class_i.__name__}_{i}",
148
136
  memory=memory,
@@ -151,18 +139,12 @@ class AgentGroup:
151
139
  simulator=self.simulator,
152
140
  )
153
141
  agent.set_exp_id(self.exp_id) # type: ignore
154
- if self.mlflow_client is not None:
155
- agent.set_mlflow_client(self.mlflow_client)
156
142
  if self.messager is not None:
157
143
  agent.set_messager(self.messager)
158
144
  if self.enable_avro:
159
145
  agent.set_avro_file(self.avro_file) # type: ignore
160
146
  if self.enable_pgsql:
161
147
  agent.set_pgsql_writer(self._pgsql_writer)
162
- if self.faiss_query is not None:
163
- agent.memory.set_faiss_query(self.faiss_query)
164
- if self.embedding_model is not None:
165
- agent.memory.set_embedding_model(self.embedding_model)
166
148
  if self.agent_config_file is not None and self.agent_config_file[i]:
167
149
  agent.load_from_file(self.agent_config_file[i])
168
150
  self.agents.append(agent)
@@ -193,11 +175,21 @@ class AgentGroup:
193
175
  self.message_dispatch_task.cancel() # type: ignore
194
176
  await asyncio.gather(self.message_dispatch_task, return_exceptions=True) # type: ignore
195
177
 
178
+ async def insert_agents(self):
179
+ bind_tasks = []
180
+ for agent in self.agents:
181
+ bind_tasks.append(agent.bind_to_simulator()) # type: ignore
182
+ await asyncio.gather(*bind_tasks)
183
+
196
184
  async def init_agents(self):
197
185
  logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
198
186
  logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
199
- for agent in self.agents:
200
- await agent.bind_to_simulator() # type: ignore
187
+ while True:
188
+ day = await self.simulator.get_simulator_day()
189
+ if day == 0:
190
+ break
191
+ await asyncio.sleep(1)
192
+ await self.insert_agents()
201
193
  self.id2agent = {agent._uuid: agent for agent in self.agents}
202
194
  logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
203
195
  assert self.messager is not None
@@ -221,7 +213,7 @@ class AgentGroup:
221
213
  with open(filename, "wb") as f:
222
214
  profiles = []
223
215
  for agent in self.agents:
224
- profile = await agent.memory._profile.export()
216
+ profile = await agent.status.profile.export()
225
217
  profile = profile[0]
226
218
  profile["id"] = agent._uuid
227
219
  profiles.append(profile)
@@ -252,7 +244,7 @@ class AgentGroup:
252
244
  if not issubclass(type(self.agents[0]), InstitutionAgent):
253
245
  profiles: list[Any] = []
254
246
  for agent in self.agents:
255
- profile = await agent.memory._profile.export()
247
+ profile = await agent.status.profile.export()
256
248
  profile = profile[0]
257
249
  profile["id"] = agent._uuid
258
250
  profiles.append(
@@ -271,7 +263,7 @@ class AgentGroup:
271
263
  else:
272
264
  profiles: list[Any] = []
273
265
  for agent in self.agents:
274
- profile = await agent.memory._profile.export()
266
+ profile = await agent.status.profile.export()
275
267
  profile = profile[0]
276
268
  profile["id"] = agent._uuid
277
269
  profiles.append(
@@ -290,6 +282,16 @@ class AgentGroup:
290
282
  await self._pgsql_writer.async_write_profile.remote( # type:ignore
291
283
  profiles
292
284
  )
285
+ if self.faiss_query is not None:
286
+ logger.debug(f"-----Initializing embeddings in AgentGroup {self._uuid} ...")
287
+ embedding_tasks = []
288
+ for agent in self.agents:
289
+ embedding_tasks.append(agent.memory.initialize_embeddings())
290
+ agent.memory.set_search_components(self.faiss_query, self.embedding_model)
291
+ agent.memory.set_simulator(self.simulator)
292
+ await asyncio.gather(*embedding_tasks)
293
+ logger.debug(f"-----Embedding initialized in AgentGroup {self._uuid} ...")
294
+
293
295
  self.initialized = True
294
296
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
295
297
 
@@ -307,7 +309,7 @@ class AgentGroup:
307
309
  if keys:
308
310
  for key in keys:
309
311
  assert values is not None
310
- if not agent.memory.get(key) == values[keys.index(key)]:
312
+ if not agent.status.get(key) == values[keys.index(key)]:
311
313
  add = False
312
314
  break
313
315
  if add:
@@ -315,18 +317,21 @@ class AgentGroup:
315
317
  elif keys:
316
318
  for key in keys:
317
319
  assert values is not None
318
- if not agent.memory.get(key) == values[keys.index(key)]:
320
+ if not agent.status.get(key) == values[keys.index(key)]:
319
321
  add = False
320
322
  break
321
323
  if add:
322
324
  filtered_uuids.append(agent._uuid)
323
325
  return filtered_uuids
324
326
 
325
- async def gather(self, content: str):
327
+ async def gather(self, content: str, target_agent_uuids: Optional[list[str]] = None):
326
328
  logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
327
329
  results = {}
330
+ if target_agent_uuids is None:
331
+ target_agent_uuids = self.agent_uuids
328
332
  for agent in self.agents:
329
- results[agent._uuid] = await agent.memory.get(content)
333
+ if agent._uuid in target_agent_uuids:
334
+ results[agent._uuid] = await agent.status.get(content)
330
335
  return results
331
336
 
332
337
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
@@ -334,7 +339,7 @@ class AgentGroup:
334
339
  f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}"
335
340
  )
336
341
  agent = self.id2agent[target_agent_uuid]
337
- await agent.memory.update(target_key, content)
342
+ await agent.status.update(target_key, content)
338
343
 
339
344
  async def message_dispatch(self):
340
345
  logger.debug(f"-----Starting message dispatch for group {self._uuid}")
@@ -394,7 +399,7 @@ class AgentGroup:
394
399
  if not issubclass(type(self.agents[0]), InstitutionAgent):
395
400
  for agent in self.agents:
396
401
  _date_time = datetime.now(timezone.utc)
397
- position = await agent.memory.get("position")
402
+ position = await agent.status.get("position")
398
403
  x = position["xy_position"]["x"]
399
404
  y = position["xy_position"]["y"]
400
405
  lng, lat = self.projector(x, y, inverse=True)
@@ -404,8 +409,11 @@ class AgentGroup:
404
409
  parent_id = position["lane_position"]["lane_id"]
405
410
  else:
406
411
  parent_id = -1
407
- needs = await agent.memory.get("needs")
408
- action = await agent.memory.get("current_step")
412
+ hunger_satisfaction = await agent.status.get("hunger_satisfaction")
413
+ energy_satisfaction = await agent.status.get("energy_satisfaction")
414
+ safety_satisfaction = await agent.status.get("safety_satisfaction")
415
+ social_satisfaction = await agent.status.get("social_satisfaction")
416
+ action = await agent.status.get("current_step")
409
417
  action = action["intention"]
410
418
  avro = {
411
419
  "id": agent._uuid,
@@ -415,10 +423,10 @@ class AgentGroup:
415
423
  "lat": lat,
416
424
  "parent_id": parent_id,
417
425
  "action": action,
418
- "hungry": needs["hungry"],
419
- "tired": needs["tired"],
420
- "safe": needs["safe"],
421
- "social": needs["social"],
426
+ "hungry": hunger_satisfaction,
427
+ "tired": energy_satisfaction,
428
+ "safe": safety_satisfaction,
429
+ "social": social_satisfaction,
422
430
  "created_at": int(_date_time.timestamp() * 1000),
423
431
  }
424
432
  avros.append(avro)
@@ -429,54 +437,54 @@ class AgentGroup:
429
437
  for agent in self.agents:
430
438
  _date_time = datetime.now(timezone.utc)
431
439
  try:
432
- nominal_gdp = await agent.memory.get("nominal_gdp")
440
+ nominal_gdp = await agent.status.get("nominal_gdp")
433
441
  except:
434
442
  nominal_gdp = []
435
443
  try:
436
- real_gdp = await agent.memory.get("real_gdp")
444
+ real_gdp = await agent.status.get("real_gdp")
437
445
  except:
438
446
  real_gdp = []
439
447
  try:
440
- unemployment = await agent.memory.get("unemployment")
448
+ unemployment = await agent.status.get("unemployment")
441
449
  except:
442
450
  unemployment = []
443
451
  try:
444
- wages = await agent.memory.get("wages")
452
+ wages = await agent.status.get("wages")
445
453
  except:
446
454
  wages = []
447
455
  try:
448
- prices = await agent.memory.get("prices")
456
+ prices = await agent.status.get("prices")
449
457
  except:
450
458
  prices = []
451
459
  try:
452
- inventory = await agent.memory.get("inventory")
460
+ inventory = await agent.status.get("inventory")
453
461
  except:
454
462
  inventory = 0
455
463
  try:
456
- price = await agent.memory.get("price")
464
+ price = await agent.status.get("price")
457
465
  except:
458
466
  price = 0.0
459
467
  try:
460
- interest_rate = await agent.memory.get("interest_rate")
468
+ interest_rate = await agent.status.get("interest_rate")
461
469
  except:
462
470
  interest_rate = 0.0
463
471
  try:
464
- bracket_cutoffs = await agent.memory.get("bracket_cutoffs")
472
+ bracket_cutoffs = await agent.status.get("bracket_cutoffs")
465
473
  except:
466
474
  bracket_cutoffs = []
467
475
  try:
468
- bracket_rates = await agent.memory.get("bracket_rates")
476
+ bracket_rates = await agent.status.get("bracket_rates")
469
477
  except:
470
478
  bracket_rates = []
471
479
  try:
472
- employees = await agent.memory.get("employees")
480
+ employees = await agent.status.get("employees")
473
481
  except:
474
482
  employees = []
475
483
  avro = {
476
484
  "id": agent._uuid,
477
485
  "day": _day,
478
486
  "t": _t,
479
- "type": await agent.memory.get("type"),
487
+ "type": await agent.status.get("type"),
480
488
  "nominal_gdp": nominal_gdp,
481
489
  "real_gdp": real_gdp,
482
490
  "unemployment": unemployment,
@@ -519,7 +527,7 @@ class AgentGroup:
519
527
  if not issubclass(type(self.agents[0]), InstitutionAgent):
520
528
  for agent in self.agents:
521
529
  _date_time = datetime.now(timezone.utc)
522
- position = await agent.memory.get("position")
530
+ position = await agent.status.get("position")
523
531
  x = position["xy_position"]["x"]
524
532
  y = position["xy_position"]["y"]
525
533
  lng, lat = self.projector(x, y, inverse=True)
@@ -528,10 +536,12 @@ class AgentGroup:
528
536
  elif "lane_position" in position:
529
537
  parent_id = position["lane_position"]["lane_id"]
530
538
  else:
531
- # BUG: 需要处理
532
539
  parent_id = -1
533
- needs = await agent.memory.get("needs")
534
- action = await agent.memory.get("current_step")
540
+ hunger_satisfaction = await agent.status.get("hunger_satisfaction")
541
+ energy_satisfaction = await agent.status.get("energy_satisfaction")
542
+ safety_satisfaction = await agent.status.get("safety_satisfaction")
543
+ social_satisfaction = await agent.status.get("social_satisfaction")
544
+ action = await agent.status.get("current_step")
535
545
  action = action["intention"]
536
546
  _status_dict = {
537
547
  "id": agent._uuid,
@@ -541,10 +551,10 @@ class AgentGroup:
541
551
  "lat": lat,
542
552
  "parent_id": parent_id,
543
553
  "action": action,
544
- "hungry": needs["hungry"],
545
- "tired": needs["tired"],
546
- "safe": needs["safe"],
547
- "social": needs["social"],
554
+ "hungry": hunger_satisfaction,
555
+ "tired": energy_satisfaction,
556
+ "safe": safety_satisfaction,
557
+ "social": social_satisfaction,
548
558
  "created_at": _date_time,
549
559
  }
550
560
  _statuses_time_list.append((_status_dict, _date_time))
@@ -552,54 +562,54 @@ class AgentGroup:
552
562
  # institution
553
563
  for agent in self.agents:
554
564
  _date_time = datetime.now(timezone.utc)
555
- position = await agent.memory.get("position")
565
+ position = await agent.status.get("position")
556
566
  x = position["xy_position"]["x"]
557
567
  y = position["xy_position"]["y"]
558
568
  lng, lat = self.projector(x, y, inverse=True)
559
569
  # ATTENTION: no valid position for an institution
560
570
  parent_id = -1
561
571
  try:
562
- nominal_gdp = await agent.memory.get("nominal_gdp")
572
+ nominal_gdp = await agent.status.get("nominal_gdp")
563
573
  except:
564
574
  nominal_gdp = []
565
575
  try:
566
- real_gdp = await agent.memory.get("real_gdp")
576
+ real_gdp = await agent.status.get("real_gdp")
567
577
  except:
568
578
  real_gdp = []
569
579
  try:
570
- unemployment = await agent.memory.get("unemployment")
580
+ unemployment = await agent.status.get("unemployment")
571
581
  except:
572
582
  unemployment = []
573
583
  try:
574
- wages = await agent.memory.get("wages")
584
+ wages = await agent.status.get("wages")
575
585
  except:
576
586
  wages = []
577
587
  try:
578
- prices = await agent.memory.get("prices")
588
+ prices = await agent.status.get("prices")
579
589
  except:
580
590
  prices = []
581
591
  try:
582
- inventory = await agent.memory.get("inventory")
592
+ inventory = await agent.status.get("inventory")
583
593
  except:
584
594
  inventory = 0
585
595
  try:
586
- price = await agent.memory.get("price")
596
+ price = await agent.status.get("price")
587
597
  except:
588
598
  price = 0.0
589
599
  try:
590
- interest_rate = await agent.memory.get("interest_rate")
600
+ interest_rate = await agent.status.get("interest_rate")
591
601
  except:
592
602
  interest_rate = 0.0
593
603
  try:
594
- bracket_cutoffs = await agent.memory.get("bracket_cutoffs")
604
+ bracket_cutoffs = await agent.status.get("bracket_cutoffs")
595
605
  except:
596
606
  bracket_cutoffs = []
597
607
  try:
598
- bracket_rates = await agent.memory.get("bracket_rates")
608
+ bracket_rates = await agent.status.get("bracket_rates")
599
609
  except:
600
610
  bracket_rates = []
601
611
  try:
602
- employees = await agent.memory.get("employees")
612
+ employees = await agent.status.get("employees")
603
613
  except:
604
614
  employees = []
605
615
  _status_dict = {
@@ -610,7 +620,7 @@ class AgentGroup:
610
620
  "lat": lat,
611
621
  "parent_id": parent_id,
612
622
  "action": "",
613
- "type": await agent.memory.get("type"),
623
+ "type": await agent.status.get("type"),
614
624
  "nominal_gdp": nominal_gdp,
615
625
  "real_gdp": real_gdp,
616
626
  "unemployment": unemployment,
@@ -653,35 +663,18 @@ class AgentGroup:
653
663
  )
654
664
 
655
665
  async def step(self):
656
- if not self.initialized:
657
- await self.init_agents()
658
-
659
- tasks = [agent.run() for agent in self.agents]
660
- await asyncio.gather(*tasks)
661
- await self.save_status()
662
-
663
- async def run(self, day: int = 1):
664
- """运行模拟器
665
-
666
- Args:
667
- day: 运行天数,默认为1天
668
- """
669
666
  try:
670
- # 获取开始时间
671
- start_time = await self.simulator.get_time()
672
- start_time = int(start_time)
673
- # 计算结束时间(秒)
674
- end_time = start_time + day * 24 * 3600 # 将天数转换为秒
675
-
676
- while True:
677
- current_time = await self.simulator.get_time()
678
- current_time = int(current_time)
679
- if current_time >= end_time:
680
- break
681
- await self.step()
682
-
667
+ tasks = [agent.run() for agent in self.agents]
668
+ await asyncio.gather(*tasks)
683
669
  except Exception as e:
684
670
  import traceback
671
+ logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
672
+ raise RuntimeError(str(e)) from e
685
673
 
674
+ async def save(self, day: int, t: int):
675
+ try:
676
+ await self.save_status(day, t)
677
+ except Exception as e:
678
+ import traceback
686
679
  logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
687
680
  raise RuntimeError(str(e)) from e