pycityagent 2.0.0a14__py3-none-any.whl → 2.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +14 -0
- pycityagent/agent.py +80 -63
- pycityagent/environment/simulator.py +5 -4
- pycityagent/memory/memory.py +8 -7
- pycityagent/memory/memory_base.py +6 -4
- pycityagent/message/messager.py +8 -7
- pycityagent/simulation/agentgroup.py +135 -51
- pycityagent/simulation/simulation.py +206 -92
- pycityagent/survey/manager.py +10 -14
- pycityagent/survey/models.py +24 -24
- pycityagent/utils/__init__.py +2 -2
- pycityagent/utils/avro_schema.py +26 -1
- pycityagent/workflow/tool.py +1 -4
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/METADATA +3 -2
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/RECORD +16 -16
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/WHEEL +0 -0
@@ -2,23 +2,27 @@ import asyncio
|
|
2
2
|
import json
|
3
3
|
import logging
|
4
4
|
import os
|
5
|
-
from pathlib import Path
|
6
|
-
import uuid
|
7
|
-
from datetime import datetime
|
8
5
|
import random
|
9
|
-
|
10
|
-
import
|
11
|
-
from
|
6
|
+
import uuid
|
7
|
+
from collections.abc import Sequence
|
8
|
+
from concurrent.futures import ThreadPoolExecutor
|
9
|
+
from datetime import datetime, timezone
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
12
|
+
|
12
13
|
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
14
|
+
import yaml
|
15
|
+
from mosstool.map._map_util.const import AOI_START_ID
|
16
|
+
|
17
|
+
from pycityagent.environment.simulator import Simulator
|
13
18
|
from pycityagent.memory.memory import Memory
|
14
19
|
from pycityagent.message.messager import Messager
|
15
20
|
from pycityagent.survey import Survey
|
16
|
-
from pycityagent.utils.avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
|
17
21
|
|
18
22
|
from ..agent import Agent, InstitutionAgent
|
19
23
|
from .agentgroup import AgentGroup
|
20
24
|
|
21
|
-
logger = logging.getLogger(
|
25
|
+
logger = logging.getLogger("pycityagent")
|
22
26
|
|
23
27
|
|
24
28
|
class AgentSimulation:
|
@@ -29,23 +33,28 @@ class AgentSimulation:
|
|
29
33
|
agent_class: Union[type[Agent], list[type[Agent]]],
|
30
34
|
config: dict,
|
31
35
|
agent_prefix: str = "agent_",
|
36
|
+
exp_name: str = "default_experiment",
|
37
|
+
logging_level: int = logging.WARNING,
|
32
38
|
):
|
33
39
|
"""
|
34
40
|
Args:
|
35
41
|
agent_class: 智能体类
|
36
42
|
config: 配置
|
37
43
|
agent_prefix: 智能体名称前缀
|
44
|
+
exp_name: 实验名称
|
38
45
|
"""
|
39
|
-
self.exp_id = uuid.uuid4()
|
46
|
+
self.exp_id = str(uuid.uuid4())
|
40
47
|
if isinstance(agent_class, list):
|
41
48
|
self.agent_class = agent_class
|
42
49
|
else:
|
43
50
|
self.agent_class = [agent_class]
|
51
|
+
self.logging_level = logging_level
|
44
52
|
self.config = config
|
53
|
+
self._simulator = Simulator(config["simulator_request"])
|
45
54
|
self.agent_prefix = agent_prefix
|
46
55
|
self._agents: Dict[uuid.UUID, Agent] = {}
|
47
|
-
self._groups: Dict[str, AgentGroup] = {}
|
48
|
-
self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {}
|
56
|
+
self._groups: Dict[str, AgentGroup] = {} # type:ignore
|
57
|
+
self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {} # type:ignore
|
49
58
|
self._agent_uuids: List[uuid.UUID] = []
|
50
59
|
self._user_chat_topics: Dict[uuid.UUID, str] = {}
|
51
60
|
self._user_survey_topics: Dict[uuid.UUID, str] = {}
|
@@ -61,13 +70,10 @@ class AgentSimulation:
|
|
61
70
|
asyncio.create_task(self._messager.connect())
|
62
71
|
|
63
72
|
self._enable_avro = config["storage"]["avro"]["enabled"]
|
64
|
-
self.
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
"status": self._avro_path / f"{self.exp_id}_status.avro",
|
69
|
-
"survey": self._avro_path / f"{self.exp_id}_survey.avro",
|
70
|
-
}
|
73
|
+
if not self._enable_avro:
|
74
|
+
logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
|
75
|
+
self._avro_path = Path(config["storage"]["avro"]["path"]) / f"{self.exp_id}"
|
76
|
+
self._avro_path.mkdir(parents=True, exist_ok=True)
|
71
77
|
|
72
78
|
self._enable_pgsql = config["storage"]["pgsql"]["enabled"]
|
73
79
|
self._pgsql_host = config["storage"]["pgsql"]["host"]
|
@@ -76,22 +82,56 @@ class AgentSimulation:
|
|
76
82
|
self._pgsql_user = config["storage"]["pgsql"]["user"]
|
77
83
|
self._pgsql_password = config["storage"]["pgsql"]["password"]
|
78
84
|
|
85
|
+
# 添加实验信息相关的属性
|
86
|
+
self._exp_info = {
|
87
|
+
"id": self.exp_id,
|
88
|
+
"name": exp_name,
|
89
|
+
"num_day": 0, # 将在 run 方法中更新
|
90
|
+
"status": 0,
|
91
|
+
"cur_day": 0,
|
92
|
+
"cur_t": 0.0,
|
93
|
+
"config": json.dumps(config),
|
94
|
+
"error": "",
|
95
|
+
"created_at": datetime.now(timezone.utc).isoformat(),
|
96
|
+
}
|
97
|
+
|
98
|
+
# 创建异步任务保存实验信息
|
99
|
+
self._exp_info_file = self._avro_path / "experiment_info.yaml"
|
100
|
+
with open(self._exp_info_file, "w") as f:
|
101
|
+
yaml.dump(self._exp_info, f)
|
102
|
+
|
79
103
|
@property
|
80
104
|
def agents(self):
|
81
105
|
return self._agents
|
82
|
-
|
106
|
+
|
83
107
|
@property
|
84
108
|
def groups(self):
|
85
109
|
return self._groups
|
86
|
-
|
110
|
+
|
87
111
|
@property
|
88
112
|
def agent_uuids(self):
|
89
113
|
return self._agent_uuids
|
90
|
-
|
114
|
+
|
91
115
|
@property
|
92
116
|
def agent_uuid2group(self):
|
93
117
|
return self._agent_uuid2group
|
94
118
|
|
119
|
+
def create_remote_group(
|
120
|
+
self,
|
121
|
+
group_name: str,
|
122
|
+
agents: list[Agent],
|
123
|
+
config: dict,
|
124
|
+
exp_id: str,
|
125
|
+
enable_avro: bool,
|
126
|
+
avro_path: Path,
|
127
|
+
logging_level: int = logging.WARNING,
|
128
|
+
):
|
129
|
+
"""创建远程组"""
|
130
|
+
group = AgentGroup.remote(
|
131
|
+
agents, config, exp_id, enable_avro, avro_path, logging_level
|
132
|
+
)
|
133
|
+
return group_name, group, agents
|
134
|
+
|
95
135
|
async def init_agents(
|
96
136
|
self,
|
97
137
|
agent_count: Union[int, list[int]],
|
@@ -112,7 +152,7 @@ class AgentSimulation:
|
|
112
152
|
raise ValueError("agent_class和agent_count的长度不一致")
|
113
153
|
|
114
154
|
if memory_config_func is None:
|
115
|
-
|
155
|
+
logger.warning(
|
116
156
|
"memory_config_func is None, using default memory config function"
|
117
157
|
)
|
118
158
|
memory_config_func = []
|
@@ -125,17 +165,21 @@ class AgentSimulation:
|
|
125
165
|
memory_config_func = [memory_config_func]
|
126
166
|
|
127
167
|
if len(memory_config_func) != len(agent_count):
|
128
|
-
|
168
|
+
logger.warning(
|
129
169
|
"memory_config_func和agent_count的长度不一致,使用默认的memory_config"
|
130
170
|
)
|
131
171
|
memory_config_func = []
|
132
172
|
for agent_class in self.agent_class:
|
133
|
-
if agent_class
|
173
|
+
if issubclass(agent_class, InstitutionAgent):
|
134
174
|
memory_config_func.append(self.default_memory_config_institution)
|
135
175
|
else:
|
136
176
|
memory_config_func.append(self.default_memory_config_citizen)
|
137
177
|
|
178
|
+
# 使用线程池并行创建 AgentGroup
|
179
|
+
group_creation_params = []
|
138
180
|
class_init_index = 0
|
181
|
+
|
182
|
+
# 首先收集所有需要创建的组的参数
|
139
183
|
for i in range(len(self.agent_class)):
|
140
184
|
agent_class = self.agent_class[i]
|
141
185
|
agent_count_i = agent_count[i]
|
@@ -145,15 +189,12 @@ class AgentSimulation:
|
|
145
189
|
|
146
190
|
# 获取Memory配置
|
147
191
|
extra_attributes, profile, base = memory_config_func_i()
|
148
|
-
memory = Memory(
|
149
|
-
config=extra_attributes, profile=profile, base=base
|
150
|
-
)
|
192
|
+
memory = Memory(config=extra_attributes, profile=profile, base=base)
|
151
193
|
|
152
194
|
# 创建智能体时传入Memory配置
|
153
195
|
agent = agent_class(
|
154
196
|
name=agent_name,
|
155
197
|
memory=memory,
|
156
|
-
avro_file=self._avro_file,
|
157
198
|
)
|
158
199
|
|
159
200
|
self._agents[agent._uuid] = agent
|
@@ -163,63 +204,52 @@ class AgentSimulation:
|
|
163
204
|
num_group = (agent_count_i + group_size - 1) // group_size
|
164
205
|
|
165
206
|
for k in range(num_group):
|
166
|
-
# 计算当前组的起始和结束索引
|
167
207
|
start_idx = class_init_index + k * group_size
|
168
208
|
end_idx = min(
|
169
|
-
class_init_index +
|
209
|
+
class_init_index + (k + 1) * group_size, # 修正了索引计算
|
170
210
|
class_init_index + agent_count_i,
|
171
211
|
)
|
172
212
|
|
173
|
-
# 获取当前组的agents
|
174
213
|
agents = list(self._agents.values())[start_idx:end_idx]
|
175
214
|
group_name = f"AgentType_{i}_Group_{k}"
|
176
|
-
group = AgentGroup.remote(agents, self.config, self.exp_id, self._avro_file)
|
177
|
-
self._groups[group_name] = group
|
178
|
-
for agent in agents:
|
179
|
-
self._agent_uuid2group[agent._uuid] = group
|
180
215
|
|
181
|
-
|
216
|
+
# 收集创建参数
|
217
|
+
group_creation_params.append((group_name, agents))
|
218
|
+
|
219
|
+
class_init_index += agent_count_i
|
220
|
+
|
221
|
+
# 收集所有创建组的参数
|
222
|
+
creation_tasks = []
|
223
|
+
for group_name, agents in group_creation_params:
|
224
|
+
# 直接创建异步任务
|
225
|
+
group = AgentGroup.remote(
|
226
|
+
agents,
|
227
|
+
self.config,
|
228
|
+
self.exp_id,
|
229
|
+
self._enable_avro,
|
230
|
+
self._avro_path,
|
231
|
+
self.logging_level,
|
232
|
+
)
|
233
|
+
creation_tasks.append((group_name, group, agents))
|
234
|
+
|
235
|
+
# 更新数据结构
|
236
|
+
for group_name, group, agents in creation_tasks:
|
237
|
+
self._groups[group_name] = group
|
238
|
+
for agent in agents:
|
239
|
+
self._agent_uuid2group[agent._uuid] = group
|
182
240
|
|
241
|
+
# 并行初始化所有组的agents
|
183
242
|
init_tasks = []
|
184
243
|
for group in self._groups.values():
|
185
244
|
init_tasks.append(group.init_agents.remote())
|
186
245
|
await asyncio.gather(*init_tasks)
|
246
|
+
|
247
|
+
# 设置用户主题
|
187
248
|
for uuid, agent in self._agents.items():
|
188
249
|
self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
|
189
|
-
self._user_survey_topics[uuid] =
|
190
|
-
|
191
|
-
|
192
|
-
if self._enable_avro:
|
193
|
-
self._avro_path.mkdir(parents=True, exist_ok=True)
|
194
|
-
# profile
|
195
|
-
filename = self._avro_file["profile"]
|
196
|
-
with open(filename, "wb") as f:
|
197
|
-
profiles = []
|
198
|
-
for agent in self._agents.values():
|
199
|
-
profile = await agent.memory._profile.export()
|
200
|
-
profile = profile[0]
|
201
|
-
profile['id'] = str(agent._uuid)
|
202
|
-
profiles.append(profile)
|
203
|
-
fastavro.writer(f, PROFILE_SCHEMA, profiles)
|
204
|
-
|
205
|
-
# dialog
|
206
|
-
filename = self._avro_file["dialog"]
|
207
|
-
with open(filename, "wb") as f:
|
208
|
-
dialogs = []
|
209
|
-
fastavro.writer(f, DIALOG_SCHEMA, dialogs)
|
210
|
-
|
211
|
-
# status
|
212
|
-
filename = self._avro_file["status"]
|
213
|
-
with open(filename, "wb") as f:
|
214
|
-
statuses = []
|
215
|
-
fastavro.writer(f, STATUS_SCHEMA, statuses)
|
216
|
-
|
217
|
-
# survey
|
218
|
-
filename = self._avro_file["survey"]
|
219
|
-
with open(filename, "wb") as f:
|
220
|
-
surveys = []
|
221
|
-
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
222
|
-
|
250
|
+
self._user_survey_topics[uuid] = (
|
251
|
+
f"exps/{self.exp_id}/agents/{uuid}/user-survey"
|
252
|
+
)
|
223
253
|
|
224
254
|
async def gather(self, content: str):
|
225
255
|
"""收集智能体的特定信息"""
|
@@ -228,15 +258,26 @@ class AgentSimulation:
|
|
228
258
|
gather_tasks.append(group.gather.remote(content))
|
229
259
|
return await asyncio.gather(*gather_tasks)
|
230
260
|
|
231
|
-
async def update(self,
|
261
|
+
async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
|
232
262
|
"""更新指定智能体的记忆"""
|
233
|
-
group = self._agent_uuid2group[
|
234
|
-
await group.update.remote(
|
263
|
+
group = self._agent_uuid2group[target_agent_uuid]
|
264
|
+
await group.update.remote(target_agent_uuid, target_key, content)
|
235
265
|
|
236
266
|
def default_memory_config_institution(self):
|
237
267
|
"""默认的Memory配置函数"""
|
238
268
|
EXTRA_ATTRIBUTES = {
|
239
|
-
"type": (
|
269
|
+
"type": (
|
270
|
+
int,
|
271
|
+
random.choice(
|
272
|
+
[
|
273
|
+
economyv2.ORG_TYPE_BANK,
|
274
|
+
economyv2.ORG_TYPE_GOVERNMENT,
|
275
|
+
economyv2.ORG_TYPE_FIRM,
|
276
|
+
economyv2.ORG_TYPE_NBS,
|
277
|
+
economyv2.ORG_TYPE_UNSPECIFIED,
|
278
|
+
]
|
279
|
+
),
|
280
|
+
),
|
240
281
|
"nominal_gdp": (list, [], True),
|
241
282
|
"real_gdp": (list, [], True),
|
242
283
|
"unemployment": (list, [], True),
|
@@ -350,29 +391,35 @@ class AgentSimulation:
|
|
350
391
|
}
|
351
392
|
|
352
393
|
return EXTRA_ATTRIBUTES, PROFILE, BASE
|
353
|
-
|
354
|
-
async def send_survey(
|
394
|
+
|
395
|
+
async def send_survey(
|
396
|
+
self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None
|
397
|
+
):
|
355
398
|
"""发送问卷"""
|
356
|
-
|
399
|
+
survey_dict = survey.to_dict()
|
357
400
|
if agent_uuids is None:
|
358
401
|
agent_uuids = self._agent_uuids
|
359
402
|
payload = {
|
360
403
|
"from": "none",
|
361
|
-
"survey_id":
|
404
|
+
"survey_id": survey_dict["id"],
|
362
405
|
"timestamp": int(datetime.now().timestamp() * 1000),
|
363
|
-
"data":
|
406
|
+
"data": survey_dict,
|
364
407
|
}
|
365
408
|
for uuid in agent_uuids:
|
366
409
|
topic = self._user_survey_topics[uuid]
|
367
410
|
await self._messager.send_message(topic, payload)
|
368
411
|
|
369
|
-
async def send_interview_message(
|
412
|
+
async def send_interview_message(
|
413
|
+
self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]
|
414
|
+
):
|
370
415
|
"""发送面试消息"""
|
371
416
|
payload = {
|
372
417
|
"from": "none",
|
373
418
|
"content": content,
|
374
419
|
"timestamp": int(datetime.now().timestamp() * 1000),
|
375
420
|
}
|
421
|
+
if not isinstance(agent_uuids, Sequence):
|
422
|
+
agent_uuids = [agent_uuids]
|
376
423
|
for uuid in agent_uuids:
|
377
424
|
topic = self._user_chat_topics[uuid]
|
378
425
|
await self._messager.send_message(topic, payload)
|
@@ -388,23 +435,90 @@ class AgentSimulation:
|
|
388
435
|
logger.error(f"运行错误: {str(e)}")
|
389
436
|
raise
|
390
437
|
|
438
|
+
async def _save_exp_info(self) -> None:
|
439
|
+
"""异步保存实验信息到YAML文件"""
|
440
|
+
try:
|
441
|
+
with open(self._exp_info_file, "w") as f:
|
442
|
+
yaml.dump(self._exp_info, f)
|
443
|
+
except Exception as e:
|
444
|
+
logger.error(f"保存实验信息失败: {str(e)}")
|
445
|
+
|
446
|
+
async def _update_exp_status(self, status: int, error: str = "") -> None:
|
447
|
+
"""更新实验状态并保存"""
|
448
|
+
self._exp_info["status"] = status
|
449
|
+
self._exp_info["error"] = error
|
450
|
+
await self._save_exp_info()
|
451
|
+
|
452
|
+
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
453
|
+
"""监控实验状态并更新
|
454
|
+
|
455
|
+
Args:
|
456
|
+
stop_event: 用于通知监控任务停止的事件
|
457
|
+
"""
|
458
|
+
try:
|
459
|
+
while not stop_event.is_set():
|
460
|
+
# 更新实验状态
|
461
|
+
# 假设所有group的cur_day和cur_t是同步的,取第一个即可
|
462
|
+
self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
|
463
|
+
self._exp_info["cur_t"] = (
|
464
|
+
await self._simulator.get_simulator_second_from_start_of_day()
|
465
|
+
)
|
466
|
+
await self._save_exp_info()
|
467
|
+
|
468
|
+
await asyncio.sleep(1) # 避免过于频繁的更新
|
469
|
+
except asyncio.CancelledError:
|
470
|
+
# 正常取消,不需要特殊处理
|
471
|
+
pass
|
472
|
+
except Exception as e:
|
473
|
+
logger.error(f"监控实验状态时发生错误: {str(e)}")
|
474
|
+
raise
|
475
|
+
|
391
476
|
async def run(
|
392
477
|
self,
|
393
478
|
day: int = 1,
|
394
479
|
):
|
395
|
-
"""运行模拟器
|
396
|
-
|
397
|
-
Args:
|
398
|
-
day: 运行天数,默认为1天
|
399
|
-
"""
|
480
|
+
"""运行模拟器"""
|
400
481
|
try:
|
401
|
-
|
402
|
-
|
403
|
-
for group in self._groups.values():
|
404
|
-
tasks.append(group.run.remote(day))
|
482
|
+
self._exp_info["num_day"] += day
|
483
|
+
await self._update_exp_status(1) # 更新状态为运行中
|
405
484
|
|
406
|
-
|
485
|
+
# 创建停止事件
|
486
|
+
stop_event = asyncio.Event()
|
487
|
+
# 创建监控任务
|
488
|
+
monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
|
489
|
+
|
490
|
+
try:
|
491
|
+
tasks = []
|
492
|
+
for group in self._groups.values():
|
493
|
+
tasks.append(group.run.remote())
|
494
|
+
|
495
|
+
# 等待所有group运行完成
|
496
|
+
await asyncio.gather(*tasks)
|
497
|
+
|
498
|
+
finally:
|
499
|
+
# 设置停止事件
|
500
|
+
stop_event.set()
|
501
|
+
# 等待监控任务结束
|
502
|
+
await monitor_task
|
503
|
+
|
504
|
+
# 运行成功后更新状态
|
505
|
+
await self._update_exp_status(2)
|
407
506
|
|
408
507
|
except Exception as e:
|
409
|
-
|
410
|
-
|
508
|
+
error_msg = f"模拟器运行错误: {str(e)}"
|
509
|
+
logger.error(error_msg)
|
510
|
+
await self._update_exp_status(3, error_msg)
|
511
|
+
raise e
|
512
|
+
|
513
|
+
async def __aenter__(self):
|
514
|
+
"""异步上下文管理器入口"""
|
515
|
+
return self
|
516
|
+
|
517
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
518
|
+
"""异步上下文管理器出口"""
|
519
|
+
if exc_type is not None:
|
520
|
+
# 如果发生异常,更新状态为错误
|
521
|
+
await self._update_exp_status(3, str(exc_val))
|
522
|
+
elif self._exp_info["status"] != 3:
|
523
|
+
# 如果没有发生异常且状态不是错误,则更新为完成
|
524
|
+
await self._update_exp_status(2)
|
pycityagent/survey/manager.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
|
-
from typing import List, Dict, Optional
|
2
|
-
from datetime import datetime
|
3
|
-
import uuid
|
4
1
|
import json
|
5
|
-
|
2
|
+
import uuid
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
from .models import Page, Question, QuestionType, Survey
|
6
7
|
|
7
8
|
|
8
9
|
class SurveyManager:
|
9
10
|
def __init__(self):
|
10
|
-
self._surveys:
|
11
|
+
self._surveys: dict[str, Survey] = {}
|
11
12
|
|
12
|
-
def create_survey(
|
13
|
-
self, title: str, description: str, pages: List[dict]
|
14
|
-
) -> Survey:
|
13
|
+
def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
|
15
14
|
"""创建新问卷"""
|
16
15
|
survey_id = uuid.uuid4()
|
17
16
|
|
@@ -32,11 +31,8 @@ class SurveyManager:
|
|
32
31
|
max_rating=q.get("max_rating", 5),
|
33
32
|
)
|
34
33
|
questions.append(question)
|
35
|
-
|
36
|
-
page = Page(
|
37
|
-
name=page_data["name"],
|
38
|
-
elements=questions
|
39
|
-
)
|
34
|
+
|
35
|
+
page = Page(name=page_data["name"], elements=questions)
|
40
36
|
survey_pages.append(page)
|
41
37
|
|
42
38
|
survey = Survey(
|
@@ -53,6 +49,6 @@ class SurveyManager:
|
|
53
49
|
"""获取指定问卷"""
|
54
50
|
return self._surveys.get(survey_id)
|
55
51
|
|
56
|
-
def get_all_surveys(self) ->
|
52
|
+
def get_all_surveys(self) -> list[Survey]:
|
57
53
|
"""获取所有问卷"""
|
58
54
|
return list(self._surveys.values())
|
pycityagent/survey/models.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
|
+
import json
|
2
|
+
import uuid
|
1
3
|
from dataclasses import dataclass, field
|
2
|
-
from typing import List, Dict, Optional
|
3
4
|
from datetime import datetime
|
4
5
|
from enum import Enum
|
5
|
-
import
|
6
|
-
import json
|
6
|
+
from typing import Any
|
7
7
|
|
8
8
|
|
9
9
|
class QuestionType(Enum):
|
@@ -20,19 +20,20 @@ class Question:
|
|
20
20
|
name: str
|
21
21
|
title: str
|
22
22
|
type: QuestionType
|
23
|
-
choices:
|
24
|
-
columns:
|
25
|
-
rows:
|
23
|
+
choices: list[str] = field(default_factory=list)
|
24
|
+
columns: list[str] = field(default_factory=list)
|
25
|
+
rows: list[str] = field(default_factory=list)
|
26
|
+
required: bool = True
|
26
27
|
min_rating: int = 1
|
27
28
|
max_rating: int = 5
|
28
29
|
|
29
30
|
def to_dict(self) -> dict:
|
30
|
-
base_dict = {
|
31
|
+
base_dict: dict[str, Any] = {
|
31
32
|
"type": self.type.value,
|
32
33
|
"name": self.name,
|
33
34
|
"title": self.title,
|
34
35
|
}
|
35
|
-
|
36
|
+
|
36
37
|
if self.type in [QuestionType.RADIO, QuestionType.CHECKBOX]:
|
37
38
|
base_dict["choices"] = self.choices
|
38
39
|
elif self.type == QuestionType.MATRIX:
|
@@ -41,20 +42,17 @@ class Question:
|
|
41
42
|
elif self.type == QuestionType.RATING:
|
42
43
|
base_dict["min_rating"] = self.min_rating
|
43
44
|
base_dict["max_rating"] = self.max_rating
|
44
|
-
|
45
|
+
|
45
46
|
return base_dict
|
46
47
|
|
47
48
|
|
48
49
|
@dataclass
|
49
50
|
class Page:
|
50
51
|
name: str
|
51
|
-
elements:
|
52
|
+
elements: list[Question]
|
52
53
|
|
53
54
|
def to_dict(self) -> dict:
|
54
|
-
return {
|
55
|
-
"name": self.name,
|
56
|
-
"elements": [q.to_dict() for q in self.elements]
|
57
|
-
}
|
55
|
+
return {"name": self.name, "elements": [q.to_dict() for q in self.elements]}
|
58
56
|
|
59
57
|
|
60
58
|
@dataclass
|
@@ -62,8 +60,8 @@ class Survey:
|
|
62
60
|
id: uuid.UUID
|
63
61
|
title: str
|
64
62
|
description: str
|
65
|
-
pages:
|
66
|
-
responses:
|
63
|
+
pages: list[Page]
|
64
|
+
responses: dict[str, dict] = field(default_factory=dict)
|
67
65
|
created_at: datetime = field(default_factory=datetime.now)
|
68
66
|
|
69
67
|
def to_dict(self) -> dict:
|
@@ -83,12 +81,12 @@ class Survey:
|
|
83
81
|
"description": self.description,
|
84
82
|
"pages": [p.to_dict() for p in self.pages],
|
85
83
|
"responses": self.responses,
|
86
|
-
"created_at": self.created_at.isoformat()
|
84
|
+
"created_at": self.created_at.isoformat(),
|
87
85
|
}
|
88
86
|
return json.dumps(survey_dict)
|
89
87
|
|
90
88
|
@classmethod
|
91
|
-
def from_json(cls, json_str: str) ->
|
89
|
+
def from_json(cls, json_str: str) -> "Survey":
|
92
90
|
"""Create a Survey instance from a JSON string"""
|
93
91
|
data = json.loads(json_str)
|
94
92
|
pages = [
|
@@ -104,17 +102,19 @@ class Survey:
|
|
104
102
|
columns=q.get("columns", []),
|
105
103
|
rows=q.get("rows", []),
|
106
104
|
min_rating=q.get("min_rating", 1),
|
107
|
-
max_rating=q.get("max_rating", 5)
|
108
|
-
)
|
109
|
-
|
110
|
-
|
105
|
+
max_rating=q.get("max_rating", 5),
|
106
|
+
)
|
107
|
+
for q in p["elements"]
|
108
|
+
],
|
109
|
+
)
|
110
|
+
for p in data["pages"]
|
111
111
|
]
|
112
|
-
|
112
|
+
|
113
113
|
return cls(
|
114
114
|
id=uuid.UUID(data["id"]),
|
115
115
|
title=data["title"],
|
116
116
|
description=data["description"],
|
117
117
|
pages=pages,
|
118
118
|
responses=data.get("responses", {}),
|
119
|
-
created_at=datetime.fromisoformat(data["created_at"])
|
119
|
+
created_at=datetime.fromisoformat(data["created_at"]),
|
120
120
|
)
|
pycityagent/utils/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
|
-
from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
|
1
|
+
from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
|
2
2
|
from .survey_util import process_survey_for_llm
|
3
3
|
|
4
4
|
__all__ = [
|
5
|
-
"PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA",
|
5
|
+
"PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
|
6
6
|
"process_survey_for_llm"
|
7
7
|
]
|
pycityagent/utils/avro_schema.py
CHANGED
@@ -66,6 +66,31 @@ STATUS_SCHEMA = {
|
|
66
66
|
],
|
67
67
|
}
|
68
68
|
|
69
|
+
INSTITUTION_STATUS_SCHEMA = {
|
70
|
+
"doc": "Institution状态",
|
71
|
+
"name": "InstitutionStatus",
|
72
|
+
"namespace": "com.socialcity",
|
73
|
+
"type": "record",
|
74
|
+
"fields": [
|
75
|
+
{"name": "id", "type": "string"}, # uuid as string
|
76
|
+
{"name": "day", "type": "int"},
|
77
|
+
{"name": "t", "type": "float"},
|
78
|
+
{"name": "type", "type": "int"},
|
79
|
+
{"name": "nominal_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
80
|
+
{"name": "real_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
81
|
+
{"name": "unemployment", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
82
|
+
{"name": "wages", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
83
|
+
{"name": "prices", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
84
|
+
{"name": "inventory", "type": ["int", "null"]},
|
85
|
+
{"name": "price", "type": ["float", "null"]},
|
86
|
+
{"name": "interest_rate", "type": ["float", "null"]},
|
87
|
+
{"name": "bracket_cutoffs", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
88
|
+
{"name": "bracket_rates", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
89
|
+
{"name": "employees", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
90
|
+
{"name": "customers", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
91
|
+
],
|
92
|
+
}
|
93
|
+
|
69
94
|
SURVEY_SCHEMA = {
|
70
95
|
"doc": "Agent问卷",
|
71
96
|
"name": "AgentSurvey",
|
@@ -82,4 +107,4 @@ SURVEY_SCHEMA = {
|
|
82
107
|
"type": {"type": "long", "logicalType": "timestamp-millis"},
|
83
108
|
},
|
84
109
|
],
|
85
|
-
}
|
110
|
+
}
|