pycityagent 2.0.0a43__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a43.dist-info/METADATA +235 -0
- pycityagent-2.0.0a43.dist-info/RECORD +81 -0
- pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
@@ -0,0 +1,634 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import random
|
6
|
+
import time
|
7
|
+
import uuid
|
8
|
+
from collections.abc import Callable, Sequence
|
9
|
+
from concurrent.futures import ThreadPoolExecutor
|
10
|
+
from datetime import datetime, timezone
|
11
|
+
from pathlib import Path
|
12
|
+
from typing import Any, Optional, Union
|
13
|
+
|
14
|
+
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
15
|
+
import ray
|
16
|
+
import yaml
|
17
|
+
from langchain_core.embeddings import Embeddings
|
18
|
+
from mosstool.map._map_util.const import AOI_START_ID
|
19
|
+
|
20
|
+
from ..agent import Agent, InstitutionAgent
|
21
|
+
from ..environment.simulator import Simulator
|
22
|
+
from ..llm import SimpleEmbedding
|
23
|
+
from ..memory import Memory
|
24
|
+
from ..message.messager import Messager
|
25
|
+
from ..metrics import init_mlflow_connection
|
26
|
+
from ..survey import Survey
|
27
|
+
from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
|
28
|
+
from .agentgroup import AgentGroup
|
29
|
+
from .storage.pg import PgWriter, create_pg_tables
|
30
|
+
|
31
|
+
# Package-wide logger (named "pycityagent" rather than __name__); every log call in
# AgentSimulation below goes through it.
logger = logging.getLogger("pycityagent")
|
32
|
+
|
33
|
+
|
34
|
+
class AgentSimulation:
    """City agent simulator.

    Owns the simulator client, the MQTT ``Messager`` actor, the Ray
    ``AgentGroup`` actors hosting the agents, and the optional Avro (local
    YAML/Avro files) and PostgreSQL storage used to persist experiment info.

    Experiment status codes written to ``_exp_info["status"]``:
    0 = created, 1 = running, 2 = finished, 3 = error.
    """

    def __init__(
        self,
        agent_class: Union[type[Agent], list[type[Agent]]],
        config: dict,
        agent_prefix: str = "agent_",
        exp_name: str = "default_experiment",
        logging_level: int = logging.WARNING,
    ):
        """
        Args:
            agent_class: Agent class, or a list of agent classes (one entry per
                agent type; must line up with ``agent_count`` in ``init_agents``).
            config: Simulation configuration. ``config["simulator_request"]``
                must contain an ``"mqtt"`` section with ``"server"`` and
                ``"port"``; optional sections are ``"storage"`` (``"avro"`` /
                ``"pgsql"``) and ``"metric_request"``.
            agent_prefix: Prefix used when naming agents.
            exp_name: Experiment name.
            logging_level: Logging level forwarded to every AgentGroup.
        """
        self.exp_id = str(uuid.uuid4())
        if isinstance(agent_class, list):
            self.agent_class = agent_class
        else:
            self.agent_class = [agent_class]
        self.logging_level = logging_level
        self.config = config
        self.exp_name = exp_name
        self._simulator = Simulator(config["simulator_request"])
        self.agent_prefix = agent_prefix
        self._agents: dict[uuid.UUID, Agent] = {}
        self._groups: dict[str, AgentGroup] = {}  # type:ignore
        self._agent_uuid2group: dict[uuid.UUID, AgentGroup] = {}  # type:ignore
        self._agent_uuids: list[uuid.UUID] = []
        self._user_chat_topics: dict[uuid.UUID, str] = {}
        self._user_survey_topics: dict[uuid.UUID, str] = {}
        self._user_interview_topics: dict[uuid.UUID, str] = {}
        # Filled in by init_agents(); defined here so the attribute always exists
        # even if experiment info is saved before agents are initialised.
        self._pgsql_writers: list[Any] = []
        # NOTE(review): not referenced elsewhere in this class, and calling
        # asyncio.get_event_loop() without a running loop is deprecated in recent
        # Python versions — confirm no other module relies on it before removing.
        self._loop = asyncio.get_event_loop()

        mqtt_config = config["simulator_request"]["mqtt"]
        self._messager = Messager.remote(
            hostname=mqtt_config["server"],  # type:ignore
            port=mqtt_config["port"],
            username=mqtt_config.get("username", None),
            password=mqtt_config.get("password", None),
        )

        # storage — a missing or explicitly-null section is treated as empty
        _storage_config: dict[str, Any] = config.get("storage") or {}

        # avro
        _avro_config: dict[str, Any] = _storage_config.get("avro") or {}
        self._enable_avro = _avro_config.get("enabled", False)
        if not self._enable_avro:
            self._avro_path = None
            logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
        else:
            # One sub-directory per experiment id.
            self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
            self._avro_path.mkdir(parents=True, exist_ok=True)

        # pg
        _pgsql_config: dict[str, Any] = _storage_config.get("pgsql") or {}
        self._enable_pgsql = _pgsql_config.get("enabled", False)
        if not self._enable_pgsql:
            logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
            self._pgsql_dsn = ""
        else:
            # Accept either "data_source_name" or the shorter "dsn" key.
            self._pgsql_dsn = (
                _pgsql_config["data_source_name"]
                if "data_source_name" in _pgsql_config
                else _pgsql_config["dsn"]
            )

        # Experiment metadata, persisted by _save_exp_info().
        self._exp_created_time = datetime.now(timezone.utc)
        self._exp_updated_time = datetime.now(timezone.utc)
        self._exp_info = {
            "id": self.exp_id,
            "name": exp_name,
            "num_day": 0,  # incremented by run()
            "status": 0,
            "cur_day": 0,
            "cur_t": 0.0,
            "config": json.dumps(config),
            "error": "",
            "created_at": self._exp_created_time.isoformat(),
            "updated_at": self._exp_updated_time.isoformat(),
        }

        # Write the initial experiment-info file synchronously.
        if self._enable_avro:
            assert self._avro_path is not None
            self._exp_info_file = self._avro_path / "experiment_info.yaml"
            with open(self._exp_info_file, "w") as f:
                yaml.dump(self._exp_info, f)

    @property
    def enable_avro(
        self,
    ) -> bool:
        """Whether Avro / local-file storage is enabled."""
        return self._enable_avro

    @property
    def enable_pgsql(
        self,
    ) -> bool:
        """Whether PostgreSQL storage is enabled."""
        return self._enable_pgsql

    @property
    def agents(self) -> dict[uuid.UUID, Agent]:
        """All agents created so far, keyed by agent uuid."""
        return self._agents

    @property
    def avro_path(
        self,
    ) -> Path:
        """Per-experiment storage directory (``None`` when Avro is disabled)."""
        return self._avro_path  # type:ignore

    @property
    def groups(self):
        """AgentGroup actor handles keyed by group name."""
        return self._groups

    @property
    def agent_uuids(self):
        """Agent uuids in creation order."""
        return self._agent_uuids

    @property
    def agent_uuid2group(self):
        """Mapping from agent uuid to the AgentGroup actor hosting it."""
        return self._agent_uuid2group

    def create_remote_group(
        self,
        group_name: str,
        agents: list[Agent],
        config: dict,
        exp_id: str,
        exp_name: str,
        enable_avro: bool,
        avro_path: Path,
        enable_pgsql: bool,
        pgsql_writer: ray.ObjectRef,
        mlflow_run_id: str = None,  # type: ignore
        embedding_model: Embeddings = None,  # type: ignore
        logging_level: int = logging.WARNING,
    ):
        """Create one AgentGroup Ray actor.

        Returns:
            ``(group_name, group_actor_handle, agents)``.
        """
        group = AgentGroup.remote(
            agents,
            config,
            exp_id,
            exp_name,
            enable_avro,
            avro_path,
            enable_pgsql,
            pgsql_writer,
            mlflow_run_id,
            embedding_model,
            logging_level,
        )
        return group_name, group, agents

    def _default_memory_config_funcs(self) -> list[Callable]:
        """Return the default memory-config function for each agent class, in order."""
        return [
            self.default_memory_config_institution
            if issubclass(agent_class, InstitutionAgent)
            else self.default_memory_config_citizen
            for agent_class in self.agent_class
        ]

    async def init_agents(
        self,
        agent_count: Union[int, list[int]],
        group_size: int = 1000,
        pg_sql_writers: int = 32,
        # NOTE: this default instance is created once, when the class is defined,
        # and is shared by every call that relies on the default.
        embedding_model: Embeddings = SimpleEmbedding(),
        memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
    ) -> None:
        """Create all agents, split them into AgentGroup actors and initialise them.

        Args:
            agent_count: Total number of agents to create; if a list, each
                element is the count for the agent class at the same index.
            group_size: Maximum number of agents per group; every group is an
                independent Ray actor. Must be >= 1.
            pg_sql_writers: Maximum number of PgWriter actors to create when
                PostgreSQL is enabled. At least one is always created, and never
                more than the number of groups (extra writers would be idle).
            embedding_model: Embedding model forwarded to every group.
            memory_config_func: Function returning an
                ``(EXTRA_ATTRIBUTES, PROFILE, BASE)`` tuple used to build each
                agent's Memory; if a list, one function per agent class. When
                omitted, or when the list length does not match, per-class
                defaults are used.

        Raises:
            ValueError: if ``agent_class`` and ``agent_count`` differ in length,
                or ``group_size`` is smaller than 1.
        """
        await self._messager.connect.remote()
        if not isinstance(agent_count, list):
            agent_count = [agent_count]

        if len(self.agent_class) != len(agent_count):
            raise ValueError("agent_class和agent_count的长度不一致")
        if group_size < 1:
            raise ValueError(f"group_size must be a positive integer, got {group_size}")

        if memory_config_func is None:
            logger.warning(
                "memory_config_func is None, using default memory config function"
            )
            memory_config_func = self._default_memory_config_funcs()
        elif not isinstance(memory_config_func, list):
            memory_config_func = [memory_config_func]

        if len(memory_config_func) != len(agent_count):
            logger.warning(
                "memory_config_func和agent_count的长度不一致,使用默认的memory_config"
            )
            memory_config_func = self._default_memory_config_funcs()

        # First create every agent and record which group each one belongs to.
        group_creation_params: list[tuple[str, list[Agent]]] = []
        for i, (agent_class, count, config_func) in enumerate(
            zip(self.agent_class, agent_count, memory_config_func)
        ):
            # Agents of this class only, so groups never mix classes and never
            # pick up agents created by an earlier call.
            class_agents: list[Agent] = []
            for j in range(count):
                agent_name = f"{self.agent_prefix}_{i}_{j}"
                extra_attributes, profile, base = config_func()
                memory = Memory(config=extra_attributes, profile=profile, base=base)
                agent = agent_class(
                    name=agent_name,
                    memory=memory,
                )
                self._agents[agent._uuid] = agent  # type:ignore
                self._agent_uuids.append(agent._uuid)  # type:ignore
                class_agents.append(agent)

            # ceil(count / group_size) groups; the last one may be partial.
            for k, start in enumerate(range(0, count, group_size)):
                group_creation_params.append(
                    (f"AgentType_{i}_Group_{k}", class_agents[start : start + group_size])
                )

        # mlflow connection
        _mlflow_config = self.config.get("metric_request", {}).get("mlflow")
        if _mlflow_config:
            mlflow_run_id, _ = init_mlflow_connection(
                config=_mlflow_config,
                mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
                experiment_name=self.exp_name,
            )
        else:
            mlflow_run_id = None

        # Create tables and PostgreSQL writer actors.
        if self.enable_pgsql:
            _num_workers = max(1, min(pg_sql_writers, len(group_creation_params)))
            create_pg_tables(
                exp_id=self.exp_id,
                dsn=self._pgsql_dsn,
            )
            self._pgsql_writers = _workers = [
                PgWriter.remote(self.exp_id, self._pgsql_dsn)
                for _ in range(_num_workers)
            ]
        else:
            _num_workers = 1
            self._pgsql_writers = _workers = [None for _ in range(_num_workers)]

        # Create the group actors, spreading groups round-robin over the writers.
        creation_tasks = [
            self.create_remote_group(
                group_name=group_name,
                agents=agents,
                config=self.config,
                exp_id=self.exp_id,
                exp_name=self.exp_name,
                enable_avro=self.enable_avro,
                avro_path=self.avro_path,
                enable_pgsql=self.enable_pgsql,
                pgsql_writer=_workers[i % _num_workers],  # type:ignore
                mlflow_run_id=mlflow_run_id,  # type:ignore
                embedding_model=embedding_model,
                logging_level=self.logging_level,
            )
            for i, (group_name, agents) in enumerate(group_creation_params)
        ]

        for group_name, group, agents in creation_tasks:
            self._groups[group_name] = group
            for agent in agents:
                self._agent_uuid2group[agent._uuid] = group

        # Initialise the agents of all groups in parallel.
        init_tasks = [group.init_agents.remote() for group in self._groups.values()]
        await asyncio.gather(*init_tasks)

        # Per-agent MQTT topics used by send_interview_message / send_survey.
        for agent_uuid in self._agents:
            self._user_chat_topics[agent_uuid] = (
                f"exps/{self.exp_id}/agents/{agent_uuid}/user-chat"
            )
            self._user_survey_topics[agent_uuid] = (
                f"exps/{self.exp_id}/agents/{agent_uuid}/user-survey"
            )

    async def gather(self, content: str):
        """Ask every group to gather ``content``; returns one result per group, in group order."""
        gather_tasks = []
        for group in self._groups.values():
            gather_tasks.append(group.gather.remote(content))
        return await asyncio.gather(*gather_tasks)

    async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
        """Update ``target_key`` in the memory of the given agent, via the group hosting it.

        Raises:
            KeyError: if the agent uuid is unknown.
        """
        group = self._agent_uuid2group[target_agent_uuid]
        await group.update.remote(target_agent_uuid, target_key, content)

    def default_memory_config_institution(self):
        """Default memory config for institution agents.

        Returns:
            ``(EXTRA_ATTRIBUTES, None, None)`` — institutions get no profile or
            base section. The organisation type is drawn at random.
        """
        EXTRA_ATTRIBUTES = {
            "type": (
                int,
                random.choice(
                    [
                        economyv2.ORG_TYPE_BANK,
                        economyv2.ORG_TYPE_GOVERNMENT,
                        economyv2.ORG_TYPE_FIRM,
                        economyv2.ORG_TYPE_NBS,
                        economyv2.ORG_TYPE_UNSPECIFIED,
                    ]
                ),
            ),
            "nominal_gdp": (list, [], True),
            "real_gdp": (list, [], True),
            "unemployment": (list, [], True),
            "wages": (list, [], True),
            "prices": (list, [], True),
            "inventory": (int, 0, True),
            "price": (float, 0.0, True),
            "interest_rate": (float, 0.0, True),
            "bracket_cutoffs": (list, [], True),
            "bracket_rates": (list, [], True),
            "employees": (list, [], True),
            "customers": (list, [], True),
        }
        return EXTRA_ATTRIBUTES, None, None

    def default_memory_config_citizen(self):
        """Default memory config for citizen agents.

        Returns:
            ``(EXTRA_ATTRIBUTES, PROFILE, BASE)`` with randomly drawn needs,
            profile attributes and home/work AOI ids.
        """
        EXTRA_ATTRIBUTES = {
            # needs
            "needs": (
                dict,
                {
                    "hungry": random.random(),  # hunger
                    "tired": random.random(),  # fatigue
                    "safe": random.random(),  # safety need
                    "social": random.random(),  # social need
                },
                True,
            ),
            "current_need": (str, "none", True),
            "current_plan": (list, [], True),
            "current_step": (dict, {"intention": "", "type": ""}, True),
            "execution_context": (dict, {}, True),
            "plan_history": (list, [], True),
            # cognition
            "fulfillment": (int, 5, True),
            "emotion": (int, 5, True),
            "attitude": (int, 5, True),
            "thought": (str, "Currently nothing good or bad is happening", True),
            "emotion_types": (str, "Relief", True),
            "incident": (list, [], True),
            # social
            "friends": (list, [], True),
        }

        PROFILE = {
            "name": "unknown",
            "gender": random.choice(["male", "female"]),
            "education": random.choice(
                ["Doctor", "Master", "Bachelor", "College", "High School"]
            ),
            # NOTE(review): "sightly low" looks like a typo for "slightly low";
            # kept as-is in case other code matches on this exact value.
            "consumption": random.choice(["sightly low", "low", "medium", "high"]),
            "occupation": random.choice(
                [
                    "Student",
                    "Teacher",
                    "Doctor",
                    "Engineer",
                    "Manager",
                    "Businessman",
                    "Artist",
                    "Athlete",
                    "Other",
                ]
            ),
            "age": random.randint(18, 65),
            "skill": random.choice(
                [
                    "Good at problem-solving",
                    "Good at communication",
                    "Good at creativity",
                    "Good at teamwork",
                    "Other",
                ]
            ),
            "family_consumption": random.choice(["low", "medium", "high"]),
            # NOTE(review): "outgoint" looks like a typo; kept as-is (see above).
            "personality": random.choice(
                ["outgoint", "introvert", "ambivert", "extrovert"]
            ),
            "income": str(random.randint(1000, 10000)),
            "currency": random.randint(10000, 100000),
            "residence": random.choice(["city", "suburb", "rural"]),
            "race": random.choice(
                [
                    "Chinese",
                    "American",
                    "British",
                    "French",
                    "German",
                    "Japanese",
                    "Korean",
                    "Russian",
                    "Other",
                ]
            ),
            "religion": random.choice(
                ["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]
            ),
            "marital_status": random.choice(
                ["not married", "married", "divorced", "widowed"]
            ),
        }

        # Home and work AOIs are random ids offset from AOI_START_ID.
        BASE = {
            "home": {
                "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
            },
            "work": {
                "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
            },
        }

        return EXTRA_ATTRIBUTES, PROFILE, BASE

    async def send_survey(
        self, survey: Survey, agent_uuids: Optional[list[uuid.UUID]] = None
    ):
        """Publish ``survey`` to each agent's user-survey topic (all agents by default)."""
        survey_dict = survey.to_dict()
        if agent_uuids is None:
            agent_uuids = self._agent_uuids
        _date_time = datetime.now(timezone.utc)
        payload = {
            "from": "none",
            "survey_id": survey_dict["id"],
            "timestamp": int(_date_time.timestamp() * 1000),
            "data": survey_dict,
            "_date_time": _date_time,
        }
        for agent_uuid in agent_uuids:
            topic = self._user_survey_topics[agent_uuid]
            await self._messager.send_message.remote(topic, payload)

    async def send_interview_message(
        self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
    ):
        """Publish an interview message to the user-chat topic of one or more agents."""
        _date_time = datetime.now(timezone.utc)
        payload = {
            "from": "none",
            "content": content,
            "timestamp": int(_date_time.timestamp() * 1000),
            "_date_time": _date_time,
        }
        if not isinstance(agent_uuids, Sequence):
            agent_uuids = [agent_uuids]
        for agent_uuid in agent_uuids:
            topic = self._user_chat_topics[agent_uuid]
            await self._messager.send_message.remote(topic, payload)

    async def step(self):
        """Run a single step: call ``step`` on every group and wait for all of them."""
        try:
            tasks = []
            for group in self._groups.values():
                tasks.append(group.step.remote())
            await asyncio.gather(*tasks)
        except Exception as e:
            logger.error(f"运行错误: {str(e)}")
            raise

    async def _save_exp_info(self) -> None:
        """Persist experiment info to the YAML file (Avro) and PostgreSQL, if enabled.

        Failures are logged, not raised, so that status updates never abort a run.
        """
        try:
            if self.enable_avro:
                with open(self._exp_info_file, "w") as f:
                    yaml.dump(self._exp_info, f)
        except Exception as e:
            logger.error(f"Avro保存实验信息失败: {str(e)}")
        try:
            if self.enable_pgsql:
                worker: ray.ObjectRef = self._pgsql_writers[0]  # type:ignore
                pg_exp_info = {
                    k: self._exp_info[k] for (k, _) in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
                }
                # PostgreSQL receives datetime objects rather than ISO strings.
                pg_exp_info["created_at"] = self._exp_created_time
                pg_exp_info["updated_at"] = self._exp_updated_time
                await worker.async_update_exp_info.remote(  # type:ignore
                    pg_exp_info
                )
        except Exception as e:
            logger.error(f"PostgreSQL保存实验信息失败: {str(e)}")

    async def _update_exp_status(self, status: int, error: str = "") -> None:
        """Set experiment status / error message, bump ``updated_at`` and persist."""
        self._exp_updated_time = datetime.now(timezone.utc)
        self._exp_info["status"] = status
        self._exp_info["error"] = error
        self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
        await self._save_exp_info()

    async def _monitor_exp_status(self, stop_event: asyncio.Event):
        """Periodically refresh ``cur_day`` / ``cur_t`` and persist experiment info.

        Refreshes roughly once per second until ``stop_event`` is set, waking
        immediately when it is set so callers awaiting this task are not delayed.

        Args:
            stop_event: Event used to tell the monitor to stop.
        """
        try:
            while not stop_event.is_set():
                # Day/time come from the shared simulator clock (the original note
                # assumes all groups advance in lockstep).
                self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
                self._exp_info["cur_t"] = (
                    await self._simulator.get_simulator_second_from_start_of_day()
                )
                await self._save_exp_info()

                # Throttle updates without delaying shutdown.
                try:
                    await asyncio.wait_for(stop_event.wait(), timeout=1)
                except asyncio.TimeoutError:
                    pass
        except asyncio.CancelledError:
            # Normal cancellation; nothing to clean up.
            pass
        except Exception as e:
            logger.error(f"监控实验状态时发生错误: {str(e)}")
            raise

    async def run(
        self,
        day: int = 1,
    ):
        """Run the simulation.

        Performs ``day`` rounds; each round calls ``run`` on every group and
        waits for all of them. Status is set to 1 (running), then 2 (finished)
        on success or 3 (error) on failure.

        Raises:
            RuntimeError: wrapping any exception raised while running.
        """
        try:
            self._exp_info["num_day"] += day
            await self._update_exp_status(1)  # running

            stop_event = asyncio.Event()
            monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))

            try:
                for _ in range(day):
                    tasks = []
                    for group in self._groups.values():
                        tasks.append(group.run.remote())
                    # Wait for every group to finish this round.
                    await asyncio.gather(*tasks)
            finally:
                # Always stop and reap the monitor, even if a group failed.
                stop_event.set()
                await monitor_task

            await self._update_exp_status(2)  # finished

        except Exception as e:
            error_msg = f"模拟器运行错误: {str(e)}"
            logger.error(error_msg)
            await self._update_exp_status(3, error_msg)
            raise RuntimeError(error_msg) from e

    async def __aenter__(self):
        """Async context manager entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: record error (3) or finished (2) status."""
        if exc_type is not None:
            await self._update_exp_status(3, str(exc_val))
        elif self._exp_info["status"] != 3:
            # No exception and not already marked as failed -> finished.
            await self._update_exp_status(2)
|