pycityagent 2.0.0a14__py3-none-any.whl → 2.0.0a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,23 +2,27 @@ import asyncio
2
2
  import json
3
3
  import logging
4
4
  import os
5
- from pathlib import Path
6
- import uuid
7
- from datetime import datetime
8
5
  import random
9
- from typing import Dict, List, Optional, Callable, Union,Any
10
- import fastavro
11
- from mosstool.map._map_util.const import AOI_START_ID
6
+ import uuid
7
+ from collections.abc import Sequence
8
+ from concurrent.futures import ThreadPoolExecutor
9
+ from datetime import datetime, timezone
10
+ from pathlib import Path
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+
12
13
  import pycityproto.city.economy.v2.economy_pb2 as economyv2
14
+ import yaml
15
+ from mosstool.map._map_util.const import AOI_START_ID
16
+
17
+ from pycityagent.environment.simulator import Simulator
13
18
  from pycityagent.memory.memory import Memory
14
19
  from pycityagent.message.messager import Messager
15
20
  from pycityagent.survey import Survey
16
- from pycityagent.utils.avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
17
21
 
18
22
  from ..agent import Agent, InstitutionAgent
19
23
  from .agentgroup import AgentGroup
20
24
 
21
- logger = logging.getLogger(__name__)
25
+ logger = logging.getLogger("pycityagent")
22
26
 
23
27
 
24
28
  class AgentSimulation:
@@ -29,23 +33,28 @@ class AgentSimulation:
29
33
  agent_class: Union[type[Agent], list[type[Agent]]],
30
34
  config: dict,
31
35
  agent_prefix: str = "agent_",
36
+ exp_name: str = "default_experiment",
37
+ logging_level: int = logging.WARNING,
32
38
  ):
33
39
  """
34
40
  Args:
35
41
  agent_class: 智能体类
36
42
  config: 配置
37
43
  agent_prefix: 智能体名称前缀
44
+ exp_name: 实验名称
38
45
  """
39
- self.exp_id = uuid.uuid4()
46
+ self.exp_id = str(uuid.uuid4())
40
47
  if isinstance(agent_class, list):
41
48
  self.agent_class = agent_class
42
49
  else:
43
50
  self.agent_class = [agent_class]
51
+ self.logging_level = logging_level
44
52
  self.config = config
53
+ self._simulator = Simulator(config["simulator_request"])
45
54
  self.agent_prefix = agent_prefix
46
55
  self._agents: Dict[uuid.UUID, Agent] = {}
47
- self._groups: Dict[str, AgentGroup] = {}
48
- self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {}
56
+ self._groups: Dict[str, AgentGroup] = {} # type:ignore
57
+ self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {} # type:ignore
49
58
  self._agent_uuids: List[uuid.UUID] = []
50
59
  self._user_chat_topics: Dict[uuid.UUID, str] = {}
51
60
  self._user_survey_topics: Dict[uuid.UUID, str] = {}
@@ -61,13 +70,10 @@ class AgentSimulation:
61
70
  asyncio.create_task(self._messager.connect())
62
71
 
63
72
  self._enable_avro = config["storage"]["avro"]["enabled"]
64
- self._avro_path = Path(config["storage"]["avro"]["path"])
65
- self._avro_file = {
66
- "profile": self._avro_path / f"{self.exp_id}_profile.avro",
67
- "dialog": self._avro_path / f"{self.exp_id}_dialog.avro",
68
- "status": self._avro_path / f"{self.exp_id}_status.avro",
69
- "survey": self._avro_path / f"{self.exp_id}_survey.avro",
70
- }
73
+ if not self._enable_avro:
74
+ logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
75
+ self._avro_path = Path(config["storage"]["avro"]["path"]) / f"{self.exp_id}"
76
+ self._avro_path.mkdir(parents=True, exist_ok=True)
71
77
 
72
78
  self._enable_pgsql = config["storage"]["pgsql"]["enabled"]
73
79
  self._pgsql_host = config["storage"]["pgsql"]["host"]
@@ -76,22 +82,56 @@ class AgentSimulation:
76
82
  self._pgsql_user = config["storage"]["pgsql"]["user"]
77
83
  self._pgsql_password = config["storage"]["pgsql"]["password"]
78
84
 
85
+ # 添加实验信息相关的属性
86
+ self._exp_info = {
87
+ "id": self.exp_id,
88
+ "name": exp_name,
89
+ "num_day": 0, # 将在 run 方法中更新
90
+ "status": 0,
91
+ "cur_day": 0,
92
+ "cur_t": 0.0,
93
+ "config": json.dumps(config),
94
+ "error": "",
95
+ "created_at": datetime.now(timezone.utc).isoformat(),
96
+ }
97
+
98
+ # 创建异步任务保存实验信息
99
+ self._exp_info_file = self._avro_path / "experiment_info.yaml"
100
+ with open(self._exp_info_file, "w") as f:
101
+ yaml.dump(self._exp_info, f)
102
+
79
103
  @property
80
104
  def agents(self):
81
105
  return self._agents
82
-
106
+
83
107
  @property
84
108
  def groups(self):
85
109
  return self._groups
86
-
110
+
87
111
  @property
88
112
  def agent_uuids(self):
89
113
  return self._agent_uuids
90
-
114
+
91
115
  @property
92
116
  def agent_uuid2group(self):
93
117
  return self._agent_uuid2group
94
118
 
119
+ def create_remote_group(
120
+ self,
121
+ group_name: str,
122
+ agents: list[Agent],
123
+ config: dict,
124
+ exp_id: str,
125
+ enable_avro: bool,
126
+ avro_path: Path,
127
+ logging_level: int = logging.WARNING,
128
+ ):
129
+ """创建远程组"""
130
+ group = AgentGroup.remote(
131
+ agents, config, exp_id, enable_avro, avro_path, logging_level
132
+ )
133
+ return group_name, group, agents
134
+
95
135
  async def init_agents(
96
136
  self,
97
137
  agent_count: Union[int, list[int]],
@@ -112,7 +152,7 @@ class AgentSimulation:
112
152
  raise ValueError("agent_class和agent_count的长度不一致")
113
153
 
114
154
  if memory_config_func is None:
115
- logging.warning(
155
+ logger.warning(
116
156
  "memory_config_func is None, using default memory config function"
117
157
  )
118
158
  memory_config_func = []
@@ -125,17 +165,21 @@ class AgentSimulation:
125
165
  memory_config_func = [memory_config_func]
126
166
 
127
167
  if len(memory_config_func) != len(agent_count):
128
- logging.warning(
168
+ logger.warning(
129
169
  "memory_config_func和agent_count的长度不一致,使用默认的memory_config"
130
170
  )
131
171
  memory_config_func = []
132
172
  for agent_class in self.agent_class:
133
- if agent_class == InstitutionAgent:
173
+ if issubclass(agent_class, InstitutionAgent):
134
174
  memory_config_func.append(self.default_memory_config_institution)
135
175
  else:
136
176
  memory_config_func.append(self.default_memory_config_citizen)
137
177
 
178
+ # 使用线程池并行创建 AgentGroup
179
+ group_creation_params = []
138
180
  class_init_index = 0
181
+
182
+ # 首先收集所有需要创建的组的参数
139
183
  for i in range(len(self.agent_class)):
140
184
  agent_class = self.agent_class[i]
141
185
  agent_count_i = agent_count[i]
@@ -145,15 +189,12 @@ class AgentSimulation:
145
189
 
146
190
  # 获取Memory配置
147
191
  extra_attributes, profile, base = memory_config_func_i()
148
- memory = Memory(
149
- config=extra_attributes, profile=profile, base=base
150
- )
192
+ memory = Memory(config=extra_attributes, profile=profile, base=base)
151
193
 
152
194
  # 创建智能体时传入Memory配置
153
195
  agent = agent_class(
154
196
  name=agent_name,
155
197
  memory=memory,
156
- avro_file=self._avro_file,
157
198
  )
158
199
 
159
200
  self._agents[agent._uuid] = agent
@@ -163,63 +204,52 @@ class AgentSimulation:
163
204
  num_group = (agent_count_i + group_size - 1) // group_size
164
205
 
165
206
  for k in range(num_group):
166
- # 计算当前组的起始和结束索引
167
207
  start_idx = class_init_index + k * group_size
168
208
  end_idx = min(
169
- class_init_index + start_idx + group_size,
209
+ class_init_index + (k + 1) * group_size, # 修正了索引计算
170
210
  class_init_index + agent_count_i,
171
211
  )
172
212
 
173
- # 获取当前组的agents
174
213
  agents = list(self._agents.values())[start_idx:end_idx]
175
214
  group_name = f"AgentType_{i}_Group_{k}"
176
- group = AgentGroup.remote(agents, self.config, self.exp_id, self._avro_file)
177
- self._groups[group_name] = group
178
- for agent in agents:
179
- self._agent_uuid2group[agent._uuid] = group
180
215
 
181
- class_init_index += agent_count_i # 更新类初始索引
216
+ # 收集创建参数
217
+ group_creation_params.append((group_name, agents))
218
+
219
+ class_init_index += agent_count_i
220
+
221
+ # 收集所有创建组的参数
222
+ creation_tasks = []
223
+ for group_name, agents in group_creation_params:
224
+ # 直接创建异步任务
225
+ group = AgentGroup.remote(
226
+ agents,
227
+ self.config,
228
+ self.exp_id,
229
+ self._enable_avro,
230
+ self._avro_path,
231
+ self.logging_level,
232
+ )
233
+ creation_tasks.append((group_name, group, agents))
234
+
235
+ # 更新数据结构
236
+ for group_name, group, agents in creation_tasks:
237
+ self._groups[group_name] = group
238
+ for agent in agents:
239
+ self._agent_uuid2group[agent._uuid] = group
182
240
 
241
+ # 并行初始化所有组的agents
183
242
  init_tasks = []
184
243
  for group in self._groups.values():
185
244
  init_tasks.append(group.init_agents.remote())
186
245
  await asyncio.gather(*init_tasks)
246
+
247
+ # 设置用户主题
187
248
  for uuid, agent in self._agents.items():
188
249
  self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
189
- self._user_survey_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-survey"
190
-
191
- # save profile
192
- if self._enable_avro:
193
- self._avro_path.mkdir(parents=True, exist_ok=True)
194
- # profile
195
- filename = self._avro_file["profile"]
196
- with open(filename, "wb") as f:
197
- profiles = []
198
- for agent in self._agents.values():
199
- profile = await agent.memory._profile.export()
200
- profile = profile[0]
201
- profile['id'] = str(agent._uuid)
202
- profiles.append(profile)
203
- fastavro.writer(f, PROFILE_SCHEMA, profiles)
204
-
205
- # dialog
206
- filename = self._avro_file["dialog"]
207
- with open(filename, "wb") as f:
208
- dialogs = []
209
- fastavro.writer(f, DIALOG_SCHEMA, dialogs)
210
-
211
- # status
212
- filename = self._avro_file["status"]
213
- with open(filename, "wb") as f:
214
- statuses = []
215
- fastavro.writer(f, STATUS_SCHEMA, statuses)
216
-
217
- # survey
218
- filename = self._avro_file["survey"]
219
- with open(filename, "wb") as f:
220
- surveys = []
221
- fastavro.writer(f, SURVEY_SCHEMA, surveys)
222
-
250
+ self._user_survey_topics[uuid] = (
251
+ f"exps/{self.exp_id}/agents/{uuid}/user-survey"
252
+ )
223
253
 
224
254
  async def gather(self, content: str):
225
255
  """收集智能体的特定信息"""
@@ -228,15 +258,26 @@ class AgentSimulation:
228
258
  gather_tasks.append(group.gather.remote(content))
229
259
  return await asyncio.gather(*gather_tasks)
230
260
 
231
- async def update(self, target_agent_id: str, target_key: str, content: Any):
261
+ async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
232
262
  """更新指定智能体的记忆"""
233
- group = self._agent_uuid2group[target_agent_id]
234
- await group.update.remote(target_agent_id, target_key, content)
263
+ group = self._agent_uuid2group[target_agent_uuid]
264
+ await group.update.remote(target_agent_uuid, target_key, content)
235
265
 
236
266
  def default_memory_config_institution(self):
237
267
  """默认的Memory配置函数"""
238
268
  EXTRA_ATTRIBUTES = {
239
- "type": (int, random.choice([economyv2.ORG_TYPE_BANK, economyv2.ORG_TYPE_GOVERNMENT, economyv2.ORG_TYPE_FIRM, economyv2.ORG_TYPE_NBS, economyv2.ORG_TYPE_UNSPECIFIED])),
269
+ "type": (
270
+ int,
271
+ random.choice(
272
+ [
273
+ economyv2.ORG_TYPE_BANK,
274
+ economyv2.ORG_TYPE_GOVERNMENT,
275
+ economyv2.ORG_TYPE_FIRM,
276
+ economyv2.ORG_TYPE_NBS,
277
+ economyv2.ORG_TYPE_UNSPECIFIED,
278
+ ]
279
+ ),
280
+ ),
240
281
  "nominal_gdp": (list, [], True),
241
282
  "real_gdp": (list, [], True),
242
283
  "unemployment": (list, [], True),
@@ -350,29 +391,35 @@ class AgentSimulation:
350
391
  }
351
392
 
352
393
  return EXTRA_ATTRIBUTES, PROFILE, BASE
353
-
354
- async def send_survey(self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None):
394
+
395
+ async def send_survey(
396
+ self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None
397
+ ):
355
398
  """发送问卷"""
356
- survey = survey.to_dict()
399
+ survey_dict = survey.to_dict()
357
400
  if agent_uuids is None:
358
401
  agent_uuids = self._agent_uuids
359
402
  payload = {
360
403
  "from": "none",
361
- "survey_id": survey["id"],
404
+ "survey_id": survey_dict["id"],
362
405
  "timestamp": int(datetime.now().timestamp() * 1000),
363
- "data": survey,
406
+ "data": survey_dict,
364
407
  }
365
408
  for uuid in agent_uuids:
366
409
  topic = self._user_survey_topics[uuid]
367
410
  await self._messager.send_message(topic, payload)
368
411
 
369
- async def send_interview_message(self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]):
412
+ async def send_interview_message(
413
+ self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]
414
+ ):
370
415
  """发送面试消息"""
371
416
  payload = {
372
417
  "from": "none",
373
418
  "content": content,
374
419
  "timestamp": int(datetime.now().timestamp() * 1000),
375
420
  }
421
+ if not isinstance(agent_uuids, Sequence):
422
+ agent_uuids = [agent_uuids]
376
423
  for uuid in agent_uuids:
377
424
  topic = self._user_chat_topics[uuid]
378
425
  await self._messager.send_message(topic, payload)
@@ -388,23 +435,90 @@ class AgentSimulation:
388
435
  logger.error(f"运行错误: {str(e)}")
389
436
  raise
390
437
 
438
+ async def _save_exp_info(self) -> None:
439
+ """异步保存实验信息到YAML文件"""
440
+ try:
441
+ with open(self._exp_info_file, "w") as f:
442
+ yaml.dump(self._exp_info, f)
443
+ except Exception as e:
444
+ logger.error(f"保存实验信息失败: {str(e)}")
445
+
446
+ async def _update_exp_status(self, status: int, error: str = "") -> None:
447
+ """更新实验状态并保存"""
448
+ self._exp_info["status"] = status
449
+ self._exp_info["error"] = error
450
+ await self._save_exp_info()
451
+
452
+ async def _monitor_exp_status(self, stop_event: asyncio.Event):
453
+ """监控实验状态并更新
454
+
455
+ Args:
456
+ stop_event: 用于通知监控任务停止的事件
457
+ """
458
+ try:
459
+ while not stop_event.is_set():
460
+ # 更新实验状态
461
+ # 假设所有group的cur_day和cur_t是同步的,取第一个即可
462
+ self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
463
+ self._exp_info["cur_t"] = (
464
+ await self._simulator.get_simulator_second_from_start_of_day()
465
+ )
466
+ await self._save_exp_info()
467
+
468
+ await asyncio.sleep(1) # 避免过于频繁的更新
469
+ except asyncio.CancelledError:
470
+ # 正常取消,不需要特殊处理
471
+ pass
472
+ except Exception as e:
473
+ logger.error(f"监控实验状态时发生错误: {str(e)}")
474
+ raise
475
+
391
476
  async def run(
392
477
  self,
393
478
  day: int = 1,
394
479
  ):
395
- """运行模拟器
396
-
397
- Args:
398
- day: 运行天数,默认为1天
399
- """
480
+ """运行模拟器"""
400
481
  try:
401
- # 获取开始时间
402
- tasks = []
403
- for group in self._groups.values():
404
- tasks.append(group.run.remote(day))
482
+ self._exp_info["num_day"] += day
483
+ await self._update_exp_status(1) # 更新状态为运行中
405
484
 
406
- await asyncio.gather(*tasks)
485
+ # 创建停止事件
486
+ stop_event = asyncio.Event()
487
+ # 创建监控任务
488
+ monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
489
+
490
+ try:
491
+ tasks = []
492
+ for group in self._groups.values():
493
+ tasks.append(group.run.remote())
494
+
495
+ # 等待所有group运行完成
496
+ await asyncio.gather(*tasks)
497
+
498
+ finally:
499
+ # 设置停止事件
500
+ stop_event.set()
501
+ # 等待监控任务结束
502
+ await monitor_task
503
+
504
+ # 运行成功后更新状态
505
+ await self._update_exp_status(2)
407
506
 
408
507
  except Exception as e:
409
- logger.error(f"模拟器运行错误: {str(e)}")
410
- raise
508
+ error_msg = f"模拟器运行错误: {str(e)}"
509
+ logger.error(error_msg)
510
+ await self._update_exp_status(3, error_msg)
511
+ raise e
512
+
513
+ async def __aenter__(self):
514
+ """异步上下文管理器入口"""
515
+ return self
516
+
517
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
518
+ """异步上下文管理器出口"""
519
+ if exc_type is not None:
520
+ # 如果发生异常,更新状态为错误
521
+ await self._update_exp_status(3, str(exc_val))
522
+ elif self._exp_info["status"] != 3:
523
+ # 如果没有发生异常且状态不是错误,则更新为完成
524
+ await self._update_exp_status(2)
@@ -1,17 +1,16 @@
1
- from typing import List, Dict, Optional
2
- from datetime import datetime
3
- import uuid
4
1
  import json
5
- from .models import Survey, Question, QuestionType, Page
2
+ import uuid
3
+ from datetime import datetime
4
+ from typing import Optional
5
+
6
+ from .models import Page, Question, QuestionType, Survey
6
7
 
7
8
 
8
9
  class SurveyManager:
9
10
  def __init__(self):
10
- self._surveys: Dict[str, Survey] = {}
11
+ self._surveys: dict[str, Survey] = {}
11
12
 
12
- def create_survey(
13
- self, title: str, description: str, pages: List[dict]
14
- ) -> Survey:
13
+ def create_survey(self, title: str, description: str, pages: list[dict]) -> Survey:
15
14
  """创建新问卷"""
16
15
  survey_id = uuid.uuid4()
17
16
 
@@ -32,11 +31,8 @@ class SurveyManager:
32
31
  max_rating=q.get("max_rating", 5),
33
32
  )
34
33
  questions.append(question)
35
-
36
- page = Page(
37
- name=page_data["name"],
38
- elements=questions
39
- )
34
+
35
+ page = Page(name=page_data["name"], elements=questions)
40
36
  survey_pages.append(page)
41
37
 
42
38
  survey = Survey(
@@ -53,6 +49,6 @@ class SurveyManager:
53
49
  """获取指定问卷"""
54
50
  return self._surveys.get(survey_id)
55
51
 
56
- def get_all_surveys(self) -> List[Survey]:
52
+ def get_all_surveys(self) -> list[Survey]:
57
53
  """获取所有问卷"""
58
54
  return list(self._surveys.values())
@@ -1,9 +1,9 @@
1
+ import json
2
+ import uuid
1
3
  from dataclasses import dataclass, field
2
- from typing import List, Dict, Optional
3
4
  from datetime import datetime
4
5
  from enum import Enum
5
- import uuid
6
- import json
6
+ from typing import Any
7
7
 
8
8
 
9
9
  class QuestionType(Enum):
@@ -20,19 +20,20 @@ class Question:
20
20
  name: str
21
21
  title: str
22
22
  type: QuestionType
23
- choices: List[str] = field(default_factory=list)
24
- columns: List[str] = field(default_factory=list)
25
- rows: List[str] = field(default_factory=list)
23
+ choices: list[str] = field(default_factory=list)
24
+ columns: list[str] = field(default_factory=list)
25
+ rows: list[str] = field(default_factory=list)
26
+ required: bool = True
26
27
  min_rating: int = 1
27
28
  max_rating: int = 5
28
29
 
29
30
  def to_dict(self) -> dict:
30
- base_dict = {
31
+ base_dict: dict[str, Any] = {
31
32
  "type": self.type.value,
32
33
  "name": self.name,
33
34
  "title": self.title,
34
35
  }
35
-
36
+
36
37
  if self.type in [QuestionType.RADIO, QuestionType.CHECKBOX]:
37
38
  base_dict["choices"] = self.choices
38
39
  elif self.type == QuestionType.MATRIX:
@@ -41,20 +42,17 @@ class Question:
41
42
  elif self.type == QuestionType.RATING:
42
43
  base_dict["min_rating"] = self.min_rating
43
44
  base_dict["max_rating"] = self.max_rating
44
-
45
+
45
46
  return base_dict
46
47
 
47
48
 
48
49
  @dataclass
49
50
  class Page:
50
51
  name: str
51
- elements: List[Question]
52
+ elements: list[Question]
52
53
 
53
54
  def to_dict(self) -> dict:
54
- return {
55
- "name": self.name,
56
- "elements": [q.to_dict() for q in self.elements]
57
- }
55
+ return {"name": self.name, "elements": [q.to_dict() for q in self.elements]}
58
56
 
59
57
 
60
58
  @dataclass
@@ -62,8 +60,8 @@ class Survey:
62
60
  id: uuid.UUID
63
61
  title: str
64
62
  description: str
65
- pages: List[Page]
66
- responses: Dict[str, dict] = field(default_factory=dict)
63
+ pages: list[Page]
64
+ responses: dict[str, dict] = field(default_factory=dict)
67
65
  created_at: datetime = field(default_factory=datetime.now)
68
66
 
69
67
  def to_dict(self) -> dict:
@@ -83,12 +81,12 @@ class Survey:
83
81
  "description": self.description,
84
82
  "pages": [p.to_dict() for p in self.pages],
85
83
  "responses": self.responses,
86
- "created_at": self.created_at.isoformat()
84
+ "created_at": self.created_at.isoformat(),
87
85
  }
88
86
  return json.dumps(survey_dict)
89
87
 
90
88
  @classmethod
91
- def from_json(cls, json_str: str) -> 'Survey':
89
+ def from_json(cls, json_str: str) -> "Survey":
92
90
  """Create a Survey instance from a JSON string"""
93
91
  data = json.loads(json_str)
94
92
  pages = [
@@ -104,17 +102,19 @@ class Survey:
104
102
  columns=q.get("columns", []),
105
103
  rows=q.get("rows", []),
106
104
  min_rating=q.get("min_rating", 1),
107
- max_rating=q.get("max_rating", 5)
108
- ) for q in p["elements"]
109
- ]
110
- ) for p in data["pages"]
105
+ max_rating=q.get("max_rating", 5),
106
+ )
107
+ for q in p["elements"]
108
+ ],
109
+ )
110
+ for p in data["pages"]
111
111
  ]
112
-
112
+
113
113
  return cls(
114
114
  id=uuid.UUID(data["id"]),
115
115
  title=data["title"],
116
116
  description=data["description"],
117
117
  pages=pages,
118
118
  responses=data.get("responses", {}),
119
- created_at=datetime.fromisoformat(data["created_at"])
119
+ created_at=datetime.fromisoformat(data["created_at"]),
120
120
  )
@@ -1,7 +1,7 @@
1
- from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
1
+ from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
2
2
  from .survey_util import process_survey_for_llm
3
3
 
4
4
  __all__ = [
5
- "PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA",
5
+ "PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
6
6
  "process_survey_for_llm"
7
7
  ]
@@ -66,6 +66,31 @@ STATUS_SCHEMA = {
66
66
  ],
67
67
  }
68
68
 
69
+ INSTITUTION_STATUS_SCHEMA = {
70
+ "doc": "Institution状态",
71
+ "name": "InstitutionStatus",
72
+ "namespace": "com.socialcity",
73
+ "type": "record",
74
+ "fields": [
75
+ {"name": "id", "type": "string"}, # uuid as string
76
+ {"name": "day", "type": "int"},
77
+ {"name": "t", "type": "float"},
78
+ {"name": "type", "type": "int"},
79
+ {"name": "nominal_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
80
+ {"name": "real_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
81
+ {"name": "unemployment", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
82
+ {"name": "wages", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
83
+ {"name": "prices", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
84
+ {"name": "inventory", "type": ["int", "null"]},
85
+ {"name": "price", "type": ["float", "null"]},
86
+ {"name": "interest_rate", "type": ["float", "null"]},
87
+ {"name": "bracket_cutoffs", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
88
+ {"name": "bracket_rates", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
89
+ {"name": "employees", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
90
+ {"name": "customers", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
91
+ ],
92
+ }
93
+
69
94
  SURVEY_SCHEMA = {
70
95
  "doc": "Agent问卷",
71
96
  "name": "AgentSurvey",
@@ -82,4 +107,4 @@ SURVEY_SCHEMA = {
82
107
  "type": {"type": "long", "logicalType": "timestamp-millis"},
83
108
  },
84
109
  ],
85
- }
110
+ }