pycityagent 2.0.0a47__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a48__cp312-cp312-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/__init__.py +3 -2
- pycityagent/agent.py +109 -4
- pycityagent/cityagent/__init__.py +20 -0
- pycityagent/cityagent/bankagent.py +54 -0
- pycityagent/cityagent/blocks/__init__.py +20 -0
- pycityagent/cityagent/blocks/cognition_block.py +304 -0
- pycityagent/cityagent/blocks/dispatcher.py +78 -0
- pycityagent/cityagent/blocks/economy_block.py +356 -0
- pycityagent/cityagent/blocks/mobility_block.py +258 -0
- pycityagent/cityagent/blocks/needs_block.py +305 -0
- pycityagent/cityagent/blocks/other_block.py +103 -0
- pycityagent/cityagent/blocks/plan_block.py +309 -0
- pycityagent/cityagent/blocks/social_block.py +345 -0
- pycityagent/cityagent/blocks/time_block.py +116 -0
- pycityagent/cityagent/blocks/utils.py +66 -0
- pycityagent/cityagent/firmagent.py +75 -0
- pycityagent/cityagent/governmentagent.py +60 -0
- pycityagent/cityagent/initial.py +98 -0
- pycityagent/cityagent/memory_config.py +202 -0
- pycityagent/cityagent/nbsagent.py +92 -0
- pycityagent/cityagent/societyagent.py +291 -0
- pycityagent/memory/memory.py +0 -18
- pycityagent/message/messager.py +6 -3
- pycityagent/simulation/agentgroup.py +118 -37
- pycityagent/simulation/simulation.py +311 -316
- pycityagent/workflow/block.py +66 -1
- pycityagent/workflow/tool.py +15 -11
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/RECORD +33 -14
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a47.dist-info → pycityagent-2.0.0a48.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,16 @@
|
|
1
1
|
import asyncio
|
2
2
|
import json
|
3
3
|
import logging
|
4
|
-
import os
|
5
|
-
import random
|
6
4
|
import time
|
7
5
|
import uuid
|
8
6
|
from collections.abc import Callable, Sequence
|
9
|
-
from concurrent.futures import ThreadPoolExecutor
|
10
7
|
from datetime import datetime, timezone
|
11
8
|
from pathlib import Path
|
12
|
-
from typing import Any, Optional, Union
|
9
|
+
from typing import Any, Optional, Type, Union
|
13
10
|
|
14
|
-
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
15
11
|
import ray
|
16
12
|
import yaml
|
17
13
|
from langchain_core.embeddings import Embeddings
|
18
|
-
from mosstool.map._map_util.const import AOI_START_ID
|
19
14
|
|
20
15
|
from ..agent import Agent, InstitutionAgent
|
21
16
|
from ..environment.simulator import Simulator
|
@@ -27,17 +22,20 @@ from ..survey import Survey
|
|
27
22
|
from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
28
23
|
from .agentgroup import AgentGroup
|
29
24
|
from .storage.pg import PgWriter, create_pg_tables
|
25
|
+
from ..cityagent import SocietyAgent, FirmAgent, BankAgent, NBSAgent, GovernmentAgent, memory_config_societyagent, memory_config_government, memory_config_firm, memory_config_bank, memory_config_nbs
|
26
|
+
from ..cityagent.initial import bind_agent_info, initialize_social_network
|
30
27
|
|
31
28
|
logger = logging.getLogger("pycityagent")
|
32
29
|
|
33
|
-
|
34
30
|
class AgentSimulation:
|
35
31
|
"""城市智能体模拟器"""
|
36
32
|
|
37
33
|
def __init__(
|
38
34
|
self,
|
39
|
-
agent_class: Union[type[Agent], list[type[Agent]]],
|
40
35
|
config: dict,
|
36
|
+
agent_class: Union[None, type[Agent], list[type[Agent]]] = None,
|
37
|
+
agent_config_file: Optional[dict] = None,
|
38
|
+
enable_economy: bool = True,
|
41
39
|
agent_prefix: str = "agent_",
|
42
40
|
exp_name: str = "default_experiment",
|
43
41
|
logging_level: int = logging.WARNING,
|
@@ -52,20 +50,34 @@ class AgentSimulation:
|
|
52
50
|
self.exp_id = str(uuid.uuid4())
|
53
51
|
if isinstance(agent_class, list):
|
54
52
|
self.agent_class = agent_class
|
53
|
+
elif agent_class is None:
|
54
|
+
if enable_economy:
|
55
|
+
self.agent_class = [SocietyAgent, FirmAgent, BankAgent, NBSAgent, GovernmentAgent]
|
56
|
+
self.default_memory_config_func = [
|
57
|
+
memory_config_societyagent,
|
58
|
+
memory_config_firm,
|
59
|
+
memory_config_bank,
|
60
|
+
memory_config_nbs,
|
61
|
+
memory_config_government,
|
62
|
+
]
|
63
|
+
else:
|
64
|
+
self.agent_class = [SocietyAgent]
|
65
|
+
self.default_memory_config_func = [memory_config_societyagent]
|
55
66
|
else:
|
56
67
|
self.agent_class = [agent_class]
|
68
|
+
self.agent_config_file = agent_config_file
|
57
69
|
self.logging_level = logging_level
|
58
70
|
self.config = config
|
59
71
|
self.exp_name = exp_name
|
60
72
|
self._simulator = Simulator(config["simulator_request"])
|
61
73
|
self.agent_prefix = agent_prefix
|
62
|
-
self._agents: dict[uuid.UUID, Agent] = {}
|
63
74
|
self._groups: dict[str, AgentGroup] = {} # type:ignore
|
64
|
-
self._agent_uuid2group: dict[
|
65
|
-
self._agent_uuids: list[
|
66
|
-
self.
|
67
|
-
self.
|
68
|
-
self.
|
75
|
+
self._agent_uuid2group: dict[str, AgentGroup] = {} # type:ignore
|
76
|
+
self._agent_uuids: list[str] = []
|
77
|
+
self._type2group: dict[Type[Agent], AgentGroup] = {}
|
78
|
+
self._user_chat_topics: dict[str, str] = {}
|
79
|
+
self._user_survey_topics: dict[str, str] = {}
|
80
|
+
self._user_interview_topics: dict[str, str] = {}
|
69
81
|
self._loop = asyncio.get_event_loop()
|
70
82
|
# self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
71
83
|
|
@@ -126,6 +138,80 @@ class AgentSimulation:
|
|
126
138
|
with open(self._exp_info_file, "w") as f:
|
127
139
|
yaml.dump(self._exp_info, f)
|
128
140
|
|
141
|
+
@classmethod
|
142
|
+
async def run_from_config(cls, config: dict):
|
143
|
+
"""Directly run from config file
|
144
|
+
Basic config file should contain:
|
145
|
+
- simulation_config: file_path
|
146
|
+
- agent_config:
|
147
|
+
- agent_config_file: Optional[dict]
|
148
|
+
- memory_config_func: Optional[Union[Callable, list[Callable]]]
|
149
|
+
- init_func: Optional[list[Callable[AgentSimulation, None]]]
|
150
|
+
- group_size: Optional[int]
|
151
|
+
- embedding_model: Optional[EmbeddingModel]
|
152
|
+
- number_of_citizen: required, int
|
153
|
+
- number_of_firm: required, int
|
154
|
+
- number_of_government: required, int
|
155
|
+
- number_of_bank: required, int
|
156
|
+
- number_of_nbs: required, int
|
157
|
+
- workflow:
|
158
|
+
- list[Step]
|
159
|
+
- Step:
|
160
|
+
- type: str, "step", "run", "interview", "survey", "intervene"
|
161
|
+
- day: int if type is "run", else None
|
162
|
+
- time: int if type is "step", else None
|
163
|
+
- description: Optional[str], description of the step
|
164
|
+
- step_func: Optional[Callable[AgentSimulation, None]], only used when type is "interview", "survey" and "intervene"
|
165
|
+
- logging_level: Optional[int]
|
166
|
+
- exp_name: Optional[str]
|
167
|
+
"""
|
168
|
+
# required key check
|
169
|
+
if "simulation_config" not in config:
|
170
|
+
raise ValueError("simulation_config is required")
|
171
|
+
if "agent_config" not in config:
|
172
|
+
raise ValueError("agent_config is required")
|
173
|
+
if "workflow" not in config:
|
174
|
+
raise ValueError("workflow is required")
|
175
|
+
import yaml
|
176
|
+
logger.info("Loading config file...")
|
177
|
+
with open(config["simulation_config"], "r") as f:
|
178
|
+
simulation_config = yaml.safe_load(f)
|
179
|
+
logger.info("Creating AgentSimulation Task...")
|
180
|
+
simulation = cls(
|
181
|
+
config=simulation_config,
|
182
|
+
agent_config_file=config["agent_config"].get("agent_config_file", None),
|
183
|
+
exp_name=config.get("exp_name", "default_experiment"),
|
184
|
+
logging_level=config.get("logging_level", logging.WARNING),
|
185
|
+
)
|
186
|
+
logger.info("Initializing Agents...")
|
187
|
+
agent_count = []
|
188
|
+
agent_count.append(config["agent_config"]["number_of_citizen"])
|
189
|
+
agent_count.append(config["agent_config"]["number_of_firm"])
|
190
|
+
agent_count.append(config["agent_config"]["number_of_government"])
|
191
|
+
agent_count.append(config["agent_config"]["number_of_bank"])
|
192
|
+
agent_count.append(config["agent_config"]["number_of_nbs"])
|
193
|
+
await simulation.init_agents(
|
194
|
+
agent_count=agent_count,
|
195
|
+
group_size=config["agent_config"].get("group_size", 10000),
|
196
|
+
embedding_model=config["agent_config"].get("embedding_model", SimpleEmbedding()),
|
197
|
+
memory_config_func=config["agent_config"].get("memory_config_func", None),
|
198
|
+
)
|
199
|
+
logger.info("Running Init Functions...")
|
200
|
+
for init_func in config["agent_config"].get("init_func", [bind_agent_info, initialize_social_network]):
|
201
|
+
await init_func(simulation)
|
202
|
+
logger.info("Starting Simulation...")
|
203
|
+
for step in config["workflow"]:
|
204
|
+
logger.info(f"Running step: type: {step['type']} - description: {step.get('description', 'no description')}")
|
205
|
+
if step["type"] not in ["run", "step", "interview", "survey", "intervene"]:
|
206
|
+
raise ValueError(f"Invalid step type: {step['type']}")
|
207
|
+
if step["type"] == "run":
|
208
|
+
await simulation.run(step.get("day", 1))
|
209
|
+
elif step["type"] == "step":
|
210
|
+
await simulation.step(step.get("time", 1))
|
211
|
+
else:
|
212
|
+
await step["step_func"](simulation)
|
213
|
+
logger.info("Simulation finished")
|
214
|
+
|
129
215
|
@property
|
130
216
|
def enable_avro(
|
131
217
|
self,
|
@@ -138,10 +224,6 @@ class AgentSimulation:
|
|
138
224
|
) -> bool:
|
139
225
|
return self._enable_pgsql
|
140
226
|
|
141
|
-
@property
|
142
|
-
def agents(self) -> dict[uuid.UUID, Agent]:
|
143
|
-
return self._agents
|
144
|
-
|
145
227
|
@property
|
146
228
|
def avro_path(
|
147
229
|
self,
|
@@ -159,42 +241,82 @@ class AgentSimulation:
|
|
159
241
|
@property
|
160
242
|
def agent_uuid2group(self):
|
161
243
|
return self._agent_uuid2group
|
244
|
+
|
245
|
+
@property
|
246
|
+
def messager(self):
|
247
|
+
return self._messager
|
248
|
+
|
249
|
+
async def _save_exp_info(self) -> None:
|
250
|
+
"""异步保存实验信息到YAML文件"""
|
251
|
+
try:
|
252
|
+
if self.enable_avro:
|
253
|
+
with open(self._exp_info_file, "w") as f:
|
254
|
+
yaml.dump(self._exp_info, f)
|
255
|
+
except Exception as e:
|
256
|
+
logger.error(f"Avro保存实验信息失败: {str(e)}")
|
257
|
+
try:
|
258
|
+
if self.enable_pgsql:
|
259
|
+
worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
|
260
|
+
pg_exp_info = {
|
261
|
+
k: self._exp_info[k] for (k, _) in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
262
|
+
}
|
263
|
+
pg_exp_info["created_at"] = self._exp_created_time
|
264
|
+
pg_exp_info["updated_at"] = self._exp_updated_time
|
265
|
+
await worker.async_update_exp_info.remote( # type:ignore
|
266
|
+
pg_exp_info
|
267
|
+
)
|
268
|
+
except Exception as e:
|
269
|
+
logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
|
162
270
|
|
163
|
-
def
|
164
|
-
self
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
271
|
+
async def _update_exp_status(self, status: int, error: str = "") -> None:
|
272
|
+
self._exp_updated_time = datetime.now(timezone.utc)
|
273
|
+
"""更新实验状态并保存"""
|
274
|
+
self._exp_info["status"] = status
|
275
|
+
self._exp_info["error"] = error
|
276
|
+
self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
|
277
|
+
await self._save_exp_info()
|
278
|
+
|
279
|
+
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
280
|
+
"""监控实验状态并更新
|
281
|
+
|
282
|
+
Args:
|
283
|
+
stop_event: 用于通知监控任务停止的事件
|
284
|
+
"""
|
285
|
+
try:
|
286
|
+
while not stop_event.is_set():
|
287
|
+
# 更新实验状态
|
288
|
+
# 假设所有group的cur_day和cur_t是同步的,取第一个即可
|
289
|
+
self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
|
290
|
+
self._exp_info["cur_t"] = (
|
291
|
+
await self._simulator.get_simulator_second_from_start_of_day()
|
292
|
+
)
|
293
|
+
await self._save_exp_info()
|
294
|
+
|
295
|
+
await asyncio.sleep(1) # 避免过于频繁的更新
|
296
|
+
except asyncio.CancelledError:
|
297
|
+
# 正常取消,不需要特殊处理
|
298
|
+
pass
|
299
|
+
except Exception as e:
|
300
|
+
logger.error(f"监控实验状态时发生错误: {str(e)}")
|
301
|
+
raise
|
302
|
+
|
303
|
+
async def __aenter__(self):
|
304
|
+
"""异步上下文管理器入口"""
|
305
|
+
return self
|
306
|
+
|
307
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
308
|
+
"""异步上下文管理器出口"""
|
309
|
+
if exc_type is not None:
|
310
|
+
# 如果发生异常,更新状态为错误
|
311
|
+
await self._update_exp_status(3, str(exc_val))
|
312
|
+
elif self._exp_info["status"] != 3:
|
313
|
+
# 如果没有发生异常且状态不是错误,则更新为完成
|
314
|
+
await self._update_exp_status(2)
|
193
315
|
|
194
316
|
async def init_agents(
|
195
317
|
self,
|
196
318
|
agent_count: Union[int, list[int]],
|
197
|
-
group_size: int =
|
319
|
+
group_size: int = 10000,
|
198
320
|
pg_sql_writers: int = 32,
|
199
321
|
embedding_model: Embeddings = SimpleEmbedding(),
|
200
322
|
memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
|
@@ -206,7 +328,6 @@ class AgentSimulation:
|
|
206
328
|
group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
|
207
329
|
memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
|
208
330
|
"""
|
209
|
-
await self._messager.connect.remote()
|
210
331
|
if not isinstance(agent_count, list):
|
211
332
|
agent_count = [agent_count]
|
212
333
|
|
@@ -217,67 +338,104 @@ class AgentSimulation:
|
|
217
338
|
logger.warning(
|
218
339
|
"memory_config_func is None, using default memory config function"
|
219
340
|
)
|
220
|
-
memory_config_func =
|
221
|
-
|
222
|
-
if issubclass(agent_class, InstitutionAgent):
|
223
|
-
memory_config_func.append(self.default_memory_config_institution)
|
224
|
-
else:
|
225
|
-
memory_config_func.append(self.default_memory_config_citizen)
|
341
|
+
memory_config_func = self.default_memory_config_func
|
342
|
+
|
226
343
|
elif not isinstance(memory_config_func, list):
|
227
344
|
memory_config_func = [memory_config_func]
|
228
345
|
|
229
346
|
if len(memory_config_func) != len(agent_count):
|
230
347
|
logger.warning(
|
231
|
-
"memory_config_func
|
348
|
+
"The length of memory_config_func and agent_count does not match, using default memory_config"
|
232
349
|
)
|
233
|
-
memory_config_func =
|
234
|
-
for agent_class in self.agent_class:
|
235
|
-
if issubclass(agent_class, InstitutionAgent):
|
236
|
-
memory_config_func.append(self.default_memory_config_institution)
|
237
|
-
else:
|
238
|
-
memory_config_func.append(self.default_memory_config_citizen)
|
350
|
+
memory_config_func = self.default_memory_config_func
|
239
351
|
# 使用线程池并行创建 AgentGroup
|
240
352
|
group_creation_params = []
|
241
|
-
class_init_index = 0
|
242
353
|
|
243
|
-
#
|
354
|
+
# 分别处理机构智能体和普通智能体
|
355
|
+
institution_params = []
|
356
|
+
citizen_params = []
|
357
|
+
|
358
|
+
# 收集所有参数
|
244
359
|
for i in range(len(self.agent_class)):
|
245
360
|
agent_class = self.agent_class[i]
|
246
361
|
agent_count_i = agent_count[i]
|
247
362
|
memory_config_func_i = memory_config_func[i]
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
363
|
+
|
364
|
+
if self.agent_config_file is not None:
|
365
|
+
config_file = self.agent_config_file.get(agent_class, None)
|
366
|
+
else:
|
367
|
+
config_file = None
|
368
|
+
|
369
|
+
if issubclass(agent_class, InstitutionAgent):
|
370
|
+
institution_params.append((agent_class, agent_count_i, memory_config_func_i, config_file))
|
371
|
+
else:
|
372
|
+
citizen_params.append((agent_class, agent_count_i, memory_config_func_i, config_file))
|
373
|
+
|
374
|
+
# 处理机构智能体组
|
375
|
+
if institution_params:
|
376
|
+
total_institution_count = sum(p[1] for p in institution_params)
|
377
|
+
num_institution_groups = (total_institution_count + group_size - 1) // group_size
|
378
|
+
|
379
|
+
for k in range(num_institution_groups):
|
380
|
+
start_idx = k * group_size
|
381
|
+
remaining = total_institution_count - start_idx
|
382
|
+
number_of_agents = min(remaining, group_size)
|
383
|
+
|
384
|
+
agent_classes = []
|
385
|
+
agent_counts = []
|
386
|
+
memory_config_funcs = []
|
387
|
+
config_files = []
|
388
|
+
|
389
|
+
# 分配每种类型的机构智能体到当前组
|
390
|
+
curr_start = start_idx
|
391
|
+
for agent_class, count, mem_func, conf_file in institution_params:
|
392
|
+
if curr_start < count:
|
393
|
+
agent_classes.append(agent_class)
|
394
|
+
agent_counts.append(min(count - curr_start, number_of_agents))
|
395
|
+
memory_config_funcs.append(mem_func)
|
396
|
+
config_files.append(conf_file)
|
397
|
+
curr_start = max(0, curr_start - count)
|
398
|
+
|
399
|
+
group_creation_params.append((
|
400
|
+
agent_classes,
|
401
|
+
agent_counts,
|
402
|
+
memory_config_funcs,
|
403
|
+
f"InstitutionGroup_{k}",
|
404
|
+
config_files
|
405
|
+
))
|
406
|
+
|
407
|
+
# 处理普通智能体组
|
408
|
+
if citizen_params:
|
409
|
+
total_citizen_count = sum(p[1] for p in citizen_params)
|
410
|
+
num_citizen_groups = (total_citizen_count + group_size - 1) // group_size
|
411
|
+
|
412
|
+
for k in range(num_citizen_groups):
|
413
|
+
start_idx = k * group_size
|
414
|
+
remaining = total_citizen_count - start_idx
|
415
|
+
number_of_agents = min(remaining, group_size)
|
416
|
+
|
417
|
+
agent_classes = []
|
418
|
+
agent_counts = []
|
419
|
+
memory_config_funcs = []
|
420
|
+
config_files = []
|
421
|
+
|
422
|
+
# 分配每种类型的普通智能体到当前组
|
423
|
+
curr_start = start_idx
|
424
|
+
for agent_class, count, mem_func, conf_file in citizen_params:
|
425
|
+
if curr_start < count:
|
426
|
+
agent_classes.append(agent_class)
|
427
|
+
agent_counts.append(min(count - curr_start, number_of_agents))
|
428
|
+
memory_config_funcs.append(mem_func)
|
429
|
+
config_files.append(conf_file)
|
430
|
+
curr_start = max(0, curr_start - count)
|
431
|
+
|
432
|
+
group_creation_params.append((
|
433
|
+
agent_classes,
|
434
|
+
agent_counts,
|
435
|
+
memory_config_funcs,
|
436
|
+
f"CitizenGroup_{k}",
|
437
|
+
config_files
|
438
|
+
))
|
281
439
|
|
282
440
|
# 初始化mlflow连接
|
283
441
|
_mlflow_config = self.config.get("metric_request", {}).get("mlflow")
|
@@ -303,12 +461,14 @@ class AgentSimulation:
|
|
303
461
|
else:
|
304
462
|
_num_workers = 1
|
305
463
|
self._pgsql_writers = _workers = [None for _ in range(_num_workers)]
|
306
|
-
|
464
|
+
|
307
465
|
creation_tasks = []
|
308
|
-
for i, (group_name,
|
466
|
+
for i, (agent_class, number_of_agents, memory_config_function_group, group_name, config_file) in enumerate(group_creation_params):
|
309
467
|
# 直接创建异步任务
|
310
468
|
group = AgentGroup.remote(
|
311
|
-
|
469
|
+
agent_class,
|
470
|
+
number_of_agents,
|
471
|
+
memory_config_function_group,
|
312
472
|
self.config,
|
313
473
|
self.exp_id,
|
314
474
|
self.exp_name,
|
@@ -319,27 +479,31 @@ class AgentSimulation:
|
|
319
479
|
mlflow_run_id, # type:ignore
|
320
480
|
embedding_model,
|
321
481
|
self.logging_level,
|
482
|
+
config_file,
|
322
483
|
)
|
323
|
-
creation_tasks.append((group_name, group
|
484
|
+
creation_tasks.append((group_name, group))
|
324
485
|
|
325
486
|
# 更新数据结构
|
326
|
-
for group_name, group
|
487
|
+
for group_name, group in creation_tasks:
|
327
488
|
self._groups[group_name] = group
|
328
|
-
|
329
|
-
|
489
|
+
group_agent_uuids = ray.get(group.get_agent_uuids.remote())
|
490
|
+
for agent_uuid in group_agent_uuids:
|
491
|
+
self._agent_uuid2group[agent_uuid] = group
|
492
|
+
self._user_chat_topics[agent_uuid] = f"exps/{self.exp_id}/agents/{agent_uuid}/user-chat"
|
493
|
+
self._user_survey_topics[agent_uuid] = (
|
494
|
+
f"exps/{self.exp_id}/agents/{agent_uuid}/user-survey"
|
495
|
+
)
|
496
|
+
group_agent_type = ray.get(group.get_agent_type.remote())
|
497
|
+
for agent_type in group_agent_type:
|
498
|
+
if agent_type not in self._type2group:
|
499
|
+
self._type2group[agent_type] = []
|
500
|
+
self._type2group[agent_type].append(group)
|
330
501
|
|
331
502
|
# 并行初始化所有组的agents
|
332
503
|
init_tasks = []
|
333
504
|
for group in self._groups.values():
|
334
505
|
init_tasks.append(group.init_agents.remote())
|
335
|
-
|
336
|
-
|
337
|
-
# 设置用户主题
|
338
|
-
for uuid, agent in self._agents.items():
|
339
|
-
self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
|
340
|
-
self._user_survey_topics[uuid] = (
|
341
|
-
f"exps/{self.exp_id}/agents/{uuid}/user-survey"
|
342
|
-
)
|
506
|
+
ray.get(init_tasks)
|
343
507
|
|
344
508
|
async def gather(self, content: str):
|
345
509
|
"""收集智能体的特定信息"""
|
@@ -347,145 +511,42 @@ class AgentSimulation:
|
|
347
511
|
for group in self._groups.values():
|
348
512
|
gather_tasks.append(group.gather.remote(content))
|
349
513
|
return await asyncio.gather(*gather_tasks)
|
514
|
+
|
515
|
+
async def filter(self,
|
516
|
+
types: Optional[list[Type[Agent]]] = None,
|
517
|
+
keys: Optional[list[str]] = None,
|
518
|
+
values: Optional[list[Any]] = None) -> list[str]:
|
519
|
+
"""过滤出指定类型的智能体"""
|
520
|
+
if not types and not keys and not values:
|
521
|
+
return self._agent_uuids
|
522
|
+
group_to_filter = []
|
523
|
+
for t in types:
|
524
|
+
if t in self._type2group:
|
525
|
+
group_to_filter.extend(self._type2group[t])
|
526
|
+
else:
|
527
|
+
raise ValueError(f"type {t} not found in simulation")
|
528
|
+
filtered_uuids = []
|
529
|
+
if keys:
|
530
|
+
if len(keys) != len(values):
|
531
|
+
raise ValueError("the length of key and value does not match")
|
532
|
+
for group in group_to_filter:
|
533
|
+
filtered_uuids.extend(await group.filter.remote(types, keys, values))
|
534
|
+
return filtered_uuids
|
535
|
+
else:
|
536
|
+
for group in group_to_filter:
|
537
|
+
filtered_uuids.extend(await group.filter.remote(types))
|
538
|
+
return filtered_uuids
|
350
539
|
|
351
|
-
async def update(self, target_agent_uuid:
|
540
|
+
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
352
541
|
"""更新指定智能体的记忆"""
|
353
542
|
group = self._agent_uuid2group[target_agent_uuid]
|
354
543
|
await group.update.remote(target_agent_uuid, target_key, content)
|
355
544
|
|
356
|
-
def default_memory_config_institution(self):
|
357
|
-
"""默认的Memory配置函数"""
|
358
|
-
EXTRA_ATTRIBUTES = {
|
359
|
-
"type": (
|
360
|
-
int,
|
361
|
-
random.choice(
|
362
|
-
[
|
363
|
-
economyv2.ORG_TYPE_BANK,
|
364
|
-
economyv2.ORG_TYPE_GOVERNMENT,
|
365
|
-
economyv2.ORG_TYPE_FIRM,
|
366
|
-
economyv2.ORG_TYPE_NBS,
|
367
|
-
economyv2.ORG_TYPE_UNSPECIFIED,
|
368
|
-
]
|
369
|
-
),
|
370
|
-
),
|
371
|
-
"nominal_gdp": (list, [], True),
|
372
|
-
"real_gdp": (list, [], True),
|
373
|
-
"unemployment": (list, [], True),
|
374
|
-
"wages": (list, [], True),
|
375
|
-
"prices": (list, [], True),
|
376
|
-
"inventory": (int, 0, True),
|
377
|
-
"price": (float, 0.0, True),
|
378
|
-
"interest_rate": (float, 0.0, True),
|
379
|
-
"bracket_cutoffs": (list, [], True),
|
380
|
-
"bracket_rates": (list, [], True),
|
381
|
-
"employees": (list, [], True),
|
382
|
-
"customers": (list, [], True),
|
383
|
-
}
|
384
|
-
return EXTRA_ATTRIBUTES, None, None
|
385
|
-
|
386
|
-
def default_memory_config_citizen(self):
|
387
|
-
"""默认的Memory配置函数"""
|
388
|
-
EXTRA_ATTRIBUTES = {
|
389
|
-
# 需求信息
|
390
|
-
"needs": (
|
391
|
-
dict,
|
392
|
-
{
|
393
|
-
"hungry": random.random(), # 饥饿感
|
394
|
-
"tired": random.random(), # 疲劳感
|
395
|
-
"safe": random.random(), # 安全需
|
396
|
-
"social": random.random(), # 社会需求
|
397
|
-
},
|
398
|
-
True,
|
399
|
-
),
|
400
|
-
"current_need": (str, "none", True),
|
401
|
-
"current_plan": (list, [], True),
|
402
|
-
"current_step": (dict, {"intention": "", "type": ""}, True),
|
403
|
-
"execution_context": (dict, {}, True),
|
404
|
-
"plan_history": (list, [], True),
|
405
|
-
# cognition
|
406
|
-
"fulfillment": (int, 5, True),
|
407
|
-
"emotion": (int, 5, True),
|
408
|
-
"attitude": (int, 5, True),
|
409
|
-
"thought": (str, "Currently nothing good or bad is happening", True),
|
410
|
-
"emotion_types": (str, "Relief", True),
|
411
|
-
"incident": (list, [], True),
|
412
|
-
# social
|
413
|
-
"friends": (list, [], True),
|
414
|
-
}
|
415
|
-
|
416
|
-
PROFILE = {
|
417
|
-
"name": "unknown",
|
418
|
-
"gender": random.choice(["male", "female"]),
|
419
|
-
"education": random.choice(
|
420
|
-
["Doctor", "Master", "Bachelor", "College", "High School"]
|
421
|
-
),
|
422
|
-
"consumption": random.choice(["sightly low", "low", "medium", "high"]),
|
423
|
-
"occupation": random.choice(
|
424
|
-
[
|
425
|
-
"Student",
|
426
|
-
"Teacher",
|
427
|
-
"Doctor",
|
428
|
-
"Engineer",
|
429
|
-
"Manager",
|
430
|
-
"Businessman",
|
431
|
-
"Artist",
|
432
|
-
"Athlete",
|
433
|
-
"Other",
|
434
|
-
]
|
435
|
-
),
|
436
|
-
"age": random.randint(18, 65),
|
437
|
-
"skill": random.choice(
|
438
|
-
[
|
439
|
-
"Good at problem-solving",
|
440
|
-
"Good at communication",
|
441
|
-
"Good at creativity",
|
442
|
-
"Good at teamwork",
|
443
|
-
"Other",
|
444
|
-
]
|
445
|
-
),
|
446
|
-
"family_consumption": random.choice(["low", "medium", "high"]),
|
447
|
-
"personality": random.choice(
|
448
|
-
["outgoint", "introvert", "ambivert", "extrovert"]
|
449
|
-
),
|
450
|
-
"income": str(random.randint(1000, 10000)),
|
451
|
-
"currency": random.randint(10000, 100000),
|
452
|
-
"residence": random.choice(["city", "suburb", "rural"]),
|
453
|
-
"race": random.choice(
|
454
|
-
[
|
455
|
-
"Chinese",
|
456
|
-
"American",
|
457
|
-
"British",
|
458
|
-
"French",
|
459
|
-
"German",
|
460
|
-
"Japanese",
|
461
|
-
"Korean",
|
462
|
-
"Russian",
|
463
|
-
"Other",
|
464
|
-
]
|
465
|
-
),
|
466
|
-
"religion": random.choice(
|
467
|
-
["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]
|
468
|
-
),
|
469
|
-
"marital_status": random.choice(
|
470
|
-
["not married", "married", "divorced", "widowed"]
|
471
|
-
),
|
472
|
-
}
|
473
|
-
|
474
|
-
BASE = {
|
475
|
-
"home": {
|
476
|
-
"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
|
477
|
-
},
|
478
|
-
"work": {
|
479
|
-
"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
|
480
|
-
},
|
481
|
-
}
|
482
|
-
|
483
|
-
return EXTRA_ATTRIBUTES, PROFILE, BASE
|
484
|
-
|
485
545
|
async def send_survey(
|
486
|
-
self, survey: Survey, agent_uuids: Optional[list[
|
546
|
+
self, survey: Survey, agent_uuids: Optional[list[str]] = None
|
487
547
|
):
|
488
548
|
"""发送问卷"""
|
549
|
+
await self.messager.connect()
|
489
550
|
survey_dict = survey.to_dict()
|
490
551
|
if agent_uuids is None:
|
491
552
|
agent_uuids = self._agent_uuids
|
@@ -499,12 +560,13 @@ class AgentSimulation:
|
|
499
560
|
}
|
500
561
|
for uuid in agent_uuids:
|
501
562
|
topic = self._user_survey_topics[uuid]
|
502
|
-
await self.
|
563
|
+
await self.messager.send_message(topic, payload)
|
503
564
|
|
504
565
|
async def send_interview_message(
|
505
|
-
self, content: str, agent_uuids: Union[
|
566
|
+
self, content: str, agent_uuids: Union[str, list[str]]
|
506
567
|
):
|
507
|
-
"""
|
568
|
+
"""发送采访消息"""
|
569
|
+
await self.messager.connect()
|
508
570
|
_date_time = datetime.now(timezone.utc)
|
509
571
|
payload = {
|
510
572
|
"from": "none",
|
@@ -512,11 +574,11 @@ class AgentSimulation:
|
|
512
574
|
"timestamp": int(_date_time.timestamp() * 1000),
|
513
575
|
"_date_time": _date_time,
|
514
576
|
}
|
515
|
-
if not isinstance(agent_uuids,
|
577
|
+
if not isinstance(agent_uuids, list):
|
516
578
|
agent_uuids = [agent_uuids]
|
517
579
|
for uuid in agent_uuids:
|
518
580
|
topic = self._user_chat_topics[uuid]
|
519
|
-
await self.
|
581
|
+
await self.messager.send_message(topic, payload)
|
520
582
|
|
521
583
|
async def step(self):
|
522
584
|
"""运行一步, 即每个智能体执行一次forward"""
|
@@ -529,60 +591,6 @@ class AgentSimulation:
|
|
529
591
|
logger.error(f"运行错误: {str(e)}")
|
530
592
|
raise
|
531
593
|
|
532
|
-
async def _save_exp_info(self) -> None:
|
533
|
-
"""异步保存实验信息到YAML文件"""
|
534
|
-
try:
|
535
|
-
if self.enable_avro:
|
536
|
-
with open(self._exp_info_file, "w") as f:
|
537
|
-
yaml.dump(self._exp_info, f)
|
538
|
-
except Exception as e:
|
539
|
-
logger.error(f"Avro保存实验信息失败: {str(e)}")
|
540
|
-
try:
|
541
|
-
if self.enable_pgsql:
|
542
|
-
worker: ray.ObjectRef = self._pgsql_writers[0] # type:ignore
|
543
|
-
pg_exp_info = {
|
544
|
-
k: self._exp_info[k] for (k, _) in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
545
|
-
}
|
546
|
-
pg_exp_info["created_at"] = self._exp_created_time
|
547
|
-
pg_exp_info["updated_at"] = self._exp_updated_time
|
548
|
-
await worker.async_update_exp_info.remote( # type:ignore
|
549
|
-
pg_exp_info
|
550
|
-
)
|
551
|
-
except Exception as e:
|
552
|
-
logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")
|
553
|
-
|
554
|
-
async def _update_exp_status(self, status: int, error: str = "") -> None:
|
555
|
-
self._exp_updated_time = datetime.now(timezone.utc)
|
556
|
-
"""更新实验状态并保存"""
|
557
|
-
self._exp_info["status"] = status
|
558
|
-
self._exp_info["error"] = error
|
559
|
-
self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
|
560
|
-
await self._save_exp_info()
|
561
|
-
|
562
|
-
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
563
|
-
"""监控实验状态并更新
|
564
|
-
|
565
|
-
Args:
|
566
|
-
stop_event: 用于通知监控任务停止的事件
|
567
|
-
"""
|
568
|
-
try:
|
569
|
-
while not stop_event.is_set():
|
570
|
-
# 更新实验状态
|
571
|
-
# 假设所有group的cur_day和cur_t是同步的,取第一个即可
|
572
|
-
self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
|
573
|
-
self._exp_info["cur_t"] = (
|
574
|
-
await self._simulator.get_simulator_second_from_start_of_day()
|
575
|
-
)
|
576
|
-
await self._save_exp_info()
|
577
|
-
|
578
|
-
await asyncio.sleep(1) # 避免过于频繁的更新
|
579
|
-
except asyncio.CancelledError:
|
580
|
-
# 正常取消,不需要特殊处理
|
581
|
-
pass
|
582
|
-
except Exception as e:
|
583
|
-
logger.error(f"监控实验状态时发生错误: {str(e)}")
|
584
|
-
raise
|
585
|
-
|
586
594
|
async def run(
|
587
595
|
self,
|
588
596
|
day: int = 1,
|
@@ -619,16 +627,3 @@ class AgentSimulation:
|
|
619
627
|
logger.error(error_msg)
|
620
628
|
await self._update_exp_status(3, error_msg)
|
621
629
|
raise RuntimeError(error_msg) from e
|
622
|
-
|
623
|
-
async def __aenter__(self):
|
624
|
-
"""异步上下文管理器入口"""
|
625
|
-
return self
|
626
|
-
|
627
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
628
|
-
"""异步上下文管理器出口"""
|
629
|
-
if exc_type is not None:
|
630
|
-
# 如果发生异常,更新状态为错误
|
631
|
-
await self._update_exp_status(3, str(exc_val))
|
632
|
-
elif self._exp_info["status"] != 3:
|
633
|
-
# 如果没有发生异常且状态不是错误,则更新为完成
|
634
|
-
await self._update_exp_status(2)
|