pycityagent 2.0.0a66__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a67__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +157 -57
- pycityagent/agent/agent_base.py +316 -43
- pycityagent/cityagent/bankagent.py +49 -9
- pycityagent/cityagent/blocks/__init__.py +1 -2
- pycityagent/cityagent/blocks/cognition_block.py +54 -31
- pycityagent/cityagent/blocks/dispatcher.py +22 -17
- pycityagent/cityagent/blocks/economy_block.py +46 -32
- pycityagent/cityagent/blocks/mobility_block.py +130 -100
- pycityagent/cityagent/blocks/needs_block.py +101 -44
- pycityagent/cityagent/blocks/other_block.py +42 -33
- pycityagent/cityagent/blocks/plan_block.py +59 -42
- pycityagent/cityagent/blocks/social_block.py +167 -116
- pycityagent/cityagent/blocks/utils.py +13 -6
- pycityagent/cityagent/firmagent.py +17 -35
- pycityagent/cityagent/governmentagent.py +3 -3
- pycityagent/cityagent/initial.py +79 -44
- pycityagent/cityagent/memory_config.py +108 -88
- pycityagent/cityagent/message_intercept.py +0 -4
- pycityagent/cityagent/metrics.py +41 -0
- pycityagent/cityagent/nbsagent.py +24 -36
- pycityagent/cityagent/societyagent.py +7 -3
- pycityagent/cli/wrapper.py +2 -2
- pycityagent/economy/econ_client.py +407 -81
- pycityagent/environment/__init__.py +0 -3
- pycityagent/environment/sim/__init__.py +0 -3
- pycityagent/environment/sim/aoi_service.py +2 -2
- pycityagent/environment/sim/client.py +3 -31
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/lane_service.py +8 -8
- pycityagent/environment/sim/light_service.py +8 -8
- pycityagent/environment/sim/pause_service.py +9 -10
- pycityagent/environment/sim/person_service.py +20 -20
- pycityagent/environment/sim/road_service.py +2 -2
- pycityagent/environment/sim/sim_env.py +21 -5
- pycityagent/environment/sim/social_service.py +4 -4
- pycityagent/environment/simulator.py +249 -27
- pycityagent/environment/utils/__init__.py +2 -2
- pycityagent/environment/utils/geojson.py +2 -2
- pycityagent/environment/utils/grpc.py +4 -4
- pycityagent/environment/utils/map_utils.py +2 -2
- pycityagent/llm/embeddings.py +147 -28
- pycityagent/llm/llm.py +122 -77
- pycityagent/llm/llmconfig.py +5 -0
- pycityagent/llm/utils.py +4 -0
- pycityagent/memory/__init__.py +0 -4
- pycityagent/memory/const.py +2 -2
- pycityagent/memory/faiss_query.py +140 -61
- pycityagent/memory/memory.py +393 -90
- pycityagent/memory/memory_base.py +140 -34
- pycityagent/memory/profile.py +13 -13
- pycityagent/memory/self_define.py +13 -13
- pycityagent/memory/state.py +14 -14
- pycityagent/message/message_interceptor.py +253 -3
- pycityagent/message/messager.py +133 -6
- pycityagent/metrics/mlflow_client.py +47 -4
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +3 -2
- pycityagent/simulation/agentgroup.py +145 -52
- pycityagent/simulation/simulation.py +257 -62
- pycityagent/survey/manager.py +45 -3
- pycityagent/survey/models.py +42 -2
- pycityagent/tools/__init__.py +1 -2
- pycityagent/tools/tool.py +93 -69
- pycityagent/utils/avro_schema.py +2 -2
- pycityagent/utils/parsers/code_block_parser.py +1 -1
- pycityagent/utils/parsers/json_parser.py +2 -2
- pycityagent/utils/parsers/parser_base.py +2 -2
- pycityagent/workflow/block.py +64 -13
- pycityagent/workflow/prompt.py +31 -23
- pycityagent/workflow/trigger.py +91 -24
- {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/METADATA +2 -2
- pycityagent-2.0.0a67.dist-info/RECORD +97 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +0 -198
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +0 -416
- pycityagent/environment/sidecar/__init__.py +0 -8
- pycityagent/environment/sidecar/sidecarv2.py +0 -109
- pycityagent/environment/sim/economy_services.py +0 -192
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent-2.0.0a66.dist-info/RECORD +0 -105
- {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a66.dist-info → pycityagent-2.0.0a67.dist-info}/top_level.txt +0 -0
@@ -13,20 +13,30 @@ import yaml
|
|
13
13
|
from langchain_core.embeddings import Embeddings
|
14
14
|
|
15
15
|
from ..agent import Agent, InstitutionAgent
|
16
|
-
from ..cityagent import
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
16
|
+
from ..cityagent import BankAgent, FirmAgent, GovernmentAgent, NBSAgent, SocietyAgent
|
17
|
+
from ..cityagent.memory_config import (
|
18
|
+
memory_config_bank,
|
19
|
+
memory_config_firm,
|
20
|
+
memory_config_government,
|
21
|
+
memory_config_nbs,
|
22
|
+
memory_config_societyagent,
|
23
|
+
memory_config_init,
|
24
|
+
)
|
21
25
|
from ..cityagent.initial import bind_agent_info, initialize_social_network
|
22
|
-
from ..cityagent.message_intercept import (
|
23
|
-
|
24
|
-
|
26
|
+
from ..cityagent.message_intercept import (
|
27
|
+
EdgeMessageBlock,
|
28
|
+
MessageBlockListener,
|
29
|
+
PointMessageBlock,
|
30
|
+
)
|
25
31
|
from ..economy.econ_client import EconomyClient
|
26
32
|
from ..environment import Simulator
|
27
33
|
from ..llm import SimpleEmbedding
|
28
|
-
from ..message import (
|
29
|
-
|
34
|
+
from ..message import (
|
35
|
+
MessageBlockBase,
|
36
|
+
MessageBlockListenerBase,
|
37
|
+
MessageInterceptor,
|
38
|
+
Messager,
|
39
|
+
)
|
30
40
|
from ..metrics import init_mlflow_connection
|
31
41
|
from ..metrics.mlflow_client import MlflowClient
|
32
42
|
from ..survey import Survey
|
@@ -36,8 +46,23 @@ from .storage.pg import PgWriter, create_pg_tables
|
|
36
46
|
|
37
47
|
logger = logging.getLogger("pycityagent")
|
38
48
|
|
49
|
+
|
50
|
+
__all__ = ["AgentSimulation"]
|
51
|
+
|
39
52
|
class AgentSimulation:
|
40
|
-
"""
|
53
|
+
"""
|
54
|
+
A class to simulate a multi-agent system.
|
55
|
+
|
56
|
+
This simulation framework is designed to facilitate the creation and management of multiple agent types within an experiment.
|
57
|
+
It allows for the configuration of different agents, memory configurations, and metric extractors, as well as enabling institutional settings.
|
58
|
+
|
59
|
+
Attributes:
|
60
|
+
exp_id (str): A unique identifier for the current experiment.
|
61
|
+
agent_class (List[Type[Agent]]): A list of agent classes that will be instantiated in the simulation.
|
62
|
+
agent_config_file (Optional[dict]): Configuration file or dictionary for initializing agents.
|
63
|
+
logging_level (int): The level of logging to be used throughout the simulation.
|
64
|
+
default_memory_config_func (Dict[Type[Agent], Callable]): Dictionary mapping agent classes to their respective memory configuration functions.
|
65
|
+
"""
|
41
66
|
|
42
67
|
def __init__(
|
43
68
|
self,
|
@@ -51,15 +76,26 @@ class AgentSimulation:
|
|
51
76
|
logging_level: int = logging.WARNING,
|
52
77
|
):
|
53
78
|
"""
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
79
|
+
Initializes the AgentSimulation with the given parameters.
|
80
|
+
|
81
|
+
- **Description**:
|
82
|
+
- Sets up the simulation environment based on the provided configuration. Depending on the `enable_institution` flag,
|
83
|
+
it can include a predefined set of institutional agents. If specific agent classes are provided, those will be used instead.
|
84
|
+
|
85
|
+
- **Args**:
|
86
|
+
- `config` (dict): The main configuration dictionary for the simulation.
|
87
|
+
- `agent_class` (Union[None, Type[Agent], List[Type[Agent]]], optional):
|
88
|
+
Either a single agent class or a list of agent classes to instantiate. Defaults to None, which implies a default set of agents.
|
89
|
+
- `agent_config_file` (Optional[dict], optional): An optional configuration file or dictionary used to initialize agents. Defaults to None.
|
90
|
+
- `metric_extractors` (Optional[List[Tuple[int, Callable]]], optional):
|
91
|
+
A list of tuples containing intervals and callables for extracting metrics from the simulation. Defaults to None.
|
92
|
+
- `enable_institution` (bool, optional): Flag indicating whether institutional agents should be included in the simulation. Defaults to True.
|
93
|
+
- `agent_prefix` (str, optional): Prefix string for naming agents. Defaults to "agent_".
|
94
|
+
- `exp_name` (str, optional): The name of the experiment. Defaults to "default_experiment".
|
95
|
+
- `logging_level` (int, optional): Logging level to set for the simulation's logger. Defaults to logging.WARNING.
|
96
|
+
|
97
|
+
- **Returns**:
|
98
|
+
- None
|
63
99
|
"""
|
64
100
|
self.exp_id = str(uuid.uuid4())
|
65
101
|
if isinstance(agent_class, list):
|
@@ -170,8 +206,7 @@ class AgentSimulation:
|
|
170
206
|
experiment_name=exp_name,
|
171
207
|
run_id=mlflow_run_id,
|
172
208
|
)
|
173
|
-
|
174
|
-
self.metric_extractors = metric_extractors
|
209
|
+
self.metric_extractors = metric_extractors
|
175
210
|
else:
|
176
211
|
logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
|
177
212
|
self.mlflow_client = None
|
@@ -270,13 +305,14 @@ class AgentSimulation:
|
|
270
305
|
logging_level=config.get("logging_level", logging.WARNING),
|
271
306
|
)
|
272
307
|
environment = config.get(
|
273
|
-
"environment",
|
308
|
+
"environment",
|
274
309
|
{
|
275
|
-
"weather": "The weather is normal",
|
276
|
-
"crime": "The crime rate is low",
|
277
|
-
"pollution": "The pollution level is low",
|
278
|
-
"temperature": "The temperature is normal"
|
279
|
-
|
310
|
+
"weather": "The weather is normal",
|
311
|
+
"crime": "The crime rate is low",
|
312
|
+
"pollution": "The pollution level is low",
|
313
|
+
"temperature": "The temperature is normal",
|
314
|
+
"day": "Workday"
|
315
|
+
},
|
280
316
|
)
|
281
317
|
simulation._simulator.set_environment(environment)
|
282
318
|
logger.info("Initializing Agents...")
|
@@ -289,7 +325,7 @@ class AgentSimulation:
|
|
289
325
|
}
|
290
326
|
if agent_count.get(SocietyAgent, 0) == 0:
|
291
327
|
raise ValueError("number_of_citizen is required")
|
292
|
-
|
328
|
+
|
293
329
|
# support MessageInterceptor
|
294
330
|
if "message_intercept" in config:
|
295
331
|
_intercept_config = config["message_intercept"]
|
@@ -328,8 +364,10 @@ class AgentSimulation:
|
|
328
364
|
embedding_model=config["agent_config"].get(
|
329
365
|
"embedding_model", SimpleEmbedding()
|
330
366
|
),
|
331
|
-
memory_config_func=config["agent_config"].get("memory_config_func", None),
|
332
|
-
memory_config_init_func=config["agent_config"].get(
|
367
|
+
memory_config_func=config["agent_config"].get("memory_config_func", None),
|
368
|
+
memory_config_init_func=config["agent_config"].get(
|
369
|
+
"memory_config_init_func", None
|
370
|
+
),
|
333
371
|
**_message_intercept_kwargs,
|
334
372
|
environment=environment,
|
335
373
|
)
|
@@ -339,6 +377,9 @@ class AgentSimulation:
|
|
339
377
|
):
|
340
378
|
await init_func(simulation)
|
341
379
|
logger.info("Starting Simulation...")
|
380
|
+
llm_log_lists = []
|
381
|
+
mqtt_log_lists = []
|
382
|
+
simulator_log_lists = []
|
342
383
|
for step in config["workflow"]:
|
343
384
|
logger.info(
|
344
385
|
f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
|
@@ -346,11 +387,17 @@ class AgentSimulation:
|
|
346
387
|
if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
|
347
388
|
raise ValueError(f"Invalid step type: {step['type']}")
|
348
389
|
if step["type"] == "run":
|
349
|
-
await simulation.run(step.get("days", 1))
|
390
|
+
llm_log_list, mqtt_log_list, simulator_log_list = await simulation.run(step.get("days", 1))
|
391
|
+
llm_log_lists.extend(llm_log_list)
|
392
|
+
mqtt_log_lists.extend(mqtt_log_list)
|
393
|
+
simulator_log_lists.extend(simulator_log_list)
|
350
394
|
elif step["type"] == "step":
|
351
395
|
times = step.get("times", 1)
|
352
396
|
for _ in range(times):
|
353
|
-
await simulation.step()
|
397
|
+
llm_log_list, mqtt_log_list, simulator_log_list = await simulation.step()
|
398
|
+
llm_log_lists.extend(llm_log_list)
|
399
|
+
mqtt_log_lists.extend(mqtt_log_list)
|
400
|
+
simulator_log_lists.extend(simulator_log_list)
|
354
401
|
elif step["type"] == "pause":
|
355
402
|
await simulation.pause_simulator()
|
356
403
|
elif step["type"] == "resume":
|
@@ -358,7 +405,7 @@ class AgentSimulation:
|
|
358
405
|
else:
|
359
406
|
await step["func"](simulation)
|
360
407
|
logger.info("Simulation finished")
|
361
|
-
|
408
|
+
return llm_log_lists, mqtt_log_lists, simulator_log_lists
|
362
409
|
@property
|
363
410
|
def enable_avro(
|
364
411
|
self,
|
@@ -434,7 +481,7 @@ class AgentSimulation:
|
|
434
481
|
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
435
482
|
"""监控实验状态并更新
|
436
483
|
|
437
|
-
Args
|
484
|
+
- **Args**:
|
438
485
|
stop_event: 用于通知监控任务停止的事件
|
439
486
|
"""
|
440
487
|
try:
|
@@ -488,15 +535,32 @@ class AgentSimulation:
|
|
488
535
|
memory_config_func: Optional[dict[type[Agent], Callable]] = None,
|
489
536
|
environment: Optional[dict[str, str]] = None,
|
490
537
|
) -> None:
|
491
|
-
"""
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
538
|
+
"""
|
539
|
+
Initialize agents within the simulation.
|
540
|
+
|
541
|
+
- **Description**:
|
542
|
+
- Asynchronously initializes a specified number of agents for each provided agent class.
|
543
|
+
- Agents are grouped into independent Ray actors based on the `group_size`, with configurations for database writers,
|
544
|
+
message interceptors, and memory settings. Optionally updates the simulator's environment variables.
|
545
|
+
|
546
|
+
- **Args**:
|
547
|
+
- `agent_count` (dict[Type[Agent], int]): Dictionary mapping agent classes to the number of instances to create.
|
548
|
+
- `group_size` (int, optional): Number of agents per group, each group runs as an independent Ray actor. Defaults to 10000.
|
549
|
+
- `pg_sql_writers` (int, optional): Number of independent PgSQL writer processes. Defaults to 32.
|
550
|
+
- `message_interceptors` (int, optional): Number of message interceptor processes. Defaults to 1.
|
551
|
+
- `message_interceptor_blocks` (Optional[List[MessageBlockBase]], optional): List of message interception blocks. Defaults to None.
|
552
|
+
- `social_black_list` (Optional[List[Tuple[str, str]]], optional): List of tuples representing pairs of agents that should not communicate. Defaults to None.
|
553
|
+
- `message_listener` (Optional[MessageBlockListenerBase], optional): Listener for intercepted messages. Defaults to None.
|
554
|
+
- `embedding_model` (Embeddings, optional): Model used for generating embeddings for agents' memories. Defaults to SimpleEmbedding().
|
555
|
+
- `memory_config_init_func` (Optional[Callable], optional): Initialization function for setting up memory configuration. Defaults to None.
|
556
|
+
- `memory_config_func` (Optional[Dict[Type[Agent], Callable]], optional): Dictionary mapping agent classes to their memory configuration functions. Defaults to None.
|
557
|
+
- `environment` (Optional[Dict[str, str]], optional): Environment variables to update in the simulation. Defaults to None.
|
558
|
+
|
559
|
+
- **Raises**:
|
560
|
+
- `ValueError`: If the lengths of `agent_class` and `agent_count` do not match.
|
561
|
+
|
562
|
+
- **Returns**:
|
563
|
+
- `None`
|
500
564
|
"""
|
501
565
|
self.agent_count = agent_count
|
502
566
|
|
@@ -516,8 +580,6 @@ class AgentSimulation:
|
|
516
580
|
citizen_params = []
|
517
581
|
|
518
582
|
# 收集所有参数
|
519
|
-
print(self.agent_class)
|
520
|
-
print(agent_count)
|
521
583
|
for i in range(len(self.agent_class)):
|
522
584
|
agent_class = self.agent_class[i]
|
523
585
|
agent_count_i = agent_count[agent_class]
|
@@ -721,10 +783,32 @@ class AgentSimulation:
|
|
721
783
|
)
|
722
784
|
await self.messager.start_listening.remote() # type:ignore
|
723
785
|
|
786
|
+
agent_ids = set()
|
787
|
+
org_ids = set()
|
788
|
+
for group in self._groups.values():
|
789
|
+
ids = await group.get_economy_ids.remote()
|
790
|
+
agent_ids.update(ids[0])
|
791
|
+
org_ids.update(ids[1])
|
792
|
+
await self.economy_client.set_ids(agent_ids, org_ids)
|
793
|
+
for group in self._groups.values():
|
794
|
+
await group.set_economy_ids.remote(agent_ids, org_ids)
|
795
|
+
|
724
796
|
async def gather(
|
725
797
|
self, content: str, target_agent_uuids: Optional[list[str]] = None
|
726
798
|
):
|
727
|
-
"""
|
799
|
+
"""
|
800
|
+
Collect specific information from agents.
|
801
|
+
|
802
|
+
- **Description**:
|
803
|
+
- Asynchronously gathers specified content from targeted agents within all groups.
|
804
|
+
|
805
|
+
- **Args**:
|
806
|
+
- `content` (str): The information to collect from the agents.
|
807
|
+
- `target_agent_uuids` (Optional[List[str]], optional): A list of agent UUIDs to target. Defaults to None, meaning all agents are targeted.
|
808
|
+
|
809
|
+
- **Returns**:
|
810
|
+
- Result of the gathering process as returned by each group's `gather` method.
|
811
|
+
"""
|
728
812
|
gather_tasks = []
|
729
813
|
for group in self._groups.values():
|
730
814
|
gather_tasks.append(group.gather.remote(content, target_agent_uuids))
|
@@ -736,7 +820,20 @@ class AgentSimulation:
|
|
736
820
|
keys: Optional[list[str]] = None,
|
737
821
|
values: Optional[list[Any]] = None,
|
738
822
|
) -> list[str]:
|
739
|
-
"""
|
823
|
+
"""
|
824
|
+
Filter out agents of specified types or with matching key-value pairs.
|
825
|
+
|
826
|
+
- **Args**:
|
827
|
+
- `types` (Optional[List[Type[Agent]]], optional): Types of agents to filter for. Defaults to None.
|
828
|
+
- `keys` (Optional[List[str]], optional): Keys to match in agent attributes. Defaults to None.
|
829
|
+
- `values` (Optional[List[Any]], optional): Values corresponding to keys for matching. Defaults to None.
|
830
|
+
|
831
|
+
- **Raises**:
|
832
|
+
- `ValueError`: If neither types nor keys and values are provided, or if the lengths of keys and values do not match.
|
833
|
+
|
834
|
+
- **Returns**:
|
835
|
+
- `List[str]`: A list of filtered agent UUIDs.
|
836
|
+
"""
|
740
837
|
if not types and not keys and not values:
|
741
838
|
return self._agent_uuids
|
742
839
|
group_to_filter = []
|
@@ -759,12 +856,26 @@ class AgentSimulation:
|
|
759
856
|
return filtered_uuids
|
760
857
|
|
761
858
|
async def update_environment(self, key: str, value: str):
|
859
|
+
"""
|
860
|
+
Update the environment variables for the simulation and all agent groups.
|
861
|
+
|
862
|
+
- **Args**:
|
863
|
+
- `key` (str): The environment variable key to update.
|
864
|
+
- `value` (str): The new value for the environment variable.
|
865
|
+
"""
|
762
866
|
self._simulator.update_environment(key, value)
|
763
867
|
for group in self._groups.values():
|
764
868
|
await group.update_environment.remote(key, value)
|
765
869
|
|
766
870
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
767
|
-
"""
|
871
|
+
"""
|
872
|
+
Update the memory of a specified agent.
|
873
|
+
|
874
|
+
- **Args**:
|
875
|
+
- `target_agent_uuid` (str): The UUID of the target agent to update.
|
876
|
+
- `target_key` (str): The key in the agent's memory to update.
|
877
|
+
- `content` (Any): The new content to set for the target key.
|
878
|
+
"""
|
768
879
|
group = self._agent_uuid2group[target_agent_uuid]
|
769
880
|
await group.update.remote(target_agent_uuid, target_key, content)
|
770
881
|
|
@@ -775,13 +886,30 @@ class AgentSimulation:
|
|
775
886
|
content: Any,
|
776
887
|
mode: Literal["replace", "merge"] = "replace",
|
777
888
|
):
|
778
|
-
"""
|
889
|
+
"""
|
890
|
+
Update economic data for a specified agent.
|
891
|
+
|
892
|
+
- **Args**:
|
893
|
+
- `target_agent_id` (int): The ID of the target agent whose economic data to update.
|
894
|
+
- `target_key` (str): The key in the agent's economic data to update.
|
895
|
+
- `content` (Any): The new content to set for the target key.
|
896
|
+
- `mode` (Literal["replace", "merge"], optional): Mode of updating the economic data. Defaults to "replace".
|
897
|
+
"""
|
779
898
|
await self.economy_client.update(
|
780
899
|
id=target_agent_id, key=target_key, value=content, mode=mode
|
781
900
|
)
|
782
901
|
|
783
902
|
async def send_survey(self, survey: Survey, agent_uuids: list[str] = []):
|
784
|
-
"""
|
903
|
+
"""
|
904
|
+
Send a survey to specified agents.
|
905
|
+
|
906
|
+
- **Args**:
|
907
|
+
- `survey` (Survey): The survey object to send.
|
908
|
+
- `agent_uuids` (List[str], optional): List of agent UUIDs to receive the survey. Defaults to an empty list.
|
909
|
+
|
910
|
+
- **Returns**:
|
911
|
+
- None
|
912
|
+
"""
|
785
913
|
survey_dict = survey.to_dict()
|
786
914
|
_date_time = datetime.now(timezone.utc)
|
787
915
|
payload = {
|
@@ -806,7 +934,16 @@ class AgentSimulation:
|
|
806
934
|
async def send_interview_message(
|
807
935
|
self, content: str, agent_uuids: Union[str, list[str]]
|
808
936
|
):
|
809
|
-
"""
|
937
|
+
"""
|
938
|
+
Send an interview message to specified agents.
|
939
|
+
|
940
|
+
- **Args**:
|
941
|
+
- `content` (str): The content of the message to send.
|
942
|
+
- `agent_uuids` (Union[str, List[str]]): A single UUID string or a list of UUID strings for the agents to receive the message.
|
943
|
+
|
944
|
+
- **Returns**:
|
945
|
+
- None
|
946
|
+
"""
|
810
947
|
_date_time = datetime.now(timezone.utc)
|
811
948
|
payload = {
|
812
949
|
"from": SURVEY_SENDER_UUID,
|
@@ -829,12 +966,37 @@ class AgentSimulation:
|
|
829
966
|
await asyncio.sleep(3)
|
830
967
|
|
831
968
|
async def extract_metric(self, metric_extractors: list[Callable]):
|
832
|
-
"""
|
969
|
+
"""
|
970
|
+
Extract metrics using provided extractors.
|
971
|
+
|
972
|
+
- **Description**:
|
973
|
+
- Asynchronously applies each metric extractor function to the simulation to collect various metrics.
|
974
|
+
|
975
|
+
- **Args**:
|
976
|
+
- `metric_extractors` (List[Callable]): A list of callable functions that take the simulation instance as an argument and return a metric or perform some form of analysis.
|
977
|
+
|
978
|
+
- **Returns**:
|
979
|
+
- None
|
980
|
+
"""
|
833
981
|
for metric_extractor in metric_extractors:
|
834
982
|
await metric_extractor(self)
|
835
983
|
|
836
984
|
async def step(self):
|
837
|
-
"""
|
985
|
+
"""
|
986
|
+
Execute one step of the simulation where each agent performs its forward action.
|
987
|
+
|
988
|
+
- **Description**:
|
989
|
+
- Checks if new agents need to be inserted based on the current day of the simulation. If so, it inserts them.
|
990
|
+
- Executes the forward method for each agent group to advance the simulation by one step.
|
991
|
+
- Saves the state of all agent groups after the step has been completed.
|
992
|
+
- Optionally extracts metrics if the current step matches the interval specified for any metric extractors.
|
993
|
+
|
994
|
+
- **Raises**:
|
995
|
+
- `RuntimeError`: If there is an error during the execution of the step, it logs the error and rethrows it as a RuntimeError.
|
996
|
+
|
997
|
+
- **Returns**:
|
998
|
+
- None
|
999
|
+
"""
|
838
1000
|
try:
|
839
1001
|
# check whether insert agents
|
840
1002
|
simulator_day = await self._simulator.get_simulator_day()
|
@@ -852,11 +1014,21 @@ class AgentSimulation:
|
|
852
1014
|
# step
|
853
1015
|
simulator_day = await self._simulator.get_simulator_day()
|
854
1016
|
simulator_time = int(await self._simulator.get_time())
|
855
|
-
logger.info(
|
1017
|
+
logger.info(
|
1018
|
+
f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}"
|
1019
|
+
)
|
856
1020
|
tasks = []
|
857
1021
|
for group in self._groups.values():
|
858
1022
|
tasks.append(group.step.remote())
|
859
|
-
await asyncio.gather(*tasks)
|
1023
|
+
log_messages_groups = await asyncio.gather(*tasks)
|
1024
|
+
llm_log_list = []
|
1025
|
+
mqtt_log_list = []
|
1026
|
+
simulator_log_list = []
|
1027
|
+
for log_messages_group in log_messages_groups:
|
1028
|
+
llm_log_list.extend(log_messages_group['llm_log'])
|
1029
|
+
mqtt_log_list.extend(log_messages_group['mqtt_log'])
|
1030
|
+
simulator_log_list.extend(log_messages_group['simulator_log'])
|
1031
|
+
|
860
1032
|
# save
|
861
1033
|
simulator_day = await self._simulator.get_simulator_day()
|
862
1034
|
simulator_time = int(await self._simulator.get_time())
|
@@ -865,14 +1037,15 @@ class AgentSimulation:
|
|
865
1037
|
save_tasks.append(group.save.remote(simulator_day, simulator_time))
|
866
1038
|
await asyncio.gather(*save_tasks)
|
867
1039
|
self._total_steps += 1
|
868
|
-
if self.
|
869
|
-
print(f"total_steps: {self._total_steps}, excute metric")
|
1040
|
+
if self.metric_extractors is not None: # type:ignore
|
870
1041
|
to_excute_metric = [
|
871
1042
|
metric[1]
|
872
|
-
for metric in self.
|
1043
|
+
for metric in self.metric_extractors # type:ignore
|
873
1044
|
if self._total_steps % metric[0] == 0
|
874
1045
|
]
|
875
1046
|
await self.extract_metric(to_excute_metric)
|
1047
|
+
|
1048
|
+
return llm_log_list, mqtt_log_list, simulator_log_list
|
876
1049
|
except Exception as e:
|
877
1050
|
import traceback
|
878
1051
|
|
@@ -883,7 +1056,26 @@ class AgentSimulation:
|
|
883
1056
|
self,
|
884
1057
|
day: int = 1,
|
885
1058
|
):
|
886
|
-
"""
|
1059
|
+
"""
|
1060
|
+
Run the simulation for a specified number of days.
|
1061
|
+
|
1062
|
+
- **Args**:
|
1063
|
+
- `day` (int, optional): The number of days to run the simulation. Defaults to 1.
|
1064
|
+
|
1065
|
+
- **Description**:
|
1066
|
+
- Updates the experiment status to running and sets up monitoring for the experiment's status.
|
1067
|
+
- Runs the simulation loop until the end time, which is calculated based on the current time and the number of days to simulate.
|
1068
|
+
- After completing the simulation, updates the experiment status to finished, or to failed if an exception occurs.
|
1069
|
+
|
1070
|
+
- **Raises**:
|
1071
|
+
- `RuntimeError`: If there is an error during the simulation, it logs the error and updates the experiment status to failed before rethrowing the exception.
|
1072
|
+
|
1073
|
+
- **Returns**:
|
1074
|
+
- None
|
1075
|
+
"""
|
1076
|
+
llm_log_lists = []
|
1077
|
+
mqtt_log_lists = []
|
1078
|
+
simulator_log_lists = []
|
887
1079
|
try:
|
888
1080
|
self._exp_info["num_day"] += day
|
889
1081
|
await self._update_exp_status(1) # 更新状态为运行中
|
@@ -901,7 +1093,10 @@ class AgentSimulation:
|
|
901
1093
|
current_time = await self._simulator.get_time()
|
902
1094
|
if current_time >= end_time: # type:ignore
|
903
1095
|
break
|
904
|
-
await self.step()
|
1096
|
+
llm_log_list, mqtt_log_list, simulator_log_list = await self.step()
|
1097
|
+
llm_log_lists.extend(llm_log_list)
|
1098
|
+
mqtt_log_lists.extend(mqtt_log_list)
|
1099
|
+
simulator_log_lists.extend(simulator_log_list)
|
905
1100
|
finally:
|
906
1101
|
# 设置停止事件
|
907
1102
|
stop_event.set()
|
@@ -910,7 +1105,7 @@ class AgentSimulation:
|
|
910
1105
|
|
911
1106
|
# 运行成功后更新状态
|
912
1107
|
await self._update_exp_status(2)
|
913
|
-
|
1108
|
+
return llm_log_lists, mqtt_log_lists, simulator_log_lists
|
914
1109
|
except Exception as e:
|
915
1110
|
error_msg = f"模拟器运行错误: {str(e)}"
|
916
1111
|
logger.error(error_msg)
|
pycityagent/survey/manager.py
CHANGED
@@ -8,10 +8,33 @@ from .models import Page, Question, QuestionType, Survey
|
|
8
8
|
|
9
9
|
class SurveyManager:
|
10
10
|
def __init__(self):
|
11
|
+
"""
|
12
|
+
Initializes a new instance of the SurveyManager class.
|
13
|
+
|
14
|
+
- **Description**:
|
15
|
+
- Manages the creation and retrieval of surveys. Uses an internal dictionary to store survey instances by their unique identifier.
|
16
|
+
|
17
|
+
- **Attributes**:
|
18
|
+
- `_surveys` (dict[str, Survey]): A dictionary mapping survey IDs to `Survey` instances.
|
19
|
+
"""
|
11
20
|
self._surveys: dict[str, Survey] = {}
|
12
21
|
|
13
22
|
def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
|
14
|
-
"""
|
23
|
+
"""
|
24
|
+
Create a new survey with specified title, description, and pages containing questions.
|
25
|
+
|
26
|
+
- **Description**:
|
27
|
+
- Generates a unique ID for the survey, converts the provided page and question data into `Page` and `Question` objects,
|
28
|
+
and adds the created survey to the internal storage.
|
29
|
+
|
30
|
+
- **Args**:
|
31
|
+
- `title` (str): The title of the survey.
|
32
|
+
- `description` (str): A brief description of what the survey is about.
|
33
|
+
- `pages` (list[dict]): A list of dictionaries where each dictionary contains page information including elements which are question definitions.
|
34
|
+
|
35
|
+
- **Returns**:
|
36
|
+
- `Survey`: An instance of the `Survey` class representing the newly created survey.
|
37
|
+
"""
|
15
38
|
survey_id = uuid.uuid4()
|
16
39
|
|
17
40
|
# 转换页面和问题数据
|
@@ -46,9 +69,28 @@ class SurveyManager:
|
|
46
69
|
return survey
|
47
70
|
|
48
71
|
def get_survey(self, survey_id: str) -> Optional[Survey]:
|
49
|
-
"""
|
72
|
+
"""
|
73
|
+
Retrieve a specific survey by its unique identifier.
|
74
|
+
|
75
|
+
- **Description**:
|
76
|
+
- Searches for a survey within the internal storage using the provided survey ID.
|
77
|
+
|
78
|
+
- **Args**:
|
79
|
+
- `survey_id` (str): The unique identifier of the survey to retrieve.
|
80
|
+
|
81
|
+
- **Returns**:
|
82
|
+
- `Optional[Survey]`: An instance of the `Survey` class if found; otherwise, None.
|
83
|
+
"""
|
50
84
|
return self._surveys.get(survey_id)
|
51
85
|
|
52
86
|
def get_all_surveys(self) -> list[Survey]:
|
53
|
-
"""
|
87
|
+
"""
|
88
|
+
Get all surveys that have been created.
|
89
|
+
|
90
|
+
- **Description**:
|
91
|
+
- Retrieves all stored surveys from the internal storage.
|
92
|
+
|
93
|
+
- **Returns**:
|
94
|
+
- `list[Survey]`: A list of `Survey` instances.
|
95
|
+
"""
|
54
96
|
return list(self._surveys.values())
|
pycityagent/survey/models.py
CHANGED
@@ -57,6 +57,18 @@ class Page:
|
|
57
57
|
|
58
58
|
@dataclass
|
59
59
|
class Survey:
|
60
|
+
"""
|
61
|
+
Represents a survey with metadata and associated pages containing questions.
|
62
|
+
|
63
|
+
- **Attributes**:
|
64
|
+
- `id` (uuid.UUID): Unique identifier for the survey.
|
65
|
+
- `title` (str): Title of the survey.
|
66
|
+
- `description` (str): Description of the survey's purpose or content.
|
67
|
+
- `pages` (list[Page]): A list of `Page` objects, each containing a set of questions.
|
68
|
+
- `responses` (dict[str, dict], optional): Dictionary mapping response IDs to their data. Defaults to an empty dictionary.
|
69
|
+
- `created_at` (datetime, optional): Timestamp of when the survey was created. Defaults to the current time.
|
70
|
+
"""
|
71
|
+
|
60
72
|
id: uuid.UUID
|
61
73
|
title: str
|
62
74
|
description: str
|
@@ -65,6 +77,15 @@ class Survey:
|
|
65
77
|
created_at: datetime = field(default_factory=datetime.now)
|
66
78
|
|
67
79
|
def to_dict(self) -> dict:
|
80
|
+
"""
|
81
|
+
Convert the survey instance into a dictionary representation.
|
82
|
+
|
83
|
+
- **Description**:
|
84
|
+
- Creates a dictionary containing the survey's ID, title, description, and pages in a simplified format suitable for serialization.
|
85
|
+
|
86
|
+
- **Returns**:
|
87
|
+
- `dict`: Simplified dictionary representation of the survey.
|
88
|
+
"""
|
68
89
|
return {
|
69
90
|
"id": str(self.id),
|
70
91
|
"title": self.title,
|
@@ -74,7 +95,15 @@ class Survey:
|
|
74
95
|
}
|
75
96
|
|
76
97
|
def to_json(self) -> str:
|
77
|
-
"""
|
98
|
+
"""
|
99
|
+
Serialize the survey instance to a JSON string for MQTT transmission.
|
100
|
+
|
101
|
+
- **Description**:
|
102
|
+
- Converts the survey into a JSON string that includes all necessary information for reconstructing the survey object on another system.
|
103
|
+
|
104
|
+
- **Returns**:
|
105
|
+
- `str`: JSON string representing the survey.
|
106
|
+
"""
|
78
107
|
survey_dict = {
|
79
108
|
"id": str(self.id),
|
80
109
|
"title": self.title,
|
@@ -87,7 +116,18 @@ class Survey:
|
|
87
116
|
|
88
117
|
@classmethod
|
89
118
|
def from_json(cls, json_str: str) -> "Survey":
|
90
|
-
"""
|
119
|
+
"""
|
120
|
+
Deserialize a JSON string into a new Survey instance.
|
121
|
+
|
122
|
+
- **Description**:
|
123
|
+
- Parses a JSON string into a Python dictionary and uses it to create a new `Survey` instance.
|
124
|
+
|
125
|
+
- **Args**:
|
126
|
+
- `json_str` (str): JSON string representation of a survey.
|
127
|
+
|
128
|
+
- **Returns**:
|
129
|
+
- `Survey`: An instance of `Survey` initialized with the data from the JSON string.
|
130
|
+
"""
|
91
131
|
data = json.loads(json_str)
|
92
132
|
pages = [
|
93
133
|
Page(
|
pycityagent/tools/__init__.py
CHANGED