pycityagent 2.0.0a65__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a67__cp39-cp39-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +157 -57
- pycityagent/agent/agent_base.py +316 -43
- pycityagent/cityagent/bankagent.py +49 -9
- pycityagent/cityagent/blocks/__init__.py +1 -2
- pycityagent/cityagent/blocks/cognition_block.py +54 -31
- pycityagent/cityagent/blocks/dispatcher.py +22 -17
- pycityagent/cityagent/blocks/economy_block.py +46 -32
- pycityagent/cityagent/blocks/mobility_block.py +209 -105
- pycityagent/cityagent/blocks/needs_block.py +101 -54
- pycityagent/cityagent/blocks/other_block.py +42 -33
- pycityagent/cityagent/blocks/plan_block.py +59 -42
- pycityagent/cityagent/blocks/social_block.py +167 -126
- pycityagent/cityagent/blocks/utils.py +13 -6
- pycityagent/cityagent/firmagent.py +17 -35
- pycityagent/cityagent/governmentagent.py +3 -3
- pycityagent/cityagent/initial.py +79 -49
- pycityagent/cityagent/memory_config.py +123 -94
- pycityagent/cityagent/message_intercept.py +0 -4
- pycityagent/cityagent/metrics.py +41 -0
- pycityagent/cityagent/nbsagent.py +24 -36
- pycityagent/cityagent/societyagent.py +9 -4
- pycityagent/cli/wrapper.py +2 -2
- pycityagent/economy/econ_client.py +407 -81
- pycityagent/environment/__init__.py +0 -3
- pycityagent/environment/sim/__init__.py +0 -3
- pycityagent/environment/sim/aoi_service.py +2 -2
- pycityagent/environment/sim/client.py +3 -31
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/lane_service.py +8 -8
- pycityagent/environment/sim/light_service.py +8 -8
- pycityagent/environment/sim/pause_service.py +9 -10
- pycityagent/environment/sim/person_service.py +20 -20
- pycityagent/environment/sim/road_service.py +2 -2
- pycityagent/environment/sim/sim_env.py +21 -5
- pycityagent/environment/sim/social_service.py +4 -4
- pycityagent/environment/simulator.py +249 -27
- pycityagent/environment/utils/__init__.py +2 -2
- pycityagent/environment/utils/geojson.py +2 -2
- pycityagent/environment/utils/grpc.py +4 -4
- pycityagent/environment/utils/map_utils.py +2 -2
- pycityagent/llm/embeddings.py +147 -28
- pycityagent/llm/llm.py +178 -111
- pycityagent/llm/llmconfig.py +5 -0
- pycityagent/llm/utils.py +4 -0
- pycityagent/memory/__init__.py +0 -4
- pycityagent/memory/const.py +2 -2
- pycityagent/memory/faiss_query.py +140 -61
- pycityagent/memory/memory.py +394 -91
- pycityagent/memory/memory_base.py +140 -34
- pycityagent/memory/profile.py +13 -13
- pycityagent/memory/self_define.py +13 -13
- pycityagent/memory/state.py +14 -14
- pycityagent/message/message_interceptor.py +253 -3
- pycityagent/message/messager.py +133 -6
- pycityagent/metrics/mlflow_client.py +47 -4
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +3 -2
- pycityagent/simulation/agentgroup.py +150 -54
- pycityagent/simulation/simulation.py +276 -66
- pycityagent/survey/manager.py +45 -3
- pycityagent/survey/models.py +42 -2
- pycityagent/tools/__init__.py +1 -2
- pycityagent/tools/tool.py +93 -69
- pycityagent/utils/avro_schema.py +2 -2
- pycityagent/utils/parsers/code_block_parser.py +1 -1
- pycityagent/utils/parsers/json_parser.py +2 -2
- pycityagent/utils/parsers/parser_base.py +2 -2
- pycityagent/workflow/block.py +64 -13
- pycityagent/workflow/prompt.py +31 -23
- pycityagent/workflow/trigger.py +91 -24
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/METADATA +2 -2
- pycityagent-2.0.0a67.dist-info/RECORD +97 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +0 -198
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +0 -416
- pycityagent/environment/sidecar/__init__.py +0 -8
- pycityagent/environment/sidecar/sidecarv2.py +0 -109
- pycityagent/environment/sim/economy_services.py +0 -192
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent-2.0.0a65.dist-info/RECORD +0 -105
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a65.dist-info → pycityagent-2.0.0a67.dist-info}/top_level.txt +0 -0
@@ -13,19 +13,30 @@ import yaml
|
|
13
13
|
from langchain_core.embeddings import Embeddings
|
14
14
|
|
15
15
|
from ..agent import Agent, InstitutionAgent
|
16
|
-
from ..cityagent import
|
17
|
-
|
18
|
-
|
19
|
-
|
16
|
+
from ..cityagent import BankAgent, FirmAgent, GovernmentAgent, NBSAgent, SocietyAgent
|
17
|
+
from ..cityagent.memory_config import (
|
18
|
+
memory_config_bank,
|
19
|
+
memory_config_firm,
|
20
|
+
memory_config_government,
|
21
|
+
memory_config_nbs,
|
22
|
+
memory_config_societyagent,
|
23
|
+
memory_config_init,
|
24
|
+
)
|
20
25
|
from ..cityagent.initial import bind_agent_info, initialize_social_network
|
21
|
-
from ..cityagent.message_intercept import (
|
22
|
-
|
23
|
-
|
26
|
+
from ..cityagent.message_intercept import (
|
27
|
+
EdgeMessageBlock,
|
28
|
+
MessageBlockListener,
|
29
|
+
PointMessageBlock,
|
30
|
+
)
|
24
31
|
from ..economy.econ_client import EconomyClient
|
25
32
|
from ..environment import Simulator
|
26
33
|
from ..llm import SimpleEmbedding
|
27
|
-
from ..message import (
|
28
|
-
|
34
|
+
from ..message import (
|
35
|
+
MessageBlockBase,
|
36
|
+
MessageBlockListenerBase,
|
37
|
+
MessageInterceptor,
|
38
|
+
Messager,
|
39
|
+
)
|
29
40
|
from ..metrics import init_mlflow_connection
|
30
41
|
from ..metrics.mlflow_client import MlflowClient
|
31
42
|
from ..survey import Survey
|
@@ -36,26 +47,55 @@ from .storage.pg import PgWriter, create_pg_tables
|
|
36
47
|
logger = logging.getLogger("pycityagent")
|
37
48
|
|
38
49
|
|
50
|
+
__all__ = ["AgentSimulation"]
|
51
|
+
|
39
52
|
class AgentSimulation:
|
40
|
-
"""
|
53
|
+
"""
|
54
|
+
A class to simulate a multi-agent system.
|
55
|
+
|
56
|
+
This simulation framework is designed to facilitate the creation and management of multiple agent types within an experiment.
|
57
|
+
It allows for the configuration of different agents, memory configurations, and metric extractors, as well as enabling institutional settings.
|
58
|
+
|
59
|
+
Attributes:
|
60
|
+
exp_id (str): A unique identifier for the current experiment.
|
61
|
+
agent_class (List[Type[Agent]]): A list of agent classes that will be instantiated in the simulation.
|
62
|
+
agent_config_file (Optional[dict]): Configuration file or dictionary for initializing agents.
|
63
|
+
logging_level (int): The level of logging to be used throughout the simulation.
|
64
|
+
default_memory_config_func (Dict[Type[Agent], Callable]): Dictionary mapping agent classes to their respective memory configuration functions.
|
65
|
+
"""
|
41
66
|
|
42
67
|
def __init__(
|
43
68
|
self,
|
44
69
|
config: dict,
|
45
70
|
agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
|
46
71
|
agent_config_file: Optional[dict] = None,
|
47
|
-
|
72
|
+
metric_extractors: Optional[list[tuple[int, Callable]]] = None,
|
48
73
|
enable_institution: bool = True,
|
49
74
|
agent_prefix: str = "agent_",
|
50
75
|
exp_name: str = "default_experiment",
|
51
76
|
logging_level: int = logging.WARNING,
|
52
77
|
):
|
53
78
|
"""
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
79
|
+
Initializes the AgentSimulation with the given parameters.
|
80
|
+
|
81
|
+
- **Description**:
|
82
|
+
- Sets up the simulation environment based on the provided configuration. Depending on the `enable_institution` flag,
|
83
|
+
it can include a predefined set of institutional agents. If specific agent classes are provided, those will be used instead.
|
84
|
+
|
85
|
+
- **Args**:
|
86
|
+
- `config` (dict): The main configuration dictionary for the simulation.
|
87
|
+
- `agent_class` (Union[None, Type[Agent], List[Type[Agent]]], optional):
|
88
|
+
Either a single agent class or a list of agent classes to instantiate. Defaults to None, which implies a default set of agents.
|
89
|
+
- `agent_config_file` (Optional[dict], optional): An optional configuration file or dictionary used to initialize agents. Defaults to None.
|
90
|
+
- `metric_extractors` (Optional[List[Tuple[int, Callable]]], optional):
|
91
|
+
A list of tuples containing intervals and callables for extracting metrics from the simulation. Defaults to None.
|
92
|
+
- `enable_institution` (bool, optional): Flag indicating whether institutional agents should be included in the simulation. Defaults to True.
|
93
|
+
- `agent_prefix` (str, optional): Prefix string for naming agents. Defaults to "agent_".
|
94
|
+
- `exp_name` (str, optional): The name of the experiment. Defaults to "default_experiment".
|
95
|
+
- `logging_level` (int, optional): Logging level to set for the simulation's logger. Defaults to logging.WARNING.
|
96
|
+
|
97
|
+
- **Returns**:
|
98
|
+
- None
|
59
99
|
"""
|
60
100
|
self.exp_id = str(uuid.uuid4())
|
61
101
|
if isinstance(agent_class, list):
|
@@ -166,11 +206,11 @@ class AgentSimulation:
|
|
166
206
|
experiment_name=exp_name,
|
167
207
|
run_id=mlflow_run_id,
|
168
208
|
)
|
169
|
-
self.
|
209
|
+
self.metric_extractors = metric_extractors
|
170
210
|
else:
|
171
211
|
logger.warning("Mlflow is not enabled, NO MLFLOW STORAGE")
|
172
212
|
self.mlflow_client = None
|
173
|
-
self.
|
213
|
+
self.metric_extractors = None
|
174
214
|
|
175
215
|
# pg
|
176
216
|
_pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
|
@@ -216,8 +256,9 @@ class AgentSimulation:
|
|
216
256
|
- enable_institution: bool, default is True
|
217
257
|
- agent_config:
|
218
258
|
- agent_config_file: Optional[dict[type[Agent], str]]
|
259
|
+
- memory_config_init_func: Optional[Callable]
|
219
260
|
- memory_config_func: Optional[dict[type[Agent], Callable]]
|
220
|
-
-
|
261
|
+
- metric_extractors: Optional[list[tuple[int, Callable]]]
|
221
262
|
- init_func: Optional[list[Callable[AgentSimulation, None]]]
|
222
263
|
- group_size: Optional[int]
|
223
264
|
- embedding_model: Optional[EmbeddingModel]
|
@@ -258,28 +299,33 @@ class AgentSimulation:
|
|
258
299
|
simulation = cls(
|
259
300
|
config=simulation_config,
|
260
301
|
agent_config_file=config["agent_config"].get("agent_config_file", None),
|
261
|
-
|
302
|
+
metric_extractors=config["agent_config"].get("metric_extractors", None),
|
262
303
|
enable_institution=config.get("enable_institution", True),
|
263
304
|
exp_name=config.get("exp_name", "default_experiment"),
|
264
305
|
logging_level=config.get("logging_level", logging.WARNING),
|
265
306
|
)
|
266
307
|
environment = config.get(
|
267
|
-
"environment",
|
308
|
+
"environment",
|
268
309
|
{
|
269
|
-
"weather": "The weather is normal",
|
270
|
-
"crime": "The crime rate is low",
|
271
|
-
"pollution": "The pollution level is low",
|
272
|
-
"temperature": "The temperature is normal"
|
273
|
-
|
310
|
+
"weather": "The weather is normal",
|
311
|
+
"crime": "The crime rate is low",
|
312
|
+
"pollution": "The pollution level is low",
|
313
|
+
"temperature": "The temperature is normal",
|
314
|
+
"day": "Workday"
|
315
|
+
},
|
274
316
|
)
|
275
317
|
simulation._simulator.set_environment(environment)
|
276
318
|
logger.info("Initializing Agents...")
|
277
|
-
agent_count =
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
319
|
+
agent_count = {
|
320
|
+
SocietyAgent: config["agent_config"].get("number_of_citizen", 0),
|
321
|
+
FirmAgent: config["agent_config"].get("number_of_firm", 0),
|
322
|
+
GovernmentAgent: config["agent_config"].get("number_of_government", 0),
|
323
|
+
BankAgent: config["agent_config"].get("number_of_bank", 0),
|
324
|
+
NBSAgent: config["agent_config"].get("number_of_nbs", 0),
|
325
|
+
}
|
326
|
+
if agent_count.get(SocietyAgent, 0) == 0:
|
327
|
+
raise ValueError("number_of_citizen is required")
|
328
|
+
|
283
329
|
# support MessageInterceptor
|
284
330
|
if "message_intercept" in config:
|
285
331
|
_intercept_config = config["message_intercept"]
|
@@ -319,6 +365,9 @@ class AgentSimulation:
|
|
319
365
|
"embedding_model", SimpleEmbedding()
|
320
366
|
),
|
321
367
|
memory_config_func=config["agent_config"].get("memory_config_func", None),
|
368
|
+
memory_config_init_func=config["agent_config"].get(
|
369
|
+
"memory_config_init_func", None
|
370
|
+
),
|
322
371
|
**_message_intercept_kwargs,
|
323
372
|
environment=environment,
|
324
373
|
)
|
@@ -328,6 +377,9 @@ class AgentSimulation:
|
|
328
377
|
):
|
329
378
|
await init_func(simulation)
|
330
379
|
logger.info("Starting Simulation...")
|
380
|
+
llm_log_lists = []
|
381
|
+
mqtt_log_lists = []
|
382
|
+
simulator_log_lists = []
|
331
383
|
for step in config["workflow"]:
|
332
384
|
logger.info(
|
333
385
|
f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
|
@@ -335,11 +387,17 @@ class AgentSimulation:
|
|
335
387
|
if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
|
336
388
|
raise ValueError(f"Invalid step type: {step['type']}")
|
337
389
|
if step["type"] == "run":
|
338
|
-
await simulation.run(step.get("days", 1))
|
390
|
+
llm_log_list, mqtt_log_list, simulator_log_list = await simulation.run(step.get("days", 1))
|
391
|
+
llm_log_lists.extend(llm_log_list)
|
392
|
+
mqtt_log_lists.extend(mqtt_log_list)
|
393
|
+
simulator_log_lists.extend(simulator_log_list)
|
339
394
|
elif step["type"] == "step":
|
340
395
|
times = step.get("times", 1)
|
341
396
|
for _ in range(times):
|
342
|
-
await simulation.step()
|
397
|
+
llm_log_list, mqtt_log_list, simulator_log_list = await simulation.step()
|
398
|
+
llm_log_lists.extend(llm_log_list)
|
399
|
+
mqtt_log_lists.extend(mqtt_log_list)
|
400
|
+
simulator_log_lists.extend(simulator_log_list)
|
343
401
|
elif step["type"] == "pause":
|
344
402
|
await simulation.pause_simulator()
|
345
403
|
elif step["type"] == "resume":
|
@@ -347,7 +405,7 @@ class AgentSimulation:
|
|
347
405
|
else:
|
348
406
|
await step["func"](simulation)
|
349
407
|
logger.info("Simulation finished")
|
350
|
-
|
408
|
+
return llm_log_lists, mqtt_log_lists, simulator_log_lists
|
351
409
|
@property
|
352
410
|
def enable_avro(
|
353
411
|
self,
|
@@ -423,7 +481,7 @@ class AgentSimulation:
|
|
423
481
|
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
424
482
|
"""监控实验状态并更新
|
425
483
|
|
426
|
-
Args
|
484
|
+
- **Args**:
|
427
485
|
stop_event: 用于通知监控任务停止的事件
|
428
486
|
"""
|
429
487
|
try:
|
@@ -465,7 +523,7 @@ class AgentSimulation:
|
|
465
523
|
|
466
524
|
async def init_agents(
|
467
525
|
self,
|
468
|
-
agent_count:
|
526
|
+
agent_count: dict[type[Agent], int],
|
469
527
|
group_size: int = 10000,
|
470
528
|
pg_sql_writers: int = 32,
|
471
529
|
message_interceptors: int = 1,
|
@@ -473,25 +531,44 @@ class AgentSimulation:
|
|
473
531
|
social_black_list: Optional[list[tuple[str, str]]] = None,
|
474
532
|
message_listener: Optional[MessageBlockListenerBase] = None,
|
475
533
|
embedding_model: Embeddings = SimpleEmbedding(),
|
534
|
+
memory_config_init_func: Optional[Callable] = None,
|
476
535
|
memory_config_func: Optional[dict[type[Agent], Callable]] = None,
|
477
536
|
environment: Optional[dict[str, str]] = None,
|
478
537
|
) -> None:
|
479
|
-
"""初始化智能体
|
480
|
-
|
481
|
-
Args:
|
482
|
-
agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
|
483
|
-
group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
|
484
|
-
pg_sql_writers: 独立的PgSQL writer数量
|
485
|
-
message_interceptors: message拦截器数量
|
486
|
-
memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 每个元素表示一个智能体类创建的Memory配置函数
|
487
|
-
environment: 环境变量,用于更新模拟器的环境变量
|
488
538
|
"""
|
489
|
-
|
490
|
-
|
539
|
+
Initialize agents within the simulation.
|
540
|
+
|
541
|
+
- **Description**:
|
542
|
+
- Asynchronously initializes a specified number of agents for each provided agent class.
|
543
|
+
- Agents are grouped into independent Ray actors based on the `group_size`, with configurations for database writers,
|
544
|
+
message interceptors, and memory settings. Optionally updates the simulator's environment variables.
|
545
|
+
|
546
|
+
- **Args**:
|
547
|
+
- `agent_count` (dict[Type[Agent], int]): Dictionary mapping agent classes to the number of instances to create.
|
548
|
+
- `group_size` (int, optional): Number of agents per group, each group runs as an independent Ray actor. Defaults to 10000.
|
549
|
+
- `pg_sql_writers` (int, optional): Number of independent PgSQL writer processes. Defaults to 32.
|
550
|
+
- `message_interceptors` (int, optional): Number of message interceptor processes. Defaults to 1.
|
551
|
+
- `message_interceptor_blocks` (Optional[List[MessageBlockBase]], optional): List of message interception blocks. Defaults to None.
|
552
|
+
- `social_black_list` (Optional[List[Tuple[str, str]]], optional): List of tuples representing pairs of agents that should not communicate. Defaults to None.
|
553
|
+
- `message_listener` (Optional[MessageBlockListenerBase], optional): Listener for intercepted messages. Defaults to None.
|
554
|
+
- `embedding_model` (Embeddings, optional): Model used for generating embeddings for agents' memories. Defaults to SimpleEmbedding().
|
555
|
+
- `memory_config_init_func` (Optional[Callable], optional): Initialization function for setting up memory configuration. Defaults to None.
|
556
|
+
- `memory_config_func` (Optional[Dict[Type[Agent], Callable]], optional): Dictionary mapping agent classes to their memory configuration functions. Defaults to None.
|
557
|
+
- `environment` (Optional[Dict[str, str]], optional): Environment variables to update in the simulation. Defaults to None.
|
558
|
+
|
559
|
+
- **Raises**:
|
560
|
+
- `ValueError`: If the lengths of `agent_class` and `agent_count` do not match.
|
561
|
+
|
562
|
+
- **Returns**:
|
563
|
+
- `None`
|
564
|
+
"""
|
565
|
+
self.agent_count = agent_count
|
491
566
|
|
492
567
|
if len(self.agent_class) != len(agent_count):
|
493
|
-
raise ValueError("agent_class
|
568
|
+
raise ValueError("The length of agent_class and agent_count does not match")
|
494
569
|
|
570
|
+
if memory_config_init_func is not None:
|
571
|
+
await memory_config_init(self)
|
495
572
|
if memory_config_func is None:
|
496
573
|
memory_config_func = self.default_memory_config_func # type:ignore
|
497
574
|
|
@@ -505,7 +582,7 @@ class AgentSimulation:
|
|
505
582
|
# 收集所有参数
|
506
583
|
for i in range(len(self.agent_class)):
|
507
584
|
agent_class = self.agent_class[i]
|
508
|
-
agent_count_i = agent_count[
|
585
|
+
agent_count_i = agent_count[agent_class]
|
509
586
|
assert memory_config_func is not None
|
510
587
|
memory_config_func_i = memory_config_func.get(
|
511
588
|
agent_class, self.default_memory_config_func[agent_class] # type:ignore
|
@@ -706,10 +783,32 @@ class AgentSimulation:
|
|
706
783
|
)
|
707
784
|
await self.messager.start_listening.remote() # type:ignore
|
708
785
|
|
786
|
+
agent_ids = set()
|
787
|
+
org_ids = set()
|
788
|
+
for group in self._groups.values():
|
789
|
+
ids = await group.get_economy_ids.remote()
|
790
|
+
agent_ids.update(ids[0])
|
791
|
+
org_ids.update(ids[1])
|
792
|
+
await self.economy_client.set_ids(agent_ids, org_ids)
|
793
|
+
for group in self._groups.values():
|
794
|
+
await group.set_economy_ids.remote(agent_ids, org_ids)
|
795
|
+
|
709
796
|
async def gather(
|
710
797
|
self, content: str, target_agent_uuids: Optional[list[str]] = None
|
711
798
|
):
|
712
|
-
"""
|
799
|
+
"""
|
800
|
+
Collect specific information from agents.
|
801
|
+
|
802
|
+
- **Description**:
|
803
|
+
- Asynchronously gathers specified content from targeted agents within all groups.
|
804
|
+
|
805
|
+
- **Args**:
|
806
|
+
- `content` (str): The information to collect from the agents.
|
807
|
+
- `target_agent_uuids` (Optional[List[str]], optional): A list of agent UUIDs to target. Defaults to None, meaning all agents are targeted.
|
808
|
+
|
809
|
+
- **Returns**:
|
810
|
+
- Result of the gathering process as returned by each group's `gather` method.
|
811
|
+
"""
|
713
812
|
gather_tasks = []
|
714
813
|
for group in self._groups.values():
|
715
814
|
gather_tasks.append(group.gather.remote(content, target_agent_uuids))
|
@@ -721,7 +820,20 @@ class AgentSimulation:
|
|
721
820
|
keys: Optional[list[str]] = None,
|
722
821
|
values: Optional[list[Any]] = None,
|
723
822
|
) -> list[str]:
|
724
|
-
"""
|
823
|
+
"""
|
824
|
+
Filter out agents of specified types or with matching key-value pairs.
|
825
|
+
|
826
|
+
- **Args**:
|
827
|
+
- `types` (Optional[List[Type[Agent]]], optional): Types of agents to filter for. Defaults to None.
|
828
|
+
- `keys` (Optional[List[str]], optional): Keys to match in agent attributes. Defaults to None.
|
829
|
+
- `values` (Optional[List[Any]], optional): Values corresponding to keys for matching. Defaults to None.
|
830
|
+
|
831
|
+
- **Raises**:
|
832
|
+
- `ValueError`: If neither types nor keys and values are provided, or if the lengths of keys and values do not match.
|
833
|
+
|
834
|
+
- **Returns**:
|
835
|
+
- `List[str]`: A list of filtered agent UUIDs.
|
836
|
+
"""
|
725
837
|
if not types and not keys and not values:
|
726
838
|
return self._agent_uuids
|
727
839
|
group_to_filter = []
|
@@ -744,12 +856,26 @@ class AgentSimulation:
|
|
744
856
|
return filtered_uuids
|
745
857
|
|
746
858
|
async def update_environment(self, key: str, value: str):
|
859
|
+
"""
|
860
|
+
Update the environment variables for the simulation and all agent groups.
|
861
|
+
|
862
|
+
- **Args**:
|
863
|
+
- `key` (str): The environment variable key to update.
|
864
|
+
- `value` (str): The new value for the environment variable.
|
865
|
+
"""
|
747
866
|
self._simulator.update_environment(key, value)
|
748
867
|
for group in self._groups.values():
|
749
868
|
await group.update_environment.remote(key, value)
|
750
869
|
|
751
870
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
752
|
-
"""
|
871
|
+
"""
|
872
|
+
Update the memory of a specified agent.
|
873
|
+
|
874
|
+
- **Args**:
|
875
|
+
- `target_agent_uuid` (str): The UUID of the target agent to update.
|
876
|
+
- `target_key` (str): The key in the agent's memory to update.
|
877
|
+
- `content` (Any): The new content to set for the target key.
|
878
|
+
"""
|
753
879
|
group = self._agent_uuid2group[target_agent_uuid]
|
754
880
|
await group.update.remote(target_agent_uuid, target_key, content)
|
755
881
|
|
@@ -760,13 +886,30 @@ class AgentSimulation:
|
|
760
886
|
content: Any,
|
761
887
|
mode: Literal["replace", "merge"] = "replace",
|
762
888
|
):
|
763
|
-
"""
|
889
|
+
"""
|
890
|
+
Update economic data for a specified agent.
|
891
|
+
|
892
|
+
- **Args**:
|
893
|
+
- `target_agent_id` (int): The ID of the target agent whose economic data to update.
|
894
|
+
- `target_key` (str): The key in the agent's economic data to update.
|
895
|
+
- `content` (Any): The new content to set for the target key.
|
896
|
+
- `mode` (Literal["replace", "merge"], optional): Mode of updating the economic data. Defaults to "replace".
|
897
|
+
"""
|
764
898
|
await self.economy_client.update(
|
765
899
|
id=target_agent_id, key=target_key, value=content, mode=mode
|
766
900
|
)
|
767
901
|
|
768
902
|
async def send_survey(self, survey: Survey, agent_uuids: list[str] = []):
|
769
|
-
"""
|
903
|
+
"""
|
904
|
+
Send a survey to specified agents.
|
905
|
+
|
906
|
+
- **Args**:
|
907
|
+
- `survey` (Survey): The survey object to send.
|
908
|
+
- `agent_uuids` (List[str], optional): List of agent UUIDs to receive the survey. Defaults to an empty list.
|
909
|
+
|
910
|
+
- **Returns**:
|
911
|
+
- None
|
912
|
+
"""
|
770
913
|
survey_dict = survey.to_dict()
|
771
914
|
_date_time = datetime.now(timezone.utc)
|
772
915
|
payload = {
|
@@ -791,7 +934,16 @@ class AgentSimulation:
|
|
791
934
|
async def send_interview_message(
|
792
935
|
self, content: str, agent_uuids: Union[str, list[str]]
|
793
936
|
):
|
794
|
-
"""
|
937
|
+
"""
|
938
|
+
Send an interview message to specified agents.
|
939
|
+
|
940
|
+
- **Args**:
|
941
|
+
- `content` (str): The content of the message to send.
|
942
|
+
- `agent_uuids` (Union[str, List[str]]): A single UUID string or a list of UUID strings for the agents to receive the message.
|
943
|
+
|
944
|
+
- **Returns**:
|
945
|
+
- None
|
946
|
+
"""
|
795
947
|
_date_time = datetime.now(timezone.utc)
|
796
948
|
payload = {
|
797
949
|
"from": SURVEY_SENDER_UUID,
|
@@ -814,12 +966,37 @@ class AgentSimulation:
|
|
814
966
|
await asyncio.sleep(3)
|
815
967
|
|
816
968
|
async def extract_metric(self, metric_extractors: list[Callable]):
|
817
|
-
"""
|
969
|
+
"""
|
970
|
+
Extract metrics using provided extractors.
|
971
|
+
|
972
|
+
- **Description**:
|
973
|
+
- Asynchronously applies each metric extractor function to the simulation to collect various metrics.
|
974
|
+
|
975
|
+
- **Args**:
|
976
|
+
- `metric_extractors` (List[Callable]): A list of callable functions that take the simulation instance as an argument and return a metric or perform some form of analysis.
|
977
|
+
|
978
|
+
- **Returns**:
|
979
|
+
- None
|
980
|
+
"""
|
818
981
|
for metric_extractor in metric_extractors:
|
819
982
|
await metric_extractor(self)
|
820
983
|
|
821
984
|
async def step(self):
|
822
|
-
"""
|
985
|
+
"""
|
986
|
+
Execute one step of the simulation where each agent performs its forward action.
|
987
|
+
|
988
|
+
- **Description**:
|
989
|
+
- Checks if new agents need to be inserted based on the current day of the simulation. If so, it inserts them.
|
990
|
+
- Executes the forward method for each agent group to advance the simulation by one step.
|
991
|
+
- Saves the state of all agent groups after the step has been completed.
|
992
|
+
- Optionally extracts metrics if the current step matches the interval specified for any metric extractors.
|
993
|
+
|
994
|
+
- **Raises**:
|
995
|
+
- `RuntimeError`: If there is an error during the execution of the step, it logs the error and rethrows it as a RuntimeError.
|
996
|
+
|
997
|
+
- **Returns**:
|
998
|
+
- None
|
999
|
+
"""
|
823
1000
|
try:
|
824
1001
|
# check whether insert agents
|
825
1002
|
simulator_day = await self._simulator.get_simulator_day()
|
@@ -837,11 +1014,21 @@ class AgentSimulation:
|
|
837
1014
|
# step
|
838
1015
|
simulator_day = await self._simulator.get_simulator_day()
|
839
1016
|
simulator_time = int(await self._simulator.get_time())
|
840
|
-
logger.info(
|
1017
|
+
logger.info(
|
1018
|
+
f"Start simulation day {simulator_day} at {simulator_time}, step {self._total_steps}"
|
1019
|
+
)
|
841
1020
|
tasks = []
|
842
1021
|
for group in self._groups.values():
|
843
1022
|
tasks.append(group.step.remote())
|
844
|
-
await asyncio.gather(*tasks)
|
1023
|
+
log_messages_groups = await asyncio.gather(*tasks)
|
1024
|
+
llm_log_list = []
|
1025
|
+
mqtt_log_list = []
|
1026
|
+
simulator_log_list = []
|
1027
|
+
for log_messages_group in log_messages_groups:
|
1028
|
+
llm_log_list.extend(log_messages_group['llm_log'])
|
1029
|
+
mqtt_log_list.extend(log_messages_group['mqtt_log'])
|
1030
|
+
simulator_log_list.extend(log_messages_group['simulator_log'])
|
1031
|
+
|
845
1032
|
# save
|
846
1033
|
simulator_day = await self._simulator.get_simulator_day()
|
847
1034
|
simulator_time = int(await self._simulator.get_time())
|
@@ -850,14 +1037,15 @@ class AgentSimulation:
|
|
850
1037
|
save_tasks.append(group.save.remote(simulator_day, simulator_time))
|
851
1038
|
await asyncio.gather(*save_tasks)
|
852
1039
|
self._total_steps += 1
|
853
|
-
if self.
|
854
|
-
print(f"total_steps: {self._total_steps}, excute metric")
|
1040
|
+
if self.metric_extractors is not None: # type:ignore
|
855
1041
|
to_excute_metric = [
|
856
1042
|
metric[1]
|
857
|
-
for metric in self.
|
1043
|
+
for metric in self.metric_extractors # type:ignore
|
858
1044
|
if self._total_steps % metric[0] == 0
|
859
1045
|
]
|
860
1046
|
await self.extract_metric(to_excute_metric)
|
1047
|
+
|
1048
|
+
return llm_log_list, mqtt_log_list, simulator_log_list
|
861
1049
|
except Exception as e:
|
862
1050
|
import traceback
|
863
1051
|
|
@@ -868,7 +1056,26 @@ class AgentSimulation:
|
|
868
1056
|
self,
|
869
1057
|
day: int = 1,
|
870
1058
|
):
|
871
|
-
"""
|
1059
|
+
"""
|
1060
|
+
Run the simulation for a specified number of days.
|
1061
|
+
|
1062
|
+
- **Args**:
|
1063
|
+
- `day` (int, optional): The number of days to run the simulation. Defaults to 1.
|
1064
|
+
|
1065
|
+
- **Description**:
|
1066
|
+
- Updates the experiment status to running and sets up monitoring for the experiment's status.
|
1067
|
+
- Runs the simulation loop until the end time, which is calculated based on the current time and the number of days to simulate.
|
1068
|
+
- After completing the simulation, updates the experiment status to finished, or to failed if an exception occurs.
|
1069
|
+
|
1070
|
+
- **Raises**:
|
1071
|
+
- `RuntimeError`: If there is an error during the simulation, it logs the error and updates the experiment status to failed before rethrowing the exception.
|
1072
|
+
|
1073
|
+
- **Returns**:
|
1074
|
+
- None
|
1075
|
+
"""
|
1076
|
+
llm_log_lists = []
|
1077
|
+
mqtt_log_lists = []
|
1078
|
+
simulator_log_lists = []
|
872
1079
|
try:
|
873
1080
|
self._exp_info["num_day"] += day
|
874
1081
|
await self._update_exp_status(1) # 更新状态为运行中
|
@@ -886,7 +1093,10 @@ class AgentSimulation:
|
|
886
1093
|
current_time = await self._simulator.get_time()
|
887
1094
|
if current_time >= end_time: # type:ignore
|
888
1095
|
break
|
889
|
-
await self.step()
|
1096
|
+
llm_log_list, mqtt_log_list, simulator_log_list = await self.step()
|
1097
|
+
llm_log_lists.extend(llm_log_list)
|
1098
|
+
mqtt_log_lists.extend(mqtt_log_list)
|
1099
|
+
simulator_log_lists.extend(simulator_log_list)
|
890
1100
|
finally:
|
891
1101
|
# 设置停止事件
|
892
1102
|
stop_event.set()
|
@@ -895,7 +1105,7 @@ class AgentSimulation:
|
|
895
1105
|
|
896
1106
|
# 运行成功后更新状态
|
897
1107
|
await self._update_exp_status(2)
|
898
|
-
|
1108
|
+
return llm_log_lists, mqtt_log_lists, simulator_log_lists
|
899
1109
|
except Exception as e:
|
900
1110
|
error_msg = f"模拟器运行错误: {str(e)}"
|
901
1111
|
logger.error(error_msg)
|
pycityagent/survey/manager.py
CHANGED
@@ -8,10 +8,33 @@ from .models import Page, Question, QuestionType, Survey
|
|
8
8
|
|
9
9
|
class SurveyManager:
|
10
10
|
def __init__(self):
|
11
|
+
"""
|
12
|
+
Initializes a new instance of the SurveyManager class.
|
13
|
+
|
14
|
+
- **Description**:
|
15
|
+
- Manages the creation and retrieval of surveys. Uses an internal dictionary to store survey instances by their unique identifier.
|
16
|
+
|
17
|
+
- **Attributes**:
|
18
|
+
- `_surveys` (dict[str, Survey]): A dictionary mapping survey IDs to `Survey` instances.
|
19
|
+
"""
|
11
20
|
self._surveys: dict[str, Survey] = {}
|
12
21
|
|
13
22
|
def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
|
14
|
-
"""
|
23
|
+
"""
|
24
|
+
Create a new survey with specified title, description, and pages containing questions.
|
25
|
+
|
26
|
+
- **Description**:
|
27
|
+
- Generates a unique ID for the survey, converts the provided page and question data into `Page` and `Question` objects,
|
28
|
+
and adds the created survey to the internal storage.
|
29
|
+
|
30
|
+
- **Args**:
|
31
|
+
- `title` (str): The title of the survey.
|
32
|
+
- `description` (str): A brief description of what the survey is about.
|
33
|
+
- `pages` (list[dict]): A list of dictionaries where each dictionary contains page information including elements which are question definitions.
|
34
|
+
|
35
|
+
- **Returns**:
|
36
|
+
- `Survey`: An instance of the `Survey` class representing the newly created survey.
|
37
|
+
"""
|
15
38
|
survey_id = uuid.uuid4()
|
16
39
|
|
17
40
|
# 转换页面和问题数据
|
@@ -46,9 +69,28 @@ class SurveyManager:
|
|
46
69
|
return survey
|
47
70
|
|
48
71
|
def get_survey(self, survey_id: str) -> Optional[Survey]:
|
49
|
-
"""
|
72
|
+
"""
|
73
|
+
Retrieve a specific survey by its unique identifier.
|
74
|
+
|
75
|
+
- **Description**:
|
76
|
+
- Searches for a survey within the internal storage using the provided survey ID.
|
77
|
+
|
78
|
+
- **Args**:
|
79
|
+
- `survey_id` (str): The unique identifier of the survey to retrieve.
|
80
|
+
|
81
|
+
- **Returns**:
|
82
|
+
- `Optional[Survey]`: An instance of the `Survey` class if found; otherwise, None.
|
83
|
+
"""
|
50
84
|
return self._surveys.get(survey_id)
|
51
85
|
|
52
86
|
def get_all_surveys(self) -> list[Survey]:
|
53
|
-
"""
|
87
|
+
"""
|
88
|
+
Get all surveys that have been created.
|
89
|
+
|
90
|
+
- **Description**:
|
91
|
+
- Retrieves all stored surveys from the internal storage.
|
92
|
+
|
93
|
+
- **Returns**:
|
94
|
+
- `list[Survey]`: A list of `Survey` instances.
|
95
|
+
"""
|
54
96
|
return list(self._surveys.values())
|