pycityagent 2.0.0a94__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a96__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +5 -5
- pycityagent/agent/agent_base.py +1 -6
- pycityagent/cityagent/__init__.py +6 -5
- pycityagent/cityagent/bankagent.py +2 -2
- pycityagent/cityagent/blocks/__init__.py +4 -4
- pycityagent/cityagent/blocks/cognition_block.py +7 -4
- pycityagent/cityagent/blocks/economy_block.py +227 -135
- pycityagent/cityagent/blocks/mobility_block.py +70 -27
- pycityagent/cityagent/blocks/needs_block.py +11 -12
- pycityagent/cityagent/blocks/other_block.py +2 -2
- pycityagent/cityagent/blocks/plan_block.py +22 -24
- pycityagent/cityagent/blocks/social_block.py +15 -17
- pycityagent/cityagent/blocks/utils.py +3 -2
- pycityagent/cityagent/firmagent.py +1 -1
- pycityagent/cityagent/governmentagent.py +1 -1
- pycityagent/cityagent/initial.py +1 -1
- pycityagent/cityagent/memory_config.py +0 -1
- pycityagent/cityagent/message_intercept.py +7 -8
- pycityagent/cityagent/nbsagent.py +1 -1
- pycityagent/cityagent/societyagent.py +1 -2
- pycityagent/configs/__init__.py +18 -0
- pycityagent/configs/exp_config.py +202 -0
- pycityagent/configs/sim_config.py +251 -0
- pycityagent/configs/utils.py +17 -0
- pycityagent/environment/__init__.py +2 -0
- pycityagent/{economy → environment/economy}/econ_client.py +14 -32
- pycityagent/environment/sim/sim_env.py +17 -24
- pycityagent/environment/simulator.py +36 -113
- pycityagent/llm/__init__.py +1 -2
- pycityagent/llm/llm.py +54 -167
- pycityagent/memory/memory.py +13 -12
- pycityagent/message/message_interceptor.py +5 -4
- pycityagent/message/messager.py +3 -5
- pycityagent/metrics/__init__.py +1 -1
- pycityagent/metrics/mlflow_client.py +20 -17
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +18 -20
- pycityagent/simulation/simulation.py +157 -210
- pycityagent/survey/manager.py +0 -2
- pycityagent/utils/__init__.py +3 -0
- pycityagent/utils/config_const.py +20 -0
- pycityagent/workflow/__init__.py +1 -2
- pycityagent/workflow/block.py +0 -3
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/METADATA +7 -24
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/RECORD +50 -46
- pycityagent/llm/llmconfig.py +0 -18
- /pycityagent/{economy → environment/economy}/__init__.py +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a94.dist-info → pycityagent-2.0.0a96.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
|
|
1
1
|
import asyncio
|
2
|
+
import inspect
|
2
3
|
import json
|
3
4
|
import logging
|
4
5
|
import time
|
@@ -6,7 +7,7 @@ import uuid
|
|
6
7
|
from collections.abc import Callable
|
7
8
|
from datetime import datetime, timezone
|
8
9
|
from pathlib import Path
|
9
|
-
from typing import Any, Literal, Optional, Type, Union
|
10
|
+
from typing import Any, Literal, Optional, Type, Union, cast
|
10
11
|
|
11
12
|
import ray
|
12
13
|
import yaml
|
@@ -23,19 +24,21 @@ from ..cityagent.memory_config import (memory_config_bank, memory_config_firm,
|
|
23
24
|
from ..cityagent.message_intercept import (EdgeMessageBlock,
|
24
25
|
MessageBlockListener,
|
25
26
|
PointMessageBlock)
|
26
|
-
from ..
|
27
|
-
from ..environment import Simulator
|
27
|
+
from ..configs import ExpConfig, SimConfig
|
28
|
+
from ..environment import EconomyClient, Simulator
|
28
29
|
from ..llm import SimpleEmbedding
|
29
30
|
from ..message import (MessageBlockBase, MessageBlockListenerBase,
|
30
31
|
MessageInterceptor, Messager)
|
31
32
|
from ..metrics import init_mlflow_connection
|
32
33
|
from ..metrics.mlflow_client import MlflowClient
|
33
34
|
from ..survey import Survey
|
34
|
-
from ..utils import SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
35
|
+
from ..utils import (SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES,
|
36
|
+
WorkflowType)
|
35
37
|
from .agentgroup import AgentGroup
|
36
38
|
from .storage.pg import PgWriter, create_pg_tables
|
37
39
|
|
38
40
|
logger = logging.getLogger("pycityagent")
|
41
|
+
ExpConfig.model_rebuild() # rebuild the schema due to circular import
|
39
42
|
|
40
43
|
|
41
44
|
__all__ = ["AgentSimulation"]
|
@@ -58,9 +61,9 @@ class AgentSimulation:
|
|
58
61
|
|
59
62
|
def __init__(
|
60
63
|
self,
|
61
|
-
config:
|
64
|
+
config: SimConfig,
|
62
65
|
agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
|
63
|
-
|
66
|
+
agent_class_configs: Optional[dict] = None,
|
64
67
|
metric_extractors: Optional[list[tuple[int, Callable]]] = None,
|
65
68
|
enable_institution: bool = True,
|
66
69
|
agent_prefix: str = "agent_",
|
@@ -75,10 +78,10 @@ class AgentSimulation:
|
|
75
78
|
it can include a predefined set of institutional agents. If specific agent classes are provided, those will be used instead.
|
76
79
|
|
77
80
|
- **Args**:
|
78
|
-
- `config` (
|
81
|
+
- `config` (SimConfig): The main configuration for the simulation.
|
79
82
|
- `agent_class` (Union[None, Type[Agent], List[Type[Agent]]], optional):
|
80
83
|
Either a single agent class or a list of agent classes to instantiate. Defaults to None, which implies a default set of agents.
|
81
|
-
- `
|
84
|
+
- `agent_class_configs` (Optional[dict], optional): An optional configuration dict used to initialize agents. Defaults to None.
|
82
85
|
- `metric_extractors` (Optional[List[Tuple[int, Callable]]], optional):
|
83
86
|
A list of tuples containing intervals and callables for extracting metrics from the simulation. Defaults to None.
|
84
87
|
- `enable_institution` (bool, optional): Flag indicating whether institutional agents should be included in the simulation. Defaults to True.
|
@@ -115,41 +118,18 @@ class AgentSimulation:
|
|
115
118
|
}
|
116
119
|
else:
|
117
120
|
self.agent_class = [agent_class]
|
118
|
-
self.
|
121
|
+
self.agent_class_configs = agent_class_configs
|
119
122
|
self.logging_level = logging_level
|
120
123
|
self.config = config
|
121
124
|
self.exp_name = exp_name
|
122
|
-
|
123
|
-
if "server" in _simulator_config:
|
124
|
-
raise ValueError(f"Passing Traffic Simulation address is not supported!")
|
125
|
-
simulator = Simulator(config["simulator_request"], create_map=True)
|
125
|
+
simulator = Simulator(config, create_map=True)
|
126
126
|
self._simulator = simulator
|
127
127
|
self._map_ref = self._simulator.map
|
128
128
|
server_addr = self._simulator.get_server_addr()
|
129
|
-
config
|
130
|
-
self._economy_client = EconomyClient(
|
131
|
-
config["simulator_request"]["simulator"]["server"]
|
132
|
-
)
|
129
|
+
config.SetServerAddress(server_addr)
|
130
|
+
self._economy_client = EconomyClient(server_addr)
|
133
131
|
if enable_institution:
|
134
132
|
self._economy_addr = economy_addr = server_addr
|
135
|
-
if economy_addr is None:
|
136
|
-
raise ValueError(
|
137
|
-
f"`simulator` not provided in `simulator_request`, thus unable to activate economy!"
|
138
|
-
)
|
139
|
-
_req_dict: dict = self.config["simulator_request"]
|
140
|
-
if "economy" in _req_dict:
|
141
|
-
if _req_dict["economy"] is None:
|
142
|
-
_req_dict["economy"] = {}
|
143
|
-
if "server" in _req_dict["economy"]:
|
144
|
-
raise ValueError(
|
145
|
-
f"Passing Economy Simulation address is not supported!"
|
146
|
-
)
|
147
|
-
else:
|
148
|
-
_req_dict["economy"]["server"] = economy_addr
|
149
|
-
else:
|
150
|
-
_req_dict["economy"] = {
|
151
|
-
"server": economy_addr,
|
152
|
-
}
|
153
133
|
self.agent_prefix = agent_prefix
|
154
134
|
self._groups: dict[str, AgentGroup] = {} # type:ignore
|
155
135
|
self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
|
@@ -161,42 +141,45 @@ class AgentSimulation:
|
|
161
141
|
self._loop = asyncio.get_event_loop()
|
162
142
|
self._total_steps = 0
|
163
143
|
self._simulator_day = 0
|
164
|
-
# self._last_asyncio_pg_task = None #
|
144
|
+
# self._last_asyncio_pg_task = None # hide SQL write IO to calculation task
|
165
145
|
|
146
|
+
mqtt_config = config.prop_mqtt
|
166
147
|
self._messager = Messager.remote(
|
167
|
-
hostname=
|
168
|
-
port=
|
169
|
-
username=
|
170
|
-
password=
|
148
|
+
hostname=mqtt_config.server, # type:ignore
|
149
|
+
port=mqtt_config.port,
|
150
|
+
username=mqtt_config.username,
|
151
|
+
password=mqtt_config.password,
|
171
152
|
)
|
172
153
|
|
173
154
|
# storage
|
174
|
-
_storage_config: dict[str, Any] = config.get("storage", {})
|
175
|
-
if _storage_config is None:
|
176
|
-
_storage_config = {}
|
177
155
|
|
178
156
|
# avro
|
179
|
-
|
180
|
-
|
181
|
-
|
157
|
+
avro_config = config.prop_avro_config
|
158
|
+
if avro_config is not None:
|
159
|
+
self._enable_avro: bool = avro_config.enabled # type:ignore
|
160
|
+
if not self._enable_avro:
|
161
|
+
self._avro_path = None
|
162
|
+
logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
|
163
|
+
else:
|
164
|
+
self._avro_path = Path(avro_config.path) / f"{self.exp_id}"
|
165
|
+
self._avro_path.mkdir(parents=True, exist_ok=True)
|
166
|
+
else:
|
182
167
|
self._avro_path = None
|
183
168
|
logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
|
184
|
-
|
185
|
-
self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
|
186
|
-
self._avro_path.mkdir(parents=True, exist_ok=True)
|
169
|
+
self._enable_avro = False
|
187
170
|
|
188
171
|
# mlflow
|
189
|
-
|
190
|
-
if
|
172
|
+
metric_config = config.prop_metric_request
|
173
|
+
if metric_config is not None and metric_config.mlflow is not None:
|
191
174
|
logger.info(f"-----Creating Mlflow client...")
|
192
175
|
mlflow_run_id, _ = init_mlflow_connection(
|
193
|
-
config=
|
176
|
+
config=metric_config.mlflow,
|
194
177
|
experiment_uuid=self.exp_id,
|
195
178
|
mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
|
196
179
|
experiment_name=self.exp_name,
|
197
180
|
)
|
198
181
|
self.mlflow_client = MlflowClient(
|
199
|
-
config=
|
182
|
+
config=metric_config.mlflow,
|
200
183
|
experiment_uuid=self.exp_id,
|
201
184
|
mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
|
202
185
|
experiment_name=exp_name,
|
@@ -209,35 +192,36 @@ class AgentSimulation:
|
|
209
192
|
self.metric_extractors = None
|
210
193
|
|
211
194
|
# pg
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
195
|
+
pgsql_config = config.prop_postgre_sql_config
|
196
|
+
if pgsql_config is not None:
|
197
|
+
self._enable_pgsql: bool = pgsql_config.enabled # type:ignore
|
198
|
+
if not self._enable_pgsql:
|
199
|
+
logger.warning(
|
200
|
+
"PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE"
|
201
|
+
)
|
202
|
+
self._pgsql_dsn = ""
|
203
|
+
else:
|
204
|
+
self._pgsql_dsn = pgsql_config.dsn
|
217
205
|
else:
|
218
|
-
self.
|
219
|
-
_pgsql_config["data_source_name"]
|
220
|
-
if "data_source_name" in _pgsql_config
|
221
|
-
else _pgsql_config["dsn"]
|
222
|
-
)
|
206
|
+
self._enable_pgsql = False
|
223
207
|
|
224
|
-
#
|
208
|
+
# add experiment info related properties
|
225
209
|
self._exp_created_time = datetime.now(timezone.utc)
|
226
210
|
self._exp_updated_time = datetime.now(timezone.utc)
|
227
211
|
self._exp_info = {
|
228
212
|
"id": self.exp_id,
|
229
213
|
"name": exp_name,
|
230
|
-
"num_day": 0, #
|
214
|
+
"num_day": 0, # will be updated in run method
|
231
215
|
"status": 0,
|
232
216
|
"cur_day": 0,
|
233
217
|
"cur_t": 0.0,
|
234
|
-
"config":
|
218
|
+
"config": str(config.model_dump()),
|
235
219
|
"error": "",
|
236
220
|
"created_at": self._exp_created_time.isoformat(),
|
237
221
|
"updated_at": self._exp_updated_time.isoformat(),
|
238
222
|
}
|
239
223
|
|
240
|
-
#
|
224
|
+
# create async task to save experiment info
|
241
225
|
if self._enable_avro:
|
242
226
|
assert self._avro_path is not None
|
243
227
|
self._exp_info_file = self._avro_path / "experiment_info.yaml"
|
@@ -245,7 +229,7 @@ class AgentSimulation:
|
|
245
229
|
yaml.dump(self._exp_info, f)
|
246
230
|
|
247
231
|
@classmethod
|
248
|
-
async def run_from_config(cls, config:
|
232
|
+
async def run_from_config(cls, config: ExpConfig, sim_config: SimConfig):
|
249
233
|
"""Directly run from config file
|
250
234
|
Basic config file should contain:
|
251
235
|
- simulation_config: file_path
|
@@ -280,130 +264,101 @@ class AgentSimulation:
|
|
280
264
|
- logging_level: Optional[int]
|
281
265
|
- exp_name: Optional[str]
|
282
266
|
"""
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
if "agent_config" not in config:
|
287
|
-
raise ValueError("agent_config is required")
|
288
|
-
if "workflow" not in config:
|
289
|
-
raise ValueError("workflow is required")
|
290
|
-
import yaml
|
291
|
-
|
292
|
-
logger.info("Loading config file...")
|
293
|
-
with open(config["simulation_config"], "r") as f:
|
294
|
-
simulation_config = yaml.safe_load(f)
|
267
|
+
|
268
|
+
agent_config = config.prop_agent_config
|
269
|
+
|
295
270
|
logger.info("Creating AgentSimulation Task...")
|
296
271
|
simulation = cls(
|
297
|
-
config=
|
298
|
-
|
299
|
-
metric_extractors=config
|
300
|
-
enable_institution=
|
301
|
-
exp_name=config.
|
302
|
-
logging_level=config.
|
303
|
-
)
|
304
|
-
environment = config.get(
|
305
|
-
"environment",
|
306
|
-
{
|
307
|
-
"weather": "The weather is normal",
|
308
|
-
"crime": "The crime rate is low",
|
309
|
-
"pollution": "The pollution level is low",
|
310
|
-
"temperature": "The temperature is normal",
|
311
|
-
"day": "Workday",
|
312
|
-
},
|
272
|
+
config=sim_config,
|
273
|
+
agent_class_configs=agent_config.agent_class_configs,
|
274
|
+
metric_extractors=config.prop_metric_extractors,
|
275
|
+
enable_institution=agent_config.enable_institution,
|
276
|
+
exp_name=config.exp_name,
|
277
|
+
logging_level=config.logging_level,
|
313
278
|
)
|
279
|
+
environment = config.prop_environment.model_dump()
|
314
280
|
simulation._simulator.set_environment(environment)
|
315
281
|
logger.info("Initializing Agents...")
|
316
|
-
agent_count = {
|
317
|
-
SocietyAgent:
|
318
|
-
FirmAgent:
|
319
|
-
GovernmentAgent:
|
320
|
-
BankAgent:
|
321
|
-
NBSAgent:
|
282
|
+
agent_count: dict[type[Agent], int] = {
|
283
|
+
SocietyAgent: agent_config.number_of_citizen,
|
284
|
+
FirmAgent: agent_config.number_of_firm,
|
285
|
+
GovernmentAgent: agent_config.number_of_government,
|
286
|
+
BankAgent: agent_config.number_of_bank,
|
287
|
+
NBSAgent: agent_config.number_of_nbs,
|
322
288
|
}
|
323
289
|
if agent_count.get(SocietyAgent, 0) == 0:
|
324
290
|
raise ValueError("number_of_citizen is required")
|
325
291
|
|
326
292
|
# support MessageInterceptor
|
327
|
-
if
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
_kwargs = {
|
342
|
-
k: v
|
343
|
-
for k, v in _intercept_config.items()
|
344
|
-
if k
|
345
|
-
in {
|
346
|
-
"max_violation_time",
|
347
|
-
}
|
348
|
-
}
|
349
|
-
_interceptor_blocks = [EdgeMessageBlock(**_kwargs)]
|
293
|
+
if config.message_intercept is not None:
|
294
|
+
intercept_config = config.message_intercept
|
295
|
+
if intercept_config.mode == "point":
|
296
|
+
_interceptor_blocks = [
|
297
|
+
PointMessageBlock(
|
298
|
+
max_violation_time=intercept_config.max_violation_time
|
299
|
+
)
|
300
|
+
]
|
301
|
+
elif intercept_config.mode == "edge":
|
302
|
+
_interceptor_blocks = [
|
303
|
+
EdgeMessageBlock(
|
304
|
+
max_violation_time=intercept_config.max_violation_time
|
305
|
+
)
|
306
|
+
]
|
350
307
|
else:
|
351
|
-
raise ValueError(
|
308
|
+
raise ValueError(
|
309
|
+
f"Unsupported interception mode `{intercept_config.mode}!`"
|
310
|
+
)
|
352
311
|
_message_intercept_kwargs = {
|
353
312
|
"message_interceptor_blocks": _interceptor_blocks,
|
354
313
|
"message_listener": MessageBlockListener(),
|
355
314
|
}
|
356
315
|
else:
|
357
316
|
_message_intercept_kwargs = {}
|
317
|
+
embedding_model = agent_config.embedding_model
|
318
|
+
if embedding_model is None:
|
319
|
+
embedding_model = SimpleEmbedding()
|
358
320
|
await simulation.init_agents(
|
359
321
|
agent_count=agent_count,
|
360
|
-
group_size=
|
361
|
-
embedding_model=
|
362
|
-
|
363
|
-
|
364
|
-
memory_config_func=config["agent_config"].get("memory_config_func", None),
|
365
|
-
memory_config_init_func=config["agent_config"].get(
|
366
|
-
"memory_config_init_func", None
|
367
|
-
),
|
322
|
+
group_size=agent_config.group_size,
|
323
|
+
embedding_model=embedding_model,
|
324
|
+
memory_config_func=agent_config.memory_config_func,
|
325
|
+
memory_config_init_func=agent_config.memory_config_init_func,
|
368
326
|
**_message_intercept_kwargs,
|
369
327
|
environment=environment,
|
370
|
-
llm_semaphore=config.
|
328
|
+
llm_semaphore=config.llm_semaphore,
|
371
329
|
)
|
372
330
|
logger.info("Running Init Functions...")
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
331
|
+
init_funcs = agent_config.init_func
|
332
|
+
if init_funcs is None:
|
333
|
+
init_funcs = [bind_agent_info, initialize_social_network]
|
334
|
+
for init_func in init_funcs:
|
335
|
+
if inspect.iscoroutinefunction(init_func):
|
336
|
+
await init_func(simulation)
|
337
|
+
else:
|
338
|
+
init_func = cast(Callable, init_func)
|
339
|
+
init_func(simulation)
|
377
340
|
logger.info("Starting Simulation...")
|
378
341
|
llm_log_lists = []
|
379
342
|
mqtt_log_lists = []
|
380
343
|
simulator_log_lists = []
|
381
344
|
agent_time_log_lists = []
|
382
|
-
for step in config
|
345
|
+
for step in config.prop_workflow:
|
383
346
|
logger.info(
|
384
|
-
f"Running step: type: {step
|
347
|
+
f"Running step: type: {step.type} - description: {step.description}"
|
385
348
|
)
|
386
|
-
if step
|
387
|
-
"
|
388
|
-
|
389
|
-
|
390
|
-
"survey",
|
391
|
-
"intervene",
|
392
|
-
"pause",
|
393
|
-
"resume",
|
394
|
-
"function",
|
395
|
-
]:
|
396
|
-
raise ValueError(f"Invalid step type: {step['type']}")
|
397
|
-
if step["type"] == "run":
|
349
|
+
if step.type not in {t.value for t in WorkflowType}:
|
350
|
+
raise ValueError(f"Invalid step type: {step.type}")
|
351
|
+
if step.type == WorkflowType.RUN:
|
352
|
+
_days = cast(int, step.days)
|
398
353
|
llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = (
|
399
|
-
await simulation.run(
|
354
|
+
await simulation.run(_days)
|
400
355
|
)
|
401
356
|
llm_log_lists.extend(llm_log_list)
|
402
357
|
mqtt_log_lists.extend(mqtt_log_list)
|
403
358
|
simulator_log_lists.extend(simulator_log_list)
|
404
359
|
agent_time_log_lists.extend(agent_time_log_list)
|
405
|
-
elif step
|
406
|
-
times = step.
|
360
|
+
elif step.type == WorkflowType.STEP:
|
361
|
+
times = cast(int, step.times)
|
407
362
|
for _ in range(times):
|
408
363
|
(
|
409
364
|
llm_log_list,
|
@@ -415,12 +370,13 @@ class AgentSimulation:
|
|
415
370
|
mqtt_log_lists.extend(mqtt_log_list)
|
416
371
|
simulator_log_lists.extend(simulator_log_list)
|
417
372
|
agent_time_log_lists.extend(agent_time_log_list)
|
418
|
-
elif step
|
373
|
+
elif step.type == WorkflowType.PAUSE:
|
419
374
|
await simulation.pause_simulator()
|
420
|
-
elif step
|
375
|
+
elif step.type == WorkflowType.RESUME:
|
421
376
|
await simulation.resume_simulator()
|
422
|
-
|
423
|
-
|
377
|
+
elif step.type == WorkflowType.FUNCTION:
|
378
|
+
_func = cast(Callable, step.func)
|
379
|
+
await _func(simulation)
|
424
380
|
logger.info("Simulation finished")
|
425
381
|
return llm_log_lists, mqtt_log_lists, simulator_log_lists, agent_time_log_lists
|
426
382
|
|
@@ -467,13 +423,13 @@ class AgentSimulation:
|
|
467
423
|
return self._message_interceptors[0] # type:ignore
|
468
424
|
|
469
425
|
async def _save_exp_info(self) -> None:
|
470
|
-
"""
|
426
|
+
"""Async save experiment info to YAML file"""
|
471
427
|
try:
|
472
428
|
if self.enable_avro:
|
473
429
|
with open(self._exp_info_file, "w") as f:
|
474
430
|
yaml.dump(self._exp_info, f)
|
475
431
|
except Exception as e:
|
476
|
-
logger.error(f"Avro
|
432
|
+
logger.error(f"Avro save experiment info failed: {str(e)}")
|
477
433
|
try:
|
478
434
|
if self.enable_pgsql:
|
479
435
|
worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
|
@@ -486,51 +442,51 @@ class AgentSimulation:
|
|
486
442
|
pg_exp_info
|
487
443
|
)
|
488
444
|
except Exception as e:
|
489
|
-
logger.error(f"PostgreSQL
|
445
|
+
logger.error(f"PostgreSQL save experiment info failed: {str(e)}")
|
490
446
|
|
491
447
|
async def _update_exp_status(self, status: int, error: str = "") -> None:
|
492
448
|
self._exp_updated_time = datetime.now(timezone.utc)
|
493
|
-
"""
|
449
|
+
"""Update experiment status and save"""
|
494
450
|
self._exp_info["status"] = status
|
495
451
|
self._exp_info["error"] = error
|
496
452
|
self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
|
497
453
|
await self._save_exp_info()
|
498
454
|
|
499
455
|
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
500
|
-
"""
|
456
|
+
"""Monitor experiment status and update
|
501
457
|
|
502
458
|
- **Args**:
|
503
|
-
stop_event:
|
459
|
+
stop_event: event for notifying monitor task to stop
|
504
460
|
"""
|
505
461
|
try:
|
506
462
|
while not stop_event.is_set():
|
507
|
-
#
|
508
|
-
#
|
463
|
+
# update experiment status
|
464
|
+
# assume all groups' cur_day and cur_t are synchronized, take the first one
|
509
465
|
self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
|
510
466
|
self._exp_info["cur_t"] = (
|
511
467
|
await self._simulator.get_simulator_second_from_start_of_day()
|
512
468
|
)
|
513
469
|
await self._save_exp_info()
|
514
470
|
|
515
|
-
await asyncio.sleep(1) #
|
471
|
+
await asyncio.sleep(1) # avoid too frequent updates
|
516
472
|
except asyncio.CancelledError:
|
517
|
-
#
|
473
|
+
# normal cancellation, no special handling needed
|
518
474
|
pass
|
519
475
|
except Exception as e:
|
520
|
-
logger.error(f"
|
476
|
+
logger.error(f"Error monitoring experiment status: {str(e)}")
|
521
477
|
raise
|
522
478
|
|
523
479
|
async def __aenter__(self):
|
524
|
-
"""
|
480
|
+
"""Async context manager entry"""
|
525
481
|
return self
|
526
482
|
|
527
483
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
528
|
-
"""
|
484
|
+
"""Async context manager exit"""
|
529
485
|
if exc_type is not None:
|
530
|
-
#
|
486
|
+
# if exception occurs, update status to error
|
531
487
|
await self._update_exp_status(3, str(exc_val))
|
532
488
|
elif self._exp_info["status"] != 3:
|
533
|
-
#
|
489
|
+
# if no exception and status is not error, update to finished
|
534
490
|
await self._update_exp_status(2)
|
535
491
|
|
536
492
|
async def pause_simulator(self):
|
@@ -582,7 +538,6 @@ class AgentSimulation:
|
|
582
538
|
- `None`
|
583
539
|
"""
|
584
540
|
self.agent_count = agent_count
|
585
|
-
|
586
541
|
if len(self.agent_class) != len(agent_count):
|
587
542
|
raise ValueError("The length of agent_class and agent_count does not match")
|
588
543
|
|
@@ -591,14 +546,14 @@ class AgentSimulation:
|
|
591
546
|
if memory_config_func is None:
|
592
547
|
memory_config_func = self.default_memory_config_func # type:ignore
|
593
548
|
|
594
|
-
#
|
549
|
+
# use thread pool to create AgentGroup
|
595
550
|
group_creation_params = []
|
596
551
|
|
597
|
-
#
|
552
|
+
# process institution agent and citizen agent
|
598
553
|
institution_params = []
|
599
554
|
citizen_params = []
|
600
555
|
|
601
|
-
#
|
556
|
+
# collect all parameters
|
602
557
|
for i in range(len(self.agent_class)):
|
603
558
|
agent_class = self.agent_class[i]
|
604
559
|
agent_count_i = agent_count[agent_class]
|
@@ -607,8 +562,8 @@ class AgentSimulation:
|
|
607
562
|
agent_class, self.default_memory_config_func[agent_class] # type:ignore
|
608
563
|
)
|
609
564
|
|
610
|
-
if self.
|
611
|
-
config_file = self.
|
565
|
+
if self.agent_class_configs is not None:
|
566
|
+
config_file = self.agent_class_configs.get(agent_class, None)
|
612
567
|
else:
|
613
568
|
config_file = None
|
614
569
|
|
@@ -621,7 +576,7 @@ class AgentSimulation:
|
|
621
576
|
(agent_class, agent_count_i, memory_config_func_i, config_file)
|
622
577
|
)
|
623
578
|
|
624
|
-
#
|
579
|
+
# process institution group
|
625
580
|
if institution_params:
|
626
581
|
total_institution_count = sum(p[1] for p in institution_params)
|
627
582
|
num_institution_groups = (
|
@@ -638,7 +593,7 @@ class AgentSimulation:
|
|
638
593
|
memory_config_funcs = {}
|
639
594
|
config_files = {}
|
640
595
|
|
641
|
-
#
|
596
|
+
# assign each type of institution agent to current group
|
642
597
|
curr_start = start_idx
|
643
598
|
for agent_class, count, mem_func, conf_file in institution_params:
|
644
599
|
if curr_start < count:
|
@@ -658,7 +613,7 @@ class AgentSimulation:
|
|
658
613
|
)
|
659
614
|
)
|
660
615
|
|
661
|
-
#
|
616
|
+
# process citizen group
|
662
617
|
if citizen_params:
|
663
618
|
total_citizen_count = sum(p[1] for p in citizen_params)
|
664
619
|
num_citizen_groups = (total_citizen_count + group_size - 1) // group_size
|
@@ -673,7 +628,7 @@ class AgentSimulation:
|
|
673
628
|
memory_config_funcs = {}
|
674
629
|
config_files = {}
|
675
630
|
|
676
|
-
#
|
631
|
+
# assign each type of citizen agent to current group
|
677
632
|
curr_start = start_idx
|
678
633
|
for agent_class, count, mem_func, conf_file in citizen_params:
|
679
634
|
if curr_start < count:
|
@@ -692,18 +647,18 @@ class AgentSimulation:
|
|
692
647
|
config_files,
|
693
648
|
)
|
694
649
|
)
|
695
|
-
#
|
696
|
-
|
697
|
-
if
|
650
|
+
# initialize mlflow connection
|
651
|
+
metric_config = self.config.prop_metric_request
|
652
|
+
if metric_config is not None and metric_config.mlflow is not None:
|
698
653
|
mlflow_run_id, _ = init_mlflow_connection(
|
699
654
|
experiment_uuid=self.exp_id,
|
700
|
-
config=
|
655
|
+
config=metric_config.mlflow,
|
701
656
|
mlflow_run_name=f"{self.exp_name}_{1000*int(time.time())}",
|
702
657
|
experiment_name=self.exp_name,
|
703
658
|
)
|
704
659
|
else:
|
705
660
|
mlflow_run_id = None
|
706
|
-
#
|
661
|
+
# create table
|
707
662
|
if self.enable_pgsql:
|
708
663
|
_num_workers = min(1, pg_sql_writers)
|
709
664
|
create_pg_tables(
|
@@ -726,7 +681,7 @@ class AgentSimulation:
|
|
726
681
|
self._message_abort_listening_queue = _queue = None
|
727
682
|
_interceptor_blocks = message_interceptor_blocks
|
728
683
|
_black_list = [] if social_black_list is None else social_black_list
|
729
|
-
_llm_config = self.config.
|
684
|
+
_llm_config = self.config.llm_request
|
730
685
|
if message_interceptor_blocks is not None:
|
731
686
|
_num_interceptors = min(1, message_interceptors)
|
732
687
|
self._message_interceptors = _interceptors = [
|
@@ -751,7 +706,7 @@ class AgentSimulation:
|
|
751
706
|
group_name,
|
752
707
|
config_file,
|
753
708
|
) in enumerate(group_creation_params):
|
754
|
-
#
|
709
|
+
# create async task directly
|
755
710
|
group = AgentGroup.remote(
|
756
711
|
agent_class,
|
757
712
|
number_of_agents,
|
@@ -774,7 +729,7 @@ class AgentSimulation:
|
|
774
729
|
)
|
775
730
|
creation_tasks.append((group_name, group))
|
776
731
|
|
777
|
-
#
|
732
|
+
# update data structure
|
778
733
|
for group_name, group in creation_tasks:
|
779
734
|
self._groups[group_name] = group
|
780
735
|
group_agent_uuids = ray.get(group.get_agent_uuids.remote())
|
@@ -793,7 +748,7 @@ class AgentSimulation:
|
|
793
748
|
self._type2group[agent_type] = []
|
794
749
|
self._type2group[agent_type].append(group)
|
795
750
|
|
796
|
-
#
|
751
|
+
# parallel initialize all groups' agents
|
797
752
|
await self.resume_simulator()
|
798
753
|
init_tasks = []
|
799
754
|
for group in self._groups.values():
|
@@ -1063,7 +1018,7 @@ class AgentSimulation:
|
|
1063
1018
|
except Exception as e:
|
1064
1019
|
import traceback
|
1065
1020
|
|
1066
|
-
logger.error(f"
|
1021
|
+
logger.error(f"Simulation error: {str(e)}\n{traceback.format_exc()}")
|
1067
1022
|
raise RuntimeError(str(e)) from e
|
1068
1023
|
|
1069
1024
|
async def run(
|
@@ -1093,27 +1048,19 @@ class AgentSimulation:
|
|
1093
1048
|
agent_time_log_lists = []
|
1094
1049
|
try:
|
1095
1050
|
self._exp_info["num_day"] += day
|
1096
|
-
await self._update_exp_status(1) #
|
1051
|
+
await self._update_exp_status(1) # Update status to running
|
1097
1052
|
|
1098
|
-
#
|
1053
|
+
# Create stop event
|
1099
1054
|
stop_event = asyncio.Event()
|
1100
|
-
#
|
1055
|
+
# Create monitor task
|
1101
1056
|
monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
|
1102
1057
|
|
1103
1058
|
try:
|
1104
1059
|
end_day = self._simulator_day + day
|
1105
1060
|
while True:
|
1106
1061
|
current_day = await self._simulator.get_simulator_day()
|
1107
|
-
# check whether insert agents
|
1108
|
-
need_insert_agents = False
|
1109
1062
|
if current_day > self._simulator_day:
|
1110
1063
|
self._simulator_day = current_day
|
1111
|
-
# if need_insert_agents:
|
1112
|
-
# insert_tasks = []
|
1113
|
-
# for group in self._groups.values():
|
1114
|
-
# insert_tasks.append(group.insert_agent.remote())
|
1115
|
-
# await asyncio.gather(*insert_tasks)
|
1116
|
-
|
1117
1064
|
if current_day >= end_day: # type:ignore
|
1118
1065
|
break
|
1119
1066
|
(
|
@@ -1127,12 +1074,12 @@ class AgentSimulation:
|
|
1127
1074
|
simulator_log_lists.extend(simulator_log_list)
|
1128
1075
|
agent_time_log_lists.extend(agent_time_log_list)
|
1129
1076
|
finally:
|
1130
|
-
#
|
1077
|
+
# Set stop event
|
1131
1078
|
stop_event.set()
|
1132
|
-
#
|
1079
|
+
# Wait for monitor task to finish
|
1133
1080
|
await monitor_task
|
1134
1081
|
|
1135
|
-
#
|
1082
|
+
# Update experiment status after successful run
|
1136
1083
|
await self._update_exp_status(2)
|
1137
1084
|
return (
|
1138
1085
|
llm_log_lists,
|
@@ -1141,7 +1088,7 @@ class AgentSimulation:
|
|
1141
1088
|
agent_time_log_lists,
|
1142
1089
|
)
|
1143
1090
|
except Exception as e:
|
1144
|
-
error_msg = f"
|
1091
|
+
error_msg = f"Simulation error: {str(e)}"
|
1145
1092
|
logger.error(error_msg)
|
1146
1093
|
await self._update_exp_status(3, error_msg)
|
1147
1094
|
raise RuntimeError(error_msg) from e
|