pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. pycityagent/agent/agent.py +83 -62
  2. pycityagent/agent/agent_base.py +81 -54
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +45 -57
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/message_intercept.py +99 -0
  17. pycityagent/cityagent/nbsagent.py +6 -29
  18. pycityagent/cityagent/societyagent.py +325 -127
  19. pycityagent/cli/wrapper.py +4 -0
  20. pycityagent/economy/econ_client.py +0 -2
  21. pycityagent/environment/__init__.py +7 -1
  22. pycityagent/environment/sim/client.py +10 -1
  23. pycityagent/environment/sim/clock_service.py +2 -2
  24. pycityagent/environment/sim/pause_service.py +61 -0
  25. pycityagent/environment/sim/sim_env.py +34 -46
  26. pycityagent/environment/simulator.py +18 -14
  27. pycityagent/llm/embeddings.py +0 -24
  28. pycityagent/llm/llm.py +18 -10
  29. pycityagent/memory/faiss_query.py +29 -26
  30. pycityagent/memory/memory.py +733 -247
  31. pycityagent/message/__init__.py +8 -1
  32. pycityagent/message/message_interceptor.py +322 -0
  33. pycityagent/message/messager.py +42 -11
  34. pycityagent/pycityagent-sim +0 -0
  35. pycityagent/simulation/agentgroup.py +137 -96
  36. pycityagent/simulation/simulation.py +184 -38
  37. pycityagent/simulation/storage/pg.py +2 -2
  38. pycityagent/tools/tool.py +7 -9
  39. pycityagent/utils/__init__.py +7 -2
  40. pycityagent/utils/pg_query.py +1 -0
  41. pycityagent/utils/survey_util.py +26 -23
  42. pycityagent/workflow/block.py +14 -7
  43. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
  44. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
  45. pycityagent/cityagent/blocks/time_block.py +0 -116
  46. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
  47. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
  48. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
  49. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,15 @@ from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
18
18
  memory_config_government, memory_config_nbs,
19
19
  memory_config_societyagent)
20
20
  from ..cityagent.initial import bind_agent_info, initialize_social_network
21
- from ..environment.simulator import Simulator
22
- from ..llm import SimpleEmbedding
21
+ from ..environment import Simulator
22
+ from ..llm import LLM, LLMConfig, SimpleEmbedding
23
23
  from ..memory import Memory
24
- from ..message.messager import Messager
24
+ from ..message import (MessageBlockBase, MessageBlockListenerBase,
25
+ MessageInterceptor, Messager)
25
26
  from ..metrics import init_mlflow_connection
27
+ from ..metrics.mlflow_client import MlflowClient
26
28
  from ..survey import Survey
27
- from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
29
+ from ..utils import SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
28
30
  from .agentgroup import AgentGroup
29
31
  from .storage.pg import PgWriter, create_pg_tables
30
32
 
@@ -39,6 +41,7 @@ class AgentSimulation:
39
41
  config: dict,
40
42
  agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
41
43
  agent_config_file: Optional[dict] = None,
44
+ metric_extractor: Optional[list[tuple[int, Callable]]] = None,
42
45
  enable_economy: bool = True,
43
46
  agent_prefix: str = "agent_",
44
47
  exp_name: str = "default_experiment",
@@ -80,6 +83,15 @@ class AgentSimulation:
80
83
  self.config = config
81
84
  self.exp_name = exp_name
82
85
  self._simulator = Simulator(config["simulator_request"])
86
+ if enable_economy:
87
+ self._economy_env = self._simulator._sim_env
88
+ _req_dict = self.config["simulator_request"]
89
+ if "economy" in _req_dict:
90
+ _req_dict["economy"]["server"] = self._economy_env.sim_addr
91
+ else:
92
+ _req_dict["economy"] = {
93
+ "server": self._economy_env.sim_addr,
94
+ }
83
95
  self.agent_prefix = agent_prefix
84
96
  self._groups: dict[str, AgentGroup] = {} # type:ignore
85
97
  self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
@@ -89,6 +101,8 @@ class AgentSimulation:
89
101
  self._user_survey_topics: dict[str, str] = {}
90
102
  self._user_interview_topics: dict[str, str] = {}
91
103
  self._loop = asyncio.get_event_loop()
104
+ self._total_steps = 0
105
+ self._simulator_day = 0
92
106
  # self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
93
107
 
94
108
  self._messager = Messager.remote(
@@ -102,6 +116,7 @@ class AgentSimulation:
102
116
  _storage_config: dict[str, Any] = config.get("storage", {})
103
117
  if _storage_config is None:
104
118
  _storage_config = {}
119
+
105
120
  # avro
106
121
  _avro_config: dict[str, Any] = _storage_config.get("avro", {})
107
122
  self._enable_avro = _avro_config.get("enabled", False)
@@ -112,6 +127,27 @@ class AgentSimulation:
112
127
  self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
113
128
  self._avro_path.mkdir(parents=True, exist_ok=True)
114
129
 
130
+ # mlflow
131
+ _mlflow_config: dict[str, Any] = config.get("metric_request", {}).get("mlflow")
132
+ mlflow_run_id, _ = init_mlflow_connection(
133
+ config=_mlflow_config,
134
+ mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
135
+ experiment_name=self.exp_name,
136
+ )
137
+ if _mlflow_config:
138
+ logger.info(f"-----Creating Mlflow client...")
139
+ self.mlflow_client = MlflowClient(
140
+ config=_mlflow_config,
141
+ mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
142
+ experiment_name=exp_name,
143
+ run_id=mlflow_run_id,
144
+ )
145
+ self.metric_extractor = metric_extractor
146
+ else:
147
+ logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
148
+ self.mlflow_client = None
149
+ self.metric_extractor = None
150
+
115
151
  # pg
116
152
  _pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
117
153
  self._enable_pgsql = _pgsql_config.get("enabled", False)
@@ -167,11 +203,11 @@ class AgentSimulation:
167
203
  - workflow:
168
204
  - list[Step]
169
205
  - Step:
170
- - type: str, "step", "run", "interview", "survey", "intervene"
206
+ - type: str, "step", "run", "interview", "survey", "intervene", "pause", "resume"
171
207
  - day: int if type is "run", else None
172
- - time: int if type is "step", else None
208
+ - times: int if type is "step", else None
173
209
  - description: Optional[str], description of the step
174
- - step_func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
210
+ - func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
175
211
  - logging_level: Optional[int]
176
212
  - exp_name: Optional[str]
177
213
  """
@@ -201,6 +237,7 @@ class AgentSimulation:
201
237
  agent_count.append(config["agent_config"]["number_of_government"])
202
238
  agent_count.append(config["agent_config"]["number_of_bank"])
203
239
  agent_count.append(config["agent_config"]["number_of_nbs"])
240
+ # TODO(yanjunbo): support MessageInterceptor
204
241
  await simulation.init_agents(
205
242
  agent_count=agent_count,
206
243
  group_size=config["agent_config"].get("group_size", 10000),
@@ -224,10 +261,15 @@ class AgentSimulation:
224
261
  if step["type"] == "run":
225
262
  await simulation.run(step.get("day", 1))
226
263
  elif step["type"] == "step":
227
- # await simulation.step(step.get("time", 1))
228
- await simulation.step()
264
+ times = step.get("times", 1)
265
+ for _ in range(times):
266
+ await simulation.step()
267
+ elif step["type"] == "pause":
268
+ await simulation.pause_simulator()
269
+ elif step["type"] == "resume":
270
+ await simulation.resume_simulator()
229
271
  else:
230
- await step["step_func"](simulation)
272
+ await step["func"](simulation)
231
273
  logger.info("Simulation finished")
232
274
 
233
275
  @property
@@ -261,9 +303,13 @@ class AgentSimulation:
261
303
  return self._agent_uuid2group
262
304
 
263
305
  @property
264
- def messager(self):
306
+ def messager(self) -> ray.ObjectRef:
265
307
  return self._messager
266
308
 
309
+ @property
310
+ def message_interceptor(self) -> ray.ObjectRef:
311
+ return self._message_interceptors[0] # type:ignore
312
+
267
313
  async def _save_exp_info(self) -> None:
268
314
  """异步保存实验信息到YAML文件"""
269
315
  try:
@@ -331,11 +377,21 @@ class AgentSimulation:
331
377
  # 如果没有发生异常且状态不是错误,则更新为完成
332
378
  await self._update_exp_status(2)
333
379
 
380
+ async def pause_simulator(self):
381
+ await self._simulator.pause()
382
+
383
+ async def resume_simulator(self):
384
+ await self._simulator.resume()
385
+
334
386
  async def init_agents(
335
387
  self,
336
388
  agent_count: Union[int, list[int]],
337
389
  group_size: int = 10000,
338
390
  pg_sql_writers: int = 32,
391
+ message_interceptors: int = 1,
392
+ message_interceptor_blocks: Optional[list[MessageBlockBase]] = None,
393
+ social_black_list: Optional[list[tuple[str, str]]] = None,
394
+ message_listener: Optional[MessageBlockListenerBase] = None,
339
395
  embedding_model: Embeddings = SimpleEmbedding(),
340
396
  memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
341
397
  ) -> None:
@@ -344,6 +400,8 @@ class AgentSimulation:
344
400
  Args:
345
401
  agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
346
402
  group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
403
+ pg_sql_writers: 独立的PgSQL writer数量
404
+ message_interceptors: message拦截器数量
347
405
  memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
348
406
  """
349
407
  if not isinstance(agent_count, list):
@@ -353,12 +411,12 @@ class AgentSimulation:
353
411
  raise ValueError("agent_class和agent_count的长度不一致")
354
412
 
355
413
  if memory_config_func is None:
356
- logger.warning(
357
- "memory_config_func is None, using default memory config function"
358
- )
359
414
  memory_config_func = self.default_memory_config_func
360
415
 
361
416
  elif not isinstance(memory_config_func, list):
417
+ logger.warning(
418
+ "memory_config_func is not a list, using specific memory config function"
419
+ )
362
420
  memory_config_func = [memory_config_func]
363
421
 
364
422
  if len(memory_config_func) != len(agent_count):
@@ -464,13 +522,12 @@ class AgentSimulation:
464
522
  config_files,
465
523
  )
466
524
  )
467
-
468
525
  # 初始化mlflow连接
469
526
  _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
470
527
  if _mlflow_config:
471
528
  mlflow_run_id, _ = init_mlflow_connection(
472
529
  config=_mlflow_config,
473
- mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
530
+ mlflow_run_name=f"{self.exp_name}_{1000*int(time.time())}",
474
531
  experiment_name=self.exp_name,
475
532
  )
476
533
  else:
@@ -489,7 +546,31 @@ class AgentSimulation:
489
546
  else:
490
547
  _num_workers = 1
491
548
  self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
492
-
549
+ # message interceptor
550
+ if message_listener is not None:
551
+ self._message_abort_listening_queue = _queue = ray.util.queue.Queue() # type: ignore
552
+ await message_listener.set_queue(_queue)
553
+ else:
554
+ self._message_abort_listening_queue = _queue = None
555
+ _interceptor_blocks = message_interceptor_blocks
556
+ _black_list = [] if social_black_list is None else social_black_list
557
+ _llm_config = self.config.get("llm_request", {})
558
+ if message_interceptor_blocks is not None:
559
+ _num_interceptors = min(1, message_interceptors)
560
+ self._message_interceptors = _interceptors = [
561
+ MessageInterceptor.remote(
562
+ _interceptor_blocks, # type:ignore
563
+ _black_list,
564
+ _llm_config,
565
+ _queue,
566
+ )
567
+ for _ in range(_num_interceptors)
568
+ ]
569
+ else:
570
+ _num_interceptors = 1
571
+ self._message_interceptors = _interceptors = [
572
+ None for _ in range(_num_interceptors)
573
+ ]
493
574
  creation_tasks = []
494
575
  for i, (
495
576
  agent_class,
@@ -504,13 +585,14 @@ class AgentSimulation:
504
585
  number_of_agents,
505
586
  memory_config_function_group,
506
587
  self.config,
507
- self.exp_id,
508
588
  self.exp_name,
589
+ self.exp_id,
509
590
  self.enable_avro,
510
591
  self.avro_path,
511
592
  self.enable_pgsql,
512
593
  _workers[i % _num_workers], # type:ignore
513
- mlflow_run_id, # type:ignore
594
+ self.message_interceptor,
595
+ mlflow_run_id,
514
596
  embedding_model,
515
597
  self.logging_level,
516
598
  config_file,
@@ -536,16 +618,24 @@ class AgentSimulation:
536
618
  self._type2group[agent_type].append(group)
537
619
 
538
620
  # 并行初始化所有组的agents
621
+ await self.resume_simulator()
539
622
  init_tasks = []
540
623
  for group in self._groups.values():
541
624
  init_tasks.append(group.init_agents.remote())
542
625
  ray.get(init_tasks)
626
+ await self.messager.connect.remote() # type:ignore
627
+ await self.messager.subscribe.remote( # type:ignore
628
+ [(f"exps/{self.exp_id}/user_payback", 1)], [self.exp_id]
629
+ )
630
+ await self.messager.start_listening.remote() # type:ignore
543
631
 
544
- async def gather(self, content: str):
632
+ async def gather(
633
+ self, content: str, target_agent_uuids: Optional[list[str]] = None
634
+ ):
545
635
  """收集智能体的特定信息"""
546
636
  gather_tasks = []
547
637
  for group in self._groups.values():
548
- gather_tasks.append(group.gather.remote(content))
638
+ gather_tasks.append(group.gather.remote(content, target_agent_uuids))
549
639
  return await asyncio.gather(*gather_tasks)
550
640
 
551
641
  async def filter(
@@ -585,13 +675,13 @@ class AgentSimulation:
585
675
  self, survey: Survey, agent_uuids: Optional[list[str]] = None
586
676
  ):
587
677
  """发送问卷"""
588
- await self.messager.connect()
678
+ await self.messager.connect.remote() # type:ignore
589
679
  survey_dict = survey.to_dict()
590
680
  if agent_uuids is None:
591
681
  agent_uuids = self._agent_uuids
592
682
  _date_time = datetime.now(timezone.utc)
593
683
  payload = {
594
- "from": "none",
684
+ "from": SURVEY_SENDER_UUID,
595
685
  "survey_id": survey_dict["id"],
596
686
  "timestamp": int(_date_time.timestamp() * 1000),
597
687
  "data": survey_dict,
@@ -599,16 +689,23 @@ class AgentSimulation:
599
689
  }
600
690
  for uuid in agent_uuids:
601
691
  topic = self._user_survey_topics[uuid]
602
- await self.messager.send_message(topic, payload)
692
+ await self.messager.send_message.remote(topic, payload) # type:ignore
693
+ remain_payback = len(agent_uuids)
694
+ while True:
695
+ messages = await self.messager.fetch_messages.remote() # type:ignore
696
+ logger.info(f"Received {len(messages)} payback messages [survey]")
697
+ remain_payback -= len(messages)
698
+ if remain_payback <= 0:
699
+ break
700
+ await asyncio.sleep(3)
603
701
 
604
702
  async def send_interview_message(
605
703
  self, content: str, agent_uuids: Union[str, list[str]]
606
704
  ):
607
705
  """发送采访消息"""
608
- await self.messager.connect()
609
706
  _date_time = datetime.now(timezone.utc)
610
707
  payload = {
611
- "from": "none",
708
+ "from": SURVEY_SENDER_UUID,
612
709
  "content": content,
613
710
  "timestamp": int(_date_time.timestamp() * 1000),
614
711
  "_date_time": _date_time,
@@ -617,24 +714,72 @@ class AgentSimulation:
617
714
  agent_uuids = [agent_uuids]
618
715
  for uuid in agent_uuids:
619
716
  topic = self._user_chat_topics[uuid]
620
- await self.messager.send_message(topic, payload)
717
+ await self.messager.send_message.remote(topic, payload) # type:ignore
718
+ remain_payback = len(agent_uuids)
719
+ while True:
720
+ messages = await self.messager.fetch_messages.remote() # type:ignore
721
+ logger.info(f"Received {len(messages)} payback messages [interview]")
722
+ remain_payback -= len(messages)
723
+ if remain_payback <= 0:
724
+ break
725
+ await asyncio.sleep(3)
726
+
727
+ async def extract_metric(self, metric_extractors: list[Callable]):
728
+ """提取指标"""
729
+ for metric_extractor in metric_extractors:
730
+ await metric_extractor(self)
621
731
 
622
732
  async def step(self):
623
- """运行一步, 即每个智能体执行一次forward"""
733
+ """Run one step, each agent execute one forward"""
624
734
  try:
735
+ # check whether insert agents
736
+ simulator_day = await self._simulator.get_simulator_day()
737
+ print(
738
+ f"simulator_day: {simulator_day}, self._simulator_day: {self._simulator_day}"
739
+ )
740
+ need_insert_agents = False
741
+ if simulator_day > self._simulator_day:
742
+ need_insert_agents = True
743
+ self._simulator_day = simulator_day
744
+ if need_insert_agents:
745
+ await self.resume_simulator()
746
+ insert_tasks = []
747
+ for group in self._groups.values():
748
+ insert_tasks.append(group.insert_agents.remote())
749
+ await asyncio.gather(*insert_tasks)
750
+
751
+ # step
625
752
  tasks = []
626
753
  for group in self._groups.values():
627
754
  tasks.append(group.step.remote())
628
755
  await asyncio.gather(*tasks)
756
+ # save
757
+ simulator_day = await self._simulator.get_simulator_day()
758
+ simulator_time = int(await self._simulator.get_time())
759
+ save_tasks = []
760
+ for group in self._groups.values():
761
+ save_tasks.append(group.save.remote(simulator_day, simulator_time))
762
+ await asyncio.gather(*save_tasks)
763
+ self._total_steps += 1
764
+ if self.metric_extractor is not None:
765
+ print(f"total_steps: {self._total_steps}, excute metric")
766
+ to_excute_metric = [
767
+ metric[1]
768
+ for metric in self.metric_extractor
769
+ if self._total_steps % metric[0] == 0
770
+ ]
771
+ await self.extract_metric(to_excute_metric)
629
772
  except Exception as e:
630
- logger.error(f"运行错误: {str(e)}")
631
- raise
773
+ import traceback
774
+
775
+ logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
776
+ raise RuntimeError(str(e)) from e
632
777
 
633
778
  async def run(
634
779
  self,
635
780
  day: int = 1,
636
781
  ):
637
- """运行模拟器"""
782
+ """Run the simulation by days"""
638
783
  try:
639
784
  self._exp_info["num_day"] += day
640
785
  await self._update_exp_status(1) # 更新状态为运行中
@@ -645,13 +790,14 @@ class AgentSimulation:
645
790
  monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
646
791
 
647
792
  try:
648
- for _ in range(day):
649
- tasks = []
650
- for group in self._groups.values():
651
- tasks.append(group.run.remote())
652
- # 等待所有group运行完成
653
- await asyncio.gather(*tasks)
654
-
793
+ end_time = (
794
+ await self._simulator.get_time() + day * 24 * 3600
795
+ ) # type:ignore
796
+ while True:
797
+ current_time = await self._simulator.get_time()
798
+ if current_time >= end_time: # type:ignore
799
+ break
800
+ await self.step()
655
801
  finally:
656
802
  # 设置停止事件
657
803
  stop_event.set()
@@ -71,11 +71,11 @@ class PgWriter:
71
71
 
72
72
  @lock_decorator
73
73
  async def async_write_status(self, rows: list[tuple]):
74
- _tuple_types = [str, int, float, float, float, int, str, str, None]
74
+ _tuple_types = [str, int, float, float, float, int, list, str, str, None]
75
75
  table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_status"
76
76
  async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
77
77
  copy_sql = psycopg.sql.SQL(
78
- "COPY {} (id, day, t, lng, lat, parent_id, action, status, created_at) FROM STDIN"
78
+ "COPY {} (id, day, t, lng, lat, parent_id, friend_ids, action, status, created_at) FROM STDIN"
79
79
  ).format(psycopg.sql.Identifier(table_name))
80
80
  _rows: list[Any] = []
81
81
  async with aconn.cursor() as cur:
pycityagent/tools/tool.py CHANGED
@@ -119,7 +119,7 @@ class SencePOI(Tool):
119
119
  if agent.memory is None or agent.simulator is None:
120
120
  raise ValueError("Memory or Simulator is not set.")
121
121
  if radius is None and category_prefix is None:
122
- position = await agent.memory.get("position")
122
+ position = await agent.status.get("position")
123
123
  resp = []
124
124
  for prefix in self.category_prefix:
125
125
  resp += agent.simulator.map.query_pois(
@@ -146,17 +146,15 @@ class UpdateWithSimulator(Tool):
146
146
  agent = self.agent
147
147
  if agent._simulator is None:
148
148
  return
149
- if not agent._has_bound_to_simulator:
150
- await agent._bind_to_simulator() # type: ignore
151
149
  simulator = agent.simulator
152
- memory = agent.memory
153
- person_id = await memory.get("id")
150
+ status = agent.status
151
+ person_id = await status.get("id")
154
152
  resp = await simulator.get_person(person_id)
155
153
  resp_dict = resp["person"]
156
154
  for k, v in resp_dict.get("motion", {}).items():
157
155
  try:
158
- await memory.get(k)
159
- await memory.update(
156
+ await status.get(k)
157
+ await status.update(
160
158
  k, v, mode="replace", protect_llm_read_only_fields=False
161
159
  )
162
160
  except KeyError as e:
@@ -183,9 +181,9 @@ class ResetAgentPosition(Tool):
183
181
  s: Optional[float] = None,
184
182
  ):
185
183
  agent = self.agent
186
- memory = agent.memory
184
+ status = agent.status
187
185
  await agent.simulator.reset_person_position(
188
- person_id=await memory.get("id"),
186
+ person_id=await status.get("id"),
189
187
  aoi_id=aoi_id,
190
188
  poi_id=poi_id,
191
189
  lane_id=lane_id,
@@ -1,11 +1,16 @@
1
1
  from .avro_schema import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA,
2
2
  PROFILE_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA)
3
3
  from .pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
4
- from .survey_util import process_survey_for_llm
4
+ from .survey_util import SURVEY_SENDER_UUID, process_survey_for_llm
5
5
 
6
6
  __all__ = [
7
- "PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
7
+ "PROFILE_SCHEMA",
8
+ "DIALOG_SCHEMA",
9
+ "STATUS_SCHEMA",
10
+ "SURVEY_SCHEMA",
11
+ "INSTITUTION_STATUS_SCHEMA",
8
12
  "process_survey_for_llm",
9
13
  "TO_UPDATE_EXP_INFO_KEYS_AND_TYPES",
10
14
  "PGSQL_DICT",
15
+ "SURVEY_SENDER_UUID",
11
16
  ]
@@ -54,6 +54,7 @@ PGSQL_DICT: dict[str, list[Any]] = {
54
54
  lng DOUBLE PRECISION,
55
55
  lat DOUBLE PRECISION,
56
56
  parent_id INT4,
57
+ friend_ids UUID[],
57
58
  action TEXT,
58
59
  status JSONB,
59
60
  created_at TIMESTAMPTZ
@@ -8,40 +8,40 @@ Survey Description: {survey_dict['description']}
8
8
  Please answer each question in the following format:
9
9
 
10
10
  """
11
-
11
+
12
12
  question_count = 1
13
- for page in survey_dict['pages']:
14
- for question in page['elements']:
13
+ for page in survey_dict["pages"]:
14
+ for question in page["elements"]:
15
15
  prompt += f"Question {question_count}: {question['title']}\n"
16
-
16
+
17
17
  # 根据不同类型的问题生成不同的提示
18
- if question['type'] == 'radiogroup':
19
- prompt += "Options: " + ", ".join(question['choices']) + "\n"
18
+ if question["type"] == "radiogroup":
19
+ prompt += "Options: " + ", ".join(question["choices"]) + "\n"
20
20
  prompt += "Please select ONE option\n"
21
-
22
- elif question['type'] == 'checkbox':
23
- prompt += "Options: " + ", ".join(question['choices']) + "\n"
21
+
22
+ elif question["type"] == "checkbox":
23
+ prompt += "Options: " + ", ".join(question["choices"]) + "\n"
24
24
  prompt += "You can select MULTIPLE options\n"
25
-
26
- elif question['type'] == 'rating':
25
+
26
+ elif question["type"] == "rating":
27
27
  prompt += f"Rating range: {question.get('min_rating', 1)} - {question.get('max_rating', 5)}\n"
28
28
  prompt += "Please provide a rating within the range\n"
29
-
30
- elif question['type'] == 'matrix':
31
- prompt += "Rows: " + ", ".join(question['rows']) + "\n"
32
- prompt += "Columns: " + ", ".join(question['columns']) + "\n"
29
+
30
+ elif question["type"] == "matrix":
31
+ prompt += "Rows: " + ", ".join(question["rows"]) + "\n"
32
+ prompt += "Columns: " + ", ".join(question["columns"]) + "\n"
33
33
  prompt += "Please select ONE column option for EACH row\n"
34
-
35
- elif question['type'] == 'text':
34
+
35
+ elif question["type"] == "text":
36
36
  prompt += "Please provide a text response\n"
37
-
38
- elif question['type'] == 'boolean':
37
+
38
+ elif question["type"] == "boolean":
39
39
  prompt += "Options: Yes, No\n"
40
40
  prompt += "Please select either Yes or No\n"
41
-
41
+
42
42
  prompt += "\nAnswer: [Your response here]\n\n---\n\n"
43
43
  question_count += 1
44
-
44
+
45
45
  # 添加总结提示
46
46
  prompt += """Please ensure:
47
47
  1. All required questions are answered
@@ -49,5 +49,8 @@ Please answer each question in the following format:
49
49
  3. Answers are clear and specific
50
50
 
51
51
  Format your responses exactly as requested above."""
52
-
53
- return prompt
52
+
53
+ return prompt
54
+
55
+
56
+ SURVEY_SENDER_UUID = "none"
@@ -146,6 +146,7 @@ def trigger_class():
146
146
  class Block:
147
147
  configurable_fields: list[str] = []
148
148
  default_values: dict[str, Any] = {}
149
+ fields_description: dict[str, str] = {}
149
150
 
150
151
  def __init__(
151
152
  self,
@@ -173,14 +174,20 @@ class Block:
173
174
 
174
175
  @classmethod
175
176
  def export_class_config(cls) -> dict[str, str]:
176
- return {
177
- field: cls.default_values.get(field, "default_value")
178
- for field in cls.configurable_fields
179
- }
177
+ return (
178
+ {
179
+ field: cls.default_values.get(field, "default_value")
180
+ for field in cls.configurable_fields
181
+ },
182
+ {
183
+ field: cls.fields_description.get(field, "")
184
+ for field in cls.configurable_fields
185
+ }
186
+ )
180
187
 
181
188
  @classmethod
182
189
  def import_config(cls, config: dict[str, Union[str, dict]]) -> Block:
183
- instance = cls(name=config["name"])
190
+ instance = cls(name=config["name"]) # type: ignore
184
191
  assert isinstance(config["config"], dict)
185
192
  for field, value in config["config"].items():
186
193
  if field in cls.configurable_fields:
@@ -188,12 +195,12 @@ class Block:
188
195
 
189
196
  # 递归创建子Block
190
197
  for child_config in config.get("children", []):
191
- child_block = Block.import_config(child_config)
198
+ child_block = Block.import_config(child_config) # type: ignore
192
199
  setattr(instance, child_block.name.lower(), child_block)
193
200
 
194
201
  return instance
195
202
 
196
- def load_from_config(self, config: dict[str, list[Dict]]) -> None:
203
+ def load_from_config(self, config: dict[str, list[dict]]) -> None:
197
204
  """
198
205
  使用配置更新当前Block实例的参数,并递归更新子Block。
199
206
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pycityagent
3
- Version: 2.0.0a52
3
+ Version: 2.0.0a54
4
4
  Summary: LLM-based city environment agent building library
5
5
  Author-email: Yuwei Yan <pinkgranite86@gmail.com>, Junbo Yan <yanjb20thu@gmali.com>, Jun Zhang <zhangjun990222@gmali.com>
6
6
  License: MIT License
@@ -45,7 +45,7 @@ Requires-Dist: openai>=1.58.1
45
45
  Requires-Dist: Pillow<12.0.0,>=11.0.0
46
46
  Requires-Dist: protobuf<5.0.0,<=4.24.0
47
47
  Requires-Dist: pycitydata>=1.0.3
48
- Requires-Dist: pycityproto>=2.1.5
48
+ Requires-Dist: pycityproto>=2.2.0
49
49
  Requires-Dist: requests>=2.32.3
50
50
  Requires-Dist: Shapely>=2.0.6
51
51
  Requires-Dist: PyYAML>=6.0.2