pycityagent 2.0.0a93__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a95__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. pycityagent/agent/agent.py +5 -5
  2. pycityagent/agent/agent_base.py +1 -2
  3. pycityagent/cityagent/__init__.py +6 -5
  4. pycityagent/cityagent/bankagent.py +2 -2
  5. pycityagent/cityagent/blocks/__init__.py +4 -4
  6. pycityagent/cityagent/blocks/cognition_block.py +7 -4
  7. pycityagent/cityagent/blocks/economy_block.py +227 -135
  8. pycityagent/cityagent/blocks/mobility_block.py +70 -27
  9. pycityagent/cityagent/blocks/needs_block.py +11 -12
  10. pycityagent/cityagent/blocks/other_block.py +2 -2
  11. pycityagent/cityagent/blocks/plan_block.py +22 -24
  12. pycityagent/cityagent/blocks/social_block.py +15 -17
  13. pycityagent/cityagent/blocks/utils.py +3 -2
  14. pycityagent/cityagent/firmagent.py +1 -1
  15. pycityagent/cityagent/governmentagent.py +1 -1
  16. pycityagent/cityagent/initial.py +1 -1
  17. pycityagent/cityagent/memory_config.py +0 -1
  18. pycityagent/cityagent/message_intercept.py +7 -8
  19. pycityagent/cityagent/nbsagent.py +1 -1
  20. pycityagent/cityagent/societyagent.py +1 -2
  21. pycityagent/configs/__init__.py +18 -0
  22. pycityagent/configs/exp_config.py +202 -0
  23. pycityagent/configs/sim_config.py +251 -0
  24. pycityagent/configs/utils.py +17 -0
  25. pycityagent/environment/__init__.py +2 -0
  26. pycityagent/{economy → environment/economy}/econ_client.py +14 -32
  27. pycityagent/environment/sim/sim_env.py +17 -24
  28. pycityagent/environment/simulator.py +36 -113
  29. pycityagent/llm/__init__.py +1 -2
  30. pycityagent/llm/llm.py +60 -166
  31. pycityagent/memory/memory.py +13 -12
  32. pycityagent/message/message_interceptor.py +5 -4
  33. pycityagent/message/messager.py +3 -5
  34. pycityagent/metrics/__init__.py +1 -1
  35. pycityagent/metrics/mlflow_client.py +20 -17
  36. pycityagent/pycityagent-sim +0 -0
  37. pycityagent/simulation/agentgroup.py +17 -19
  38. pycityagent/simulation/simulation.py +157 -210
  39. pycityagent/survey/manager.py +0 -2
  40. pycityagent/utils/__init__.py +3 -0
  41. pycityagent/utils/config_const.py +20 -0
  42. pycityagent/workflow/__init__.py +1 -2
  43. pycityagent/workflow/block.py +0 -3
  44. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/METADATA +7 -24
  45. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/RECORD +50 -46
  46. pycityagent/llm/llmconfig.py +0 -18
  47. /pycityagent/{economy → environment/economy}/__init__.py +0 -0
  48. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/LICENSE +0 -0
  49. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/WHEEL +0 -0
  50. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/entry_points.txt +0 -0
  51. {pycityagent-2.0.0a93.dist-info → pycityagent-2.0.0a95.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import inspect
2
3
  import json
3
4
  import logging
4
5
  import time
@@ -6,7 +7,7 @@ import uuid
6
7
  from collections.abc import Callable
7
8
  from datetime import datetime, timezone
8
9
  from pathlib import Path
9
- from typing import Any, Literal, Optional, Type, Union
10
+ from typing import Any, Literal, Optional, Type, Union, cast
10
11
 
11
12
  import ray
12
13
  import yaml
@@ -23,19 +24,21 @@ from ..cityagent.memory_config import (memory_config_bank, memory_config_firm,
23
24
  from ..cityagent.message_intercept import (EdgeMessageBlock,
24
25
  MessageBlockListener,
25
26
  PointMessageBlock)
26
- from ..economy.econ_client import EconomyClient
27
- from ..environment import Simulator
27
+ from ..configs import ExpConfig, SimConfig
28
+ from ..environment import EconomyClient, Simulator
28
29
  from ..llm import SimpleEmbedding
29
30
  from ..message import (MessageBlockBase, MessageBlockListenerBase,
30
31
  MessageInterceptor, Messager)
31
32
  from ..metrics import init_mlflow_connection
32
33
  from ..metrics.mlflow_client import MlflowClient
33
34
  from ..survey import Survey
34
- from ..utils import SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
35
+ from ..utils import (SURVEY_SENDER_UUID, TO_UPDATE_EXP_INFO_KEYS_AND_TYPES,
36
+ WorkflowType)
35
37
  from .agentgroup import AgentGroup
36
38
  from .storage.pg import PgWriter, create_pg_tables
37
39
 
38
40
  logger = logging.getLogger("pycityagent")
41
+ ExpConfig.model_rebuild() # rebuild the schema due to circular import
39
42
 
40
43
 
41
44
  __all__ = ["AgentSimulation"]
@@ -58,9 +61,9 @@ class AgentSimulation:
58
61
 
59
62
  def __init__(
60
63
  self,
61
- config: dict,
64
+ config: SimConfig,
62
65
  agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
63
- agent_config_file: Optional[dict] = None,
66
+ agent_class_configs: Optional[dict] = None,
64
67
  metric_extractors: Optional[list[tuple[int, Callable]]] = None,
65
68
  enable_institution: bool = True,
66
69
  agent_prefix: str = "agent_",
@@ -75,10 +78,10 @@ class AgentSimulation:
75
78
  it can include a predefined set of institutional agents. If specific agent classes are provided, those will be used instead.
76
79
 
77
80
  - **Args**:
78
- - `config` (dict): The main configuration dictionary for the simulation.
81
+ - `config` (SimConfig): The main configuration for the simulation.
79
82
  - `agent_class` (Union[None, Type[Agent], List[Type[Agent]]], optional):
80
83
  Either a single agent class or a list of agent classes to instantiate. Defaults to None, which implies a default set of agents.
81
- - `agent_config_file` (Optional[dict], optional): An optional configuration file or dictionary used to initialize agents. Defaults to None.
84
+ - `agent_class_configs` (Optional[dict], optional): An optional configuration dict used to initialize agents. Defaults to None.
82
85
  - `metric_extractors` (Optional[List[Tuple[int, Callable]]], optional):
83
86
  A list of tuples containing intervals and callables for extracting metrics from the simulation. Defaults to None.
84
87
  - `enable_institution` (bool, optional): Flag indicating whether institutional agents should be included in the simulation. Defaults to True.
@@ -115,41 +118,18 @@ class AgentSimulation:
115
118
  }
116
119
  else:
117
120
  self.agent_class = [agent_class]
118
- self.agent_config_file = agent_config_file
121
+ self.agent_class_configs = agent_class_configs
119
122
  self.logging_level = logging_level
120
123
  self.config = config
121
124
  self.exp_name = exp_name
122
- _simulator_config = config["simulator_request"].get("simulator", {})
123
- if "server" in _simulator_config:
124
- raise ValueError(f"Passing Traffic Simulation address is not supported!")
125
- simulator = Simulator(config["simulator_request"], create_map=True)
125
+ simulator = Simulator(config, create_map=True)
126
126
  self._simulator = simulator
127
127
  self._map_ref = self._simulator.map
128
128
  server_addr = self._simulator.get_server_addr()
129
- config["simulator_request"]["simulator"]["server"] = server_addr
130
- self._economy_client = EconomyClient(
131
- config["simulator_request"]["simulator"]["server"]
132
- )
129
+ config.SetServerAddress(server_addr)
130
+ self._economy_client = EconomyClient(server_addr)
133
131
  if enable_institution:
134
132
  self._economy_addr = economy_addr = server_addr
135
- if economy_addr is None:
136
- raise ValueError(
137
- f"`simulator` not provided in `simulator_request`, thus unable to activate economy!"
138
- )
139
- _req_dict: dict = self.config["simulator_request"]
140
- if "economy" in _req_dict:
141
- if _req_dict["economy"] is None:
142
- _req_dict["economy"] = {}
143
- if "server" in _req_dict["economy"]:
144
- raise ValueError(
145
- f"Passing Economy Simulation address is not supported!"
146
- )
147
- else:
148
- _req_dict["economy"]["server"] = economy_addr
149
- else:
150
- _req_dict["economy"] = {
151
- "server": economy_addr,
152
- }
153
133
  self.agent_prefix = agent_prefix
154
134
  self._groups: dict[str, AgentGroup] = {} # type:ignore
155
135
  self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
@@ -161,42 +141,45 @@ class AgentSimulation:
161
141
  self._loop = asyncio.get_event_loop()
162
142
  self._total_steps = 0
163
143
  self._simulator_day = 0
164
- # self._last_asyncio_pg_task = None # SQL写入的IO隐藏到计算任务后
144
+ # self._last_asyncio_pg_task = None # hide SQL write IO to calculation task
165
145
 
146
+ mqtt_config = config.prop_mqtt
166
147
  self._messager = Messager.remote(
167
- hostname=config["simulator_request"]["mqtt"]["server"], # type:ignore
168
- port=config["simulator_request"]["mqtt"]["port"],
169
- username=config["simulator_request"]["mqtt"].get("username", None),
170
- password=config["simulator_request"]["mqtt"].get("password", None),
148
+ hostname=mqtt_config.server, # type:ignore
149
+ port=mqtt_config.port,
150
+ username=mqtt_config.username,
151
+ password=mqtt_config.password,
171
152
  )
172
153
 
173
154
  # storage
174
- _storage_config: dict[str, Any] = config.get("storage", {})
175
- if _storage_config is None:
176
- _storage_config = {}
177
155
 
178
156
  # avro
179
- _avro_config: dict[str, Any] = _storage_config.get("avro", {})
180
- self._enable_avro = _avro_config.get("enabled", False)
181
- if not self._enable_avro:
157
+ avro_config = config.prop_avro_config
158
+ if avro_config is not None:
159
+ self._enable_avro: bool = avro_config.enabled # type:ignore
160
+ if not self._enable_avro:
161
+ self._avro_path = None
162
+ logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
163
+ else:
164
+ self._avro_path = Path(avro_config.path) / f"{self.exp_id}"
165
+ self._avro_path.mkdir(parents=True, exist_ok=True)
166
+ else:
182
167
  self._avro_path = None
183
168
  logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
184
- else:
185
- self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
186
- self._avro_path.mkdir(parents=True, exist_ok=True)
169
+ self._enable_avro = False
187
170
 
188
171
  # mlflow
189
- _mlflow_config: dict[str, Any] = config.get("metric_request", {}).get("mlflow")
190
- if _mlflow_config:
172
+ metric_config = config.prop_metric_request
173
+ if metric_config is not None and metric_config.mlflow is not None:
191
174
  logger.info(f"-----Creating Mlflow client...")
192
175
  mlflow_run_id, _ = init_mlflow_connection(
193
- config=_mlflow_config,
176
+ config=metric_config.mlflow,
194
177
  experiment_uuid=self.exp_id,
195
178
  mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
196
179
  experiment_name=self.exp_name,
197
180
  )
198
181
  self.mlflow_client = MlflowClient(
199
- config=_mlflow_config,
182
+ config=metric_config.mlflow,
200
183
  experiment_uuid=self.exp_id,
201
184
  mlflow_run_name=f"EXP_{exp_name}_{1000*int(time.time())}",
202
185
  experiment_name=exp_name,
@@ -209,35 +192,36 @@ class AgentSimulation:
209
192
  self.metric_extractors = None
210
193
 
211
194
  # pg
212
- _pgsql_config: dict[str, Any] = _storage_config.get("pgsql", {})
213
- self._enable_pgsql = _pgsql_config.get("enabled", False)
214
- if not self._enable_pgsql:
215
- logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
216
- self._pgsql_dsn = ""
195
+ pgsql_config = config.prop_postgre_sql_config
196
+ if pgsql_config is not None:
197
+ self._enable_pgsql: bool = pgsql_config.enabled # type:ignore
198
+ if not self._enable_pgsql:
199
+ logger.warning(
200
+ "PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE"
201
+ )
202
+ self._pgsql_dsn = ""
203
+ else:
204
+ self._pgsql_dsn = pgsql_config.dsn
217
205
  else:
218
- self._pgsql_dsn = (
219
- _pgsql_config["data_source_name"]
220
- if "data_source_name" in _pgsql_config
221
- else _pgsql_config["dsn"]
222
- )
206
+ self._enable_pgsql = False
223
207
 
224
- # 添加实验信息相关的属性
208
+ # add experiment info related properties
225
209
  self._exp_created_time = datetime.now(timezone.utc)
226
210
  self._exp_updated_time = datetime.now(timezone.utc)
227
211
  self._exp_info = {
228
212
  "id": self.exp_id,
229
213
  "name": exp_name,
230
- "num_day": 0, # 将在 run 方法中更新
214
+ "num_day": 0, # will be updated in run method
231
215
  "status": 0,
232
216
  "cur_day": 0,
233
217
  "cur_t": 0.0,
234
- "config": json.dumps(config),
218
+ "config": str(config.model_dump()),
235
219
  "error": "",
236
220
  "created_at": self._exp_created_time.isoformat(),
237
221
  "updated_at": self._exp_updated_time.isoformat(),
238
222
  }
239
223
 
240
- # 创建异步任务保存实验信息
224
+ # create async task to save experiment info
241
225
  if self._enable_avro:
242
226
  assert self._avro_path is not None
243
227
  self._exp_info_file = self._avro_path / "experiment_info.yaml"
@@ -245,7 +229,7 @@ class AgentSimulation:
245
229
  yaml.dump(self._exp_info, f)
246
230
 
247
231
  @classmethod
248
- async def run_from_config(cls, config: dict):
232
+ async def run_from_config(cls, config: ExpConfig, sim_config: SimConfig):
249
233
  """Directly run from config file
250
234
  Basic config file should contain:
251
235
  - simulation_config: file_path
@@ -280,130 +264,101 @@ class AgentSimulation:
280
264
  - logging_level: Optional[int]
281
265
  - exp_name: Optional[str]
282
266
  """
283
- # required key check
284
- if "simulation_config" not in config:
285
- raise ValueError("simulation_config is required")
286
- if "agent_config" not in config:
287
- raise ValueError("agent_config is required")
288
- if "workflow" not in config:
289
- raise ValueError("workflow is required")
290
- import yaml
291
-
292
- logger.info("Loading config file...")
293
- with open(config["simulation_config"], "r") as f:
294
- simulation_config = yaml.safe_load(f)
267
+
268
+ agent_config = config.prop_agent_config
269
+
295
270
  logger.info("Creating AgentSimulation Task...")
296
271
  simulation = cls(
297
- config=simulation_config,
298
- agent_config_file=config["agent_config"].get("agent_config_file", None),
299
- metric_extractors=config["agent_config"].get("metric_extractors", None),
300
- enable_institution=config.get("enable_institution", True),
301
- exp_name=config.get("exp_name", "default_experiment"),
302
- logging_level=config.get("logging_level", logging.WARNING),
303
- )
304
- environment = config.get(
305
- "environment",
306
- {
307
- "weather": "The weather is normal",
308
- "crime": "The crime rate is low",
309
- "pollution": "The pollution level is low",
310
- "temperature": "The temperature is normal",
311
- "day": "Workday",
312
- },
272
+ config=sim_config,
273
+ agent_class_configs=agent_config.agent_class_configs,
274
+ metric_extractors=config.prop_metric_extractors,
275
+ enable_institution=agent_config.enable_institution,
276
+ exp_name=config.exp_name,
277
+ logging_level=config.logging_level,
313
278
  )
279
+ environment = config.prop_environment.model_dump()
314
280
  simulation._simulator.set_environment(environment)
315
281
  logger.info("Initializing Agents...")
316
- agent_count = {
317
- SocietyAgent: config["agent_config"].get("number_of_citizen", 0),
318
- FirmAgent: config["agent_config"].get("number_of_firm", 0),
319
- GovernmentAgent: config["agent_config"].get("number_of_government", 0),
320
- BankAgent: config["agent_config"].get("number_of_bank", 0),
321
- NBSAgent: config["agent_config"].get("number_of_nbs", 0),
282
+ agent_count: dict[type[Agent], int] = {
283
+ SocietyAgent: agent_config.number_of_citizen,
284
+ FirmAgent: agent_config.number_of_firm,
285
+ GovernmentAgent: agent_config.number_of_government,
286
+ BankAgent: agent_config.number_of_bank,
287
+ NBSAgent: agent_config.number_of_nbs,
322
288
  }
323
289
  if agent_count.get(SocietyAgent, 0) == 0:
324
290
  raise ValueError("number_of_citizen is required")
325
291
 
326
292
  # support MessageInterceptor
327
- if "message_intercept" in config:
328
- _intercept_config = config["message_intercept"]
329
- _mode = _intercept_config.get("mode", "point")
330
- if _mode == "point":
331
- _kwargs = {
332
- k: v
333
- for k, v in _intercept_config.items()
334
- if k
335
- in {
336
- "max_violation_time",
337
- }
338
- }
339
- _interceptor_blocks = [PointMessageBlock(**_kwargs)]
340
- elif _mode == "edge":
341
- _kwargs = {
342
- k: v
343
- for k, v in _intercept_config.items()
344
- if k
345
- in {
346
- "max_violation_time",
347
- }
348
- }
349
- _interceptor_blocks = [EdgeMessageBlock(**_kwargs)]
293
+ if config.message_intercept is not None:
294
+ intercept_config = config.message_intercept
295
+ if intercept_config.mode == "point":
296
+ _interceptor_blocks = [
297
+ PointMessageBlock(
298
+ max_violation_time=intercept_config.max_violation_time
299
+ )
300
+ ]
301
+ elif intercept_config.mode == "edge":
302
+ _interceptor_blocks = [
303
+ EdgeMessageBlock(
304
+ max_violation_time=intercept_config.max_violation_time
305
+ )
306
+ ]
350
307
  else:
351
- raise ValueError(f"Unsupported interception mode `{_mode}!`")
308
+ raise ValueError(
309
+ f"Unsupported interception mode `{intercept_config.mode}!`"
310
+ )
352
311
  _message_intercept_kwargs = {
353
312
  "message_interceptor_blocks": _interceptor_blocks,
354
313
  "message_listener": MessageBlockListener(),
355
314
  }
356
315
  else:
357
316
  _message_intercept_kwargs = {}
317
+ embedding_model = agent_config.embedding_model
318
+ if embedding_model is None:
319
+ embedding_model = SimpleEmbedding()
358
320
  await simulation.init_agents(
359
321
  agent_count=agent_count,
360
- group_size=config["agent_config"].get("group_size", 10000),
361
- embedding_model=config["agent_config"].get(
362
- "embedding_model", SimpleEmbedding()
363
- ),
364
- memory_config_func=config["agent_config"].get("memory_config_func", None),
365
- memory_config_init_func=config["agent_config"].get(
366
- "memory_config_init_func", None
367
- ),
322
+ group_size=agent_config.group_size,
323
+ embedding_model=embedding_model,
324
+ memory_config_func=agent_config.memory_config_func,
325
+ memory_config_init_func=agent_config.memory_config_init_func,
368
326
  **_message_intercept_kwargs,
369
327
  environment=environment,
370
- llm_semaphore=config.get("llm_semaphore", 200),
328
+ llm_semaphore=config.llm_semaphore,
371
329
  )
372
330
  logger.info("Running Init Functions...")
373
- for init_func in config["agent_config"].get(
374
- "init_func", [bind_agent_info, initialize_social_network]
375
- ):
376
- await init_func(simulation)
331
+ init_funcs = agent_config.init_func
332
+ if init_funcs is None:
333
+ init_funcs = [bind_agent_info, initialize_social_network]
334
+ for init_func in init_funcs:
335
+ if inspect.iscoroutinefunction(init_func):
336
+ await init_func(simulation)
337
+ else:
338
+ init_func = cast(Callable, init_func)
339
+ init_func(simulation)
377
340
  logger.info("Starting Simulation...")
378
341
  llm_log_lists = []
379
342
  mqtt_log_lists = []
380
343
  simulator_log_lists = []
381
344
  agent_time_log_lists = []
382
- for step in config["workflow"]:
345
+ for step in config.prop_workflow:
383
346
  logger.info(
384
- f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}"
347
+ f"Running step: type: {step.type} - description: {step.description}"
385
348
  )
386
- if step["type"] not in [
387
- "run",
388
- "step",
389
- "interview",
390
- "survey",
391
- "intervene",
392
- "pause",
393
- "resume",
394
- "function",
395
- ]:
396
- raise ValueError(f"Invalid step type: {step['type']}")
397
- if step["type"] == "run":
349
+ if step.type not in {t.value for t in WorkflowType}:
350
+ raise ValueError(f"Invalid step type: {step.type}")
351
+ if step.type == WorkflowType.RUN:
352
+ _days = cast(int, step.days)
398
353
  llm_log_list, mqtt_log_list, simulator_log_list, agent_time_log_list = (
399
- await simulation.run(step.get("days", 1))
354
+ await simulation.run(_days)
400
355
  )
401
356
  llm_log_lists.extend(llm_log_list)
402
357
  mqtt_log_lists.extend(mqtt_log_list)
403
358
  simulator_log_lists.extend(simulator_log_list)
404
359
  agent_time_log_lists.extend(agent_time_log_list)
405
- elif step["type"] == "step":
406
- times = step.get("times", 1)
360
+ elif step.type == WorkflowType.STEP:
361
+ times = cast(int, step.times)
407
362
  for _ in range(times):
408
363
  (
409
364
  llm_log_list,
@@ -415,12 +370,13 @@ class AgentSimulation:
415
370
  mqtt_log_lists.extend(mqtt_log_list)
416
371
  simulator_log_lists.extend(simulator_log_list)
417
372
  agent_time_log_lists.extend(agent_time_log_list)
418
- elif step["type"] == "pause":
373
+ elif step.type == WorkflowType.PAUSE:
419
374
  await simulation.pause_simulator()
420
- elif step["type"] == "resume":
375
+ elif step.type == WorkflowType.RESUME:
421
376
  await simulation.resume_simulator()
422
- else:
423
- await step["func"](simulation)
377
+ elif step.type == WorkflowType.FUNCTION:
378
+ _func = cast(Callable, step.func)
379
+ await _func(simulation)
424
380
  logger.info("Simulation finished")
425
381
  return llm_log_lists, mqtt_log_lists, simulator_log_lists, agent_time_log_lists
426
382
 
@@ -467,13 +423,13 @@ class AgentSimulation:
467
423
  return self._message_interceptors[0] # type:ignore
468
424
 
469
425
  async def _save_exp_info(self) -> None:
470
- """异步保存实验信息到YAML文件"""
426
+ """Async save experiment info to YAML file"""
471
427
  try:
472
428
  if self.enable_avro:
473
429
  with open(self._exp_info_file, "w") as f:
474
430
  yaml.dump(self._exp_info, f)
475
431
  except Exception as e:
476
- logger.error(f"Avro保存实验信息失败: {str(e)}")
432
+ logger.error(f"Avro save experiment info failed: {str(e)}")
477
433
  try:
478
434
  if self.enable_pgsql:
479
435
  worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
@@ -486,51 +442,51 @@ class AgentSimulation:
486
442
  pg_exp_info
487
443
  )
488
444
  except Exception as e:
489
- logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
445
+ logger.error(f"PostgreSQL save experiment info failed: {str(e)}")
490
446
 
491
447
  async def _update_exp_status(self, status: int, error: str = "") -> None:
492
448
  self._exp_updated_time = datetime.now(timezone.utc)
493
- """更新实验状态并保存"""
449
+ """Update experiment status and save"""
494
450
  self._exp_info["status"] = status
495
451
  self._exp_info["error"] = error
496
452
  self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
497
453
  await self._save_exp_info()
498
454
 
499
455
  async def _monitor_exp_status(self, stop_event: asyncio.Event):
500
- """监控实验状态并更新
456
+ """Monitor experiment status and update
501
457
 
502
458
  - **Args**:
503
- stop_event: 用于通知监控任务停止的事件
459
+ stop_event: event for notifying monitor task to stop
504
460
  """
505
461
  try:
506
462
  while not stop_event.is_set():
507
- # 更新实验状态
508
- # 假设所有group的cur_daycur_t是同步的,取第一个即可
463
+ # update experiment status
464
+ # assume all groups' cur_day and cur_t are synchronized, take the first one
509
465
  self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
510
466
  self._exp_info["cur_t"] = (
511
467
  await self._simulator.get_simulator_second_from_start_of_day()
512
468
  )
513
469
  await self._save_exp_info()
514
470
 
515
- await asyncio.sleep(1) # 避免过于频繁的更新
471
+ await asyncio.sleep(1) # avoid too frequent updates
516
472
  except asyncio.CancelledError:
517
- # 正常取消,不需要特殊处理
473
+ # normal cancellation, no special handling needed
518
474
  pass
519
475
  except Exception as e:
520
- logger.error(f"监控实验状态时发生错误: {str(e)}")
476
+ logger.error(f"Error monitoring experiment status: {str(e)}")
521
477
  raise
522
478
 
523
479
  async def __aenter__(self):
524
- """异步上下文管理器入口"""
480
+ """Async context manager entry"""
525
481
  return self
526
482
 
527
483
  async def __aexit__(self, exc_type, exc_val, exc_tb):
528
- """异步上下文管理器出口"""
484
+ """Async context manager exit"""
529
485
  if exc_type is not None:
530
- # 如果发生异常,更新状态为错误
486
+ # if exception occurs, update status to error
531
487
  await self._update_exp_status(3, str(exc_val))
532
488
  elif self._exp_info["status"] != 3:
533
- # 如果没有发生异常且状态不是错误,则更新为完成
489
+ # if no exception and status is not error, update to finished
534
490
  await self._update_exp_status(2)
535
491
 
536
492
  async def pause_simulator(self):
@@ -582,7 +538,6 @@ class AgentSimulation:
582
538
  - `None`
583
539
  """
584
540
  self.agent_count = agent_count
585
-
586
541
  if len(self.agent_class) != len(agent_count):
587
542
  raise ValueError("The length of agent_class and agent_count does not match")
588
543
 
@@ -591,14 +546,14 @@ class AgentSimulation:
591
546
  if memory_config_func is None:
592
547
  memory_config_func = self.default_memory_config_func # type:ignore
593
548
 
594
- # 使用线程池并行创建 AgentGroup
549
+ # use thread pool to create AgentGroup
595
550
  group_creation_params = []
596
551
 
597
- # 分别处理机构智能体和普通智能体
552
+ # process institution agent and citizen agent
598
553
  institution_params = []
599
554
  citizen_params = []
600
555
 
601
- # 收集所有参数
556
+ # collect all parameters
602
557
  for i in range(len(self.agent_class)):
603
558
  agent_class = self.agent_class[i]
604
559
  agent_count_i = agent_count[agent_class]
@@ -607,8 +562,8 @@ class AgentSimulation:
607
562
  agent_class, self.default_memory_config_func[agent_class] # type:ignore
608
563
  )
609
564
 
610
- if self.agent_config_file is not None:
611
- config_file = self.agent_config_file.get(agent_class, None)
565
+ if self.agent_class_configs is not None:
566
+ config_file = self.agent_class_configs.get(agent_class, None)
612
567
  else:
613
568
  config_file = None
614
569
 
@@ -621,7 +576,7 @@ class AgentSimulation:
621
576
  (agent_class, agent_count_i, memory_config_func_i, config_file)
622
577
  )
623
578
 
624
- # 处理机构智能体组
579
+ # process institution group
625
580
  if institution_params:
626
581
  total_institution_count = sum(p[1] for p in institution_params)
627
582
  num_institution_groups = (
@@ -638,7 +593,7 @@ class AgentSimulation:
638
593
  memory_config_funcs = {}
639
594
  config_files = {}
640
595
 
641
- # 分配每种类型的机构智能体到当前组
596
+ # assign each type of institution agent to current group
642
597
  curr_start = start_idx
643
598
  for agent_class, count, mem_func, conf_file in institution_params:
644
599
  if curr_start < count:
@@ -658,7 +613,7 @@ class AgentSimulation:
658
613
  )
659
614
  )
660
615
 
661
- # 处理普通智能体组
616
+ # process citizen group
662
617
  if citizen_params:
663
618
  total_citizen_count = sum(p[1] for p in citizen_params)
664
619
  num_citizen_groups = (total_citizen_count + group_size - 1) // group_size
@@ -673,7 +628,7 @@ class AgentSimulation:
673
628
  memory_config_funcs = {}
674
629
  config_files = {}
675
630
 
676
- # 分配每种类型的普通智能体到当前组
631
+ # assign each type of citizen agent to current group
677
632
  curr_start = start_idx
678
633
  for agent_class, count, mem_func, conf_file in citizen_params:
679
634
  if curr_start < count:
@@ -692,18 +647,18 @@ class AgentSimulation:
692
647
  config_files,
693
648
  )
694
649
  )
695
- # 初始化mlflow连接
696
- _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
697
- if _mlflow_config:
650
+ # initialize mlflow connection
651
+ metric_config = self.config.prop_metric_request
652
+ if metric_config is not None and metric_config.mlflow is not None:
698
653
  mlflow_run_id, _ = init_mlflow_connection(
699
654
  experiment_uuid=self.exp_id,
700
- config=_mlflow_config,
655
+ config=metric_config.mlflow,
701
656
  mlflow_run_name=f"{self.exp_name}_{1000*int(time.time())}",
702
657
  experiment_name=self.exp_name,
703
658
  )
704
659
  else:
705
660
  mlflow_run_id = None
706
- # 建表
661
+ # create table
707
662
  if self.enable_pgsql:
708
663
  _num_workers = min(1, pg_sql_writers)
709
664
  create_pg_tables(
@@ -726,7 +681,7 @@ class AgentSimulation:
726
681
  self._message_abort_listening_queue = _queue = None
727
682
  _interceptor_blocks = message_interceptor_blocks
728
683
  _black_list = [] if social_black_list is None else social_black_list
729
- _llm_config = self.config.get("llm_request", {})
684
+ _llm_config = self.config.llm_request
730
685
  if message_interceptor_blocks is not None:
731
686
  _num_interceptors = min(1, message_interceptors)
732
687
  self._message_interceptors = _interceptors = [
@@ -751,7 +706,7 @@ class AgentSimulation:
751
706
  group_name,
752
707
  config_file,
753
708
  ) in enumerate(group_creation_params):
754
- # 直接创建异步任务
709
+ # create async task directly
755
710
  group = AgentGroup.remote(
756
711
  agent_class,
757
712
  number_of_agents,
@@ -774,7 +729,7 @@ class AgentSimulation:
774
729
  )
775
730
  creation_tasks.append((group_name, group))
776
731
 
777
- # 更新数据结构
732
+ # update data structure
778
733
  for group_name, group in creation_tasks:
779
734
  self._groups[group_name] = group
780
735
  group_agent_uuids = ray.get(group.get_agent_uuids.remote())
@@ -793,7 +748,7 @@ class AgentSimulation:
793
748
  self._type2group[agent_type] = []
794
749
  self._type2group[agent_type].append(group)
795
750
 
796
- # 并行初始化所有组的agents
751
+ # parallel initialize all groups' agents
797
752
  await self.resume_simulator()
798
753
  init_tasks = []
799
754
  for group in self._groups.values():
@@ -1063,7 +1018,7 @@ class AgentSimulation:
1063
1018
  except Exception as e:
1064
1019
  import traceback
1065
1020
 
1066
- logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
1021
+ logger.error(f"Simulation error: {str(e)}\n{traceback.format_exc()}")
1067
1022
  raise RuntimeError(str(e)) from e
1068
1023
 
1069
1024
  async def run(
@@ -1093,27 +1048,19 @@ class AgentSimulation:
1093
1048
  agent_time_log_lists = []
1094
1049
  try:
1095
1050
  self._exp_info["num_day"] += day
1096
- await self._update_exp_status(1) # 更新状态为运行中
1051
+ await self._update_exp_status(1) # Update status to running
1097
1052
 
1098
- # 创建停止事件
1053
+ # Create stop event
1099
1054
  stop_event = asyncio.Event()
1100
- # 创建监控任务
1055
+ # Create monitor task
1101
1056
  monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
1102
1057
 
1103
1058
  try:
1104
1059
  end_day = self._simulator_day + day
1105
1060
  while True:
1106
1061
  current_day = await self._simulator.get_simulator_day()
1107
- # check whether insert agents
1108
- need_insert_agents = False
1109
1062
  if current_day > self._simulator_day:
1110
1063
  self._simulator_day = current_day
1111
- # if need_insert_agents:
1112
- # insert_tasks = []
1113
- # for group in self._groups.values():
1114
- # insert_tasks.append(group.insert_agent.remote())
1115
- # await asyncio.gather(*insert_tasks)
1116
-
1117
1064
  if current_day >= end_day: # type:ignore
1118
1065
  break
1119
1066
  (
@@ -1127,12 +1074,12 @@ class AgentSimulation:
1127
1074
  simulator_log_lists.extend(simulator_log_list)
1128
1075
  agent_time_log_lists.extend(agent_time_log_list)
1129
1076
  finally:
1130
- # 设置停止事件
1077
+ # Set stop event
1131
1078
  stop_event.set()
1132
- # 等待监控任务结束
1079
+ # Wait for monitor task to finish
1133
1080
  await monitor_task
1134
1081
 
1135
- # 运行成功后更新状态
1082
+ # Update experiment status after successful run
1136
1083
  await self._update_exp_status(2)
1137
1084
  return (
1138
1085
  llm_log_lists,
@@ -1141,7 +1088,7 @@ class AgentSimulation:
1141
1088
  agent_time_log_lists,
1142
1089
  )
1143
1090
  except Exception as e:
1144
- error_msg = f"模拟器运行错误: {str(e)}"
1091
+ error_msg = f"Simulation error: {str(e)}"
1145
1092
  logger.error(error_msg)
1146
1093
  await self._update_exp_status(3, error_msg)
1147
1094
  raise RuntimeError(error_msg) from e