pycityagent 2.0.0a47__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a48__cp312-cp312-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (33) hide show
  1. pycityagent/__init__.py +3 -2
  2. pycityagent/agent.py +109 -4
  3. pycityagent/cityagent/__init__.py +20 -0
  4. pycityagent/cityagent/bankagent.py +54 -0
  5. pycityagent/cityagent/blocks/__init__.py +20 -0
  6. pycityagent/cityagent/blocks/cognition_block.py +304 -0
  7. pycityagent/cityagent/blocks/dispatcher.py +78 -0
  8. pycityagent/cityagent/blocks/economy_block.py +356 -0
  9. pycityagent/cityagent/blocks/mobility_block.py +258 -0
  10. pycityagent/cityagent/blocks/needs_block.py +305 -0
  11. pycityagent/cityagent/blocks/other_block.py +103 -0
  12. pycityagent/cityagent/blocks/plan_block.py +309 -0
  13. pycityagent/cityagent/blocks/social_block.py +345 -0
  14. pycityagent/cityagent/blocks/time_block.py +116 -0
  15. pycityagent/cityagent/blocks/utils.py +66 -0
  16. pycityagent/cityagent/firmagent.py +75 -0
  17. pycityagent/cityagent/governmentagent.py +60 -0
  18. pycityagent/cityagent/initial.py +98 -0
  19. pycityagent/cityagent/memory_config.py +202 -0
  20. pycityagent/cityagent/nbsagent.py +92 -0
  21. pycityagent/cityagent/societyagent.py +291 -0
  22. pycityagent/memory/memory.py +0 -18
  23. pycityagent/message/messager.py +6 -3
  24. pycityagent/simulation/agentgroup.py +118 -37
  25. pycityagent/simulation/simulation.py +311 -316
  26. pycityagent/workflow/block.py +66 -1
  27. pycityagent/workflow/tool.py +15 -11
  28. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/METADATA +2 -2
  29. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/RECORD +33 -14
  30. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/LICENSE +0 -0
  31. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/WHEEL +0 -0
  32. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/entry_points.txt +0 -0
  33. {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,16 @@
1
1
  import asyncio
2
2
  import json
3
3
  import logging
4
- import os
5
- import random
6
4
  import time
7
5
  import uuid
8
6
  from collections.abc import Callable, Sequence
9
- from concurrent.futures import ThreadPoolExecutor
10
7
  from datetime import datetime, timezone
11
8
  from pathlib import Path
12
- from typing import Any, Optional, Union
9
+ from typing import Any, Optional, Type, Union
13
10
 
14
- import pycityproto.city.economy.v2.economy_pb2 as economyv2
15
11
  import ray
16
12
  import yaml
17
13
  from langchain_core.embeddings import Embeddings
18
- from mosstool.map._map_util.const import AOI_START_ID
19
14
 
20
15
  from ..agent import Agent, InstitutionAgent
21
16
  from ..environment.simulator import Simulator
@@ -27,17 +22,20 @@ from ..survey import Survey
27
22
  from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
28
23
  from .agentgroup import AgentGroup
29
24
  from .storage.pg import PgWriter, create_pg_tables
25
+ from ..cityagent import SocietyAgent, FirmAgent, BankAgent, NBSAgent, GovernmentAgent, memory_config_societyagent, memory_config_government, memory_config_firm, memory_config_bank, memory_config_nbs
26
+ from ..cityagent.initial import bind_agent_info, initialize_social_network
30
27
 
31
28
  logger = logging.getLogger("pycityagent")
32
29
 
33
-
34
30
  class AgentSimulation:
35
31
  """城市智能体模拟器"""
36
32
 
37
33
  def __init__(
38
34
  self,
39
- agent_class: Union[type[Agent], list[type[Agent]]],
40
35
  config: dict,
36
+ agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
37
+ agent_config_file: Optional[dict] = None,
38
+ enable_economy: bool = True,
41
39
  agent_prefix: str = "agent_",
42
40
  exp_name: str = "default_experiment",
43
41
  logging_level: int = logging.WARNING,
@@ -52,20 +50,34 @@ class AgentSimulation:
52
50
  self.exp_id = str(uuid.uuid4())
53
51
  if isinstance(agent_class, list):
54
52
  self.agent_class = agent_class
53
+ elif agent_class is None:
54
+ if enable_economy:
55
+ self.agent_class = [SocietyAgent, FirmAgent, BankAgent, NBSAgent, GovernmentAgent]
56
+ self.default_memory_config_func = [
57
+ memory_config_societyagent,
58
+ memory_config_firm,
59
+ memory_config_bank,
60
+ memory_config_nbs,
61
+ memory_config_government,
62
+ ]
63
+ else:
64
+ self.agent_class = [SocietyAgent]
65
+ self.default_memory_config_func = [memory_config_societyagent]
55
66
  else:
56
67
  self.agent_class = [agent_class]
68
+ self.agent_config_file = agent_config_file
57
69
  self.logging_level = logging_level
58
70
  self.config = config
59
71
  self.exp_name = exp_name
60
72
  self._simulator = Simulator(config["simulator_request"])
61
73
  self.agent_prefix = agent_prefix
62
- self._agents: dict[uuid.UUID, Agent] = {}
63
74
  self._groups: dict[str, AgentGroup] = {} # type:ignore
64
- self._agent_uuid2group: dict[uuid.UUID, AgentGroup] = {} # type:ignore
65
- self._agent_uuids: list[uuid.UUID] = []
66
- self._user_chat_topics: dict[uuid.UUID, str] = {}
67
- self._user_survey_topics: dict[uuid.UUID, str] = {}
68
- self._user_interview_topics: dict[uuid.UUID, str] = {}
75
+ self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
76
+ self._agent_uuids: list[str] = []
77
+ self._type2group: dict[Type[Agent], AgentGroup] = {}
78
+ self._user_chat_topics: dict[str, str] = {}
79
+ self._user_survey_topics: dict[str, str] = {}
80
+ self._user_interview_topics: dict[str, str] = {}
69
81
  self._loop = asyncio.get_event_loop()
70
82
  # self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
71
83
 
@@ -126,6 +138,80 @@ class AgentSimulation:
126
138
  with open(self._exp_info_file, "w") as f:
127
139
  yaml.dump(self._exp_info, f)
128
140
 
141
+ @classmethod
142
+ async def run_from_config(cls, config: dict):
143
+ """Directly run from config file
144
+ Basic config file should contain:
145
+ - simulation_config: file_path
146
+ - agent_config:
147
+ - agent_config_file: Optional[dict]
148
+ - memory_config_func: Optional[Union[Callable, list[Callable]]]
149
+ - init_func: Optional[list[Callable[AgentSimulation, None]]]
150
+ - group_size: Optional[int]
151
+ - embedding_model: Optional[EmbeddingModel]
152
+ - number_of_citizen: required, int
153
+ - number_of_firm: required, int
154
+ - number_of_government: required, int
155
+ - number_of_bank: required, int
156
+ - number_of_nbs: required, int
157
+ - workflow:
158
+ - list[Step]
159
+ - Step:
160
+ - type: str, "step", "run", "interview", "survey", "intervene"
161
+ - day: int if type is "run", else None
162
+ - time: int if type is "step", else None
163
+ - description: Optional[str], description of the step
164
+ - step_func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
165
+ - logging_level: Optional[int]
166
+ - exp_name: Optional[str]
167
+ """
168
+ # required key check
169
+ if "simulation_config" not in config:
170
+ raise ValueError("simulation_config is required")
171
+ if "agent_config" not in config:
172
+ raise ValueError("agent_config is required")
173
+ if "workflow" not in config:
174
+ raise ValueError("workflow is required")
175
+ import yaml
176
+ logger.info("Loading config file...")
177
+ with open(config["simulation_config"], "r") as f:
178
+ simulation_config = yaml.safe_load(f)
179
+ logger.info("Creating AgentSimulation Task...")
180
+ simulation = cls(
181
+ config=simulation_config,
182
+ agent_config_file=config["agent_config"].get("agent_config_file", None),
183
+ exp_name=config.get("exp_name", "default_experiment"),
184
+ logging_level=config.get("logging_level", logging.WARNING),
185
+ )
186
+ logger.info("Initializing Agents...")
187
+ agent_count = []
188
+ agent_count.append(config["agent_config"]["number_of_citizen"])
189
+ agent_count.append(config["agent_config"]["number_of_firm"])
190
+ agent_count.append(config["agent_config"]["number_of_government"])
191
+ agent_count.append(config["agent_config"]["number_of_bank"])
192
+ agent_count.append(config["agent_config"]["number_of_nbs"])
193
+ await simulation.init_agents(
194
+ agent_count=agent_count,
195
+ group_size=config["agent_config"].get("group_size", 10000),
196
+ embedding_model=config["agent_config"].get("embedding_model", SimpleEmbedding()),
197
+ memory_config_func=config["agent_config"].get("memory_config_func", None),
198
+ )
199
+ logger.info("Running Init Functions...")
200
+ for init_func in config["agent_config"].get("init_func", [bind_agent_info, initialize_social_network]):
201
+ await init_func(simulation)
202
+ logger.info("Starting Simulation...")
203
+ for step in config["workflow"]:
204
+ logger.info(f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}")
205
+ if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
206
+ raise ValueError(f"Invalid step type: {step['type']}")
207
+ if step["type"] == "run":
208
+ await simulation.run(step.get("day", 1))
209
+ elif step["type"] == "step":
210
+ await simulation.step(step.get("time", 1))
211
+ else:
212
+ await step["step_func"](simulation)
213
+ logger.info("Simulation finished")
214
+
129
215
  @property
130
216
  def enable_avro(
131
217
  self,
@@ -138,10 +224,6 @@ class AgentSimulation:
138
224
  ) -> bool:
139
225
  return self._enable_pgsql
140
226
 
141
- @property
142
- def agents(self) -> dict[uuid.UUID, Agent]:
143
- return self._agents
144
-
145
227
  @property
146
228
  def avro_path(
147
229
  self,
@@ -159,42 +241,82 @@ class AgentSimulation:
159
241
  @property
160
242
  def agent_uuid2group(self):
161
243
  return self._agent_uuid2group
244
+
245
+ @property
246
+ def messager(self):
247
+ return self._messager
248
+
249
+ async def _save_exp_info(self) -> None:
250
+ """异步保存实验信息到YAML文件"""
251
+ try:
252
+ if self.enable_avro:
253
+ with open(self._exp_info_file, "w") as f:
254
+ yaml.dump(self._exp_info, f)
255
+ except Exception as e:
256
+ logger.error(f"Avro保存实验信息失败: {str(e)}")
257
+ try:
258
+ if self.enable_pgsql:
259
+ worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
260
+ pg_exp_info = {
261
+ k: self._exp_info[k] for (k, _) in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
262
+ }
263
+ pg_exp_info["created_at"] = self._exp_created_time
264
+ pg_exp_info["updated_at"] = self._exp_updated_time
265
+ await worker.async_update_exp_info.remote( # type:ignore
266
+ pg_exp_info
267
+ )
268
+ except Exception as e:
269
+ logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
162
270
 
163
- def create_remote_group(
164
- self,
165
- group_name: str,
166
- agents: list[Agent],
167
- config: dict,
168
- exp_id: str,
169
- exp_name: str,
170
- enable_avro: bool,
171
- avro_path: Path,
172
- enable_pgsql: bool,
173
- pgsql_writer: ray.ObjectRef,
174
- mlflow_run_id: str = None, # type: ignore
175
- embedding_model: Embeddings = None, # type: ignore
176
- logging_level: int = logging.WARNING,
177
- ):
178
- """创建远程组"""
179
- group = AgentGroup.remote(
180
- agents,
181
- config,
182
- exp_id,
183
- exp_name,
184
- enable_avro,
185
- avro_path,
186
- enable_pgsql,
187
- pgsql_writer,
188
- mlflow_run_id,
189
- embedding_model,
190
- logging_level,
191
- )
192
- return group_name, group, agents
271
+ async def _update_exp_status(self, status: int, error: str = "") -> None:
272
+ self._exp_updated_time = datetime.now(timezone.utc)
273
+ """更新实验状态并保存"""
274
+ self._exp_info["status"] = status
275
+ self._exp_info["error"] = error
276
+ self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
277
+ await self._save_exp_info()
278
+
279
+ async def _monitor_exp_status(self, stop_event: asyncio.Event):
280
+ """监控实验状态并更新
281
+
282
+ Args:
283
+ stop_event: 用于通知监控任务停止的事件
284
+ """
285
+ try:
286
+ while not stop_event.is_set():
287
+ # 更新实验状态
288
+ # 假设所有group的cur_day和cur_t是同步的,取第一个即可
289
+ self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
290
+ self._exp_info["cur_t"] = (
291
+ await self._simulator.get_simulator_second_from_start_of_day()
292
+ )
293
+ await self._save_exp_info()
294
+
295
+ await asyncio.sleep(1) # 避免过于频繁的更新
296
+ except asyncio.CancelledError:
297
+ # 正常取消,不需要特殊处理
298
+ pass
299
+ except Exception as e:
300
+ logger.error(f"监控实验状态时发生错误: {str(e)}")
301
+ raise
302
+
303
+ async def __aenter__(self):
304
+ """异步上下文管理器入口"""
305
+ return self
306
+
307
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
308
+ """异步上下文管理器出口"""
309
+ if exc_type is not None:
310
+ # 如果发生异常,更新状态为错误
311
+ await self._update_exp_status(3, str(exc_val))
312
+ elif self._exp_info["status"] != 3:
313
+ # 如果没有发生异常且状态不是错误,则更新为完成
314
+ await self._update_exp_status(2)
193
315
 
194
316
  async def init_agents(
195
317
  self,
196
318
  agent_count: Union[int, list[int]],
197
- group_size: int = 1000,
319
+ group_size: int = 10000,
198
320
  pg_sql_writers: int = 32,
199
321
  embedding_model: Embeddings = SimpleEmbedding(),
200
322
  memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
@@ -206,7 +328,6 @@ class AgentSimulation:
206
328
  group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
207
329
  memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
208
330
  """
209
- await self._messager.connect.remote()
210
331
  if not isinstance(agent_count, list):
211
332
  agent_count = [agent_count]
212
333
 
@@ -217,67 +338,104 @@ class AgentSimulation:
217
338
  logger.warning(
218
339
  "memory_config_func is None, using default memory config function"
219
340
  )
220
- memory_config_func = []
221
- for agent_class in self.agent_class:
222
- if issubclass(agent_class, InstitutionAgent):
223
- memory_config_func.append(self.default_memory_config_institution)
224
- else:
225
- memory_config_func.append(self.default_memory_config_citizen)
341
+ memory_config_func = self.default_memory_config_func
342
+
226
343
  elif not isinstance(memory_config_func, list):
227
344
  memory_config_func = [memory_config_func]
228
345
 
229
346
  if len(memory_config_func) != len(agent_count):
230
347
  logger.warning(
231
- "memory_config_funcagent_count的长度不一致,使用默认的memory_config"
348
+ "The length of memory_config_func and agent_count does not match, using default memory_config"
232
349
  )
233
- memory_config_func = []
234
- for agent_class in self.agent_class:
235
- if issubclass(agent_class, InstitutionAgent):
236
- memory_config_func.append(self.default_memory_config_institution)
237
- else:
238
- memory_config_func.append(self.default_memory_config_citizen)
350
+ memory_config_func = self.default_memory_config_func
239
351
  # 使用线程池并行创建 AgentGroup
240
352
  group_creation_params = []
241
- class_init_index = 0
242
353
 
243
- # 首先收集所有需要创建的组的参数
354
+ # 分别处理机构智能体和普通智能体
355
+ institution_params = []
356
+ citizen_params = []
357
+
358
+ # 收集所有参数
244
359
  for i in range(len(self.agent_class)):
245
360
  agent_class = self.agent_class[i]
246
361
  agent_count_i = agent_count[i]
247
362
  memory_config_func_i = memory_config_func[i]
248
- for j in range(agent_count_i):
249
- agent_name = f"{self.agent_prefix}_{i}_{j}"
250
-
251
- # 获取Memory配置
252
- extra_attributes, profile, base = memory_config_func_i()
253
- memory = Memory(config=extra_attributes, profile=profile, base=base)
254
-
255
- # 创建智能体时传入Memory配置
256
- agent = agent_class(
257
- name=agent_name,
258
- memory=memory,
259
- )
260
-
261
- self._agents[agent._uuid] = agent # type:ignore
262
- self._agent_uuids.append(agent._uuid) # type:ignore
263
-
264
- # 计算需要的组数,向上取整以处理不足一组的情况
265
- num_group = (agent_count_i + group_size - 1) // group_size
266
-
267
- for k in range(num_group):
268
- start_idx = class_init_index + k * group_size
269
- end_idx = min(
270
- class_init_index + (k + 1) * group_size, # 修正了索引计算
271
- class_init_index + agent_count_i,
272
- )
273
-
274
- agents = list(self._agents.values())[start_idx:end_idx]
275
- group_name = f"AgentType_{i}_Group_{k}"
276
-
277
- # 收集创建参数
278
- group_creation_params.append((group_name, agents))
279
-
280
- class_init_index += agent_count_i
363
+
364
+ if self.agent_config_file is not None:
365
+ config_file = self.agent_config_file.get(agent_class, None)
366
+ else:
367
+ config_file = None
368
+
369
+ if issubclass(agent_class, InstitutionAgent):
370
+ institution_params.append((agent_class, agent_count_i, memory_config_func_i, config_file))
371
+ else:
372
+ citizen_params.append((agent_class, agent_count_i, memory_config_func_i, config_file))
373
+
374
+ # 处理机构智能体组
375
+ if institution_params:
376
+ total_institution_count = sum(p[1] for p in institution_params)
377
+ num_institution_groups = (total_institution_count + group_size - 1) // group_size
378
+
379
+ for k in range(num_institution_groups):
380
+ start_idx = k * group_size
381
+ remaining = total_institution_count - start_idx
382
+ number_of_agents = min(remaining, group_size)
383
+
384
+ agent_classes = []
385
+ agent_counts = []
386
+ memory_config_funcs = []
387
+ config_files = []
388
+
389
+ # 分配每种类型的机构智能体到当前组
390
+ curr_start = start_idx
391
+ for agent_class, count, mem_func, conf_file in institution_params:
392
+ if curr_start < count:
393
+ agent_classes.append(agent_class)
394
+ agent_counts.append(min(count - curr_start, number_of_agents))
395
+ memory_config_funcs.append(mem_func)
396
+ config_files.append(conf_file)
397
+ curr_start = max(0, curr_start - count)
398
+
399
+ group_creation_params.append((
400
+ agent_classes,
401
+ agent_counts,
402
+ memory_config_funcs,
403
+ f"InstitutionGroup_{k}",
404
+ config_files
405
+ ))
406
+
407
+ # 处理普通智能体组
408
+ if citizen_params:
409
+ total_citizen_count = sum(p[1] for p in citizen_params)
410
+ num_citizen_groups = (total_citizen_count + group_size - 1) // group_size
411
+
412
+ for k in range(num_citizen_groups):
413
+ start_idx = k * group_size
414
+ remaining = total_citizen_count - start_idx
415
+ number_of_agents = min(remaining, group_size)
416
+
417
+ agent_classes = []
418
+ agent_counts = []
419
+ memory_config_funcs = []
420
+ config_files = []
421
+
422
+ # 分配每种类型的普通智能体到当前组
423
+ curr_start = start_idx
424
+ for agent_class, count, mem_func, conf_file in citizen_params:
425
+ if curr_start < count:
426
+ agent_classes.append(agent_class)
427
+ agent_counts.append(min(count - curr_start, number_of_agents))
428
+ memory_config_funcs.append(mem_func)
429
+ config_files.append(conf_file)
430
+ curr_start = max(0, curr_start - count)
431
+
432
+ group_creation_params.append((
433
+ agent_classes,
434
+ agent_counts,
435
+ memory_config_funcs,
436
+ f"CitizenGroup_{k}",
437
+ config_files
438
+ ))
281
439
 
282
440
  # 初始化mlflow连接
283
441
  _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
@@ -303,12 +461,14 @@ class AgentSimulation:
303
461
  else:
304
462
  _num_workers = 1
305
463
  self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
306
- # 收集所有创建组的参数
464
+
307
465
  creation_tasks = []
308
- for i, (group_name, agents) in enumerate(group_creation_params):
466
+ for i, (agent_class, number_of_agents, memory_config_function_group, group_name, config_file) in enumerate(group_creation_params):
309
467
  # 直接创建异步任务
310
468
  group = AgentGroup.remote(
311
- agents,
469
+ agent_class,
470
+ number_of_agents,
471
+ memory_config_function_group,
312
472
  self.config,
313
473
  self.exp_id,
314
474
  self.exp_name,
@@ -319,27 +479,31 @@ class AgentSimulation:
319
479
  mlflow_run_id, # type:ignore
320
480
  embedding_model,
321
481
  self.logging_level,
482
+ config_file,
322
483
  )
323
- creation_tasks.append((group_name, group, agents))
484
+ creation_tasks.append((group_name, group))
324
485
 
325
486
  # 更新数据结构
326
- for group_name, group, agents in creation_tasks:
487
+ for group_name, group in creation_tasks:
327
488
  self._groups[group_name] = group
328
- for agent in agents:
329
- self._agent_uuid2group[agent._uuid] = group
489
+ group_agent_uuids = ray.get(group.get_agent_uuids.remote())
490
+ for agent_uuid in group_agent_uuids:
491
+ self._agent_uuid2group[agent_uuid] = group
492
+ self._user_chat_topics[agent_uuid] = f"exps/{self.exp_id}/agents/{agent_uuid}/user-chat"
493
+ self._user_survey_topics[agent_uuid] = (
494
+ f"exps/{self.exp_id}/agents/{agent_uuid}/user-survey"
495
+ )
496
+ group_agent_type = ray.get(group.get_agent_type.remote())
497
+ for agent_type in group_agent_type:
498
+ if agent_type not in self._type2group:
499
+ self._type2group[agent_type] = []
500
+ self._type2group[agent_type].append(group)
330
501
 
331
502
  # 并行初始化所有组的agents
332
503
  init_tasks = []
333
504
  for group in self._groups.values():
334
505
  init_tasks.append(group.init_agents.remote())
335
- await asyncio.gather(*init_tasks)
336
-
337
- # 设置用户主题
338
- for uuid, agent in self._agents.items():
339
- self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
340
- self._user_survey_topics[uuid] = (
341
- f"exps/{self.exp_id}/agents/{uuid}/user-survey"
342
- )
506
+ ray.get(init_tasks)
343
507
 
344
508
  async def gather(self, content: str):
345
509
  """收集智能体的特定信息"""
@@ -347,145 +511,42 @@ class AgentSimulation:
347
511
  for group in self._groups.values():
348
512
  gather_tasks.append(group.gather.remote(content))
349
513
  return await asyncio.gather(*gather_tasks)
514
+
515
+ async def filter(self,
516
+ types: Optional[list[Type[Agent]]] = None,
517
+ keys: Optional[list[str]] = None,
518
+ values: Optional[list[Any]] = None) -> list[str]:
519
+ """过滤出指定类型的智能体"""
520
+ if not types and not keys and not values:
521
+ return self._agent_uuids
522
+ group_to_filter = []
523
+ for t in types:
524
+ if t in self._type2group:
525
+ group_to_filter.extend(self._type2group[t])
526
+ else:
527
+ raise ValueError(f"type {t} not found in simulation")
528
+ filtered_uuids = []
529
+ if keys:
530
+ if len(keys) != len(values):
531
+ raise ValueError("the length of key and value does not match")
532
+ for group in group_to_filter:
533
+ filtered_uuids.extend(await group.filter.remote(types, keys, values))
534
+ return filtered_uuids
535
+ else:
536
+ for group in group_to_filter:
537
+ filtered_uuids.extend(await group.filter.remote(types))
538
+ return filtered_uuids
350
539
 
351
- async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
540
+ async def update(self, target_agent_uuid: str, target_key: str, content: Any):
352
541
  """更新指定智能体的记忆"""
353
542
  group = self._agent_uuid2group[target_agent_uuid]
354
543
  await group.update.remote(target_agent_uuid, target_key, content)
355
544
 
356
- def default_memory_config_institution(self):
357
- """默认的Memory配置函数"""
358
- EXTRA_ATTRIBUTES = {
359
- "type": (
360
- int,
361
- random.choice(
362
- [
363
- economyv2.ORG_TYPE_BANK,
364
- economyv2.ORG_TYPE_GOVERNMENT,
365
- economyv2.ORG_TYPE_FIRM,
366
- economyv2.ORG_TYPE_NBS,
367
- economyv2.ORG_TYPE_UNSPECIFIED,
368
- ]
369
- ),
370
- ),
371
- "nominal_gdp": (list, [], True),
372
- "real_gdp": (list, [], True),
373
- "unemployment": (list, [], True),
374
- "wages": (list, [], True),
375
- "prices": (list, [], True),
376
- "inventory": (int, 0, True),
377
- "price": (float, 0.0, True),
378
- "interest_rate": (float, 0.0, True),
379
- "bracket_cutoffs": (list, [], True),
380
- "bracket_rates": (list, [], True),
381
- "employees": (list, [], True),
382
- "customers": (list, [], True),
383
- }
384
- return EXTRA_ATTRIBUTES, None, None
385
-
386
- def default_memory_config_citizen(self):
387
- """默认的Memory配置函数"""
388
- EXTRA_ATTRIBUTES = {
389
- # 需求信息
390
- "needs": (
391
- dict,
392
- {
393
- "hungry": random.random(), # 饥饿感
394
- "tired": random.random(), # 疲劳感
395
- "safe": random.random(), # 安全需
396
- "social": random.random(), # 社会需求
397
- },
398
- True,
399
- ),
400
- "current_need": (str, "none", True),
401
- "current_plan": (list, [], True),
402
- "current_step": (dict, {"intention": "", "type": ""}, True),
403
- "execution_context": (dict, {}, True),
404
- "plan_history": (list, [], True),
405
- # cognition
406
- "fulfillment": (int, 5, True),
407
- "emotion": (int, 5, True),
408
- "attitude": (int, 5, True),
409
- "thought": (str, "Currently nothing good or bad is happening", True),
410
- "emotion_types": (str, "Relief", True),
411
- "incident": (list, [], True),
412
- # social
413
- "friends": (list, [], True),
414
- }
415
-
416
- PROFILE = {
417
- "name": "unknown",
418
- "gender": random.choice(["male", "female"]),
419
- "education": random.choice(
420
- ["Doctor", "Master", "Bachelor", "College", "High School"]
421
- ),
422
- "consumption": random.choice(["sightly low", "low", "medium", "high"]),
423
- "occupation": random.choice(
424
- [
425
- "Student",
426
- "Teacher",
427
- "Doctor",
428
- "Engineer",
429
- "Manager",
430
- "Businessman",
431
- "Artist",
432
- "Athlete",
433
- "Other",
434
- ]
435
- ),
436
- "age": random.randint(18, 65),
437
- "skill": random.choice(
438
- [
439
- "Good at problem-solving",
440
- "Good at communication",
441
- "Good at creativity",
442
- "Good at teamwork",
443
- "Other",
444
- ]
445
- ),
446
- "family_consumption": random.choice(["low", "medium", "high"]),
447
- "personality": random.choice(
448
- ["outgoint", "introvert", "ambivert", "extrovert"]
449
- ),
450
- "income": str(random.randint(1000, 10000)),
451
- "currency": random.randint(10000, 100000),
452
- "residence": random.choice(["city", "suburb", "rural"]),
453
- "race": random.choice(
454
- [
455
- "Chinese",
456
- "American",
457
- "British",
458
- "French",
459
- "German",
460
- "Japanese",
461
- "Korean",
462
- "Russian",
463
- "Other",
464
- ]
465
- ),
466
- "religion": random.choice(
467
- ["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]
468
- ),
469
- "marital_status": random.choice(
470
- ["not married", "married", "divorced", "widowed"]
471
- ),
472
- }
473
-
474
- BASE = {
475
- "home": {
476
- "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
477
- },
478
- "work": {
479
- "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
480
- },
481
- }
482
-
483
- return EXTRA_ATTRIBUTES, PROFILE, BASE
484
-
485
545
  async def send_survey(
486
- self, survey: Survey, agent_uuids: Optional[list[uuid.UUID]] = None
546
+ self, survey: Survey, agent_uuids: Optional[list[str]] = None
487
547
  ):
488
548
  """发送问卷"""
549
+ await self.messager.connect()
489
550
  survey_dict = survey.to_dict()
490
551
  if agent_uuids is None:
491
552
  agent_uuids = self._agent_uuids
@@ -499,12 +560,13 @@ class AgentSimulation:
499
560
  }
500
561
  for uuid in agent_uuids:
501
562
  topic = self._user_survey_topics[uuid]
502
- await self._messager.send_message.remote(topic, payload)
563
+ await self.messager.send_message(topic, payload)
503
564
 
504
565
  async def send_interview_message(
505
- self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
566
+ self, content: str, agent_uuids: Union[str, list[str]]
506
567
  ):
507
- """发送面试消息"""
568
+ """发送采访消息"""
569
+ await self.messager.connect()
508
570
  _date_time = datetime.now(timezone.utc)
509
571
  payload = {
510
572
  "from": "none",
@@ -512,11 +574,11 @@ class AgentSimulation:
512
574
  "timestamp": int(_date_time.timestamp() * 1000),
513
575
  "_date_time": _date_time,
514
576
  }
515
- if not isinstance(agent_uuids, Sequence):
577
+ if not isinstance(agent_uuids, list):
516
578
  agent_uuids = [agent_uuids]
517
579
  for uuid in agent_uuids:
518
580
  topic = self._user_chat_topics[uuid]
519
- await self._messager.send_message.remote(topic, payload)
581
+ await self.messager.send_message(topic, payload)
520
582
 
521
583
  async def step(self):
522
584
  """运行一步, 即每个智能体执行一次forward"""
@@ -529,60 +591,6 @@ class AgentSimulation:
529
591
  logger.error(f"运行错误: {str(e)}")
530
592
  raise
531
593
 
532
- async def _save_exp_info(self) -> None:
533
- """异步保存实验信息到YAML文件"""
534
- try:
535
- if self.enable_avro:
536
- with open(self._exp_info_file, "w") as f:
537
- yaml.dump(self._exp_info, f)
538
- except Exception as e:
539
- logger.error(f"Avro保存实验信息失败: {str(e)}")
540
- try:
541
- if self.enable_pgsql:
542
- worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
543
- pg_exp_info = {
544
- k: self._exp_info[k] for (k, _) in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
545
- }
546
- pg_exp_info["created_at"] = self._exp_created_time
547
- pg_exp_info["updated_at"] = self._exp_updated_time
548
- await worker.async_update_exp_info.remote( # type:ignore
549
- pg_exp_info
550
- )
551
- except Exception as e:
552
- logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
553
-
554
- async def _update_exp_status(self, status: int, error: str = "") -> None:
555
- self._exp_updated_time = datetime.now(timezone.utc)
556
- """更新实验状态并保存"""
557
- self._exp_info["status"] = status
558
- self._exp_info["error"] = error
559
- self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
560
- await self._save_exp_info()
561
-
562
- async def _monitor_exp_status(self, stop_event: asyncio.Event):
563
- """监控实验状态并更新
564
-
565
- Args:
566
- stop_event: 用于通知监控任务停止的事件
567
- """
568
- try:
569
- while not stop_event.is_set():
570
- # 更新实验状态
571
- # 假设所有group的cur_day和cur_t是同步的,取第一个即可
572
- self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
573
- self._exp_info["cur_t"] = (
574
- await self._simulator.get_simulator_second_from_start_of_day()
575
- )
576
- await self._save_exp_info()
577
-
578
- await asyncio.sleep(1) # 避免过于频繁的更新
579
- except asyncio.CancelledError:
580
- # 正常取消,不需要特殊处理
581
- pass
582
- except Exception as e:
583
- logger.error(f"监控实验状态时发生错误: {str(e)}")
584
- raise
585
-
586
594
  async def run(
587
595
  self,
588
596
  day: int = 1,
@@ -619,16 +627,3 @@ class AgentSimulation:
619
627
  logger.error(error_msg)
620
628
  await self._update_exp_status(3, error_msg)
621
629
  raise RuntimeError(error_msg) from e
622
-
623
- async def __aenter__(self):
624
- """异步上下文管理器入口"""
625
- return self
626
-
627
- async def __aexit__(self, exc_type, exc_val, exc_tb):
628
- """异步上下文管理器出口"""
629
- if exc_type is not None:
630
- # 如果发生异常,更新状态为错误
631
- await self._update_exp_status(3, str(exc_val))
632
- elif self._exp_info["status"] != 3:
633
- # 如果没有发生异常且状态不是错误,则更新为完成
634
- await self._update_exp_status(2)