pycityagent 2.0.0a13__py3-none-any.whl → 2.0.0a15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +14 -0
- pycityagent/agent.py +164 -63
- pycityagent/economy/econ_client.py +2 -0
- pycityagent/environment/simulator.py +5 -4
- pycityagent/memory/const.py +1 -0
- pycityagent/memory/memory.py +8 -7
- pycityagent/memory/memory_base.py +6 -4
- pycityagent/message/messager.py +8 -7
- pycityagent/simulation/agentgroup.py +136 -14
- pycityagent/simulation/simulation.py +212 -42
- pycityagent/survey/manager.py +58 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +7 -0
- pycityagent/utils/avro_schema.py +110 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/tool.py +0 -3
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/METADATA +3 -1
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/RECORD +20 -21
- pycityagent/simulation/interview.py +0 -40
- pycityagent/simulation/survey/manager.py +0 -68
- pycityagent/simulation/survey/models.py +0 -52
- pycityagent/simulation/ui/__init__.py +0 -3
- pycityagent/simulation/ui/interface.py +0 -602
- /pycityagent/{simulation/survey → survey}/__init__.py +0 -0
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/WHEEL +0 -0
@@ -1,21 +1,25 @@
|
|
1
1
|
import asyncio
|
2
2
|
import json
|
3
3
|
import logging
|
4
|
+
import os
|
5
|
+
from pathlib import Path
|
4
6
|
import uuid
|
5
|
-
from datetime import datetime
|
7
|
+
from datetime import datetime, timezone
|
6
8
|
import random
|
7
9
|
from typing import Dict, List, Optional, Callable, Union,Any
|
8
10
|
from mosstool.map._map_util.const import AOI_START_ID
|
9
11
|
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
12
|
+
from pycityagent.environment.simulator import Simulator
|
10
13
|
from pycityagent.memory.memory import Memory
|
14
|
+
from pycityagent.message.messager import Messager
|
15
|
+
from pycityagent.survey import Survey
|
16
|
+
import yaml
|
17
|
+
from concurrent.futures import ThreadPoolExecutor
|
11
18
|
|
12
19
|
from ..agent import Agent, InstitutionAgent
|
13
|
-
from .interview import InterviewManager
|
14
|
-
from .survey import QuestionType, SurveyManager
|
15
|
-
from .ui import InterviewUI
|
16
20
|
from .agentgroup import AgentGroup
|
17
21
|
|
18
|
-
logger = logging.getLogger(
|
22
|
+
logger = logging.getLogger("pycityagent")
|
19
23
|
|
20
24
|
|
21
25
|
class AgentSimulation:
|
@@ -26,28 +30,72 @@ class AgentSimulation:
|
|
26
30
|
agent_class: Union[type[Agent], list[type[Agent]]],
|
27
31
|
config: dict,
|
28
32
|
agent_prefix: str = "agent_",
|
33
|
+
exp_name: str = "default_experiment",
|
34
|
+
logging_level: int = logging.WARNING
|
29
35
|
):
|
30
36
|
"""
|
31
37
|
Args:
|
32
38
|
agent_class: 智能体类
|
33
39
|
config: 配置
|
34
40
|
agent_prefix: 智能体名称前缀
|
41
|
+
exp_name: 实验名称
|
35
42
|
"""
|
36
|
-
self.exp_id = uuid.uuid4()
|
43
|
+
self.exp_id = str(uuid.uuid4())
|
37
44
|
if isinstance(agent_class, list):
|
38
45
|
self.agent_class = agent_class
|
39
46
|
else:
|
40
47
|
self.agent_class = [agent_class]
|
48
|
+
self.logging_level = logging_level
|
41
49
|
self.config = config
|
50
|
+
self._simulator = Simulator(config["simulator_request"])
|
42
51
|
self.agent_prefix = agent_prefix
|
43
52
|
self._agents: Dict[uuid.UUID, Agent] = {}
|
44
53
|
self._groups: Dict[str, AgentGroup] = {}
|
45
54
|
self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {}
|
46
55
|
self._agent_uuids: List[uuid.UUID] = []
|
47
|
-
|
56
|
+
self._user_chat_topics: Dict[uuid.UUID, str] = {}
|
57
|
+
self._user_survey_topics: Dict[uuid.UUID, str] = {}
|
58
|
+
self._user_interview_topics: Dict[uuid.UUID, str] = {}
|
48
59
|
self._loop = asyncio.get_event_loop()
|
49
|
-
|
50
|
-
self.
|
60
|
+
|
61
|
+
self._messager = Messager(
|
62
|
+
hostname=config["simulator_request"]["mqtt"]["server"],
|
63
|
+
port=config["simulator_request"]["mqtt"]["port"],
|
64
|
+
username=config["simulator_request"]["mqtt"].get("username", None),
|
65
|
+
password=config["simulator_request"]["mqtt"].get("password", None),
|
66
|
+
)
|
67
|
+
asyncio.create_task(self._messager.connect())
|
68
|
+
|
69
|
+
self._enable_avro = config["storage"]["avro"]["enabled"]
|
70
|
+
if not self._enable_avro:
|
71
|
+
logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
|
72
|
+
self._avro_path = Path(config["storage"]["avro"]["path"]) / f"{self.exp_id}"
|
73
|
+
self._avro_path.mkdir(parents=True, exist_ok=True)
|
74
|
+
|
75
|
+
self._enable_pgsql = config["storage"]["pgsql"]["enabled"]
|
76
|
+
self._pgsql_host = config["storage"]["pgsql"]["host"]
|
77
|
+
self._pgsql_port = config["storage"]["pgsql"]["port"]
|
78
|
+
self._pgsql_database = config["storage"]["pgsql"]["database"]
|
79
|
+
self._pgsql_user = config["storage"]["pgsql"]["user"]
|
80
|
+
self._pgsql_password = config["storage"]["pgsql"]["password"]
|
81
|
+
|
82
|
+
# 添加实验信息相关的属性
|
83
|
+
self._exp_info = {
|
84
|
+
"id": self.exp_id,
|
85
|
+
"name": exp_name,
|
86
|
+
"num_day": 0, # 将在 run 方法中更新
|
87
|
+
"status": 0,
|
88
|
+
"cur_day": 0,
|
89
|
+
"cur_t": 0.0,
|
90
|
+
"config": json.dumps(config),
|
91
|
+
"error": "",
|
92
|
+
"created_at": datetime.now(timezone.utc).isoformat()
|
93
|
+
}
|
94
|
+
|
95
|
+
# 创建异步任务保存实验信息
|
96
|
+
self._exp_info_file = self._avro_path / "experiment_info.yaml"
|
97
|
+
with open(self._exp_info_file, 'w') as f:
|
98
|
+
yaml.dump(self._exp_info, f)
|
51
99
|
|
52
100
|
@property
|
53
101
|
def agents(self):
|
@@ -64,6 +112,11 @@ class AgentSimulation:
|
|
64
112
|
@property
|
65
113
|
def agent_uuid2group(self):
|
66
114
|
return self._agent_uuid2group
|
115
|
+
|
116
|
+
def create_remote_group(
    self,
    group_name: str,
    agents: list[Agent],
    config: dict,
    exp_id: str,
    enable_avro: bool,
    avro_path: Path,
    logging_level: int = logging.WARNING,
):
    """Create one agent group through ``AgentGroup.remote``.

    This method does not read or modify any state on ``self``; every value
    the group needs is passed in explicitly.

    Args:
        group_name: Label returned unchanged alongside the new group.
        agents: Agents handed to the group; also returned unchanged.
        config: Forwarded to ``AgentGroup.remote``.
        exp_id: Forwarded to ``AgentGroup.remote``.
        enable_avro: Forwarded to ``AgentGroup.remote``.
        avro_path: Forwarded to ``AgentGroup.remote``.
        logging_level: Forwarded to ``AgentGroup.remote``.

    Returns:
        The tuple ``(group_name, group, agents)``.
    """
    remote_group = AgentGroup.remote(
        agents,
        config,
        exp_id,
        enable_avro,
        avro_path,
        logging_level,
    )
    return group_name, remote_group, agents
|
67
120
|
|
68
121
|
async def init_agents(
|
69
122
|
self,
|
@@ -85,7 +138,7 @@ class AgentSimulation:
|
|
85
138
|
raise ValueError("agent_class和agent_count的长度不一致")
|
86
139
|
|
87
140
|
if memory_config_func is None:
|
88
|
-
|
141
|
+
logger.warning(
|
89
142
|
"memory_config_func is None, using default memory config function"
|
90
143
|
)
|
91
144
|
memory_config_func = []
|
@@ -98,17 +151,21 @@ class AgentSimulation:
|
|
98
151
|
memory_config_func = [memory_config_func]
|
99
152
|
|
100
153
|
if len(memory_config_func) != len(agent_count):
|
101
|
-
|
154
|
+
logger.warning(
|
102
155
|
"memory_config_func和agent_count的长度不一致,使用默认的memory_config"
|
103
156
|
)
|
104
157
|
memory_config_func = []
|
105
158
|
for agent_class in self.agent_class:
|
106
|
-
if agent_class
|
159
|
+
if issubclass(agent_class, InstitutionAgent):
|
107
160
|
memory_config_func.append(self.default_memory_config_institution)
|
108
161
|
else:
|
109
162
|
memory_config_func.append(self.default_memory_config_citizen)
|
110
163
|
|
164
|
+
# 使用线程池并行创建 AgentGroup
|
165
|
+
group_creation_params = []
|
111
166
|
class_init_index = 0
|
167
|
+
|
168
|
+
# 首先收集所有需要创建的组的参数
|
112
169
|
for i in range(len(self.agent_class)):
|
113
170
|
agent_class = self.agent_class[i]
|
114
171
|
agent_count_i = agent_count[i]
|
@@ -133,30 +190,51 @@ class AgentSimulation:
|
|
133
190
|
|
134
191
|
# 计算需要的组数,向上取整以处理不足一组的情况
|
135
192
|
num_group = (agent_count_i + group_size - 1) // group_size
|
136
|
-
|
193
|
+
|
137
194
|
for k in range(num_group):
|
138
|
-
# 计算当前组的起始和结束索引
|
139
195
|
start_idx = class_init_index + k * group_size
|
140
196
|
end_idx = min(
|
141
|
-
class_init_index +
|
142
|
-
class_init_index + agent_count_i
|
197
|
+
class_init_index + (k + 1) * group_size, # 修正了索引计算
|
198
|
+
class_init_index + agent_count_i
|
143
199
|
)
|
144
|
-
|
145
|
-
# 获取当前组的agents
|
200
|
+
|
146
201
|
agents = list(self._agents.values())[start_idx:end_idx]
|
147
202
|
group_name = f"AgentType_{i}_Group_{k}"
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
203
|
+
|
204
|
+
# 收集创建参数
|
205
|
+
group_creation_params.append((
|
206
|
+
group_name,
|
207
|
+
agents
|
208
|
+
))
|
209
|
+
|
210
|
+
class_init_index += agent_count_i
|
211
|
+
|
212
|
+
# 收集所有创建组的参数
|
213
|
+
creation_tasks = []
|
214
|
+
for group_name, agents in group_creation_params:
|
215
|
+
# 直接创建异步任务
|
216
|
+
group = AgentGroup.remote(agents, self.config, self.exp_id,
|
217
|
+
self._enable_avro, self._avro_path,
|
218
|
+
self.logging_level)
|
219
|
+
creation_tasks.append((group_name, group, agents))
|
220
|
+
|
221
|
+
# 更新数据结构
|
222
|
+
for group_name, group, agents in creation_tasks:
|
223
|
+
self._groups[group_name] = group
|
224
|
+
for agent in agents:
|
225
|
+
self._agent_uuid2group[agent._uuid] = group
|
226
|
+
|
227
|
+
# 并行初始化所有组的agents
|
155
228
|
init_tasks = []
|
156
229
|
for group in self._groups.values():
|
157
230
|
init_tasks.append(group.init_agents.remote())
|
158
231
|
await asyncio.gather(*init_tasks)
|
159
232
|
|
233
|
+
# 设置用户主题
|
234
|
+
for uuid, agent in self._agents.items():
|
235
|
+
self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
|
236
|
+
self._user_survey_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-survey"
|
237
|
+
|
160
238
|
async def gather(self, content: str):
|
161
239
|
"""收集智能体的特定信息"""
|
162
240
|
gather_tasks = []
|
@@ -164,10 +242,10 @@ class AgentSimulation:
|
|
164
242
|
gather_tasks.append(group.gather.remote(content))
|
165
243
|
return await asyncio.gather(*gather_tasks)
|
166
244
|
|
167
|
-
async def update(self,
|
245
|
+
async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
    """Update one memory entry of a single agent.

    The group owning the agent is looked up in ``self._agent_uuid2group`` and
    the update is delegated to that group's ``update.remote``.

    Args:
        target_agent_uuid: UUID of the agent to update.
        target_key: Key to update; forwarded unchanged to the group.
        content: New value; forwarded unchanged to the group.

    Raises:
        KeyError: If the UUID is not present in ``self._agent_uuid2group``.
    """
    owning_group = self._agent_uuid2group[target_agent_uuid]
    remote_update = owning_group.update.remote
    await remote_update(target_agent_uuid, target_key, content)
|
171
249
|
|
172
250
|
def default_memory_config_institution(self):
|
173
251
|
"""默认的Memory配置函数"""
|
@@ -219,6 +297,7 @@ class AgentSimulation:
|
|
219
297
|
}
|
220
298
|
|
221
299
|
PROFILE = {
|
300
|
+
"name": "unknown",
|
222
301
|
"gender": random.choice(["male", "female"]),
|
223
302
|
"education": random.choice(
|
224
303
|
["Doctor", "Master", "Bachelor", "College", "High School"]
|
@@ -251,7 +330,7 @@ class AgentSimulation:
|
|
251
330
|
"personality": random.choice(
|
252
331
|
["outgoint", "introvert", "ambivert", "extrovert"]
|
253
332
|
),
|
254
|
-
"income": random.randint(1000, 10000),
|
333
|
+
"income": str(random.randint(1000, 10000)),
|
255
334
|
"currency": random.randint(10000, 100000),
|
256
335
|
"residence": random.choice(["city", "suburb", "rural"]),
|
257
336
|
"race": random.choice(
|
@@ -285,6 +364,32 @@ class AgentSimulation:
|
|
285
364
|
}
|
286
365
|
|
287
366
|
return EXTRA_ATTRIBUTES, PROFILE, BASE
|
367
|
+
|
368
|
+
async def send_survey(self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None):
    """Send a survey to agents through the messager.

    The survey is converted with ``survey.to_dict()`` and wrapped in a
    payload holding the sender ("none"), the survey id, a millisecond
    timestamp taken from ``datetime.now()``, and the survey dict itself.
    The same payload object is sent once per recipient, in order, to that
    agent's entry in ``self._user_survey_topics``.

    Args:
        survey: Survey to send.
        agent_uuids: Recipients. When ``None``, ``self._agent_uuids`` is used.

    Raises:
        KeyError: If a recipient has no entry in ``self._user_survey_topics``.
    """
    survey_dict = survey.to_dict()
    recipients = self._agent_uuids if agent_uuids is None else agent_uuids
    message = {
        "from": "none",
        "survey_id": survey_dict["id"],
        "timestamp": int(datetime.now().timestamp() * 1000),
        "data": survey_dict,
    }
    for agent_uuid in recipients:
        survey_topic = self._user_survey_topics[agent_uuid]
        await self._messager.send_message(survey_topic, message)
|
382
|
+
|
383
|
+
async def send_interview_message(self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]):
    """Send an interview message to one agent or to several agents.

    The payload holds the sender ("none"), the message text and a
    millisecond timestamp taken from ``datetime.now()``. It is sent once per
    recipient, in order, to that agent's entry in ``self._user_chat_topics``.

    Args:
        content: Text of the interview message.
        agent_uuids: A single agent UUID or a list of agent UUIDs.

    Raises:
        KeyError: If a recipient has no entry in ``self._user_chat_topics``.
    """
    # The signature allows a bare UUID, which is not iterable; wrap it so the
    # loop below handles both accepted forms.
    if isinstance(agent_uuids, uuid.UUID):
        agent_uuids = [agent_uuids]

    payload = {
        "from": "none",
        "content": content,
        "timestamp": int(datetime.now().timestamp() * 1000),
    }
    # The loop variable must not be called ``uuid``: that would make the name
    # local to this function and break the ``uuid.UUID`` check above.
    for agent_uuid in agent_uuids:
        topic = self._user_chat_topics[agent_uuid]
        await self._messager.send_message(topic, payload)
|
288
393
|
|
289
394
|
async def step(self):
|
290
395
|
"""运行一步, 即每个智能体执行一次forward"""
|
@@ -297,23 +402,88 @@ class AgentSimulation:
|
|
297
402
|
logger.error(f"运行错误: {str(e)}")
|
298
403
|
raise
|
299
404
|
|
300
|
-
async def
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
405
|
+
async def _save_exp_info(self) -> None:
    """Dump ``self._exp_info`` as YAML into ``self._exp_info_file``.

    The file is opened in write mode, so each call replaces its previous
    contents. Any exception raised while opening or dumping is logged and
    swallowed; nothing is propagated to the caller. Although this is a
    coroutine, the write itself is an ordinary blocking call.
    """
    try:
        with open(self._exp_info_file, 'w') as info_stream:
            yaml.dump(self._exp_info, info_stream)
    except Exception as e:
        logger.error(f"保存实验信息失败: {str(e)}")
|
412
|
+
|
413
|
+
async def _update_exp_status(self, status: int, error: str = "") -> None:
|
414
|
+
"""更新实验状态并保存"""
|
415
|
+
self._exp_info["status"] = status
|
416
|
+
self._exp_info["error"] = error
|
417
|
+
await self._save_exp_info()
|
305
418
|
|
419
|
+
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
420
|
+
"""监控实验状态并更新
|
421
|
+
|
306
422
|
Args:
|
307
|
-
|
423
|
+
stop_event: 用于通知监控任务停止的事件
|
308
424
|
"""
|
309
425
|
try:
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
426
|
+
while not stop_event.is_set():
|
427
|
+
# 更新实验状态
|
428
|
+
# 假设所有group的cur_day和cur_t是同步的,取第一个即可
|
429
|
+
self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
|
430
|
+
self._exp_info["cur_t"] = await self._simulator.get_simulator_second_from_start_of_day()
|
431
|
+
await self._save_exp_info()
|
432
|
+
|
433
|
+
await asyncio.sleep(1) # 避免过于频繁的更新
|
434
|
+
except asyncio.CancelError:
|
435
|
+
# 正常取消,不需要特殊处理
|
436
|
+
pass
|
437
|
+
except Exception as e:
|
438
|
+
logger.error(f"监控实验状态时发生错误: {str(e)}")
|
439
|
+
raise
|
314
440
|
|
315
|
-
|
441
|
+
async def run(
    self,
    day: int = 1,
):
    """Run every agent group once while a monitor task tracks progress.

    Steps, in order:
      1. add ``day`` to ``self._exp_info["num_day"]`` and set status 1;
      2. start ``_monitor_exp_status`` as a background task;
      3. call ``run.remote()`` on every group and wait for all of them;
      4. set the stop event and wait for the monitor, whether or not step 3
         succeeded;
      5. set status 2.

    Any exception raised along the way is logged, recorded as status 3 with
    the error text, and re-raised.

    Args:
        day: Number of days added to the experiment's ``num_day`` counter.
            NOTE(review): ``day`` is not passed to ``group.run.remote()``
            here; confirm how groups learn the run length.
    """
    try:
        self._exp_info["num_day"] += day
        await self._update_exp_status(1)  # status 1 while the run is in progress

        stop_event = asyncio.Event()
        watcher = asyncio.create_task(self._monitor_exp_status(stop_event))

        try:
            group_runs = [group.run.remote() for group in self._groups.values()]
            # Wait until every group has finished.
            await asyncio.gather(*group_runs)
        finally:
            # Tell the monitor to stop, then wait for it to exit.
            stop_event.set()
            await watcher

        # Reached only when the groups and the monitor finished without error.
        await self._update_exp_status(2)

    except Exception as e:
        error_msg = f"模拟器运行错误: {str(e)}"
        logger.error(error_msg)
        await self._update_exp_status(3, error_msg)
        raise e
|
477
|
+
|
478
|
+
async def __aenter__(self):
    """Async context-manager entry: return this simulation unchanged."""
    return self
|
481
|
+
|
482
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
483
|
+
"""异步上下文管理器出口"""
|
484
|
+
if exc_type is not None:
|
485
|
+
# 如果发生异常,更新状态为错误
|
486
|
+
await self._update_exp_status(3, str(exc_val))
|
487
|
+
elif self._exp_info["status"] != 3:
|
488
|
+
# 如果没有发生异常且状态不是错误,则更新为完成
|
489
|
+
await self._update_exp_status(2)
|
@@ -0,0 +1,58 @@
|
|
1
|
+
from typing import List, Dict, Optional
|
2
|
+
from datetime import datetime
|
3
|
+
import uuid
|
4
|
+
import json
|
5
|
+
from .models import Survey, Question, QuestionType, Page
|
6
|
+
|
7
|
+
|
8
|
+
class SurveyManager:
    """In-memory registry of surveys, keyed by the string form of their id."""

    def __init__(self):
        self._surveys: Dict[str, Survey] = {}

    def create_survey(
        self, title: str, description: str, pages: List[dict]
    ) -> Survey:
        """Build a survey from plain dicts, register it and return it.

        Args:
            title: Survey title.
            description: Survey description.
            pages: One dict per page. Each needs a ``"name"`` and an
                ``"elements"`` list of question dicts. A question dict needs
                ``"name"``, ``"title"`` and ``"type"`` (a ``QuestionType``
                value). Optional keys and their defaults: ``required``
                (True), ``choices`` / ``columns`` / ``rows`` (empty list),
                ``min_rating`` (1), ``max_rating`` (5).

        Returns:
            The new ``Survey``, stored under ``str(survey.id)``.

        Note:
            ``required`` is passed to ``Question`` as a keyword argument, so
            ``Question`` must declare a field of that name.
        """
        new_id = uuid.uuid4()

        # Convert the raw page and question dicts into model objects.
        built_pages = []
        for raw_page in pages:
            page_questions = [
                Question(
                    name=raw_q["name"],
                    title=raw_q["title"],
                    type=QuestionType(raw_q["type"]),
                    required=raw_q.get("required", True),
                    choices=raw_q.get("choices", []),
                    columns=raw_q.get("columns", []),
                    rows=raw_q.get("rows", []),
                    min_rating=raw_q.get("min_rating", 1),
                    max_rating=raw_q.get("max_rating", 5),
                )
                for raw_q in raw_page["elements"]
            ]
            built_pages.append(Page(name=raw_page["name"], elements=page_questions))

        created = Survey(
            id=new_id,
            title=title,
            description=description,
            pages=built_pages,
        )
        self._surveys[str(new_id)] = created
        return created

    def get_survey(self, survey_id: str) -> Optional[Survey]:
        """Return the survey stored under ``survey_id``, or ``None`` if absent."""
        return self._surveys.get(survey_id)

    def get_all_surveys(self) -> List[Survey]:
        """Return a new list holding every registered survey."""
        return [*self._surveys.values()]
|
@@ -0,0 +1,120 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from typing import List, Dict, Optional
|
3
|
+
from datetime import datetime
|
4
|
+
from enum import Enum
|
5
|
+
import uuid
|
6
|
+
import json
|
7
|
+
|
8
|
+
|
9
|
+
class QuestionType(Enum):
    """Supported question kinds; each value is the type string used in serialized surveys."""

    TEXT = "text"
    RADIO = "radiogroup"
    CHECKBOX = "checkbox"
    BOOLEAN = "boolean"
    RATING = "rating"
    MATRIX = "matrix"


@dataclass
class Question:
    """A single survey question.

    ``choices`` is serialized only for RADIO and CHECKBOX questions,
    ``columns`` and ``rows`` only for MATRIX, and ``min_rating`` and
    ``max_rating`` only for RATING.
    """

    name: str
    title: str
    type: QuestionType
    choices: List[str] = field(default_factory=list)
    columns: List[str] = field(default_factory=list)
    rows: List[str] = field(default_factory=list)
    min_rating: int = 1
    max_rating: int = 5
    # Declared last so the existing positional order is unchanged. Code that
    # builds questions from dicts passes ``required=...``; without this field
    # that call raises TypeError.
    required: bool = True

    def to_dict(self) -> dict:
        """Return the question as a plain dict, with only the fields relevant to its type."""
        base_dict = {
            "type": self.type.value,
            "name": self.name,
            "title": self.title,
            # Emitted so the flag survives a to_json / from_json round trip.
            "required": self.required,
        }

        if self.type in [QuestionType.RADIO, QuestionType.CHECKBOX]:
            base_dict["choices"] = self.choices
        elif self.type == QuestionType.MATRIX:
            base_dict["columns"] = self.columns
            base_dict["rows"] = self.rows
        elif self.type == QuestionType.RATING:
            base_dict["min_rating"] = self.min_rating
            base_dict["max_rating"] = self.max_rating

        return base_dict


@dataclass
class Page:
    """A named survey page holding an ordered list of questions."""

    name: str
    elements: List[Question]

    def to_dict(self) -> dict:
        """Return the page as a plain dict with its questions serialized."""
        return {
            "name": self.name,
            "elements": [q.to_dict() for q in self.elements]
        }


@dataclass
class Survey:
    """A survey: identity, descriptive text, pages and collected responses."""

    id: uuid.UUID
    title: str
    description: str
    pages: List[Page]
    responses: Dict[str, dict] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> dict:
        """Return a summary dict: id, title, description, pages and the response count."""
        return {
            "id": str(self.id),
            "title": self.title,
            "description": self.description,
            "pages": [p.to_dict() for p in self.pages],
            "response_count": len(self.responses),
        }

    def to_json(self) -> str:
        """Convert the survey to a JSON string for MQTT transmission.

        Unlike ``to_dict``, this includes the full ``responses`` mapping and
        ``created_at`` in ISO format, which is what ``from_json`` reads back.
        """
        survey_dict = {
            "id": str(self.id),
            "title": self.title,
            "description": self.description,
            "pages": [p.to_dict() for p in self.pages],
            "responses": self.responses,
            "created_at": self.created_at.isoformat()
        }
        return json.dumps(survey_dict)

    @classmethod
    def from_json(cls, json_str: str) -> 'Survey':
        """Create a Survey instance from a JSON string.

        Question keys missing from the input fall back to the same defaults
        the ``Question`` dataclass declares.

        Raises:
            KeyError: If a mandatory key (survey ``id`` / ``title`` /
                ``description`` / ``pages`` / ``created_at``, page ``name`` /
                ``elements``, question ``name`` / ``title`` / ``type``) is
                missing.
            ValueError: If a question type, the id, or the timestamp cannot
                be parsed.
        """
        data = json.loads(json_str)
        pages = [
            Page(
                name=p["name"],
                elements=[
                    Question(
                        name=q["name"],
                        title=q["title"],
                        type=QuestionType(q["type"]),
                        required=q.get("required", True),
                        choices=q.get("choices", []),
                        columns=q.get("columns", []),
                        rows=q.get("rows", []),
                        min_rating=q.get("min_rating", 1),
                        max_rating=q.get("max_rating", 5)
                    ) for q in p["elements"]
                ]
            ) for p in data["pages"]
        ]

        return cls(
            id=uuid.UUID(data["id"]),
            title=data["title"],
            description=data["description"],
            pages=pages,
            responses=data.get("responses", {}),
            created_at=datetime.fromisoformat(data["created_at"])
        )
|
pycityagent/utils/__init__.py
CHANGED
@@ -0,0 +1,7 @@
|
|
1
|
+
from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
|
2
|
+
from .survey_util import process_survey_for_llm
|
3
|
+
|
4
|
+
__all__ = [
|
5
|
+
"PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
|
6
|
+
"process_survey_for_llm"
|
7
|
+
]
|
@@ -0,0 +1,110 @@
|
|
1
|
+
PROFILE_SCHEMA = {
|
2
|
+
"doc": "Agent属性",
|
3
|
+
"name": "AgentProfile",
|
4
|
+
"namespace": "com.socialcity",
|
5
|
+
"type": "record",
|
6
|
+
"fields": [
|
7
|
+
{"name": "id", "type": "string"}, # uuid as string
|
8
|
+
{"name": "name", "type": "string"},
|
9
|
+
{"name": "gender", "type": "string"},
|
10
|
+
{"name": "age", "type": "float"},
|
11
|
+
{"name": "education", "type": "string"},
|
12
|
+
{"name": "skill", "type": "string"},
|
13
|
+
{"name": "occupation", "type": "string"},
|
14
|
+
{"name": "family_consumption", "type": "string"},
|
15
|
+
{"name": "consumption", "type": "string"},
|
16
|
+
{"name": "personality", "type": "string"},
|
17
|
+
{"name": "income", "type": "string"},
|
18
|
+
{"name": "currency", "type": "float"},
|
19
|
+
{"name": "residence", "type": "string"},
|
20
|
+
{"name": "race", "type": "string"},
|
21
|
+
{"name": "religion", "type": "string"},
|
22
|
+
{"name": "marital_status", "type": "string"},
|
23
|
+
],
|
24
|
+
}
|
25
|
+
|
26
|
+
DIALOG_SCHEMA = {
|
27
|
+
"doc": "Agent对话",
|
28
|
+
"name": "AgentDialog",
|
29
|
+
"namespace": "com.socialcity",
|
30
|
+
"type": "record",
|
31
|
+
"fields": [
|
32
|
+
{"name": "id", "type": "string"}, # uuid as string
|
33
|
+
{"name": "day", "type": "int"},
|
34
|
+
{"name": "t", "type": "float"},
|
35
|
+
{"name": "type", "type": "int"},
|
36
|
+
{"name": "speaker", "type": "string"},
|
37
|
+
{"name": "content", "type": "string"},
|
38
|
+
{
|
39
|
+
"name": "created_at",
|
40
|
+
"type": {"type": "long", "logicalType": "timestamp-millis"},
|
41
|
+
},
|
42
|
+
],
|
43
|
+
}
|
44
|
+
|
45
|
+
STATUS_SCHEMA = {
|
46
|
+
"doc": "Agent状态",
|
47
|
+
"name": "AgentStatus",
|
48
|
+
"namespace": "com.socialcity",
|
49
|
+
"type": "record",
|
50
|
+
"fields": [
|
51
|
+
{"name": "id", "type": "string"}, # uuid as string
|
52
|
+
{"name": "day", "type": "int"},
|
53
|
+
{"name": "t", "type": "float"},
|
54
|
+
{"name": "lng", "type": "double"},
|
55
|
+
{"name": "lat", "type": "double"},
|
56
|
+
{"name": "parent_id", "type": "int"},
|
57
|
+
{"name": "action", "type": "string"},
|
58
|
+
{"name": "hungry", "type": "float"},
|
59
|
+
{"name": "tired", "type": "float"},
|
60
|
+
{"name": "safe", "type": "float"},
|
61
|
+
{"name": "social", "type": "float"},
|
62
|
+
{
|
63
|
+
"name": "created_at",
|
64
|
+
"type": {"type": "long", "logicalType": "timestamp-millis"},
|
65
|
+
},
|
66
|
+
],
|
67
|
+
}
|
68
|
+
|
69
|
+
INSTITUTION_STATUS_SCHEMA = {
|
70
|
+
"doc": "Institution状态",
|
71
|
+
"name": "InstitutionStatus",
|
72
|
+
"namespace": "com.socialcity",
|
73
|
+
"type": "record",
|
74
|
+
"fields": [
|
75
|
+
{"name": "id", "type": "string"}, # uuid as string
|
76
|
+
{"name": "day", "type": "int"},
|
77
|
+
{"name": "t", "type": "float"},
|
78
|
+
{"name": "type", "type": "int"},
|
79
|
+
{"name": "nominal_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
80
|
+
{"name": "real_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
81
|
+
{"name": "unemployment", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
82
|
+
{"name": "wages", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
83
|
+
{"name": "prices", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
84
|
+
{"name": "inventory", "type": ["int", "null"]},
|
85
|
+
{"name": "price", "type": ["float", "null"]},
|
86
|
+
{"name": "interest_rate", "type": ["float", "null"]},
|
87
|
+
{"name": "bracket_cutoffs", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
88
|
+
{"name": "bracket_rates", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
89
|
+
{"name": "employees", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
90
|
+
{"name": "customers", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
91
|
+
],
|
92
|
+
}
|
93
|
+
|
94
|
+
SURVEY_SCHEMA = {
|
95
|
+
"doc": "Agent问卷",
|
96
|
+
"name": "AgentSurvey",
|
97
|
+
"namespace": "com.socialcity",
|
98
|
+
"type": "record",
|
99
|
+
"fields": [
|
100
|
+
{"name": "id", "type": "string"}, # uuid as string
|
101
|
+
{"name": "day", "type": "int"},
|
102
|
+
{"name": "t", "type": "float"},
|
103
|
+
{"name": "survey_id", "type": "string"},
|
104
|
+
{"name": "result", "type": "string"},
|
105
|
+
{
|
106
|
+
"name": "created_at",
|
107
|
+
"type": {"type": "long", "logicalType": "timestamp-millis"},
|
108
|
+
},
|
109
|
+
],
|
110
|
+
}
|