pycityagent 2.0.0a65__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a67__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. pycityagent/agent/agent.py +157 -57
  2. pycityagent/agent/agent_base.py +316 -43
  3. pycityagent/cityagent/bankagent.py +49 -9
  4. pycityagent/cityagent/blocks/__init__.py +1 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +54 -31
  6. pycityagent/cityagent/blocks/dispatcher.py +22 -17
  7. pycityagent/cityagent/blocks/economy_block.py +46 -32
  8. pycityagent/cityagent/blocks/mobility_block.py +209 -105
  9. pycityagent/cityagent/blocks/needs_block.py +101 -54
  10. pycityagent/cityagent/blocks/other_block.py +42 -33
  11. pycityagent/cityagent/blocks/plan_block.py +59 -42
  12. pycityagent/cityagent/blocks/social_block.py +167 -126
  13. pycityagent/cityagent/blocks/utils.py +13 -6
  14. pycityagent/cityagent/firmagent.py +17 -35
  15. pycityagent/cityagent/governmentagent.py +3 -3
  16. pycityagent/cityagent/initial.py +79 -49
  17. pycityagent/cityagent/memory_config.py +123 -94
  18. pycityagent/cityagent/message_intercept.py +0 -4
  19. pycityagent/cityagent/metrics.py +41 -0
  20. pycityagent/cityagent/nbsagent.py +24 -36
  21. pycityagent/cityagent/societyagent.py +9 -4
  22. pycityagent/cli/wrapper.py +2 -2
  23. pycityagent/economy/econ_client.py +407 -81
  24. pycityagent/environment/__init__.py +0 -3
  25. pycityagent/environment/sim/__init__.py +0 -3
  26. pycityagent/environment/sim/aoi_service.py +2 -2
  27. pycityagent/environment/sim/client.py +3 -31
  28. pycityagent/environment/sim/clock_service.py +2 -2
  29. pycityagent/environment/sim/lane_service.py +8 -8
  30. pycityagent/environment/sim/light_service.py +8 -8
  31. pycityagent/environment/sim/pause_service.py +9 -10
  32. pycityagent/environment/sim/person_service.py +20 -20
  33. pycityagent/environment/sim/road_service.py +2 -2
  34. pycityagent/environment/sim/sim_env.py +21 -5
  35. pycityagent/environment/sim/social_service.py +4 -4
  36. pycityagent/environment/simulator.py +249 -27
  37. pycityagent/environment/utils/__init__.py +2 -2
  38. pycityagent/environment/utils/geojson.py +2 -2
  39. pycityagent/environment/utils/grpc.py +4 -4
  40. pycityagent/environment/utils/map_utils.py +2 -2
  41. pycityagent/llm/embeddings.py +147 -28
  42. pycityagent/llm/llm.py +178 -111
  43. pycityagent/llm/llmconfig.py +5 -0
  44. pycityagent/llm/utils.py +4 -0
  45. pycityagent/memory/__init__.py +0 -4
  46. pycityagent/memory/const.py +2 -2
  47. pycityagent/memory/faiss_query.py +140 -61
  48. pycityagent/memory/memory.py +394 -91
  49. pycityagent/memory/memory_base.py +140 -34
  50. pycityagent/memory/profile.py +13 -13
  51. pycityagent/memory/self_define.py +13 -13
  52. pycityagent/memory/state.py +14 -14
  53. pycityagent/message/message_interceptor.py +253 -3
  54. pycityagent/message/messager.py +133 -6
  55. pycityagent/metrics/mlflow_client.py +47 -4
  56. pycityagent/pycityagent-sim +0 -0
  57. pycityagent/pycityagent-ui +0 -0
  58. pycityagent/simulation/__init__.py +3 -2
  59. pycityagent/simulation/agentgroup.py +150 -54
  60. pycityagent/simulation/simulation.py +276 -66
  61. pycityagent/survey/manager.py +45 -3
  62. pycityagent/survey/models.py +42 -2
  63. pycityagent/tools/__init__.py +1 -2
  64. pycityagent/tools/tool.py +93 -69
  65. pycityagent/utils/avro_schema.py +2 -2
  66. pycityagent/utils/parsers/code_block_parser.py +1 -1
  67. pycityagent/utils/parsers/json_parser.py +2 -2
  68. pycityagent/utils/parsers/parser_base.py +2 -2
  69. pycityagent/workflow/block.py +64 -13
  70. pycityagent/workflow/prompt.py +31 -23
  71. pycityagent/workflow/trigger.py +91 -24
  72. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/METADATA +2 -2
  73. pycityagent-2.0.0a67.dist-info/RECORD +97 -0
  74. pycityagent/environment/interact/__init__.py +0 -0
  75. pycityagent/environment/interact/interact.py +0 -198
  76. pycityagent/environment/message/__init__.py +0 -0
  77. pycityagent/environment/sence/__init__.py +0 -0
  78. pycityagent/environment/sence/static.py +0 -416
  79. pycityagent/environment/sidecar/__init__.py +0 -8
  80. pycityagent/environment/sidecar/sidecarv2.py +0 -109
  81. pycityagent/environment/sim/economy_services.py +0 -192
  82. pycityagent/metrics/utils/const.py +0 -0
  83. pycityagent-2.0.0a65.dist-info/RECORD +0 -105
  84. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/LICENSE +0 -0
  85. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/WHEEL +0 -0
  86. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/entry_points.txt +0 -0
  87. {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/top_level.txt +0 -0
@@ -13,19 +13,30 @@ import yaml
13
13
  from langchain_core.embeddings import Embeddings
14
14
 
15
15
  from ..agent import Agent, InstitutionAgent
16
- from ..cityagent import (BankAgent, FirmAgent, GovernmentAgent, NBSAgent,
17
- SocietyAgent, memory_config_bank, memory_config_firm,
18
- memory_config_government, memory_config_nbs,
19
- memory_config_societyagent)
16
+ from ..cityagent import BankAgent, FirmAgent, GovernmentAgent, NBSAgent, SocietyAgent
17
+ from ..cityagent.memory_config import (
18
+ memory_config_bank,
19
+ memory_config_firm,
20
+ memory_config_government,
21
+ memory_config_nbs,
22
+ memory_config_societyagent,
23
+ memory_config_init,
24
+ )
20
25
  from ..cityagent.initial import bind_agent_info, initialize_social_network
21
- from ..cityagent.message_intercept import (EdgeMessageBlock,
22
- MessageBlockListener,
23
- PointMessageBlock)
26
+ from ..cityagent.message_intercept import (
27
+ EdgeMessageBlock,
28
+ MessageBlockListener,
29
+ PointMessageBlock,
30
+ )
24
31
  from ..economy.econ_client import EconomyClient
25
32
  from ..environment import Simulator
26
33
  from ..llm import SimpleEmbedding
27
- from ..message import (MessageBlockBase, MessageBlockListenerBase,
28
- MessageInterceptor, Messager)
34
+ from ..message import (
35
+ MessageBlockBase,
36
+ MessageBlockListenerBase,
37
+ MessageInterceptor,
38
+ Messager,
39
+ )
29
40
  from ..metrics import init_mlflow_connection
30
41
  from ..metrics.mlflow_client import MlflowClient
31
42
  from ..survey import Survey
@@ -36,26 +47,55 @@ from .storage.pg import PgWriter, create_pg_tables
36
47
  logger = logging.getLogger("pycityagent")
37
48
 
38
49
 
50
+ __all__ = ["AgentSimulation"]
51
+
39
52
  class AgentSimulation:
40
- """城市智能体模拟器"""
53
+ """
54
+ A class to simulate a multi-agent system.
55
+
56
+ This simulation framework is designed to facilitate the creation and management of multiple agent types within an experiment.
57
+ It allows for the configuration of different agents, memory configurations, and metric extractors, as well as enabling institutional settings.
58
+
59
+ Attributes:
60
+ exp_id (str): A unique identifier for the current experiment.
61
+ agent_class (List[Type[Agent]]): A list of agent classes that will be instantiated in the simulation.
62
+ agent_config_file (Optional[dict]): Configuration file or dictionary for initializing agents.
63
+ logging_level (int): The level of logging to be used throughout the simulation.
64
+ default_memory_config_func (Dict[Type[Agent], Callable]): Dictionary mapping agent classes to their respective memory configuration functions.
65
+ """
41
66
 
42
67
  def __init__(
43
68
  self,
44
69
  config: dict,
45
70
  agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
46
71
  agent_config_file: Optional[dict] = None,
47
- metric_extractor: Optional[list[tuple[int, Callable]]] = None,
72
+ metric_extractors: Optional[list[tuple[int, Callable]]] = None,
48
73
  enable_institution: bool = True,
49
74
  agent_prefix: str = "agent_",
50
75
  exp_name: str = "default_experiment",
51
76
  logging_level: int = logging.WARNING,
52
77
  ):
53
78
  """
54
- Args:
55
- agent_class: 智能体类
56
- config: 配置
57
- agent_prefix: 智能体名称前缀
58
- exp_name: 实验名称
79
+ Initializes the AgentSimulation with the given parameters.
80
+
81
+ - **Description**:
82
+ - Sets up the simulation environment based on the provided configuration. Depending on the `enable_institution` flag,
83
+ it can include a predefined set of institutional agents. If specific agent classes are provided, those will be used instead.
84
+
85
+ - **Args**:
86
+ - `config` (dict): The main configuration dictionary for the simulation.
87
+ - `agent_class` (Union[None, Type[Agent], List[Type[Agent]]], optional):
88
+ Either a single agent class or a list of agent classes to instantiate. Defaults to None, which implies a default set of agents.
89
+ - `agent_config_file` (Optional[dict], optional): An optional configuration file or dictionary used to initialize agents. Defaults to None.
90
+ - `metric_extractors` (Optional[List[Tuple[int, Callable]]], optional):
91
+ A list of tuples containing intervals and callables for extracting metrics from the simulation. Defaults to None.
92
+ - `enable_institution` (bool, optional): Flag indicating whether institutional agents should be included in the simulation. Defaults to True.
93
+ - `agent_prefix` (str, optional): Prefix string for naming agents. Defaults to "agent_".
94
+ - `exp_name` (str, optional): The name of the experiment. Defaults to "default_experiment".
95
+ - `logging_level` (int, optional): Logging level to set for the simulation's logger. Defaults to logging.WARNING.
96
+
97
+ - **Returns**:
98
+ - None
59
99
  """
60
100
  self.exp_id = str(uuid.uuid4())
61
101
  if isinstance(agent_class, list):
@@ -166,11 +206,11 @@ class AgentSimulation:
166
206
  experiment_name=exp_name,
167
207
  run_id=mlflow_run_id,
168
208
  )
169
- self.metric_extractor = metric_extractor
209
+ self.metric_extractors = metric_extractors
170
210
  else:
171
211
  logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
172
212
  self.mlflow_client = None
173
- self.metric_extractor = None
213
+ self.metric_extractors = None
174
214
 
175
215
  # pg
176
216
  _pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
@@ -216,8 +256,9 @@ class AgentSimulation:
216
256
  - enable_institution: bool, default is True
217
257
  - agent_config:
218
258
  - agent_config_file: Optional[dict[type[Agent], str]]
259
+ - memory_config_init_func: Optional[Callable]
219
260
  - memory_config_func: Optional[dict[type[Agent], Callable]]
220
- - metric_extractor: Optional[list[tuple[int, Callable]]]
261
+ - metric_extractors: Optional[list[tuple[int, Callable]]]
221
262
  - init_func: Optional[list[Callable[AgentSimulation, None]]]
222
263
  - group_size: Optional[int]
223
264
  - embedding_model: Optional[EmbeddingModel]
@@ -258,28 +299,33 @@ class AgentSimulation:
258
299
  simulation = cls(
259
300
  config=simulation_config,
260
301
  agent_config_file=config["agent_config"].get("agent_config_file", None),
261
- metric_extractor=config["agent_config"].get("metric_extractor", None),
302
+ metric_extractors=config["agent_config"].get("metric_extractors", None),
262
303
  enable_institution=config.get("enable_institution", True),
263
304
  exp_name=config.get("exp_name", "default_experiment"),
264
305
  logging_level=config.get("logging_level", logging.WARNING),
265
306
  )
266
307
  environment = config.get(
267
- "environment",
308
+ "environment",
268
309
  {
269
- "weather": "The weather is normal",
270
- "crime": "The crime rate is low",
271
- "pollution": "The pollution level is low",
272
- "temperature": "The temperature is normal"
273
- }
310
+ "weather": "The weather is normal",
311
+ "crime": "The crime rate is low",
312
+ "pollution": "The pollution level is low",
313
+ "temperature": "The temperature is normal",
314
+ "day": "Workday"
315
+ },
274
316
  )
275
317
  simulation._simulator.set_environment(environment)
276
318
  logger.info("Initializing Agents...")
277
- agent_count = []
278
- agent_count.append(config["agent_config"]["number_of_citizen"])
279
- agent_count.append(config["agent_config"]["number_of_firm"])
280
- agent_count.append(config["agent_config"]["number_of_government"])
281
- agent_count.append(config["agent_config"]["number_of_bank"])
282
- agent_count.append(config["agent_config"]["number_of_nbs"])
319
+ agent_count = {
320
+ SocietyAgent: config["agent_config"].get("number_of_citizen", 0),
321
+ FirmAgent: config["agent_config"].get("number_of_firm", 0),
322
+ GovernmentAgent: config["agent_config"].get("number_of_government", 0),
323
+ BankAgent: config["agent_config"].get("number_of_bank", 0),
324
+ NBSAgent: config["agent_config"].get("number_of_nbs", 0),
325
+ }
326
+ if agent_count.get(SocietyAgent, 0) == 0:
327
+ raise ValueError("number_of_citizen is required")
328
+
283
329
  # support MessageInterceptor
284
330
  if "message_intercept" in config:
285
331
  _intercept_config = config["message_intercept"]
@@ -319,6 +365,9 @@ class AgentSimulation:
319
365
  "embedding_model", SimpleEmbedding()
320
366
  ),
321
367
  memory_config_func=config["agent_config"].get("memory_config_func", None),
368
+ memory_config_init_func=config["agent_config"].get(
369
+ "memory_config_init_func", None
370
+ ),
322
371
  **_message_intercept_kwargs,
323
372
  environment=environment,
324
373
  )
@@ -328,6 +377,9 @@ class AgentSimulation:
328
377
  ):
329
378
  await init_func(simulation)
330
379
  logger.info("Starting Simulation...")
380
+ llm_log_lists = []
381
+ mqtt_log_lists = []
382
+ simulator_log_lists = []
331
383
  for step in config["workflow"]:
332
384
  logger.info(
333
385
  f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
@@ -335,11 +387,17 @@ class AgentSimulation:
335
387
  if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
336
388
  raise ValueError(f"Invalid step type: {step['type']}")
337
389
  if step["type"] == "run":
338
- await simulation.run(step.get("days", 1))
390
+ llm_log_list, mqtt_log_list, simulator_log_list = await simulation.run(step.get("days", 1))
391
+ llm_log_lists.extend(llm_log_list)
392
+ mqtt_log_lists.extend(mqtt_log_list)
393
+ simulator_log_lists.extend(simulator_log_list)
339
394
  elif step["type"] == "step":
340
395
  times = step.get("times", 1)
341
396
  for _ in range(times):
342
- await simulation.step()
397
+ llm_log_list, mqtt_log_list, simulator_log_list = await simulation.step()
398
+ llm_log_lists.extend(llm_log_list)
399
+ mqtt_log_lists.extend(mqtt_log_list)
400
+ simulator_log_lists.extend(simulator_log_list)
343
401
  elif step["type"] == "pause":
344
402
  await simulation.pause_simulator()
345
403
  elif step["type"] == "resume":
@@ -347,7 +405,7 @@ class AgentSimulation:
347
405
  else:
348
406
  await step["func"](simulation)
349
407
  logger.info("Simulation finished")
350
-
408
+ return llm_log_lists, mqtt_log_lists, simulator_log_lists
351
409
  @property
352
410
  def enable_avro(
353
411
  self,
@@ -423,7 +481,7 @@ class AgentSimulation:
423
481
  async def _monitor_exp_status(self, stop_event: asyncio.Event):
424
482
  """监控实验状态并更新
425
483
 
426
- Args:
484
+ - **Args**:
427
485
  stop_event: 用于通知监控任务停止的事件
428
486
  """
429
487
  try:
@@ -465,7 +523,7 @@ class AgentSimulation:
465
523
 
466
524
  async def init_agents(
467
525
  self,
468
- agent_count: Union[int, list[int]],
526
+ agent_count: dict[type[Agent], int],
469
527
  group_size: int = 10000,
470
528
  pg_sql_writers: int = 32,
471
529
  message_interceptors: int = 1,
@@ -473,25 +531,44 @@ class AgentSimulation:
473
531
  social_black_list: Optional[list[tuple[str, str]]] = None,
474
532
  message_listener: Optional[MessageBlockListenerBase] = None,
475
533
  embedding_model: Embeddings = SimpleEmbedding(),
534
+ memory_config_init_func: Optional[Callable] = None,
476
535
  memory_config_func: Optional[dict[type[Agent], Callable]] = None,
477
536
  environment: Optional[dict[str, str]] = None,
478
537
  ) -> None:
479
- """初始化智能体
480
-
481
- Args:
482
- agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
483
- group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
484
- pg_sql_writers: 独立的PgSQL writer数量
485
- message_interceptors: message拦截器数量
486
- memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 每个元素表示一个智能体类创建的Memory配置函数
487
- environment: 环境变量,用于更新模拟器的环境变量
488
538
  """
489
- if not isinstance(agent_count, list):
490
- agent_count = [agent_count]
539
+ Initialize agents within the simulation.
540
+
541
+ - **Description**:
542
+ - Asynchronously initializes a specified number of agents for each provided agent class.
543
+ - Agents are grouped into independent Ray actors based on the `group_size`, with configurations for database writers,
544
+ message interceptors, and memory settings. Optionally updates the simulator's environment variables.
545
+
546
+ - **Args**:
547
+ - `agent_count` (dict[Type[Agent], int]): Dictionary mapping agent classes to the number of instances to create.
548
+ - `group_size` (int, optional): Number of agents per group, each group runs as an independent Ray actor. Defaults to 10000.
549
+ - `pg_sql_writers` (int, optional): Number of independent PgSQL writer processes. Defaults to 32.
550
+ - `message_interceptors` (int, optional): Number of message interceptor processes. Defaults to 1.
551
+ - `message_interceptor_blocks` (Optional[List[MessageBlockBase]], optional): List of message interception blocks. Defaults to None.
552
+ - `social_black_list` (Optional[List[Tuple[str, str]]], optional): List of tuples representing pairs of agents that should not communicate. Defaults to None.
553
+ - `message_listener` (Optional[MessageBlockListenerBase], optional): Listener for intercepted messages. Defaults to None.
554
+ - `embedding_model` (Embeddings, optional): Model used for generating embeddings for agents' memories. Defaults to SimpleEmbedding().
555
+ - `memory_config_init_func` (Optional[Callable], optional): Initialization function for setting up memory configuration. Defaults to None.
556
+ - `memory_config_func` (Optional[Dict[Type[Agent], Callable]], optional): Dictionary mapping agent classes to their memory configuration functions. Defaults to None.
557
+ - `environment` (Optional[Dict[str, str]], optional): Environment variables to update in the simulation. Defaults to None.
558
+
559
+ - **Raises**:
560
+ - `ValueError`: If the lengths of `agent_class` and `agent_count` do not match.
561
+
562
+ - **Returns**:
563
+ - `None`
564
+ """
565
+ self.agent_count = agent_count
491
566
 
492
567
  if len(self.agent_class) != len(agent_count):
493
- raise ValueError("agent_classagent_count的长度不一致")
568
+ raise ValueError("The length of agent_class and agent_count does not match")
494
569
 
570
+ if memory_config_init_func is not None:
571
+ await memory_config_init(self)
495
572
  if memory_config_func is None:
496
573
  memory_config_func = self.default_memory_config_func # type:ignore
497
574
 
@@ -505,7 +582,7 @@ class AgentSimulation:
505
582
  # 收集所有参数
506
583
  for i in range(len(self.agent_class)):
507
584
  agent_class = self.agent_class[i]
508
- agent_count_i = agent_count[i]
585
+ agent_count_i = agent_count[agent_class]
509
586
  assert memory_config_func is not None
510
587
  memory_config_func_i = memory_config_func.get(
511
588
  agent_class, self.default_memory_config_func[agent_class] # type:ignore
@@ -706,10 +783,32 @@ class AgentSimulation:
706
783
  )
707
784
  await self.messager.start_listening.remote() # type:ignore
708
785
 
786
+ agent_ids = set()
787
+ org_ids = set()
788
+ for group in self._groups.values():
789
+ ids = await group.get_economy_ids.remote()
790
+ agent_ids.update(ids[0])
791
+ org_ids.update(ids[1])
792
+ await self.economy_client.set_ids(agent_ids, org_ids)
793
+ for group in self._groups.values():
794
+ await group.set_economy_ids.remote(agent_ids, org_ids)
795
+
709
796
  async def gather(
710
797
  self, content: str, target_agent_uuids: Optional[list[str]] = None
711
798
  ):
712
- """收集智能体的特定信息"""
799
+ """
800
+ Collect specific information from agents.
801
+
802
+ - **Description**:
803
+ - Asynchronously gathers specified content from targeted agents within all groups.
804
+
805
+ - **Args**:
806
+ - `content` (str): The information to collect from the agents.
807
+ - `target_agent_uuids` (Optional[List[str]], optional): A list of agent UUIDs to target. Defaults to None, meaning all agents are targeted.
808
+
809
+ - **Returns**:
810
+ - Result of the gathering process as returned by each group's `gather` method.
811
+ """
713
812
  gather_tasks = []
714
813
  for group in self._groups.values():
715
814
  gather_tasks.append(group.gather.remote(content, target_agent_uuids))
@@ -721,7 +820,20 @@ class AgentSimulation:
721
820
  keys: Optional[list[str]] = None,
722
821
  values: Optional[list[Any]] = None,
723
822
  ) -> list[str]:
724
- """过滤出指定类型的智能体"""
823
+ """
824
+ Filter out agents of specified types or with matching key-value pairs.
825
+
826
+ - **Args**:
827
+ - `types` (Optional[List[Type[Agent]]], optional): Types of agents to filter for. Defaults to None.
828
+ - `keys` (Optional[List[str]], optional): Keys to match in agent attributes. Defaults to None.
829
+ - `values` (Optional[List[Any]], optional): Values corresponding to keys for matching. Defaults to None.
830
+
831
+ - **Raises**:
832
+ - `ValueError`: If neither types nor keys and values are provided, or if the lengths of keys and values do not match.
833
+
834
+ - **Returns**:
835
+ - `List[str]`: A list of filtered agent UUIDs.
836
+ """
725
837
  if not types and not keys and not values:
726
838
  return self._agent_uuids
727
839
  group_to_filter = []
@@ -744,12 +856,26 @@ class AgentSimulation:
744
856
  return filtered_uuids
745
857
 
746
858
  async def update_environment(self, key: str, value: str):
859
+ """
860
+ Update the environment variables for the simulation and all agent groups.
861
+
862
+ - **Args**:
863
+ - `key` (str): The environment variable key to update.
864
+ - `value` (str): The new value for the environment variable.
865
+ """
747
866
  self._simulator.update_environment(key, value)
748
867
  for group in self._groups.values():
749
868
  await group.update_environment.remote(key, value)
750
869
 
751
870
  async def update(self, target_agent_uuid: str, target_key: str, content: Any):
752
- """更新指定智能体的记忆"""
871
+ """
872
+ Update the memory of a specified agent.
873
+
874
+ - **Args**:
875
+ - `target_agent_uuid` (str): The UUID of the target agent to update.
876
+ - `target_key` (str): The key in the agent's memory to update.
877
+ - `content` (Any): The new content to set for the target key.
878
+ """
753
879
  group = self._agent_uuid2group[target_agent_uuid]
754
880
  await group.update.remote(target_agent_uuid, target_key, content)
755
881
 
@@ -760,13 +886,30 @@ class AgentSimulation:
760
886
  content: Any,
761
887
  mode: Literal["replace", "merge"] = "replace",
762
888
  ):
763
- """更新指定智能体的经济数据"""
889
+ """
890
+ Update economic data for a specified agent.
891
+
892
+ - **Args**:
893
+ - `target_agent_id` (int): The ID of the target agent whose economic data to update.
894
+ - `target_key` (str): The key in the agent's economic data to update.
895
+ - `content` (Any): The new content to set for the target key.
896
+ - `mode` (Literal["replace", "merge"], optional): Mode of updating the economic data. Defaults to "replace".
897
+ """
764
898
  await self.economy_client.update(
765
899
  id=target_agent_id, key=target_key, value=content, mode=mode
766
900
  )
767
901
 
768
902
  async def send_survey(self, survey: Survey, agent_uuids: list[str] = []):
769
- """发送问卷"""
903
+ """
904
+ Send a survey to specified agents.
905
+
906
+ - **Args**:
907
+ - `survey` (Survey): The survey object to send.
908
+ - `agent_uuids` (List[str], optional): List of agent UUIDs to receive the survey. Defaults to an empty list.
909
+
910
+ - **Returns**:
911
+ - None
912
+ """
770
913
  survey_dict = survey.to_dict()
771
914
  _date_time = datetime.now(timezone.utc)
772
915
  payload = {
@@ -791,7 +934,16 @@ class AgentSimulation:
791
934
  async def send_interview_message(
792
935
  self, content: str, agent_uuids: Union[str, list[str]]
793
936
  ):
794
- """发送采访消息"""
937
+ """
938
+ Send an interview message to specified agents.
939
+
940
+ - **Args**:
941
+ - `content` (str): The content of the message to send.
942
+ - `agent_uuids` (Union[str, List[str]]): A single UUID string or a list of UUID strings for the agents to receive the message.
943
+
944
+ - **Returns**:
945
+ - None
946
+ """
795
947
  _date_time = datetime.now(timezone.utc)
796
948
  payload = {
797
949
  "from": SURVEY_SENDER_UUID,
@@ -814,12 +966,37 @@ class AgentSimulation:
814
966
  await asyncio.sleep(3)
815
967
 
816
968
  async def extract_metric(self, metric_extractors: list[Callable]):
817
- """提取指标"""
969
+ """
970
+ Extract metrics using provided extractors.
971
+
972
+ - **Description**:
973
+ - Asynchronously applies each metric extractor function to the simulation to collect various metrics.
974
+
975
+ - **Args**:
976
+ - `metric_extractors` (List[Callable]): A list of callable functions that take the simulation instance as an argument and return a metric or perform some form of analysis.
977
+
978
+ - **Returns**:
979
+ - None
980
+ """
818
981
  for metric_extractor in metric_extractors:
819
982
  await metric_extractor(self)
820
983
 
821
984
  async def step(self):
822
- """Run one step, each agent execute one forward"""
985
+ """
986
+ Execute one step of the simulation where each agent performs its forward action.
987
+
988
+ - **Description**:
989
+ - Checks if new agents need to be inserted based on the current day of the simulation. If so, it inserts them.
990
+ - Executes the forward method for each agent group to advance the simulation by one step.
991
+ - Saves the state of all agent groups after the step has been completed.
992
+ - Optionally extracts metrics if the current step matches the interval specified for any metric extractors.
993
+
994
+ - **Raises**:
995
+ - `RuntimeError`: If there is an error during the execution of the step, it logs the error and rethrows it as a RuntimeError.
996
+
997
+ - **Returns**:
998
+ - None
999
+ """
823
1000
  try:
824
1001
  # check whether insert agents
825
1002
  simulator_day = await self._simulator.get_simulator_day()
@@ -837,11 +1014,21 @@ class AgentSimulation:
837
1014
  # step
838
1015
  simulator_day = await self._simulator.get_simulator_day()
839
1016
  simulator_time = int(await self._simulator.get_time())
840
- logger.info(f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}")
1017
+ logger.info(
1018
+ f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}"
1019
+ )
841
1020
  tasks = []
842
1021
  for group in self._groups.values():
843
1022
  tasks.append(group.step.remote())
844
- await asyncio.gather(*tasks)
1023
+ log_messages_groups = await asyncio.gather(*tasks)
1024
+ llm_log_list = []
1025
+ mqtt_log_list = []
1026
+ simulator_log_list = []
1027
+ for log_messages_group in log_messages_groups:
1028
+ llm_log_list.extend(log_messages_group['llm_log'])
1029
+ mqtt_log_list.extend(log_messages_group['mqtt_log'])
1030
+ simulator_log_list.extend(log_messages_group['simulator_log'])
1031
+
845
1032
  # save
846
1033
  simulator_day = await self._simulator.get_simulator_day()
847
1034
  simulator_time = int(await self._simulator.get_time())
@@ -850,14 +1037,15 @@ class AgentSimulation:
850
1037
  save_tasks.append(group.save.remote(simulator_day, simulator_time))
851
1038
  await asyncio.gather(*save_tasks)
852
1039
  self._total_steps += 1
853
- if self.metric_extractor is not None:
854
- print(f"total_steps: {self._total_steps}, excute metric")
1040
+ if self.metric_extractors is not None: # type:ignore
855
1041
  to_excute_metric = [
856
1042
  metric[1]
857
- for metric in self.metric_extractor
1043
+ for metric in self.metric_extractors # type:ignore
858
1044
  if self._total_steps % metric[0] == 0
859
1045
  ]
860
1046
  await self.extract_metric(to_excute_metric)
1047
+
1048
+ return llm_log_list, mqtt_log_list, simulator_log_list
861
1049
  except Exception as e:
862
1050
  import traceback
863
1051
 
@@ -868,7 +1056,26 @@ class AgentSimulation:
868
1056
  self,
869
1057
  day: int = 1,
870
1058
  ):
871
- """Run the simulation by days"""
1059
+ """
1060
+ Run the simulation for a specified number of days.
1061
+
1062
+ - **Args**:
1063
+ - `day` (int, optional): The number of days to run the simulation. Defaults to 1.
1064
+
1065
+ - **Description**:
1066
+ - Updates the experiment status to running and sets up monitoring for the experiment's status.
1067
+ - Runs the simulation loop until the end time, which is calculated based on the current time and the number of days to simulate.
1068
+ - After completing the simulation, updates the experiment status to finished, or to failed if an exception occurs.
1069
+
1070
+ - **Raises**:
1071
+ - `RuntimeError`: If there is an error during the simulation, it logs the error and updates the experiment status to failed before rethrowing the exception.
1072
+
1073
+ - **Returns**:
1074
+ - None
1075
+ """
1076
+ llm_log_lists = []
1077
+ mqtt_log_lists = []
1078
+ simulator_log_lists = []
872
1079
  try:
873
1080
  self._exp_info["num_day"] += day
874
1081
  await self._update_exp_status(1) # 更新状态为运行中
@@ -886,7 +1093,10 @@ class AgentSimulation:
886
1093
  current_time = await self._simulator.get_time()
887
1094
  if current_time >= end_time: # type:ignore
888
1095
  break
889
- await self.step()
1096
+ llm_log_list, mqtt_log_list, simulator_log_list = await self.step()
1097
+ llm_log_lists.extend(llm_log_list)
1098
+ mqtt_log_lists.extend(mqtt_log_list)
1099
+ simulator_log_lists.extend(simulator_log_list)
890
1100
  finally:
891
1101
  # 设置停止事件
892
1102
  stop_event.set()
@@ -895,7 +1105,7 @@ class AgentSimulation:
895
1105
 
896
1106
  # 运行成功后更新状态
897
1107
  await self._update_exp_status(2)
898
-
1108
+ return llm_log_lists, mqtt_log_lists, simulator_log_lists
899
1109
  except Exception as e:
900
1110
  error_msg = f"模拟器运行错误: {str(e)}"
901
1111
  logger.error(error_msg)
@@ -8,10 +8,33 @@ from .models import Page, Question, QuestionType, Survey
8
8
 
9
9
  class SurveyManager:
10
10
  def __init__(self):
11
+ """
12
+ Initializes a new instance of the SurveyManager class.
13
+
14
+ - **Description**:
15
+ - Manages the creation and retrieval of surveys. Uses an internal dictionary to store survey instances by their unique identifier.
16
+
17
+ - **Attributes**:
18
+ - `_surveys` (dict[str, Survey]): A dictionary mapping survey IDs to `Survey` instances.
19
+ """
11
20
  self._surveys: dict[str, Survey] = {}
12
21
 
13
22
  def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
14
- """创建新问卷"""
23
+ """
24
+ Create a new survey with specified title, description, and pages containing questions.
25
+
26
+ - **Description**:
27
+ - Generates a unique ID for the survey, converts the provided page and question data into `Page` and `Question` objects,
28
+ and adds the created survey to the internal storage.
29
+
30
+ - **Args**:
31
+ - `title` (str): The title of the survey.
32
+ - `description` (str): A brief description of what the survey is about.
33
+ - `pages` (list[dict]): A list of dictionaries where each dictionary contains page information including elements which are question definitions.
34
+
35
+ - **Returns**:
36
+ - `Survey`: An instance of the `Survey` class representing the newly created survey.
37
+ """
15
38
  survey_id = uuid.uuid4()
16
39
 
17
40
  # 转换页面和问题数据
@@ -46,9 +69,28 @@ class SurveyManager:
46
69
  return survey
47
70
 
48
71
  def get_survey(self, survey_id: str) -> Optional[Survey]:
49
- """获取指定问卷"""
72
+ """
73
+ Retrieve a specific survey by its unique identifier.
74
+
75
+ - **Description**:
76
+ - Searches for a survey within the internal storage using the provided survey ID.
77
+
78
+ - **Args**:
79
+ - `survey_id` (str): The unique identifier of the survey to retrieve.
80
+
81
+ - **Returns**:
82
+ - `Optional[Survey]`: An instance of the `Survey` class if found; otherwise, None.
83
+ """
50
84
  return self._surveys.get(survey_id)
51
85
 
52
86
  def get_all_surveys(self) -> list[Survey]:
53
- """获取所有问卷"""
87
+ """
88
+ Get all surveys that have been created.
89
+
90
+ - **Description**:
91
+ - Retrieves all stored surveys from the internal storage.
92
+
93
+ - **Returns**:
94
+ - `list[Survey]`: A list of `Survey` instances.
95
+ """
54
96
  return list(self._surveys.values())