pycityagent 2.0.0a43__cp312-cp312-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. pycityagent/__init__.py +23 -0
  2. pycityagent/agent.py +833 -0
  3. pycityagent/cli/wrapper.py +44 -0
  4. pycityagent/economy/__init__.py +5 -0
  5. pycityagent/economy/econ_client.py +355 -0
  6. pycityagent/environment/__init__.py +7 -0
  7. pycityagent/environment/interact/__init__.py +0 -0
  8. pycityagent/environment/interact/interact.py +198 -0
  9. pycityagent/environment/message/__init__.py +0 -0
  10. pycityagent/environment/sence/__init__.py +0 -0
  11. pycityagent/environment/sence/static.py +416 -0
  12. pycityagent/environment/sidecar/__init__.py +8 -0
  13. pycityagent/environment/sidecar/sidecarv2.py +109 -0
  14. pycityagent/environment/sim/__init__.py +29 -0
  15. pycityagent/environment/sim/aoi_service.py +39 -0
  16. pycityagent/environment/sim/client.py +126 -0
  17. pycityagent/environment/sim/clock_service.py +44 -0
  18. pycityagent/environment/sim/economy_services.py +192 -0
  19. pycityagent/environment/sim/lane_service.py +111 -0
  20. pycityagent/environment/sim/light_service.py +122 -0
  21. pycityagent/environment/sim/person_service.py +295 -0
  22. pycityagent/environment/sim/road_service.py +39 -0
  23. pycityagent/environment/sim/sim_env.py +145 -0
  24. pycityagent/environment/sim/social_service.py +59 -0
  25. pycityagent/environment/simulator.py +331 -0
  26. pycityagent/environment/utils/__init__.py +14 -0
  27. pycityagent/environment/utils/base64.py +16 -0
  28. pycityagent/environment/utils/const.py +244 -0
  29. pycityagent/environment/utils/geojson.py +24 -0
  30. pycityagent/environment/utils/grpc.py +57 -0
  31. pycityagent/environment/utils/map_utils.py +157 -0
  32. pycityagent/environment/utils/port.py +11 -0
  33. pycityagent/environment/utils/protobuf.py +41 -0
  34. pycityagent/llm/__init__.py +11 -0
  35. pycityagent/llm/embeddings.py +231 -0
  36. pycityagent/llm/llm.py +377 -0
  37. pycityagent/llm/llmconfig.py +13 -0
  38. pycityagent/llm/utils.py +6 -0
  39. pycityagent/memory/__init__.py +13 -0
  40. pycityagent/memory/const.py +43 -0
  41. pycityagent/memory/faiss_query.py +302 -0
  42. pycityagent/memory/memory.py +448 -0
  43. pycityagent/memory/memory_base.py +170 -0
  44. pycityagent/memory/profile.py +165 -0
  45. pycityagent/memory/self_define.py +165 -0
  46. pycityagent/memory/state.py +173 -0
  47. pycityagent/memory/utils.py +28 -0
  48. pycityagent/message/__init__.py +3 -0
  49. pycityagent/message/messager.py +88 -0
  50. pycityagent/metrics/__init__.py +6 -0
  51. pycityagent/metrics/mlflow_client.py +147 -0
  52. pycityagent/metrics/utils/const.py +0 -0
  53. pycityagent/pycityagent-sim +0 -0
  54. pycityagent/pycityagent-ui +0 -0
  55. pycityagent/simulation/__init__.py +8 -0
  56. pycityagent/simulation/agentgroup.py +580 -0
  57. pycityagent/simulation/simulation.py +634 -0
  58. pycityagent/simulation/storage/pg.py +184 -0
  59. pycityagent/survey/__init__.py +4 -0
  60. pycityagent/survey/manager.py +54 -0
  61. pycityagent/survey/models.py +120 -0
  62. pycityagent/utils/__init__.py +11 -0
  63. pycityagent/utils/avro_schema.py +109 -0
  64. pycityagent/utils/decorators.py +99 -0
  65. pycityagent/utils/parsers/__init__.py +13 -0
  66. pycityagent/utils/parsers/code_block_parser.py +37 -0
  67. pycityagent/utils/parsers/json_parser.py +86 -0
  68. pycityagent/utils/parsers/parser_base.py +60 -0
  69. pycityagent/utils/pg_query.py +92 -0
  70. pycityagent/utils/survey_util.py +53 -0
  71. pycityagent/workflow/__init__.py +26 -0
  72. pycityagent/workflow/block.py +211 -0
  73. pycityagent/workflow/prompt.py +79 -0
  74. pycityagent/workflow/tool.py +240 -0
  75. pycityagent/workflow/trigger.py +163 -0
  76. pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
  77. pycityagent-2.0.0a43.dist-info/METADATA +235 -0
  78. pycityagent-2.0.0a43.dist-info/RECORD +81 -0
  79. pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
  80. pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
  81. pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
@@ -0,0 +1,634 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ import time
7
+ import uuid
8
+ from collections.abc import Callable, Sequence
9
+ from concurrent.futures import ThreadPoolExecutor
10
+ from datetime import datetime, timezone
11
+ from pathlib import Path
12
+ from typing import Any, Optional, Union
13
+
14
+ import pycityproto.city.economy.v2.economy_pb2 as economyv2
15
+ import ray
16
+ import yaml
17
+ from langchain_core.embeddings import Embeddings
18
+ from mosstool.map._map_util.const import AOI_START_ID
19
+
20
+ from ..agent import Agent, InstitutionAgent
21
+ from ..environment.simulator import Simulator
22
+ from ..llm import SimpleEmbedding
23
+ from ..memory import Memory
24
+ from ..message.messager import Messager
25
+ from ..metrics import init_mlflow_connection
26
+ from ..survey import Survey
27
+ from ..utils import TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
28
+ from .agentgroup import AgentGroup
29
+ from .storage.pg import PgWriter, create_pg_tables
30
+
31
# Module logger: logs under the shared "pycityagent" name rather than __name__.
logger = logging.getLogger(name="pycityagent")
32
+
33
+
34
class AgentSimulation:
    """City agent simulation driver.

    Creates agents, distributes them across Ray ``AgentGroup`` actors, forwards
    surveys and interview messages to them through the messager actor, and
    records experiment status to a YAML file (when Avro storage is enabled)
    and/or PostgreSQL (when PostgreSQL storage is enabled).
    """

    def __init__(
        self,
        agent_class: Union[type[Agent], list[type[Agent]]],
        config: dict,
        agent_prefix: str = "agent_",
        exp_name: str = "default_experiment",
        logging_level: int = logging.WARNING,
    ):
        """
        Args:
            agent_class: Agent class, or list of agent classes, to instantiate.
            config: Simulation config. ``config["simulator_request"]`` and its
                ``"mqtt"`` section are required; ``"storage"`` and
                ``"metric_request"`` are optional and may be missing or null.
            agent_prefix: Prefix for generated agent names.
            exp_name: Experiment name.
            logging_level: Logging level forwarded to every remote AgentGroup.
        """
        self.exp_id = str(uuid.uuid4())
        self.agent_class = (
            agent_class if isinstance(agent_class, list) else [agent_class]
        )
        self.logging_level = logging_level
        self.config = config
        self.exp_name = exp_name
        self._simulator = Simulator(config["simulator_request"])
        self.agent_prefix = agent_prefix
        self._agents: dict[uuid.UUID, Agent] = {}
        self._groups: dict[str, AgentGroup] = {}  # type:ignore
        self._agent_uuid2group: dict[uuid.UUID, AgentGroup] = {}  # type:ignore
        self._agent_uuids: list[uuid.UUID] = []
        self._user_chat_topics: dict[uuid.UUID, str] = {}
        self._user_survey_topics: dict[uuid.UUID, str] = {}
        # NOTE(review): never populated in this file.
        self._user_interview_topics: dict[uuid.UUID, str] = {}
        # NOTE(review): unused in this file; asyncio.get_event_loop() may emit a
        # DeprecationWarning on Python 3.12+ when no event loop is set.
        self._loop = asyncio.get_event_loop()

        _mqtt_config = config["simulator_request"]["mqtt"]
        self._messager = Messager.remote(
            hostname=_mqtt_config["server"],  # type:ignore
            port=_mqtt_config["port"],
            username=_mqtt_config.get("username", None),
            password=_mqtt_config.get("password", None),
        )

        # Storage: every section may be absent or explicitly null in the config.
        _storage_config: dict[str, Any] = config.get("storage") or {}

        # Avro (local files under <path>/<exp_id>)
        _avro_config: dict[str, Any] = _storage_config.get("avro") or {}
        self._enable_avro = _avro_config.get("enabled", False)
        if not self._enable_avro:
            self._avro_path = None
            logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
        else:
            self._avro_path = Path(_avro_config["path"]) / f"{self.exp_id}"
            self._avro_path.mkdir(parents=True, exist_ok=True)

        # PostgreSQL ("data_source_name" takes precedence over "dsn")
        _pgsql_config: dict[str, Any] = _storage_config.get("pgsql") or {}
        self._enable_pgsql = _pgsql_config.get("enabled", False)
        if not self._enable_pgsql:
            logger.warning("PostgreSQL is not enabled, NO POSTGRESQL DATABASE STORAGE")
            self._pgsql_dsn = ""
        else:
            self._pgsql_dsn = (
                _pgsql_config["data_source_name"]
                if "data_source_name" in _pgsql_config
                else _pgsql_config["dsn"]
            )

        # Experiment bookkeeping. Status codes used in this class:
        # 0 = created, 1 = running, 2 = finished, 3 = error.
        _now = datetime.now(timezone.utc)
        self._exp_created_time = _now
        self._exp_updated_time = _now
        self._exp_info = {
            "id": self.exp_id,
            "name": exp_name,
            "num_day": 0,  # incremented by run()
            "status": 0,
            "cur_day": 0,
            "cur_t": 0.0,
            "config": json.dumps(config),
            "error": "",
            "created_at": self._exp_created_time.isoformat(),
            "updated_at": self._exp_updated_time.isoformat(),
        }

        # Write the initial experiment info synchronously when Avro storage is on.
        if self._enable_avro:
            assert self._avro_path is not None
            self._exp_info_file = self._avro_path / "experiment_info.yaml"
            self._write_exp_info_yaml()

    @property
    def enable_avro(
        self,
    ) -> bool:
        """Whether local Avro storage is enabled."""
        return self._enable_avro

    @property
    def enable_pgsql(
        self,
    ) -> bool:
        """Whether PostgreSQL storage is enabled."""
        return self._enable_pgsql

    @property
    def agents(self) -> dict[uuid.UUID, Agent]:
        """All created agents, keyed by agent uuid."""
        return self._agents

    @property
    def avro_path(
        self,
    ) -> Path:
        """Experiment Avro directory; ``None`` when Avro storage is disabled."""
        return self._avro_path  # type:ignore

    @property
    def groups(self):
        """Remote AgentGroup actors, keyed by group name."""
        return self._groups

    @property
    def agent_uuids(self):
        """Agent uuids in creation order."""
        return self._agent_uuids

    @property
    def agent_uuid2group(self):
        """Mapping from agent uuid to the AgentGroup actor hosting it."""
        return self._agent_uuid2group

    def create_remote_group(
        self,
        group_name: str,
        agents: list[Agent],
        config: dict,
        exp_id: str,
        exp_name: str,
        enable_avro: bool,
        avro_path: Path,
        enable_pgsql: bool,
        pgsql_writer: ray.ObjectRef,
        mlflow_run_id: str = None,  # type: ignore
        embedding_model: Embeddings = None,  # type: ignore
        logging_level: int = logging.WARNING,
    ):
        """Create one remote AgentGroup actor.

        Returns:
            ``(group_name, group_actor_handle, agents)``.
        """
        group = AgentGroup.remote(
            agents,
            config,
            exp_id,
            exp_name,
            enable_avro,
            avro_path,
            enable_pgsql,
            pgsql_writer,
            mlflow_run_id,
            embedding_model,
            logging_level,
        )
        return group_name, group, agents

    def _default_memory_config_funcs(self) -> list[Callable]:
        """Return the default memory-config function for each agent class, in order."""
        return [
            self.default_memory_config_institution
            if issubclass(agent_class, InstitutionAgent)
            else self.default_memory_config_citizen
            for agent_class in self.agent_class
        ]

    async def init_agents(
        self,
        agent_count: Union[int, list[int]],
        group_size: int = 1000,
        pg_sql_writers: int = 32,
        embedding_model: Embeddings = SimpleEmbedding(),
        memory_config_func: Optional[Union[Callable, list[Callable]]] = None,
    ) -> None:
        """Create agents, distribute them into remote AgentGroups and initialize them.

        Args:
            agent_count: Total number of agents to create. If a list, each element
                is the number of agents for the agent class at the same index.
            group_size: Maximum number of agents per group; each group is an
                independent Ray actor.
            pg_sql_writers: Maximum number of PgWriter actors to create when
                PostgreSQL storage is enabled. At least one is always created,
                and never more than the number of groups.
            embedding_model: Embedding model passed to every group.
            memory_config_func: Function returning an
                ``(EXTRA_ATTRIBUTES, PROFILE, BASE)`` tuple used to build each
                agent's Memory. If a list, each element applies to the agent
                class at the same index. Falls back to the defaults when ``None``
                or when its length does not match ``agent_count``.

        Raises:
            ValueError: If the numbers of agent classes and counts differ, or if
                ``group_size`` is less than 1.
        """
        if not isinstance(agent_count, list):
            agent_count = [agent_count]

        if len(self.agent_class) != len(agent_count):
            raise ValueError("agent_class和agent_count的长度不一致")
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")

        await self._messager.connect.remote()

        if memory_config_func is None:
            logger.warning(
                "memory_config_func is None, using default memory config function"
            )
            memory_config_func = self._default_memory_config_funcs()
        elif not isinstance(memory_config_func, list):
            memory_config_func = [memory_config_func]

        if len(memory_config_func) != len(agent_count):
            logger.warning(
                "memory_config_func和agent_count的长度不一致,使用默认的memory_config"
            )
            memory_config_func = self._default_memory_config_funcs()

        # Create all agents first, then split each class's agents into groups.
        group_creation_params: list[tuple[str, list[Agent]]] = []
        for i, (agent_class, count, config_func) in enumerate(
            zip(self.agent_class, agent_count, memory_config_func)
        ):
            class_agents: list[Agent] = []
            for j in range(count):
                extra_attributes, profile, base = config_func()
                memory = Memory(config=extra_attributes, profile=profile, base=base)
                agent = agent_class(
                    name=f"{self.agent_prefix}_{i}_{j}",
                    memory=memory,
                )
                self._agents[agent._uuid] = agent  # type:ignore
                self._agent_uuids.append(agent._uuid)  # type:ignore
                class_agents.append(agent)

            # Consecutive slices of at most group_size; the last may be smaller.
            for k, start in enumerate(range(0, count, group_size)):
                group_creation_params.append(
                    (
                        f"AgentType_{i}_Group_{k}",
                        class_agents[start : start + group_size],
                    )
                )

        # Initialize the mlflow connection when configured.
        _mlflow_config = (self.config.get("metric_request") or {}).get("mlflow")
        if _mlflow_config:
            mlflow_run_id, _ = init_mlflow_connection(
                config=_mlflow_config,
                mlflow_run_name=f"EXP_{self.exp_name}_{1000*int(time.time())}",
                experiment_name=self.exp_name,
            )
        else:
            mlflow_run_id = None

        # Create tables and the PgWriter actors shared round-robin by the groups.
        if self.enable_pgsql:
            # At least one writer (index 0 is also used by _save_exp_info), and no
            # more writers than groups since extra ones would never be used.
            _num_workers = max(1, min(pg_sql_writers, len(group_creation_params)))
            create_pg_tables(
                exp_id=self.exp_id,
                dsn=self._pgsql_dsn,
            )
            self._pgsql_writers = _workers = [
                PgWriter.remote(self.exp_id, self._pgsql_dsn)
                for _ in range(_num_workers)
            ]
        else:
            _num_workers = 1
            self._pgsql_writers = _workers = [None]

        # Create the remote groups and record which group hosts each agent.
        for i, (group_name, agents) in enumerate(group_creation_params):
            _, group, _ = self.create_remote_group(
                group_name=group_name,
                agents=agents,
                config=self.config,
                exp_id=self.exp_id,
                exp_name=self.exp_name,
                enable_avro=self.enable_avro,
                avro_path=self.avro_path,
                enable_pgsql=self.enable_pgsql,
                pgsql_writer=_workers[i % _num_workers],  # type:ignore
                mlflow_run_id=mlflow_run_id,  # type:ignore
                embedding_model=embedding_model,
                logging_level=self.logging_level,
            )
            self._groups[group_name] = group
            for agent in agents:
                self._agent_uuid2group[agent._uuid] = group

        # Initialize the agents of all groups concurrently.
        await asyncio.gather(
            *(group.init_agents.remote() for group in self._groups.values())
        )

        # Per-agent MQTT topics for user chat and surveys.
        for agent_uuid in self._agents:
            self._user_chat_topics[agent_uuid] = (
                f"exps/{self.exp_id}/agents/{agent_uuid}/user-chat"
            )
            self._user_survey_topics[agent_uuid] = (
                f"exps/{self.exp_id}/agents/{agent_uuid}/user-survey"
            )

    async def gather(self, content: str):
        """Collect ``content`` from every group; returns one result per group."""
        return await asyncio.gather(
            *(group.gather.remote(content) for group in self._groups.values())
        )

    async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
        """Update ``target_key`` in the memory of the given agent via its group."""
        group = self._agent_uuid2group[target_agent_uuid]
        await group.update.remote(target_agent_uuid, target_key, content)

    def default_memory_config_institution(self):
        """Default memory config for institution agents.

        Returns:
            ``(EXTRA_ATTRIBUTES, None, None)`` with a randomly chosen org type.
        """
        EXTRA_ATTRIBUTES = {
            "type": (
                int,
                random.choice(
                    [
                        economyv2.ORG_TYPE_BANK,
                        economyv2.ORG_TYPE_GOVERNMENT,
                        economyv2.ORG_TYPE_FIRM,
                        economyv2.ORG_TYPE_NBS,
                        economyv2.ORG_TYPE_UNSPECIFIED,
                    ]
                ),
            ),
            "nominal_gdp": (list, [], True),
            "real_gdp": (list, [], True),
            "unemployment": (list, [], True),
            "wages": (list, [], True),
            "prices": (list, [], True),
            "inventory": (int, 0, True),
            "price": (float, 0.0, True),
            "interest_rate": (float, 0.0, True),
            "bracket_cutoffs": (list, [], True),
            "bracket_rates": (list, [], True),
            "employees": (list, [], True),
            "customers": (list, [], True),
        }
        return EXTRA_ATTRIBUTES, None, None

    def default_memory_config_citizen(self):
        """Default memory config for citizen agents with a randomized profile.

        Returns:
            ``(EXTRA_ATTRIBUTES, PROFILE, BASE)``.
        """
        EXTRA_ATTRIBUTES = {
            # needs
            "needs": (
                dict,
                {
                    "hungry": random.random(),  # hunger
                    "tired": random.random(),  # fatigue
                    "safe": random.random(),  # safety need
                    "social": random.random(),  # social need
                },
                True,
            ),
            "current_need": (str, "none", True),
            "current_plan": (list, [], True),
            "current_step": (dict, {"intention": "", "type": ""}, True),
            "execution_context": (dict, {}, True),
            "plan_history": (list, [], True),
            # cognition
            "fulfillment": (int, 5, True),
            "emotion": (int, 5, True),
            "attitude": (int, 5, True),
            "thought": (str, "Currently nothing good or bad is happening", True),
            "emotion_types": (str, "Relief", True),
            "incident": (list, [], True),
            # social
            "friends": (list, [], True),
        }

        # NOTE(review): "sightly low" and "outgoint" look like typos; kept verbatim
        # in case other code matches on these exact values.
        PROFILE = {
            "name": "unknown",
            "gender": random.choice(["male", "female"]),
            "education": random.choice(
                ["Doctor", "Master", "Bachelor", "College", "High School"]
            ),
            "consumption": random.choice(["sightly low", "low", "medium", "high"]),
            "occupation": random.choice(
                [
                    "Student",
                    "Teacher",
                    "Doctor",
                    "Engineer",
                    "Manager",
                    "Businessman",
                    "Artist",
                    "Athlete",
                    "Other",
                ]
            ),
            "age": random.randint(18, 65),
            "skill": random.choice(
                [
                    "Good at problem-solving",
                    "Good at communication",
                    "Good at creativity",
                    "Good at teamwork",
                    "Other",
                ]
            ),
            "family_consumption": random.choice(["low", "medium", "high"]),
            "personality": random.choice(
                ["outgoint", "introvert", "ambivert", "extrovert"]
            ),
            "income": str(random.randint(1000, 10000)),
            "currency": random.randint(10000, 100000),
            "residence": random.choice(["city", "suburb", "rural"]),
            "race": random.choice(
                [
                    "Chinese",
                    "American",
                    "British",
                    "French",
                    "German",
                    "Japanese",
                    "Korean",
                    "Russian",
                    "Other",
                ]
            ),
            "religion": random.choice(
                ["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]
            ),
            "marital_status": random.choice(
                ["not married", "married", "divorced", "widowed"]
            ),
        }

        # NOTE(review): assumes AOI ids AOI_START_ID+1 .. AOI_START_ID+50000 exist
        # in the loaded map — confirm for the map in use.
        BASE = {
            "home": {
                "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
            },
            "work": {
                "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
            },
        }

        return EXTRA_ATTRIBUTES, PROFILE, BASE

    async def send_survey(
        self, survey: Survey, agent_uuids: Optional[list[uuid.UUID]] = None
    ):
        """Send a survey to the given agents (all agents when ``agent_uuids`` is None)."""
        survey_dict = survey.to_dict()
        if agent_uuids is None:
            agent_uuids = self._agent_uuids
        _date_time = datetime.now(timezone.utc)
        payload = {
            "from": "none",
            "survey_id": survey_dict["id"],
            "timestamp": int(_date_time.timestamp() * 1000),
            "data": survey_dict,
            "_date_time": _date_time,
        }
        for agent_uuid in agent_uuids:
            topic = self._user_survey_topics[agent_uuid]
            await self._messager.send_message.remote(topic, payload)

    async def send_interview_message(
        self, content: str, agent_uuids: Union[uuid.UUID, list[uuid.UUID]]
    ):
        """Send an interview (user-chat) message to one agent or a list of agents."""
        _date_time = datetime.now(timezone.utc)
        payload = {
            "from": "none",
            "content": content,
            "timestamp": int(_date_time.timestamp() * 1000),
            "_date_time": _date_time,
        }
        # A str is a Sequence too; treat a single string uuid as one agent
        # instead of iterating over its characters.
        if isinstance(agent_uuids, str) or not isinstance(agent_uuids, Sequence):
            agent_uuids = [agent_uuids]
        for agent_uuid in agent_uuids:
            topic = self._user_chat_topics[agent_uuid]
            await self._messager.send_message.remote(topic, payload)

    async def step(self):
        """Run one step: every group performs one step concurrently."""
        try:
            await asyncio.gather(
                *(group.step.remote() for group in self._groups.values())
            )
        except Exception as e:
            logger.error("运行错误: %s", e)
            raise

    def _write_exp_info_yaml(self) -> None:
        """Dump the current experiment info to ``experiment_info.yaml``."""
        with open(self._exp_info_file, "w") as f:
            yaml.dump(self._exp_info, f)

    async def _save_exp_info(self) -> None:
        """Persist experiment info to YAML and/or PostgreSQL; failures are logged, not raised."""
        try:
            if self.enable_avro:
                self._write_exp_info_yaml()
        except Exception as e:
            logger.error("Avro保存实验信息失败: %s", e)
        try:
            if self.enable_pgsql:
                worker: ray.ObjectRef = self._pgsql_writers[0]  # type:ignore
                pg_exp_info = {
                    k: self._exp_info[k] for (k, _) in TO_UPDATE_EXP_INFO_KEYS_AND_TYPES
                }
                # PostgreSQL receives datetime objects rather than ISO strings.
                pg_exp_info["created_at"] = self._exp_created_time
                pg_exp_info["updated_at"] = self._exp_updated_time
                await worker.async_update_exp_info.remote(  # type:ignore
                    pg_exp_info
                )
        except Exception as e:
            logger.error("PostgreSQL保存实验信息失败: %s", e)

    async def _update_exp_status(self, status: int, error: str = "") -> None:
        """Set the experiment status/error, refresh ``updated_at`` and save."""
        self._exp_updated_time = datetime.now(timezone.utc)
        self._exp_info["status"] = status
        self._exp_info["error"] = error
        self._exp_info["updated_at"] = self._exp_updated_time.isoformat()
        await self._save_exp_info()

    async def _monitor_exp_status(self, stop_event: asyncio.Event):
        """Periodically mirror the simulator clock into the experiment info.

        Args:
            stop_event: Set by the caller to stop the monitoring loop.
        """
        try:
            while not stop_event.is_set():
                # Current day and seconds-since-start-of-day come from the simulator.
                self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
                self._exp_info["cur_t"] = (
                    await self._simulator.get_simulator_second_from_start_of_day()
                )
                await self._save_exp_info()

                await asyncio.sleep(1)  # throttle updates
        except asyncio.CancelledError:
            # Normal cancellation; nothing to clean up.
            pass
        except Exception as e:
            logger.error("监控实验状态时发生错误: %s", e)
            raise

    async def run(
        self,
        day: int = 1,
    ):
        """Run the simulation.

        For ``day`` iterations, runs every group's ``run`` concurrently (presumably
        one simulated day per call — TODO confirm against AgentGroup.run). Status is
        set to 1 before starting, 2 on success and 3 on error, while a background
        task keeps ``cur_day``/``cur_t`` up to date.

        Raises:
            RuntimeError: Wrapping any exception raised while running.
        """
        try:
            self._exp_info["num_day"] += day
            await self._update_exp_status(1)  # running

            stop_event = asyncio.Event()
            monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))

            try:
                for _ in range(day):
                    # Wait for all groups to finish this iteration.
                    await asyncio.gather(
                        *(group.run.remote() for group in self._groups.values())
                    )
            finally:
                # Always stop the monitor and wait for it to exit.
                stop_event.set()
                await monitor_task

            await self._update_exp_status(2)  # finished

        except Exception as e:
            error_msg = f"模拟器运行错误: {str(e)}"
            logger.error(error_msg)
            await self._update_exp_status(3, error_msg)
            raise RuntimeError(error_msg) from e

    async def __aenter__(self):
        """Enter the async context; returns ``self``."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the async context, recording error (3) or finished (2) status.

        Exceptions are not suppressed.
        """
        if exc_type is not None:
            await self._update_exp_status(3, str(exc_val))
        elif self._exp_info["status"] != 3:
            await self._update_exp_status(2)