pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (49) hide show
  1. pycityagent/agent/agent.py +83 -62
  2. pycityagent/agent/agent_base.py +81 -54
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +45 -57
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/message_intercept.py +99 -0
  17. pycityagent/cityagent/nbsagent.py +6 -29
  18. pycityagent/cityagent/societyagent.py +325 -127
  19. pycityagent/cli/wrapper.py +4 -0
  20. pycityagent/economy/econ_client.py +0 -2
  21. pycityagent/environment/__init__.py +7 -1
  22. pycityagent/environment/sim/client.py +10 -1
  23. pycityagent/environment/sim/clock_service.py +2 -2
  24. pycityagent/environment/sim/pause_service.py +61 -0
  25. pycityagent/environment/sim/sim_env.py +34 -46
  26. pycityagent/environment/simulator.py +18 -14
  27. pycityagent/llm/embeddings.py +0 -24
  28. pycityagent/llm/llm.py +18 -10
  29. pycityagent/memory/faiss_query.py +29 -26
  30. pycityagent/memory/memory.py +733 -247
  31. pycityagent/message/__init__.py +8 -1
  32. pycityagent/message/message_interceptor.py +322 -0
  33. pycityagent/message/messager.py +42 -11
  34. pycityagent/pycityagent-sim +0 -0
  35. pycityagent/simulation/agentgroup.py +137 -96
  36. pycityagent/simulation/simulation.py +184 -38
  37. pycityagent/simulation/storage/pg.py +2 -2
  38. pycityagent/tools/tool.py +7 -9
  39. pycityagent/utils/__init__.py +7 -2
  40. pycityagent/utils/pg_query.py +1 -0
  41. pycityagent/utils/survey_util.py +26 -23
  42. pycityagent/workflow/block.py +14 -7
  43. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
  44. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
  45. pycityagent/cityagent/blocks/time_block.py +0 -116
  46. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
  47. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
  48. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
  49. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -18,13 +18,15 @@ from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
18
18
  memory_config_government, memory_config_nbs,
19
19
  memory_config_societyagent)
20
20
  from ..cityagent.initial import bind_agent_info, initialize_social_network
21
- from ..environment.simulator import Simulator
22
- from ..llm import SimpleEmbedding
21
+ from ..environment import Simulator
22
+ from ..llm import LLM, LLMConfig, SimpleEmbedding
23
23
  from ..memory import Memory
24
- from ..message.messager import Messager
24
+ from ..message import (MessageBlockBase, MessageBlockListenerBase,
25
+ MessageInterceptor, Messager)
25
26
  from ..metrics import init_mlflow_connection
27
+ from ..metrics.mlflow_client import MlflowClient
26
28
  from ..survey import Survey
27
- from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
29
+ from ..utils import SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
28
30
  from .agentgroup import AgentGroup
29
31
  from .storage.pg import PgWriter, create_pg_tables
30
32
 
@@ -39,6 +41,7 @@ class AgentSimulation:
39
41
  config: dict,
40
42
  agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
41
43
  agent_config_file: Optional[dict] = None,
44
+ metric_extractor: Optional[list[tuple[int, Callable]]] = None,
42
45
  enable_economy: bool = True,
43
46
  agent_prefix: str = "agent_",
44
47
  exp_name: str = "default_experiment",
@@ -80,6 +83,15 @@ class AgentSimulation:
80
83
  self.config = config
81
84
  self.exp_name = exp_name
82
85
  self._simulator = Simulator(config["simulator_request"])
86
+ if enable_economy:
87
+ self._economy_env = self._simulator._sim_env
88
+ _req_dict = self.config["simulator_request"]
89
+ if "economy" in _req_dict:
90
+ _req_dict["economy"]["server"] = self._economy_env.sim_addr
91
+ else:
92
+ _req_dict["economy"] = {
93
+ "server": self._economy_env.sim_addr,
94
+ }
83
95
  self.agent_prefix = agent_prefix
84
96
  self._groups: dict[str, AgentGroup] = {} # type:ignore
85
97
  self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
@@ -89,6 +101,8 @@ class AgentSimulation:
89
101
  self._user_survey_topics: dict[str, str] = {}
90
102
  self._user_interview_topics: dict[str, str] = {}
91
103
  self._loop = asyncio.get_event_loop()
104
+ self._total_steps = 0
105
+ self._simulator_day = 0
92
106
  # self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
93
107
 
94
108
  self._messager = Messager.remote(
@@ -102,6 +116,7 @@ class AgentSimulation:
102
116
  _storage_config: dict[str, Any] = config.get("storage", {})
103
117
  if _storage_config is None:
104
118
  _storage_config = {}
119
+
105
120
  # avro
106
121
  _avro_config: dict[str, Any] = _storage_config.get("avro", {})
107
122
  self._enable_avro = _avro_config.get("enabled", False)
@@ -112,6 +127,27 @@ class AgentSimulation:
112
127
  self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
113
128
  self._avro_path.mkdir(parents=True, exist_ok=True)
114
129
 
130
+ # mlflow
131
+ _mlflow_config: dict[str, Any] = config.get("metric_request", {}).get("mlflow")
132
+ mlflow_run_id, _ = init_mlflow_connection(
133
+ config=_mlflow_config,
134
+ mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
135
+ experiment_name=self.exp_name,
136
+ )
137
+ if _mlflow_config:
138
+ logger.info(f"-----Creating Mlflow client...")
139
+ self.mlflow_client = MlflowClient(
140
+ config=_mlflow_config,
141
+ mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
142
+ experiment_name=exp_name,
143
+ run_id=mlflow_run_id,
144
+ )
145
+ self.metric_extractor = metric_extractor
146
+ else:
147
+ logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
148
+ self.mlflow_client = None
149
+ self.metric_extractor = None
150
+
115
151
  # pg
116
152
  _pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
117
153
  self._enable_pgsql = _pgsql_config.get("enabled", False)
@@ -167,11 +203,11 @@ class AgentSimulation:
167
203
  - workflow:
168
204
  - list[Step]
169
205
  - Step:
170
- - type: str, "step", "run", "interview", "survey", "intervene"
206
+ - type: str, "step", "run", "interview", "survey", "intervene", "pause", "resume"
171
207
  - day: int if type is "run", else None
172
- - time: int if type is "step", else None
208
+ - times: int if type is "step", else None
173
209
  - description: Optional[str], description of the step
174
- - step_func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
210
+ - func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
175
211
  - logging_level: Optional[int]
176
212
  - exp_name: Optional[str]
177
213
  """
@@ -201,6 +237,7 @@ class AgentSimulation:
201
237
  agent_count.append(config["agent_config"]["number_of_government"])
202
238
  agent_count.append(config["agent_config"]["number_of_bank"])
203
239
  agent_count.append(config["agent_config"]["number_of_nbs"])
240
+ # TODO(yanjunbo): support MessageInterceptor
204
241
  await simulation.init_agents(
205
242
  agent_count=agent_count,
206
243
  group_size=config["agent_config"].get("group_size", 10000),
@@ -224,10 +261,15 @@ class AgentSimulation:
224
261
  if step["type"] == "run":
225
262
  await simulation.run(step.get("day", 1))
226
263
  elif step["type"] == "step":
227
- # await simulation.step(step.get("time", 1))
228
- await simulation.step()
264
+ times = step.get("times", 1)
265
+ for _ in range(times):
266
+ await simulation.step()
267
+ elif step["type"] == "pause":
268
+ await simulation.pause_simulator()
269
+ elif step["type"] == "resume":
270
+ await simulation.resume_simulator()
229
271
  else:
230
- await step["step_func"](simulation)
272
+ await step["func"](simulation)
231
273
  logger.info("Simulation finished")
232
274
 
233
275
  @property
@@ -261,9 +303,13 @@ class AgentSimulation:
261
303
  return self._agent_uuid2group
262
304
 
263
305
  @property
264
- def messager(self):
306
+ def messager(self) -> ray.ObjectRef:
265
307
  return self._messager
266
308
 
309
+ @property
310
+ def message_interceptor(self) -> ray.ObjectRef:
311
+ return self._message_interceptors[0] # type:ignore
312
+
267
313
  async def _save_exp_info(self) -> None:
268
314
  """异步保存实验信息到YAML文件"""
269
315
  try:
@@ -331,11 +377,21 @@ class AgentSimulation:
331
377
  # 如果没有发生异常且状态不是错误,则更新为完成
332
378
  await self._update_exp_status(2)
333
379
 
380
+ async def pause_simulator(self):
381
+ await self._simulator.pause()
382
+
383
+ async def resume_simulator(self):
384
+ await self._simulator.resume()
385
+
334
386
  async def init_agents(
335
387
  self,
336
388
  agent_count: Union[int, list[int]],
337
389
  group_size: int = 10000,
338
390
  pg_sql_writers: int = 32,
391
+ message_interceptors: int = 1,
392
+ message_interceptor_blocks: Optional[list[MessageBlockBase]] = None,
393
+ social_black_list: Optional[list[tuple[str, str]]] = None,
394
+ message_listener: Optional[MessageBlockListenerBase] = None,
339
395
  embedding_model: Embeddings = SimpleEmbedding(),
340
396
  memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
341
397
  ) -> None:
@@ -344,6 +400,8 @@ class AgentSimulation:
344
400
  Args:
345
401
  agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
346
402
  group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
403
+ pg_sql_writers: 独立的PgSQL writer数量
404
+ message_interceptors: message拦截器数量
347
405
  memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
348
406
  """
349
407
  if not isinstance(agent_count, list):
@@ -353,12 +411,12 @@ class AgentSimulation:
353
411
  raise ValueError("agent_class和agent_count的长度不一致")
354
412
 
355
413
  if memory_config_func is None:
356
- logger.warning(
357
- "memory_config_func is None, using default memory config function"
358
- )
359
414
  memory_config_func = self.default_memory_config_func
360
415
 
361
416
  elif not isinstance(memory_config_func, list):
417
+ logger.warning(
418
+ "memory_config_func is not a list, using specific memory config function"
419
+ )
362
420
  memory_config_func = [memory_config_func]
363
421
 
364
422
  if len(memory_config_func) != len(agent_count):
@@ -464,13 +522,12 @@ class AgentSimulation:
464
522
  config_files,
465
523
  )
466
524
  )
467
-
468
525
  # 初始化mlflow连接
469
526
  _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
470
527
  if _mlflow_config:
471
528
  mlflow_run_id, _ = init_mlflow_connection(
472
529
  config=_mlflow_config,
473
- mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
530
+ mlflow_run_name=f"{self.exp_name}_{1000*int(time.time())}",
474
531
  experiment_name=self.exp_name,
475
532
  )
476
533
  else:
@@ -489,7 +546,31 @@ class AgentSimulation:
489
546
  else:
490
547
  _num_workers = 1
491
548
  self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
492
-
549
+ # message interceptor
550
+ if message_listener is not None:
551
+ self._message_abort_listening_queue = _queue = ray.util.queue.Queue() # type: ignore
552
+ await message_listener.set_queue(_queue)
553
+ else:
554
+ self._message_abort_listening_queue = _queue = None
555
+ _interceptor_blocks = message_interceptor_blocks
556
+ _black_list = [] if social_black_list is None else social_black_list
557
+ _llm_config = self.config.get("llm_request", {})
558
+ if message_interceptor_blocks is not None:
559
+ _num_interceptors = min(1, message_interceptors)
560
+ self._message_interceptors = _interceptors = [
561
+ MessageInterceptor.remote(
562
+ _interceptor_blocks, # type:ignore
563
+ _black_list,
564
+ _llm_config,
565
+ _queue,
566
+ )
567
+ for _ in range(_num_interceptors)
568
+ ]
569
+ else:
570
+ _num_interceptors = 1
571
+ self._message_interceptors = _interceptors = [
572
+ None for _ in range(_num_interceptors)
573
+ ]
493
574
  creation_tasks = []
494
575
  for i, (
495
576
  agent_class,
@@ -504,13 +585,14 @@ class AgentSimulation:
504
585
  number_of_agents,
505
586
  memory_config_function_group,
506
587
  self.config,
507
- self.exp_id,
508
588
  self.exp_name,
589
+ self.exp_id,
509
590
  self.enable_avro,
510
591
  self.avro_path,
511
592
  self.enable_pgsql,
512
593
  _workers[i % _num_workers], # type:ignore
513
- mlflow_run_id, # type:ignore
594
+ self.message_interceptor,
595
+ mlflow_run_id,
514
596
  embedding_model,
515
597
  self.logging_level,
516
598
  config_file,
@@ -536,16 +618,24 @@ class AgentSimulation:
536
618
  self._type2group[agent_type].append(group)
537
619
 
538
620
  # 并行初始化所有组的agents
621
+ await self.resume_simulator()
539
622
  init_tasks = []
540
623
  for group in self._groups.values():
541
624
  init_tasks.append(group.init_agents.remote())
542
625
  ray.get(init_tasks)
626
+ await self.messager.connect.remote() # type:ignore
627
+ await self.messager.subscribe.remote( # type:ignore
628
+ [(f"exps/{self.exp_id}/user_payback", 1)], [self.exp_id]
629
+ )
630
+ await self.messager.start_listening.remote() # type:ignore
543
631
 
544
- async def gather(self, content: str):
632
+ async def gather(
633
+ self, content: str, target_agent_uuids: Optional[list[str]] = None
634
+ ):
545
635
  """收集智能体的特定信息"""
546
636
  gather_tasks = []
547
637
  for group in self._groups.values():
548
- gather_tasks.append(group.gather.remote(content))
638
+ gather_tasks.append(group.gather.remote(content, target_agent_uuids))
549
639
  return await asyncio.gather(*gather_tasks)
550
640
 
551
641
  async def filter(
@@ -585,13 +675,13 @@ class AgentSimulation:
585
675
  self, survey: Survey, agent_uuids: Optional[list[str]] = None
586
676
  ):
587
677
  """发送问卷"""
588
- await self.messager.connect()
678
+ await self.messager.connect.remote() # type:ignore
589
679
  survey_dict = survey.to_dict()
590
680
  if agent_uuids is None:
591
681
  agent_uuids = self._agent_uuids
592
682
  _date_time = datetime.now(timezone.utc)
593
683
  payload = {
594
- "from": "none",
684
+ "from": SURVEY_SENDER_UUID,
595
685
  "survey_id": survey_dict["id"],
596
686
  "timestamp": int(_date_time.timestamp() * 1000),
597
687
  "data": survey_dict,
@@ -599,16 +689,23 @@ class AgentSimulation:
599
689
  }
600
690
  for uuid in agent_uuids:
601
691
  topic = self._user_survey_topics[uuid]
602
- await self.messager.send_message(topic, payload)
692
+ await self.messager.send_message.remote(topic, payload) # type:ignore
693
+ remain_payback = len(agent_uuids)
694
+ while True:
695
+ messages = await self.messager.fetch_messages.remote() # type:ignore
696
+ logger.info(f"Received {len(messages)} payback messages [survey]")
697
+ remain_payback -= len(messages)
698
+ if remain_payback <= 0:
699
+ break
700
+ await asyncio.sleep(3)
603
701
 
604
702
  async def send_interview_message(
605
703
  self, content: str, agent_uuids: Union[str, list[str]]
606
704
  ):
607
705
  """发送采访消息"""
608
- await self.messager.connect()
609
706
  _date_time = datetime.now(timezone.utc)
610
707
  payload = {
611
- "from": "none",
708
+ "from": SURVEY_SENDER_UUID,
612
709
  "content": content,
613
710
  "timestamp": int(_date_time.timestamp() * 1000),
614
711
  "_date_time": _date_time,
@@ -617,24 +714,72 @@ class AgentSimulation:
617
714
  agent_uuids = [agent_uuids]
618
715
  for uuid in agent_uuids:
619
716
  topic = self._user_chat_topics[uuid]
620
- await self.messager.send_message(topic, payload)
717
+ await self.messager.send_message.remote(topic, payload) # type:ignore
718
+ remain_payback = len(agent_uuids)
719
+ while True:
720
+ messages = await self.messager.fetch_messages.remote() # type:ignore
721
+ logger.info(f"Received {len(messages)} payback messages [interview]")
722
+ remain_payback -= len(messages)
723
+ if remain_payback <= 0:
724
+ break
725
+ await asyncio.sleep(3)
726
+
727
+ async def extract_metric(self, metric_extractors: list[Callable]):
728
+ """提取指标"""
729
+ for metric_extractor in metric_extractors:
730
+ await metric_extractor(self)
621
731
 
622
732
  async def step(self):
623
- """运行一步, 即每个智能体执行一次forward"""
733
+ """Run one step, each agent execute one forward"""
624
734
  try:
735
+ # check whether insert agents
736
+ simulator_day = await self._simulator.get_simulator_day()
737
+ print(
738
+ f"simulator_day: {simulator_day}, self._simulator_day: {self._simulator_day}"
739
+ )
740
+ need_insert_agents = False
741
+ if simulator_day > self._simulator_day:
742
+ need_insert_agents = True
743
+ self._simulator_day = simulator_day
744
+ if need_insert_agents:
745
+ await self.resume_simulator()
746
+ insert_tasks = []
747
+ for group in self._groups.values():
748
+ insert_tasks.append(group.insert_agents.remote())
749
+ await asyncio.gather(*insert_tasks)
750
+
751
+ # step
625
752
  tasks = []
626
753
  for group in self._groups.values():
627
754
  tasks.append(group.step.remote())
628
755
  await asyncio.gather(*tasks)
756
+ # save
757
+ simulator_day = await self._simulator.get_simulator_day()
758
+ simulator_time = int(await self._simulator.get_time())
759
+ save_tasks = []
760
+ for group in self._groups.values():
761
+ save_tasks.append(group.save.remote(simulator_day, simulator_time))
762
+ await asyncio.gather(*save_tasks)
763
+ self._total_steps += 1
764
+ if self.metric_extractor is not None:
765
+ print(f"total_steps: {self._total_steps}, excute metric")
766
+ to_excute_metric = [
767
+ metric[1]
768
+ for metric in self.metric_extractor
769
+ if self._total_steps % metric[0] == 0
770
+ ]
771
+ await self.extract_metric(to_excute_metric)
629
772
  except Exception as e:
630
- logger.error(f"运行错误: {str(e)}")
631
- raise
773
+ import traceback
774
+
775
+ logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
776
+ raise RuntimeError(str(e)) from e
632
777
 
633
778
  async def run(
634
779
  self,
635
780
  day: int = 1,
636
781
  ):
637
- """运行模拟器"""
782
+ """Run the simulation by days"""
638
783
  try:
639
784
  self._exp_info["num_day"] += day
640
785
  await self._update_exp_status(1) # 更新状态为运行中
@@ -645,13 +790,14 @@ class AgentSimulation:
645
790
  monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
646
791
 
647
792
  try:
648
- for _ in range(day):
649
- tasks = []
650
- for group in self._groups.values():
651
- tasks.append(group.run.remote())
652
- # 等待所有group运行完成
653
- await asyncio.gather(*tasks)
654
-
793
+ end_time = (
794
+ await self._simulator.get_time() + day * 24 * 3600
795
+ ) # type:ignore
796
+ while True:
797
+ current_time = await self._simulator.get_time()
798
+ if current_time >= end_time: # type:ignore
799
+ break
800
+ await self.step()
655
801
  finally:
656
802
  # 设置停止事件
657
803
  stop_event.set()
@@ -71,11 +71,11 @@ class PgWriter:
71
71
 
72
72
  @lock_decorator
73
73
  async def async_write_status(self, rows: list[tuple]):
74
- _tuple_types = [str, int, float, float, float, int, str, str, None]
74
+ _tuple_types = [str, int, float, float, float, int, list, str, str, None]
75
75
  table_name = f"socialcity_{self.exp_id.replace('-', '_')}_agent_status"
76
76
  async with await psycopg.AsyncConnection.connect(self._dsn) as aconn:
77
77
  copy_sql = psycopg.sql.SQL(
78
- "COPY {} (id, day, t, lng, lat, parent_id, action, status, created_at) FROM STDIN"
78
+ "COPY {} (id, day, t, lng, lat, parent_id, friend_ids, action, status, created_at) FROM STDIN"
79
79
  ).format(psycopg.sql.Identifier(table_name))
80
80
  _rows: list[Any] = []
81
81
  async with aconn.cursor() as cur:
pycityagent/tools/tool.py CHANGED
@@ -119,7 +119,7 @@ class SencePOI(Tool):
119
119
  if agent.memory is None or agent.simulator is None:
120
120
  raise ValueError("Memory or Simulator is not set.")
121
121
  if radius is None and category_prefix is None:
122
- position = await agent.memory.get("position")
122
+ position = await agent.status.get("position")
123
123
  resp = []
124
124
  for prefix in self.category_prefix:
125
125
  resp += agent.simulator.map.query_pois(
@@ -146,17 +146,15 @@ class UpdateWithSimulator(Tool):
146
146
  agent = self.agent
147
147
  if agent._simulator is None:
148
148
  return
149
- if not agent._has_bound_to_simulator:
150
- await agent._bind_to_simulator() # type: ignore
151
149
  simulator = agent.simulator
152
- memory = agent.memory
153
- person_id = await memory.get("id")
150
+ status = agent.status
151
+ person_id = await status.get("id")
154
152
  resp = await simulator.get_person(person_id)
155
153
  resp_dict = resp["person"]
156
154
  for k, v in resp_dict.get("motion", {}).items():
157
155
  try:
158
- await memory.get(k)
159
- await memory.update(
156
+ await status.get(k)
157
+ await status.update(
160
158
  k, v, mode="replace", protect_llm_read_only_fields=False
161
159
  )
162
160
  except KeyError as e:
@@ -183,9 +181,9 @@ class ResetAgentPosition(Tool):
183
181
  s: Optional[float] = None,
184
182
  ):
185
183
  agent = self.agent
186
- memory = agent.memory
184
+ status = agent.status
187
185
  await agent.simulator.reset_person_position(
188
- person_id=await memory.get("id"),
186
+ person_id=await status.get("id"),
189
187
  aoi_id=aoi_id,
190
188
  poi_id=poi_id,
191
189
  lane_id=lane_id,
@@ -1,11 +1,16 @@
1
1
  from .avro_schema import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA,
2
2
  PROFILE_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA)
3
3
  from .pg_query import PGSQL_DICT, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
4
- from .survey_util import process_survey_for_llm
4
+ from .survey_util import SURVEY_SENDER_UUID, process_survey_for_llm
5
5
 
6
6
  __all__ = [
7
- "PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
7
+ "PROFILE_SCHEMA",
8
+ "DIALOG_SCHEMA",
9
+ "STATUS_SCHEMA",
10
+ "SURVEY_SCHEMA",
11
+ "INSTITUTION_STATUS_SCHEMA",
8
12
  "process_survey_for_llm",
9
13
  "TO_UPDATE_EXP_INFO_KEYS_AND_TYPES",
10
14
  "PGSQL_DICT",
15
+ "SURVEY_SENDER_UUID",
11
16
  ]
@@ -54,6 +54,7 @@ PGSQL_DICT: dict[str, list[Any]] = {
54
54
  lng DOUBLE PRECISION,
55
55
  lat DOUBLE PRECISION,
56
56
  parent_id INT4,
57
+ friend_ids UUID[],
57
58
  action TEXT,
58
59
  status JSONB,
59
60
  created_at TIMESTAMPTZ
@@ -8,40 +8,40 @@ Survey Description: {survey_dict['description']}
8
8
  Please answer each question in the following format:
9
9
 
10
10
  """
11
-
11
+
12
12
  question_count = 1
13
- for page in survey_dict['pages']:
14
- for question in page['elements']:
13
+ for page in survey_dict["pages"]:
14
+ for question in page["elements"]:
15
15
  prompt += f"Question {question_count}: {question['title']}\n"
16
-
16
+
17
17
  # 根据不同类型的问题生成不同的提示
18
- if question['type'] == 'radiogroup':
19
- prompt += "Options: " + ", ".join(question['choices']) + "\n"
18
+ if question["type"] == "radiogroup":
19
+ prompt += "Options: " + ", ".join(question["choices"]) + "\n"
20
20
  prompt += "Please select ONE option\n"
21
-
22
- elif question['type'] == 'checkbox':
23
- prompt += "Options: " + ", ".join(question['choices']) + "\n"
21
+
22
+ elif question["type"] == "checkbox":
23
+ prompt += "Options: " + ", ".join(question["choices"]) + "\n"
24
24
  prompt += "You can select MULTIPLE options\n"
25
-
26
- elif question['type'] == 'rating':
25
+
26
+ elif question["type"] == "rating":
27
27
  prompt += f"Rating range: {question.get('min_rating', 1)} - {question.get('max_rating', 5)}\n"
28
28
  prompt += "Please provide a rating within the range\n"
29
-
30
- elif question['type'] == 'matrix':
31
- prompt += "Rows: " + ", ".join(question['rows']) + "\n"
32
- prompt += "Columns: " + ", ".join(question['columns']) + "\n"
29
+
30
+ elif question["type"] == "matrix":
31
+ prompt += "Rows: " + ", ".join(question["rows"]) + "\n"
32
+ prompt += "Columns: " + ", ".join(question["columns"]) + "\n"
33
33
  prompt += "Please select ONE column option for EACH row\n"
34
-
35
- elif question['type'] == 'text':
34
+
35
+ elif question["type"] == "text":
36
36
  prompt += "Please provide a text response\n"
37
-
38
- elif question['type'] == 'boolean':
37
+
38
+ elif question["type"] == "boolean":
39
39
  prompt += "Options: Yes, No\n"
40
40
  prompt += "Please select either Yes or No\n"
41
-
41
+
42
42
  prompt += "\nAnswer: [Your response here]\n\n---\n\n"
43
43
  question_count += 1
44
-
44
+
45
45
  # 添加总结提示
46
46
  prompt += """Please ensure:
47
47
  1. All required questions are answered
@@ -49,5 +49,8 @@ Please answer each question in the following format:
49
49
  3. Answers are clear and specific
50
50
 
51
51
  Format your responses exactly as requested above."""
52
-
53
- return prompt
52
+
53
+ return prompt
54
+
55
+
56
+ SURVEY_SENDER_UUID = "none"
@@ -146,6 +146,7 @@ def trigger_class():
146
146
  class Block:
147
147
  configurable_fields: list[str] = []
148
148
  default_values: dict[str, Any] = {}
149
+ fields_description: dict[str, str] = {}
149
150
 
150
151
  def __init__(
151
152
  self,
@@ -173,14 +174,20 @@ class Block:
173
174
 
174
175
  @classmethod
175
176
  def export_class_config(cls) -> dict[str, str]:
176
- return {
177
- field: cls.default_values.get(field, "default_value")
178
- for field in cls.configurable_fields
179
- }
177
+ return (
178
+ {
179
+ field: cls.default_values.get(field, "default_value")
180
+ for field in cls.configurable_fields
181
+ },
182
+ {
183
+ field: cls.fields_description.get(field, "")
184
+ for field in cls.configurable_fields
185
+ }
186
+ )
180
187
 
181
188
  @classmethod
182
189
  def import_config(cls, config: dict[str, Union[str, dict]]) -> Block:
183
- instance = cls(name=config["name"])
190
+ instance = cls(name=config["name"]) # type: ignore
184
191
  assert isinstance(config["config"], dict)
185
192
  for field, value in config["config"].items():
186
193
  if field in cls.configurable_fields:
@@ -188,12 +195,12 @@ class Block:
188
195
 
189
196
  # 递归创建子Block
190
197
  for child_config in config.get("children", []):
191
- child_block = Block.import_config(child_config)
198
+ child_block = Block.import_config(child_config) # type: ignore
192
199
  setattr(instance, child_block.name.lower(), child_block)
193
200
 
194
201
  return instance
195
202
 
196
- def load_from_config(self, config: dict[str, list[Dict]]) -> None:
203
+ def load_from_config(self, config: dict[str, list[dict]]) -> None:
197
204
  """
198
205
  使用配置更新当前Block实例的参数,并递归更新子Block。
199
206
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: pycityagent
3
- Version: 2.0.0a52
3
+ Version: 2.0.0a54
4
4
  Summary: LLM-based city environment agent building library
5
5
  Author-email: Yuwei Yan <pinkgranite86@gmail.com>, Junbo Yan <yanjb20thu@gmali.com>, Jun Zhang <zhangjun990222@gmali.com>
6
6
  License: MIT License
@@ -45,7 +45,7 @@ Requires-Dist: openai>=1.58.1
45
45
  Requires-Dist: Pillow<12.0.0,>=11.0.0
46
46
  Requires-Dist: protobuf<5.0.0,<=4.24.0
47
47
  Requires-Dist: pycitydata>=1.0.3
48
- Requires-Dist: pycityproto>=2.1.5
48
+ Requires-Dist: pycityproto>=2.2.0
49
49
  Requires-Dist: requests>=2.32.3
50
50
  Requires-Dist: Shapely>=2.0.6
51
51
  Requires-Dist: PyYAML>=6.0.2