pycityagent 2.0.0a66__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a67__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. pycityagent/agent/agent.py +157 -57
  2. pycityagent/agent/agent_base.py +316 -43
  3. pycityagent/cityagent/bankagent.py +49 -9
  4. pycityagent/cityagent/blocks/__init__.py +1 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +54 -31
  6. pycityagent/cityagent/blocks/dispatcher.py +22 -17
  7. pycityagent/cityagent/blocks/economy_block.py +46 -32
  8. pycityagent/cityagent/blocks/mobility_block.py +130 -100
  9. pycityagent/cityagent/blocks/needs_block.py +101 -44
  10. pycityagent/cityagent/blocks/other_block.py +42 -33
  11. pycityagent/cityagent/blocks/plan_block.py +59 -42
  12. pycityagent/cityagent/blocks/social_block.py +167 -116
  13. pycityagent/cityagent/blocks/utils.py +13 -6
  14. pycityagent/cityagent/firmagent.py +17 -35
  15. pycityagent/cityagent/governmentagent.py +3 -3
  16. pycityagent/cityagent/initial.py +79 -44
  17. pycityagent/cityagent/memory_config.py +108 -88
  18. pycityagent/cityagent/message_intercept.py +0 -4
  19. pycityagent/cityagent/metrics.py +41 -0
  20. pycityagent/cityagent/nbsagent.py +24 -36
  21. pycityagent/cityagent/societyagent.py +7 -3
  22. pycityagent/cli/wrapper.py +2 -2
  23. pycityagent/economy/econ_client.py +407 -81
  24. pycityagent/environment/__init__.py +0 -3
  25. pycityagent/environment/sim/__init__.py +0 -3
  26. pycityagent/environment/sim/aoi_service.py +2 -2
  27. pycityagent/environment/sim/client.py +3 -31
  28. pycityagent/environment/sim/clock_service.py +2 -2
  29. pycityagent/environment/sim/lane_service.py +8 -8
  30. pycityagent/environment/sim/light_service.py +8 -8
  31. pycityagent/environment/sim/pause_service.py +9 -10
  32. pycityagent/environment/sim/person_service.py +20 -20
  33. pycityagent/environment/sim/road_service.py +2 -2
  34. pycityagent/environment/sim/sim_env.py +21 -5
  35. pycityagent/environment/sim/social_service.py +4 -4
  36. pycityagent/environment/simulator.py +249 -27
  37. pycityagent/environment/utils/__init__.py +2 -2
  38. pycityagent/environment/utils/geojson.py +2 -2
  39. pycityagent/environment/utils/grpc.py +4 -4
  40. pycityagent/environment/utils/map_utils.py +2 -2
  41. pycityagent/llm/embeddings.py +147 -28
  42. pycityagent/llm/llm.py +122 -77
  43. pycityagent/llm/llmconfig.py +5 -0
  44. pycityagent/llm/utils.py +4 -0
  45. pycityagent/memory/__init__.py +0 -4
  46. pycityagent/memory/const.py +2 -2
  47. pycityagent/memory/faiss_query.py +140 -61
  48. pycityagent/memory/memory.py +393 -90
  49. pycityagent/memory/memory_base.py +140 -34
  50. pycityagent/memory/profile.py +13 -13
  51. pycityagent/memory/self_define.py +13 -13
  52. pycityagent/memory/state.py +14 -14
  53. pycityagent/message/message_interceptor.py +253 -3
  54. pycityagent/message/messager.py +133 -6
  55. pycityagent/metrics/mlflow_client.py +47 -4
  56. pycityagent/pycityagent-sim +0 -0
  57. pycityagent/pycityagent-ui +0 -0
  58. pycityagent/simulation/__init__.py +3 -2
  59. pycityagent/simulation/agentgroup.py +145 -52
  60. pycityagent/simulation/simulation.py +257 -62
  61. pycityagent/survey/manager.py +45 -3
  62. pycityagent/survey/models.py +42 -2
  63. pycityagent/tools/__init__.py +1 -2
  64. pycityagent/tools/tool.py +93 -69
  65. pycityagent/utils/avro_schema.py +2 -2
  66. pycityagent/utils/parsers/code_block_parser.py +1 -1
  67. pycityagent/utils/parsers/json_parser.py +2 -2
  68. pycityagent/utils/parsers/parser_base.py +2 -2
  69. pycityagent/workflow/block.py +64 -13
  70. pycityagent/workflow/prompt.py +31 -23
  71. pycityagent/workflow/trigger.py +91 -24
  72. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/METADATA +2 -2
  73. pycityagent-2.0.0a67.dist-info/RECORD +97 -0
  74. pycityagent/environment/interact/__init__.py +0 -0
  75. pycityagent/environment/interact/interact.py +0 -198
  76. pycityagent/environment/message/__init__.py +0 -0
  77. pycityagent/environment/sence/__init__.py +0 -0
  78. pycityagent/environment/sence/static.py +0 -416
  79. pycityagent/environment/sidecar/__init__.py +0 -8
  80. pycityagent/environment/sidecar/sidecarv2.py +0 -109
  81. pycityagent/environment/sim/economy_services.py +0 -192
  82. pycityagent/metrics/utils/const.py +0 -0
  83. pycityagent-2.0.0a66.dist-info/RECORD +0 -105
  84. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/LICENSE +0 -0
  85. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/WHEEL +0 -0
  86. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/entry_points.txt +0 -0
  87. {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/top_level.txt +0 -0
@@ -13,20 +13,30 @@ import yaml
13
13
  from langchain_core.embeddings import Embeddings
14
14
 
15
15
  from ..agent import Agent, InstitutionAgent
16
- from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
17
- SocietyAgent)
18
- from ..cityagent.memory_config import (memory_config_bank, memory_config_firm,
19
- memory_config_government, memory_config_nbs,
20
- memory_config_societyagent, memory_config_init)
16
+ from ..cityagent import BankAgent, FirmAgent, GovernmentAgent, NBSAgent, SocietyAgent
17
+ from ..cityagent.memory_config import (
18
+ memory_config_bank,
19
+ memory_config_firm,
20
+ memory_config_government,
21
+ memory_config_nbs,
22
+ memory_config_societyagent,
23
+ memory_config_init,
24
+ )
21
25
  from ..cityagent.initial import bind_agent_info, initialize_social_network
22
- from ..cityagent.message_intercept import (EdgeMessageBlock,
23
- MessageBlockListener,
24
- PointMessageBlock)
26
+ from ..cityagent.message_intercept import (
27
+ EdgeMessageBlock,
28
+ MessageBlockListener,
29
+ PointMessageBlock,
30
+ )
25
31
  from ..economy.econ_client import EconomyClient
26
32
  from ..environment import Simulator
27
33
  from ..llm import SimpleEmbedding
28
- from ..message import (MessageBlockBase, MessageBlockListenerBase,
29
- MessageInterceptor, Messager)
34
+ from ..message import (
35
+ MessageBlockBase,
36
+ MessageBlockListenerBase,
37
+ MessageInterceptor,
38
+ Messager,
39
+ )
30
40
  from ..metrics import init_mlflow_connection
31
41
  from ..metrics.mlflow_client import MlflowClient
32
42
  from ..survey import Survey
@@ -36,8 +46,23 @@ from .storage.pg import PgWriter, create_pg_tables
36
46
 
37
47
  logger = logging.getLogger("pycityagent")
38
48
 
49
+
50
+ __all__ = ["AgentSimulation"]
51
+
39
52
  class AgentSimulation:
40
- """Agent Simulation"""
53
+ """
54
+ A class to simulate a multi-agent system.
55
+
56
+ This simulation framework is designed to facilitate the creation and management of multiple agent types within an experiment.
57
+ It allows for the configuration of different agents, memory configurations, and metric extractors, as well as enabling institutional settings.
58
+
59
+ Attributes:
60
+ exp_id (str): A unique identifier for the current experiment.
61
+ agent_class (List[Type[Agent]]): A list of agent classes that will be instantiated in the simulation.
62
+ agent_config_file (Optional[dict]): Configuration file or dictionary for initializing agents.
63
+ logging_level (int): The level of logging to be used throughout the simulation.
64
+ default_memory_config_func (Dict[Type[Agent], Callable]): Dictionary mapping agent classes to their respective memory configuration functions.
65
+ """
41
66
 
42
67
  def __init__(
43
68
  self,
@@ -51,15 +76,26 @@ class AgentSimulation:
51
76
  logging_level: int = logging.WARNING,
52
77
  ):
53
78
  """
54
- Args:
55
- config: Configuration
56
- agent_class: Agent class
57
- agent_config_file: Agent configuration file
58
- metric_extractors: Metric extractor
59
- enable_institution: Whether to enable institution
60
- agent_prefix: Agent name prefix
61
- exp_name: Experiment name
62
- logging_level: Logging level
79
+ Initializes the AgentSimulation with the given parameters.
80
+
81
+ - **Description**:
82
+ - Sets up the simulation environment based on the provided configuration. Depending on the `enable_institution` flag,
83
+ it can include a predefined set of institutional agents. If specific agent classes are provided, those will be used instead.
84
+
85
+ - **Args**:
86
+ - `config` (dict): The main configuration dictionary for the simulation.
87
+ - `agent_class` (Union[None, Type[Agent], List[Type[Agent]]], optional):
88
+ Either a single agent class or a list of agent classes to instantiate. Defaults to None, which implies a default set of agents.
89
+ - `agent_config_file` (Optional[dict], optional): An optional configuration file or dictionary used to initialize agents. Defaults to None.
90
+ - `metric_extractors` (Optional[List[Tuple[int, Callable]]], optional):
91
+ A list of tuples containing intervals and callables for extracting metrics from the simulation. Defaults to None.
92
+ - `enable_institution` (bool, optional): Flag indicating whether institutional agents should be included in the simulation. Defaults to True.
93
+ - `agent_prefix` (str, optional): Prefix string for naming agents. Defaults to "agent_".
94
+ - `exp_name` (str, optional): The name of the experiment. Defaults to "default_experiment".
95
+ - `logging_level` (int, optional): Logging level to set for the simulation's logger. Defaults to logging.WARNING.
96
+
97
+ - **Returns**:
98
+ - None
63
99
  """
64
100
  self.exp_id = str(uuid.uuid4())
65
101
  if isinstance(agent_class, list):
@@ -170,8 +206,7 @@ class AgentSimulation:
170
206
  experiment_name=exp_name,
171
207
  run_id=mlflow_run_id,
172
208
  )
173
- if metric_extractors is not None:
174
- self.metric_extractors = metric_extractors
209
+ self.metric_extractors = metric_extractors
175
210
  else:
176
211
  logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
177
212
  self.mlflow_client = None
@@ -270,13 +305,14 @@ class AgentSimulation:
270
305
  logging_level=config.get("logging_level", logging.WARNING),
271
306
  )
272
307
  environment = config.get(
273
- "environment",
308
+ "environment",
274
309
  {
275
- "weather": "The weather is normal",
276
- "crime": "The crime rate is low",
277
- "pollution": "The pollution level is low",
278
- "temperature": "The temperature is normal"
279
- }
310
+ "weather": "The weather is normal",
311
+ "crime": "The crime rate is low",
312
+ "pollution": "The pollution level is low",
313
+ "temperature": "The temperature is normal",
314
+ "day": "Workday"
315
+ },
280
316
  )
281
317
  simulation._simulator.set_environment(environment)
282
318
  logger.info("Initializing Agents...")
@@ -289,7 +325,7 @@ class AgentSimulation:
289
325
  }
290
326
  if agent_count.get(SocietyAgent, 0) == 0:
291
327
  raise ValueError("number_of_citizen is required")
292
-
328
+
293
329
  # support MessageInterceptor
294
330
  if "message_intercept" in config:
295
331
  _intercept_config = config["message_intercept"]
@@ -328,8 +364,10 @@ class AgentSimulation:
328
364
  embedding_model=config["agent_config"].get(
329
365
  "embedding_model", SimpleEmbedding()
330
366
  ),
331
- memory_config_func=config["agent_config"].get("memory_config_func", None),
332
- memory_config_init_func=config["agent_config"].get("memory_config_init_func", None),
367
+ memory_config_func=config["agent_config"].get("memory_config_func", None),
368
+ memory_config_init_func=config["agent_config"].get(
369
+ "memory_config_init_func", None
370
+ ),
333
371
  **_message_intercept_kwargs,
334
372
  environment=environment,
335
373
  )
@@ -339,6 +377,9 @@ class AgentSimulation:
339
377
  ):
340
378
  await init_func(simulation)
341
379
  logger.info("Starting Simulation...")
380
+ llm_log_lists = []
381
+ mqtt_log_lists = []
382
+ simulator_log_lists = []
342
383
  for step in config["workflow"]:
343
384
  logger.info(
344
385
  f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
@@ -346,11 +387,17 @@ class AgentSimulation:
346
387
  if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
347
388
  raise ValueError(f"Invalid step type: {step['type']}")
348
389
  if step["type"] == "run":
349
- await simulation.run(step.get("days", 1))
390
+ llm_log_list, mqtt_log_list, simulator_log_list = await simulation.run(step.get("days", 1))
391
+ llm_log_lists.extend(llm_log_list)
392
+ mqtt_log_lists.extend(mqtt_log_list)
393
+ simulator_log_lists.extend(simulator_log_list)
350
394
  elif step["type"] == "step":
351
395
  times = step.get("times", 1)
352
396
  for _ in range(times):
353
- await simulation.step()
397
+ llm_log_list, mqtt_log_list, simulator_log_list = await simulation.step()
398
+ llm_log_lists.extend(llm_log_list)
399
+ mqtt_log_lists.extend(mqtt_log_list)
400
+ simulator_log_lists.extend(simulator_log_list)
354
401
  elif step["type"] == "pause":
355
402
  await simulation.pause_simulator()
356
403
  elif step["type"] == "resume":
@@ -358,7 +405,7 @@ class AgentSimulation:
358
405
  else:
359
406
  await step["func"](simulation)
360
407
  logger.info("Simulation finished")
361
-
408
+ return llm_log_lists, mqtt_log_lists, simulator_log_lists
362
409
  @property
363
410
  def enable_avro(
364
411
  self,
@@ -434,7 +481,7 @@ class AgentSimulation:
434
481
  async def _monitor_exp_status(self, stop_event: asyncio.Event):
435
482
  """监控实验状态并更新
436
483
 
437
- Args:
484
+ - **Args**:
438
485
  stop_event: 用于通知监控任务停止的事件
439
486
  """
440
487
  try:
@@ -488,15 +535,32 @@ class AgentSimulation:
488
535
  memory_config_func: Optional[dict[type[Agent], Callable]] = None,
489
536
  environment: Optional[dict[str, str]] = None,
490
537
  ) -> None:
491
- """初始化智能体
492
-
493
- Args:
494
- agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
495
- group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
496
- pg_sql_writers: 独立的PgSQL writer数量
497
- message_interceptors: message拦截器数量
498
- memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 每个元素表示一个智能体类创建的Memory配置函数
499
- environment: 环境变量,用于更新模拟器的环境变量
538
+ """
539
+ Initialize agents within the simulation.
540
+
541
+ - **Description**:
542
+ - Asynchronously initializes a specified number of agents for each provided agent class.
543
+ - Agents are grouped into independent Ray actors based on the `group_size`, with configurations for database writers,
544
+ message interceptors, and memory settings. Optionally updates the simulator's environment variables.
545
+
546
+ - **Args**:
547
+ - `agent_count` (dict[Type[Agent], int]): Dictionary mapping agent classes to the number of instances to create.
548
+ - `group_size` (int, optional): Number of agents per group, each group runs as an independent Ray actor. Defaults to 10000.
549
+ - `pg_sql_writers` (int, optional): Number of independent PgSQL writer processes. Defaults to 32.
550
+ - `message_interceptors` (int, optional): Number of message interceptor processes. Defaults to 1.
551
+ - `message_interceptor_blocks` (Optional[List[MessageBlockBase]], optional): List of message interception blocks. Defaults to None.
552
+ - `social_black_list` (Optional[List[Tuple[str, str]]], optional): List of tuples representing pairs of agents that should not communicate. Defaults to None.
553
+ - `message_listener` (Optional[MessageBlockListenerBase], optional): Listener for intercepted messages. Defaults to None.
554
+ - `embedding_model` (Embeddings, optional): Model used for generating embeddings for agents' memories. Defaults to SimpleEmbedding().
555
+ - `memory_config_init_func` (Optional[Callable], optional): Initialization function for setting up memory configuration. Defaults to None.
556
+ - `memory_config_func` (Optional[Dict[Type[Agent], Callable]], optional): Dictionary mapping agent classes to their memory configuration functions. Defaults to None.
557
+ - `environment` (Optional[Dict[str, str]], optional): Environment variables to update in the simulation. Defaults to None.
558
+
559
+ - **Raises**:
560
+ - `ValueError`: If the lengths of `agent_class` and `agent_count` do not match.
561
+
562
+ - **Returns**:
563
+ - `None`
500
564
  """
501
565
  self.agent_count = agent_count
502
566
 
@@ -516,8 +580,6 @@ class AgentSimulation:
516
580
  citizen_params = []
517
581
 
518
582
  # 收集所有参数
519
- print(self.agent_class)
520
- print(agent_count)
521
583
  for i in range(len(self.agent_class)):
522
584
  agent_class = self.agent_class[i]
523
585
  agent_count_i = agent_count[agent_class]
@@ -721,10 +783,32 @@ class AgentSimulation:
721
783
  )
722
784
  await self.messager.start_listening.remote() # type:ignore
723
785
 
786
+ agent_ids = set()
787
+ org_ids = set()
788
+ for group in self._groups.values():
789
+ ids = await group.get_economy_ids.remote()
790
+ agent_ids.update(ids[0])
791
+ org_ids.update(ids[1])
792
+ await self.economy_client.set_ids(agent_ids, org_ids)
793
+ for group in self._groups.values():
794
+ await group.set_economy_ids.remote(agent_ids, org_ids)
795
+
724
796
  async def gather(
725
797
  self, content: str, target_agent_uuids: Optional[list[str]] = None
726
798
  ):
727
- """收集智能体的特定信息"""
799
+ """
800
+ Collect specific information from agents.
801
+
802
+ - **Description**:
803
+ - Asynchronously gathers specified content from targeted agents within all groups.
804
+
805
+ - **Args**:
806
+ - `content` (str): The information to collect from the agents.
807
+ - `target_agent_uuids` (Optional[List[str]], optional): A list of agent UUIDs to target. Defaults to None, meaning all agents are targeted.
808
+
809
+ - **Returns**:
810
+ - Result of the gathering process as returned by each group's `gather` method.
811
+ """
728
812
  gather_tasks = []
729
813
  for group in self._groups.values():
730
814
  gather_tasks.append(group.gather.remote(content, target_agent_uuids))
@@ -736,7 +820,20 @@ class AgentSimulation:
736
820
  keys: Optional[list[str]] = None,
737
821
  values: Optional[list[Any]] = None,
738
822
  ) -> list[str]:
739
- """过滤出指定类型的智能体"""
823
+ """
824
+ Filter out agents of specified types or with matching key-value pairs.
825
+
826
+ - **Args**:
827
+ - `types` (Optional[List[Type[Agent]]], optional): Types of agents to filter for. Defaults to None.
828
+ - `keys` (Optional[List[str]], optional): Keys to match in agent attributes. Defaults to None.
829
+ - `values` (Optional[List[Any]], optional): Values corresponding to keys for matching. Defaults to None.
830
+
831
+ - **Raises**:
832
+ - `ValueError`: If neither types nor keys and values are provided, or if the lengths of keys and values do not match.
833
+
834
+ - **Returns**:
835
+ - `List[str]`: A list of filtered agent UUIDs.
836
+ """
740
837
  if not types and not keys and not values:
741
838
  return self._agent_uuids
742
839
  group_to_filter = []
@@ -759,12 +856,26 @@ class AgentSimulation:
759
856
  return filtered_uuids
760
857
 
761
858
  async def update_environment(self, key: str, value: str):
859
+ """
860
+ Update the environment variables for the simulation and all agent groups.
861
+
862
+ - **Args**:
863
+ - `key` (str): The environment variable key to update.
864
+ - `value` (str): The new value for the environment variable.
865
+ """
762
866
  self._simulator.update_environment(key, value)
763
867
  for group in self._groups.values():
764
868
  await group.update_environment.remote(key, value)
765
869
 
766
870
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
767
- """更新指定智能体的记忆"""
871
+ """
872
+ Update the memory of a specified agent.
873
+
874
+ - **Args**:
875
+ - `target_agent_uuid` (str): The UUID of the target agent to update.
876
+ - `target_key` (str): The key in the agent's memory to update.
877
+ - `content` (Any): The new content to set for the target key.
878
+ """
768
879
  group = self._agent_uuid2group[target_agent_uuid]
769
880
  await group.update.remote(target_agent_uuid, target_key, content)
770
881
 
@@ -775,13 +886,30 @@ class AgentSimulation:
775
886
  content: Any,
776
887
  mode: Literal["replace", "merge"] = "replace",
777
888
  ):
778
- """更新指定智能体的经济数据"""
889
+ """
890
+ Update economic data for a specified agent.
891
+
892
+ - **Args**:
893
+ - `target_agent_id` (int): The ID of the target agent whose economic data to update.
894
+ - `target_key` (str): The key in the agent's economic data to update.
895
+ - `content` (Any): The new content to set for the target key.
896
+ - `mode` (Literal["replace", "merge"], optional): Mode of updating the economic data. Defaults to "replace".
897
+ """
779
898
  await self.economy_client.update(
780
899
  id=target_agent_id, key=target_key, value=content, mode=mode
781
900
  )
782
901
 
783
902
  async def send_survey(self, survey: Survey, agent_uuids: list[str] = []):
784
- """发送问卷"""
903
+ """
904
+ Send a survey to specified agents.
905
+
906
+ - **Args**:
907
+ - `survey` (Survey): The survey object to send.
908
+ - `agent_uuids` (List[str], optional): List of agent UUIDs to receive the survey. Defaults to an empty list.
909
+
910
+ - **Returns**:
911
+ - None
912
+ """
785
913
  survey_dict = survey.to_dict()
786
914
  _date_time = datetime.now(timezone.utc)
787
915
  payload = {
@@ -806,7 +934,16 @@ class AgentSimulation:
806
934
  async def send_interview_message(
807
935
  self, content: str, agent_uuids: Union[str, list[str]]
808
936
  ):
809
- """发送采访消息"""
937
+ """
938
+ Send an interview message to specified agents.
939
+
940
+ - **Args**:
941
+ - `content` (str): The content of the message to send.
942
+ - `agent_uuids` (Union[str, List[str]]): A single UUID string or a list of UUID strings for the agents to receive the message.
943
+
944
+ - **Returns**:
945
+ - None
946
+ """
810
947
  _date_time = datetime.now(timezone.utc)
811
948
  payload = {
812
949
  "from": SURVEY_SENDER_UUID,
@@ -829,12 +966,37 @@ class AgentSimulation:
829
966
  await asyncio.sleep(3)
830
967
 
831
968
  async def extract_metric(self, metric_extractors: list[Callable]):
832
- """提取指标"""
969
+ """
970
+ Extract metrics using provided extractors.
971
+
972
+ - **Description**:
973
+ - Asynchronously applies each metric extractor function to the simulation to collect various metrics.
974
+
975
+ - **Args**:
976
+ - `metric_extractors` (List[Callable]): A list of callable functions that take the simulation instance as an argument and return a metric or perform some form of analysis.
977
+
978
+ - **Returns**:
979
+ - None
980
+ """
833
981
  for metric_extractor in metric_extractors:
834
982
  await metric_extractor(self)
835
983
 
836
984
  async def step(self):
837
- """Run one step, each agent execute one forward"""
985
+ """
986
+ Execute one step of the simulation where each agent performs its forward action.
987
+
988
+ - **Description**:
989
+ - Checks if new agents need to be inserted based on the current day of the simulation. If so, it inserts them.
990
+ - Executes the forward method for each agent group to advance the simulation by one step.
991
+ - Saves the state of all agent groups after the step has been completed.
992
+ - Optionally extracts metrics if the current step matches the interval specified for any metric extractors.
993
+
994
+ - **Raises**:
995
+ - `RuntimeError`: If there is an error during the execution of the step, it logs the error and rethrows it as a RuntimeError.
996
+
997
+ - **Returns**:
998
+ - None
999
+ """
838
1000
  try:
839
1001
  # check whether insert agents
840
1002
  simulator_day = await self._simulator.get_simulator_day()
@@ -852,11 +1014,21 @@ class AgentSimulation:
852
1014
  # step
853
1015
  simulator_day = await self._simulator.get_simulator_day()
854
1016
  simulator_time = int(await self._simulator.get_time())
855
- logger.info(f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}")
1017
+ logger.info(
1018
+ f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}"
1019
+ )
856
1020
  tasks = []
857
1021
  for group in self._groups.values():
858
1022
  tasks.append(group.step.remote())
859
- await asyncio.gather(*tasks)
1023
+ log_messages_groups = await asyncio.gather(*tasks)
1024
+ llm_log_list = []
1025
+ mqtt_log_list = []
1026
+ simulator_log_list = []
1027
+ for log_messages_group in log_messages_groups:
1028
+ llm_log_list.extend(log_messages_group['llm_log'])
1029
+ mqtt_log_list.extend(log_messages_group['mqtt_log'])
1030
+ simulator_log_list.extend(log_messages_group['simulator_log'])
1031
+
860
1032
  # save
861
1033
  simulator_day = await self._simulator.get_simulator_day()
862
1034
  simulator_time = int(await self._simulator.get_time())
@@ -865,14 +1037,15 @@ class AgentSimulation:
865
1037
  save_tasks.append(group.save.remote(simulator_day, simulator_time))
866
1038
  await asyncio.gather(*save_tasks)
867
1039
  self._total_steps += 1
868
- if self.metric_extractor is not None:
869
- print(f"total_steps: {self._total_steps}, excute metric")
1040
+ if self.metric_extractors is not None: # type:ignore
870
1041
  to_excute_metric = [
871
1042
  metric[1]
872
- for metric in self.metric_extractor
1043
+ for metric in self.metric_extractors # type:ignore
873
1044
  if self._total_steps % metric[0] == 0
874
1045
  ]
875
1046
  await self.extract_metric(to_excute_metric)
1047
+
1048
+ return llm_log_list, mqtt_log_list, simulator_log_list
876
1049
  except Exception as e:
877
1050
  import traceback
878
1051
 
@@ -883,7 +1056,26 @@ class AgentSimulation:
883
1056
  self,
884
1057
  day: int = 1,
885
1058
  ):
886
- """Run the simulation by days"""
1059
+ """
1060
+ Run the simulation for a specified number of days.
1061
+
1062
+ - **Args**:
1063
+ - `day` (int, optional): The number of days to run the simulation. Defaults to 1.
1064
+
1065
+ - **Description**:
1066
+ - Updates the experiment status to running and sets up monitoring for the experiment's status.
1067
+ - Runs the simulation loop until the end time, which is calculated based on the current time and the number of days to simulate.
1068
+ - After completing the simulation, updates the experiment status to finished, or to failed if an exception occurs.
1069
+
1070
+ - **Raises**:
1071
+ - `RuntimeError`: If there is an error during the simulation, it logs the error and updates the experiment status to failed before rethrowing the exception.
1072
+
1073
+ - **Returns**:
1074
+ - None
1075
+ """
1076
+ llm_log_lists = []
1077
+ mqtt_log_lists = []
1078
+ simulator_log_lists = []
887
1079
  try:
888
1080
  self._exp_info["num_day"] += day
889
1081
  await self._update_exp_status(1) # 更新状态为运行中
@@ -901,7 +1093,10 @@ class AgentSimulation:
901
1093
  current_time = await self._simulator.get_time()
902
1094
  if current_time >= end_time: # type:ignore
903
1095
  break
904
- await self.step()
1096
+ llm_log_list, mqtt_log_list, simulator_log_list = await self.step()
1097
+ llm_log_lists.extend(llm_log_list)
1098
+ mqtt_log_lists.extend(mqtt_log_list)
1099
+ simulator_log_lists.extend(simulator_log_list)
905
1100
  finally:
906
1101
  # 设置停止事件
907
1102
  stop_event.set()
@@ -910,7 +1105,7 @@ class AgentSimulation:
910
1105
 
911
1106
  # 运行成功后更新状态
912
1107
  await self._update_exp_status(2)
913
-
1108
+ return llm_log_lists, mqtt_log_lists, simulator_log_lists
914
1109
  except Exception as e:
915
1110
  error_msg = f"模拟器运行错误: {str(e)}"
916
1111
  logger.error(error_msg)
@@ -8,10 +8,33 @@ from .models import Page, Question, QuestionType, Survey
8
8
 
9
9
  class SurveyManager:
10
10
  def __init__(self):
11
+ """
12
+ Initializes a new instance of the SurveyManager class.
13
+
14
+ - **Description**:
15
+ - Manages the creation and retrieval of surveys. Uses an internal dictionary to store survey instances by their unique identifier.
16
+
17
+ - **Attributes**:
18
+ - `_surveys` (dict[str, Survey]): A dictionary mapping survey IDs to `Survey` instances.
19
+ """
11
20
  self._surveys: dict[str, Survey] = {}
12
21
 
13
22
  def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
14
- """创建新问卷"""
23
+ """
24
+ Create a new survey with specified title, description, and pages containing questions.
25
+
26
+ - **Description**:
27
+ - Generates a unique ID for the survey, converts the provided page and question data into `Page` and `Question` objects,
28
+ and adds the created survey to the internal storage.
29
+
30
+ - **Args**:
31
+ - `title` (str): The title of the survey.
32
+ - `description` (str): A brief description of what the survey is about.
33
+ - `pages` (list[dict]): A list of dictionaries where each dictionary contains page information including elements which are question definitions.
34
+
35
+ - **Returns**:
36
+ - `Survey`: An instance of the `Survey` class representing the newly created survey.
37
+ """
15
38
  survey_id = uuid.uuid4()
16
39
 
17
40
  # 转换页面和问题数据
@@ -46,9 +69,28 @@ class SurveyManager:
46
69
  return survey
47
70
 
48
71
  def get_survey(self, survey_id: str) -> Optional[Survey]:
49
- """获取指定问卷"""
72
+ """
73
+ Retrieve a specific survey by its unique identifier.
74
+
75
+ - **Description**:
76
+ - Searches for a survey within the internal storage using the provided survey ID.
77
+
78
+ - **Args**:
79
+ - `survey_id` (str): The unique identifier of the survey to retrieve.
80
+
81
+ - **Returns**:
82
+ - `Optional[Survey]`: An instance of the `Survey` class if found; otherwise, None.
83
+ """
50
84
  return self._surveys.get(survey_id)
51
85
 
52
86
  def get_all_surveys(self) -> list[Survey]:
53
- """获取所有问卷"""
87
+ """
88
+ Get all surveys that have been created.
89
+
90
+ - **Description**:
91
+ - Retrieves all stored surveys from the internal storage.
92
+
93
+ - **Returns**:
94
+ - `list[Survey]`: A list of `Survey` instances.
95
+ """
54
96
  return list(self._surveys.values())
@@ -57,6 +57,18 @@ class Page:
57
57
 
58
58
  @dataclass
59
59
  class Survey:
60
+ """
61
+ Represents a survey with metadata and associated pages containing questions.
62
+
63
+ - **Attributes**:
64
+ - `id` (uuid.UUID): Unique identifier for the survey.
65
+ - `title` (str): Title of the survey.
66
+ - `description` (str): Description of the survey's purpose or content.
67
+ - `pages` (list[Page]): A list of `Page` objects, each containing a set of questions.
68
+ - `responses` (dict[str, dict], optional): Dictionary mapping response IDs to their data. Defaults to an empty dictionary.
69
+ - `created_at` (datetime, optional): Timestamp of when the survey was created. Defaults to the current time.
70
+ """
71
+
60
72
  id: uuid.UUID
61
73
  title: str
62
74
  description: str
@@ -65,6 +77,15 @@ class Survey:
65
77
  created_at: datetime = field(default_factory=datetime.now)
66
78
 
67
79
  def to_dict(self) -> dict:
80
+ """
81
+ Convert the survey instance into a dictionary representation.
82
+
83
+ - **Description**:
84
+ - Creates a dictionary containing the survey's ID, title, description, and pages in a simplified format suitable for serialization.
85
+
86
+ - **Returns**:
87
+ - `dict`: Simplified dictionary representation of the survey.
88
+ """
68
89
  return {
69
90
  "id": str(self.id),
70
91
  "title": self.title,
@@ -74,7 +95,15 @@ class Survey:
74
95
  }
75
96
 
76
97
  def to_json(self) -> str:
77
- """Convert the survey to a JSON string for MQTT transmission"""
98
+ """
99
+ Serialize the survey instance to a JSON string for MQTT transmission.
100
+
101
+ - **Description**:
102
+ - Converts the survey into a JSON string that includes all necessary information for reconstructing the survey object on another system.
103
+
104
+ - **Returns**:
105
+ - `str`: JSON string representing the survey.
106
+ """
78
107
  survey_dict = {
79
108
  "id": str(self.id),
80
109
  "title": self.title,
@@ -87,7 +116,18 @@ class Survey:
87
116
 
88
117
  @classmethod
89
118
  def from_json(cls, json_str: str) -> "Survey":
90
- """Create a Survey instance from a JSON string"""
119
+ """
120
+ Deserialize a JSON string into a new Survey instance.
121
+
122
+ - **Description**:
123
+ - Parses a JSON string into a Python dictionary and uses it to create a new `Survey` instance.
124
+
125
+ - **Args**:
126
+ - `json_str` (str): JSON string representation of a survey.
127
+
128
+ - **Returns**:
129
+ - `Survey`: An instance of `Survey` initialized with the data from the JSON string.
130
+ """
91
131
  data = json.loads(json_str)
92
132
  pages = [
93
133
  Page(
@@ -1,8 +1,7 @@
1
- from .tool import (ExportMlflowMetrics, GetMap, ResetAgentPosition, SencePOI,
1
+ from .tool import (ExportMlflowMetrics, GetMap, ResetAgentPosition,
2
2
  Tool, UpdateWithSimulator)
3
3
 
4
4
  __all__ = [
5
- "SencePOI",
6
5
  "Tool",
7
6
  "ExportMlflowMetrics",
8
7
  "GetMap",