pycityagent 2.0.0a6__py3-none-any.whl → 2.0.0a8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. pycityagent/agent.py +29 -5
  2. pycityagent/environment/interact/interact.py +86 -29
  3. pycityagent/environment/sence/static.py +3 -2
  4. pycityagent/environment/sim/aoi_service.py +1 -1
  5. pycityagent/environment/sim/economy_services.py +1 -1
  6. pycityagent/environment/sim/road_service.py +1 -1
  7. pycityagent/environment/sim/social_service.py +1 -1
  8. pycityagent/environment/simulator.py +6 -4
  9. pycityagent/environment/utils/__init__.py +5 -1
  10. pycityagent/llm/__init__.py +1 -1
  11. pycityagent/llm/embedding.py +36 -35
  12. pycityagent/llm/llm.py +197 -161
  13. pycityagent/llm/llmconfig.py +7 -9
  14. pycityagent/llm/utils.py +2 -2
  15. pycityagent/memory/memory.py +1 -2
  16. pycityagent/memory/memory_base.py +1 -2
  17. pycityagent/memory/profile.py +1 -2
  18. pycityagent/memory/self_define.py +1 -2
  19. pycityagent/memory/state.py +1 -2
  20. pycityagent/message/__init__.py +1 -1
  21. pycityagent/message/messager.py +11 -4
  22. pycityagent/simulation/__init__.py +1 -1
  23. pycityagent/simulation/agentgroup.py +39 -11
  24. pycityagent/simulation/interview.py +9 -5
  25. pycityagent/simulation/simulation.py +181 -61
  26. pycityagent/simulation/survey/__init__.py +1 -6
  27. pycityagent/simulation/survey/manager.py +22 -21
  28. pycityagent/simulation/survey/models.py +8 -5
  29. pycityagent/utils/decorators.py +14 -4
  30. pycityagent/utils/parsers/__init__.py +2 -1
  31. pycityagent/workflow/block.py +4 -3
  32. pycityagent/workflow/prompt.py +16 -9
  33. pycityagent/workflow/tool.py +1 -2
  34. pycityagent/workflow/trigger.py +36 -23
  35. {pycityagent-2.0.0a6.dist-info → pycityagent-2.0.0a8.dist-info}/METADATA +1 -1
  36. pycityagent-2.0.0a8.dist-info/RECORD +70 -0
  37. pycityagent-2.0.0a6.dist-info/RECORD +0 -70
  38. {pycityagent-2.0.0a6.dist-info → pycityagent-2.0.0a8.dist-info}/WHEEL +0 -0
@@ -3,8 +3,7 @@ Agent Profile
3
3
  """
4
4
 
5
5
  from copy import deepcopy
6
- from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple,
7
- Union, cast)
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
8
7
 
9
8
  from ..utils.decorators import lock_decorator
10
9
  from .const import *
@@ -3,8 +3,7 @@ Self Define Data
3
3
  """
4
4
 
5
5
  from copy import deepcopy
6
- from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple,
7
- Union, cast)
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
8
7
 
9
8
  from ..utils.decorators import lock_decorator
10
9
  from .const import *
@@ -3,8 +3,7 @@ Agent State
3
3
  """
4
4
 
5
5
  from copy import deepcopy
6
- from typing import (Any, Callable, Dict, List, Optional, Sequence, Tuple,
7
- Union, cast)
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
8
7
 
9
8
  from ..utils.decorators import lock_decorator
10
9
  from .const import *
@@ -1,3 +1,3 @@
1
1
  from .messager import Messager
2
2
 
3
- __all__ = ["Messager"]
3
+ __all__ = ["Messager"]
@@ -4,9 +4,14 @@ import logging
4
4
  import math
5
5
  from aiomqtt import Client
6
6
 
7
+
7
8
  class Messager:
8
- def __init__(self, broker, port=1883, timeout=math.inf):
9
- self.client = Client(broker, port=port, timeout=timeout)
9
+ def __init__(
10
+ self, hostname, port=1883, username=None, password=None, timeout=math.inf
11
+ ):
12
+ self.client = Client(
13
+ hostname, port=port, username=username, password=password, timeout=timeout
14
+ )
10
15
  self.connected = False # 是否已连接标志
11
16
  self.message_queue = asyncio.Queue() # 用于存储接收到的消息
12
17
  self.subscribers = {} # 订阅者信息,topic -> Agent 映射
@@ -31,7 +36,9 @@ class Messager:
31
36
 
32
37
  async def subscribe(self, topic, agent):
33
38
  if not self.is_connected():
34
- logging.error(f"Cannot subscribe to {topic} because not connected to the Broker.")
39
+ logging.error(
40
+ f"Cannot subscribe to {topic} because not connected to the Broker."
41
+ )
35
42
  return
36
43
  await self.client.subscribe(topic)
37
44
  self.subscribers[topic] = agent
@@ -48,7 +55,7 @@ class Messager:
48
55
  while not self.message_queue.empty():
49
56
  messages.append(await self.message_queue.get())
50
57
  return messages
51
-
58
+
52
59
  async def send_message(self, topic: str, payload: str, sender_id: int):
53
60
  """通过 Messager 发送消息,包含发送者 ID"""
54
61
  # 构造消息,payload 中加入 sender_id 以便接收者识别
@@ -4,4 +4,4 @@
4
4
 
5
5
  from .simulation import AgentSimulation
6
6
 
7
- __all__ = ["AgentSimulation"]
7
+ __all__ = ["AgentSimulation"]
@@ -8,15 +8,21 @@ from pycityagent.llm.llm import LLM
8
8
  from pycityagent.llm.llmconfig import LLMConfig
9
9
  from pycityagent.message import Messager
10
10
 
11
+
11
12
  @ray.remote
12
13
  class AgentGroup:
13
14
  def __init__(self, agents: list[Agent], config: dict, exp_id: str):
14
15
  self.agents = agents
15
16
  self.config = config
16
17
  self.exp_id = exp_id
17
- self.messager = Messager(config["simulator_request"]["mqtt"]["server"], config["simulator_request"]["mqtt"]["port"])
18
+ self.messager = Messager(
19
+ hostname=config["simulator_request"]["mqtt"]["server"],
20
+ port=config["simulator_request"]["mqtt"]["port"],
21
+ username=config["simulator_request"]["mqtt"].get("username", None),
22
+ password=config["simulator_request"]["mqtt"].get("password", None),
23
+ )
18
24
  self.initialized = False
19
-
25
+ self.id2agent = {}
20
26
  # Step:1 prepare LLM client
21
27
  llmConfig = LLMConfig(config["llm_request"])
22
28
  logging.info("-----Creating LLM client in remote...")
@@ -27,8 +33,13 @@ class AgentGroup:
27
33
  self.simulator = Simulator(config["simulator_request"])
28
34
 
29
35
  # Step:3 prepare Economy client
30
- logging.info("-----Creating Economy client in remote...")
31
- self.economy_client = EconomyClient(config["simulator_request"]["economy"]['server'])
36
+ if "economy" in config["simulator_request"]:
37
+ logging.info("-----Creating Economy client in remote...")
38
+ self.economy_client = EconomyClient(
39
+ config["simulator_request"]["economy"]["server"]
40
+ )
41
+ else:
42
+ self.economy_client = None
32
43
 
33
44
  for agent in self.agents:
34
45
  agent.set_exp_id(self.exp_id)
@@ -48,8 +59,20 @@ class AgentGroup:
48
59
  agent.set_messager(self.messager)
49
60
  topic = f"/exps/{self.exp_id}/agents/{agent._agent_id}/chat"
50
61
  await self.messager.subscribe(topic, agent)
62
+ topic = f"/exps/{self.exp_id}/agents/{agent._agent_id}/gather"
63
+ await self.messager.subscribe(topic, agent)
51
64
  self.initialized = True
52
65
 
66
+ async def gather(self, content: str):
67
+ results = {}
68
+ for agent in self.agents:
69
+ results[agent._agent_id] = await agent.memory.get(content)
70
+ return results
71
+
72
+ async def update(self, target_agent_id: str, target_key: str, content: any):
73
+ agent = self.id2agent[target_agent_id]
74
+ await agent.memory.update(target_key, content)
75
+
53
76
  async def step(self):
54
77
  if not self.initialized:
55
78
  await self.init_agents()
@@ -65,6 +88,8 @@ class AgentGroup:
65
88
  # Step 2: 从 Messager 获取消息
66
89
  messages = await self.messager.fetch_messages()
67
90
 
91
+ print(f"Received {len(messages)} messages")
92
+
68
93
  # Step 3: 分发消息到对应的 Agent
69
94
  for message in messages:
70
95
  topic = message.topic.value
@@ -72,14 +97,18 @@ class AgentGroup:
72
97
 
73
98
  # 添加解码步骤,将bytes转换为str
74
99
  if isinstance(payload, bytes):
75
- payload = payload.decode('utf-8')
76
- # 提取 agent_id(主题格式为 "/exps/{exp_id}/agents/{agent_id}/chat")
77
- _, _, _, agent_id, _ = topic.strip("/").split("/")
100
+ payload = payload.decode("utf-8")
101
+
102
+ # 提取 agent_id(主题格式为 "/exps/{exp_id}/agents/{agent_id}/chat""/exps/{exp_id}/agents/{agent_id}/gather"
103
+ _, _, _, agent_id, topic_type = topic.strip("/").split("/")
78
104
  agent_id = int(agent_id)
79
105
 
80
106
  if agent_id in self.id2agent:
81
107
  agent = self.id2agent[agent_id]
82
- await agent.handle_message(payload)
108
+ if topic_type == "chat":
109
+ await agent.handle_message(payload)
110
+ elif topic_type == "gather":
111
+ await agent.handle_gather_message(payload)
83
112
 
84
113
  # Step 4: 调用每个 Agent 的运行逻辑
85
114
  tasks = [agent.run() for agent in self.agents]
@@ -96,15 +125,14 @@ class AgentGroup:
96
125
  start_time = await self.simulator.get_time()
97
126
  # 计算结束时间(秒)
98
127
  end_time = start_time + day * 24 * 3600 # 将天数转换为秒
99
-
128
+
100
129
  while True:
101
130
  current_time = await self.simulator.get_time()
102
131
  if current_time >= end_time:
103
132
  break
104
-
133
+
105
134
  await self.step()
106
135
 
107
136
  except Exception as e:
108
137
  logging.error(f"模拟器运行错误: {str(e)}")
109
138
  raise
110
-
@@ -2,20 +2,24 @@ from dataclasses import dataclass
2
2
  from datetime import datetime
3
3
  from typing import List, Optional
4
4
 
5
+
5
6
  @dataclass
6
7
  class InterviewRecord:
7
8
  """采访记录"""
9
+
8
10
  timestamp: datetime
9
11
  agent_name: str
10
12
  question: str
11
13
  response: str
12
14
  blocking: bool
13
15
 
16
+
14
17
  class InterviewManager:
15
18
  """采访管理器"""
19
+
16
20
  def __init__(self):
17
21
  self._history: List[InterviewRecord] = []
18
-
22
+
19
23
  def add_record(self, agent_name: str, question: str, response: str, blocking: bool):
20
24
  """添加采访记录"""
21
25
  record = InterviewRecord(
@@ -23,14 +27,14 @@ class InterviewManager:
23
27
  agent_name=agent_name,
24
28
  question=question,
25
29
  response=response,
26
- blocking=blocking
30
+ blocking=blocking,
27
31
  )
28
32
  self._history.append(record)
29
-
33
+
30
34
  def get_agent_history(self, agent_name: str) -> List[InterviewRecord]:
31
35
  """获取指定智能体的采访历史"""
32
36
  return [r for r in self._history if r.agent_name == agent_name]
33
-
37
+
34
38
  def get_recent_history(self, limit: int = 10) -> List[InterviewRecord]:
35
39
  """获取最近的采访记录"""
36
- return sorted(self._history, key=lambda x: x.timestamp, reverse=True)[:limit]
40
+ return sorted(self._history, key=lambda x: x.timestamp, reverse=True)[:limit]
@@ -4,7 +4,7 @@ import logging
4
4
  import uuid
5
5
  from datetime import datetime
6
6
  import random
7
- from typing import Dict, List, Optional, Callable
7
+ from typing import Dict, List, Optional, Callable, Union
8
8
  from mosstool.map._map_util.const import AOI_START_ID
9
9
 
10
10
  from pycityagent.memory.memory import Memory
@@ -14,12 +14,19 @@ from .interview import InterviewManager
14
14
  from .survey import QuestionType, SurveyManager
15
15
  from .ui import InterviewUI
16
16
  from .agentgroup import AgentGroup
17
+
17
18
  logger = logging.getLogger(__name__)
18
19
 
19
20
 
20
21
  class AgentSimulation:
21
22
  """城市智能体模拟器"""
22
- def __init__(self, agent_class: type[Agent], config: dict, agent_prefix: str = "agent_"):
23
+
24
+ def __init__(
25
+ self,
26
+ agent_class: Union[type[Agent], list[type[Agent]]],
27
+ config: dict,
28
+ agent_prefix: str = "agent_",
29
+ ):
23
30
  """
24
31
  Args:
25
32
  agent_class: 智能体类
@@ -27,7 +34,10 @@ class AgentSimulation:
27
34
  agent_prefix: 智能体名称前缀
28
35
  """
29
36
  self.exp_id = uuid.uuid4()
30
- self.agent_class = agent_class
37
+ if isinstance(agent_class, list):
38
+ self.agent_class = agent_class
39
+ else:
40
+ self.agent_class = [agent_class]
31
41
  self.config = config
32
42
  self.agent_prefix = agent_prefix
33
43
  self._agents: Dict[str, Agent] = {}
@@ -40,88 +50,200 @@ class AgentSimulation:
40
50
  self._loop = asyncio.get_event_loop()
41
51
  self._blocked_agents: List[str] = [] # 新增:持续阻塞的智能体列表
42
52
  self._survey_manager = SurveyManager()
53
+ self._agentid2group: Dict[str, AgentGroup] = {}
54
+ self._agent_ids: List[str] = []
43
55
 
44
- async def init_agents(self, agent_count: int, group_size: int = 1000, memory_config_func: Callable = None) -> None:
56
+ async def init_agents(
57
+ self,
58
+ agent_count: Union[int, list[int]],
59
+ group_size: int = 1000,
60
+ memory_config_func: Union[Callable, list[Callable]] = None,
61
+ ) -> None:
45
62
  """初始化智能体
46
-
63
+
47
64
  Args:
48
- agent_count: 要创建的总智能体数量
65
+ agent_count: 要创建的总智能体数量, 如果为列表,则每个元素表示一个智能体类创建的智能体数量
49
66
  group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
50
- memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组
67
+ memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组, 如果为列表,则每个元素表示一个智能体类创建的Memory配置函数
51
68
  """
52
- if memory_config_func is None:
53
- memory_config_func = self.default_memory_config_func
69
+ if not isinstance(agent_count, list):
70
+ agent_count = [agent_count]
54
71
 
55
- for i in range(agent_count):
56
- agent_name = f"{self.agent_prefix}{i}"
72
+ if len(self.agent_class) != len(agent_count):
73
+ raise ValueError("agent_class和agent_count的长度不一致")
57
74
 
58
- # 获取Memory配置
59
- extra_attributes, profile, base = memory_config_func()
60
- memory = Memory(
61
- config=extra_attributes,
62
- profile=profile.copy(),
63
- base=base.copy()
64
- )
65
-
66
- # 创建智能体时传入Memory配置
67
- agent = self.agent_class(
68
- name=agent_name,
69
- memory=memory,
75
+ if memory_config_func is None:
76
+ logging.warning(
77
+ "memory_config_func is None, using default memory config function"
70
78
  )
79
+ memory_config_func = [self.default_memory_config_func]
80
+ elif not isinstance(memory_config_func, list):
81
+ memory_config_func = [memory_config_func]
71
82
 
72
- self._agents[agent_name] = agent
73
-
74
- # 计算需要的组数,向上取整以处理不足一组的情况
75
- num_group = (agent_count + group_size - 1) // group_size
76
-
77
- for i in range(num_group):
78
- # 计算当前组的起始和结束索引
79
- start_idx = i * group_size
80
- end_idx = min((i + 1) * group_size, agent_count)
81
-
82
- # 获取当前组的agents
83
- agents = list(self._agents.values())[start_idx:end_idx]
84
- group_name = f"{self.agent_prefix}_group_{i}"
85
- group = AgentGroup.remote(agents, self.config, self.exp_id)
86
- self._groups[group_name] = group
83
+ if len(memory_config_func) != len(agent_count):
84
+ logging.warning(
85
+ "memory_config_func和agent_count的长度不一致,使用默认的memory_config_func"
86
+ )
87
+ memory_config_func = [self.default_memory_config_func] * len(agent_count)
88
+
89
+ class_init_index = 0
90
+ for i in range(len(self.agent_class)):
91
+ agent_class = self.agent_class[i]
92
+ agent_count_i = agent_count[i]
93
+ memory_config_func_i = memory_config_func[i]
94
+ for j in range(agent_count_i):
95
+ agent_name = f"{self.agent_prefix}_{i}_{j}"
96
+
97
+ # 获取Memory配置
98
+ extra_attributes, profile, base = memory_config_func_i()
99
+ memory = Memory(
100
+ config=extra_attributes, profile=profile.copy(), base=base.copy()
101
+ )
102
+
103
+ # 创建智能体时传入Memory配置
104
+ agent = agent_class(
105
+ name=agent_name,
106
+ memory=memory,
107
+ )
108
+
109
+ self._agents[agent_name] = agent
110
+
111
+ # 计算需要的组数,向上取整以处理不足一组的情况
112
+ num_group = (agent_count_i + group_size - 1) // group_size
113
+
114
+ for k in range(num_group):
115
+ # 计算当前组的起始和结束索引
116
+ start_idx = class_init_index + k * group_size
117
+ end_idx = min(
118
+ class_init_index + start_idx + group_size,
119
+ class_init_index + agent_count_i,
120
+ )
121
+
122
+ # 获取当前组的agents
123
+ agents = list(self._agents.values())[start_idx:end_idx]
124
+ group_name = f"{self.agent_prefix}_{i}_group_{k}"
125
+ group = AgentGroup.remote(agents, self.config, self.exp_id)
126
+ self._groups[group_name] = group
127
+
128
+ class_init_index += agent_count_i # 更新类初始索引
129
+
130
+ init_tasks = []
131
+ for group in self._groups.values():
132
+ init_tasks.append(group.init_agents.remote())
133
+ await asyncio.gather(*init_tasks)
134
+
135
+ for group in self._groups.values():
136
+ agent_ids = await group.gather.remote("id")
137
+ for agent_id in agent_ids:
138
+ self._agent_ids.append(agent_id)
139
+ self._agentid2group[agent_id] = group
140
+
141
+ async def gather(self, content: str):
142
+ """收集所有智能体的ID"""
143
+ gather_tasks = []
144
+ for group in self._groups.values():
145
+ gather_tasks.append(group.gather.remote(content))
146
+ return await asyncio.gather(*gather_tasks)
147
+
148
+ async def update(self, target_agent_id: str, target_key: str, content: any):
149
+ """更新指定智能体的记忆"""
150
+ group = self._agentid2group[target_agent_id]
151
+ await group.update.remote(target_agent_id, target_key, content)
87
152
 
88
153
  def default_memory_config_func(self):
89
154
  """默认的Memory配置函数"""
90
155
  EXTRA_ATTRIBUTES = {
91
156
  # 需求信息
92
- "needs": (dict, {
93
- 'hungry': random.random(), # 饥饿感
94
- 'tired': random.random(), # 疲劳感
95
- 'safe': random.random(), # 安全需
96
- 'social': random.random(), # 社会需求
97
- }, True),
157
+ "needs": (
158
+ dict,
159
+ {
160
+ "hungry": random.random(), # 饥饿感
161
+ "tired": random.random(), # 疲劳感
162
+ "safe": random.random(), # 安全需
163
+ "social": random.random(), # 社会需求
164
+ },
165
+ True,
166
+ ),
98
167
  "current_need": (str, "none", True),
99
168
  "current_plan": (list, [], True),
100
169
  "current_step": (dict, {"intention": "", "type": ""}, True),
101
- "execution_context" : (dict, {}, True),
170
+ "execution_context": (dict, {}, True),
102
171
  "plan_history": (list, [], True),
172
+ # cognition
173
+ "fulfillment": (int, 5, True),
174
+ "emotion": (int, 5, True),
175
+ "attitude": (int, 5, True),
176
+ "thought": (str, "Currently nothing good or bad is happening", True),
177
+ "emotion_types": (str, "Relief", True),
178
+ "incident": (list, [], True),
179
+ # social
180
+ "friends": (list, [], True),
103
181
  }
104
182
 
105
183
  PROFILE = {
106
184
  "gender": random.choice(["male", "female"]),
107
- "education": random.choice(["Doctor", "Master", "Bachelor", "College", "High School"]),
185
+ "education": random.choice(
186
+ ["Doctor", "Master", "Bachelor", "College", "High School"]
187
+ ),
108
188
  "consumption": random.choice(["sightly low", "low", "medium", "high"]),
109
- "occupation": random.choice(["Student", "Teacher", "Doctor", "Engineer", "Manager", "Businessman", "Artist", "Athlete", "Other"]),
189
+ "occupation": random.choice(
190
+ [
191
+ "Student",
192
+ "Teacher",
193
+ "Doctor",
194
+ "Engineer",
195
+ "Manager",
196
+ "Businessman",
197
+ "Artist",
198
+ "Athlete",
199
+ "Other",
200
+ ]
201
+ ),
110
202
  "age": random.randint(18, 65),
111
- "skill": random.choice(["Good at problem-solving", "Good at communication", "Good at creativity", "Good at teamwork", "Other"]),
203
+ "skill": random.choice(
204
+ [
205
+ "Good at problem-solving",
206
+ "Good at communication",
207
+ "Good at creativity",
208
+ "Good at teamwork",
209
+ "Other",
210
+ ]
211
+ ),
112
212
  "family_consumption": random.choice(["low", "medium", "high"]),
113
- "personality": random.choice(["outgoint", "introvert", "ambivert", "extrovert"]),
213
+ "personality": random.choice(
214
+ ["outgoint", "introvert", "ambivert", "extrovert"]
215
+ ),
114
216
  "income": random.randint(1000, 10000),
115
217
  "currency": random.randint(10000, 100000),
116
218
  "residence": random.choice(["city", "suburb", "rural"]),
117
- "race": random.choice(["Chinese", "American", "British", "French", "German", "Japanese", "Korean", "Russian", "Other"]),
118
- "religion": random.choice(["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]),
119
- "marital_status": random.choice(["not married", "married", "divorced", "widowed"]),
120
- }
219
+ "race": random.choice(
220
+ [
221
+ "Chinese",
222
+ "American",
223
+ "British",
224
+ "French",
225
+ "German",
226
+ "Japanese",
227
+ "Korean",
228
+ "Russian",
229
+ "Other",
230
+ ]
231
+ ),
232
+ "religion": random.choice(
233
+ ["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]
234
+ ),
235
+ "marital_status": random.choice(
236
+ ["not married", "married", "divorced", "widowed"]
237
+ ),
238
+ }
121
239
 
122
240
  BASE = {
123
- "home": {"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}},
124
- "work": {"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}},
241
+ "home": {
242
+ "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
243
+ },
244
+ "work": {
245
+ "aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}
246
+ },
125
247
  }
126
248
 
127
249
  return EXTRA_ATTRIBUTES, PROFILE, BASE
@@ -218,7 +340,7 @@ class AgentSimulation:
218
340
  except Exception as e:
219
341
  logger.error(f"采访过程出错: {str(e)}")
220
342
  return f"采访过程出现错误: {str(e)}"
221
-
343
+
222
344
  async def submit_survey(self, agent_name: str, survey_id: str) -> str:
223
345
  """向智能体提交问卷
224
346
 
@@ -319,9 +441,7 @@ class AgentSimulation:
319
441
  prevent_thread_lock=True,
320
442
  quiet=True,
321
443
  )
322
- logger.info(
323
- f"Gradio Frontend is running on http://{server_name}:{server_port}"
324
- )
444
+ logger.info(f"Gradio Frontend is running on http://{server_name}:{server_port}")
325
445
 
326
446
  async def step(self):
327
447
  """运行一步, 即每个智能体执行一次forward"""
@@ -333,7 +453,7 @@ class AgentSimulation:
333
453
  except Exception as e:
334
454
  logger.error(f"运行错误: {str(e)}")
335
455
  raise
336
-
456
+
337
457
  async def run(
338
458
  self,
339
459
  day: int = 1,
@@ -348,7 +468,7 @@ class AgentSimulation:
348
468
  tasks = []
349
469
  for group in self._groups.values():
350
470
  tasks.append(group.run.remote(day))
351
-
471
+
352
472
  await asyncio.gather(*tasks)
353
473
 
354
474
  except Exception as e:
@@ -1,9 +1,4 @@
1
1
  from .models import QuestionType, Question, Survey
2
2
  from .manager import SurveyManager
3
3
 
4
- __all__ = [
5
- 'QuestionType',
6
- 'Question',
7
- 'Survey',
8
- 'SurveyManager'
9
- ]
4
+ __all__ = ["QuestionType", "Question", "Survey", "SurveyManager"]
@@ -4,14 +4,17 @@ import uuid
4
4
  import json
5
5
  from .models import Survey, Question, QuestionType
6
6
 
7
+
7
8
  class SurveyManager:
8
9
  def __init__(self):
9
10
  self._surveys: Dict[str, Survey] = {}
10
-
11
- def create_survey(self, title: str, description: str, questions: List[dict]) -> Survey:
11
+
12
+ def create_survey(
13
+ self, title: str, description: str, questions: List[dict]
14
+ ) -> Survey:
12
15
  """创建新问卷"""
13
16
  survey_id = str(uuid.uuid4())
14
-
17
+
15
18
  # 转换问题数据
16
19
  survey_questions = []
17
20
  for q in questions:
@@ -21,47 +24,45 @@ class SurveyManager:
21
24
  required=q.get("required", True),
22
25
  options=q.get("options", []),
23
26
  min_rating=q.get("min_rating", 1),
24
- max_rating=q.get("max_rating", 5)
27
+ max_rating=q.get("max_rating", 5),
25
28
  )
26
29
  survey_questions.append(question)
27
-
30
+
28
31
  survey = Survey(
29
32
  id=survey_id,
30
33
  title=title,
31
34
  description=description,
32
- questions=survey_questions
35
+ questions=survey_questions,
33
36
  )
34
-
37
+
35
38
  self._surveys[survey_id] = survey
36
39
  return survey
37
-
40
+
38
41
  def get_survey(self, survey_id: str) -> Optional[Survey]:
39
42
  """获取指定问卷"""
40
43
  return self._surveys.get(survey_id)
41
-
44
+
42
45
  def get_all_surveys(self) -> List[Survey]:
43
46
  """获取所有问卷"""
44
47
  return list(self._surveys.values())
45
-
48
+
46
49
  def add_response(self, survey_id: str, agent_name: str, response: dict) -> bool:
47
50
  """添加问卷回答"""
48
51
  survey = self.get_survey(survey_id)
49
52
  if not survey:
50
53
  return False
51
-
52
- survey.responses[agent_name] = {
53
- "timestamp": datetime.now(),
54
- **response
55
- }
54
+
55
+ survey.responses[agent_name] = {"timestamp": datetime.now(), **response}
56
56
  return True
57
-
57
+
58
58
  def export_results(self, survey_id: str) -> str:
59
59
  """导出问卷结果"""
60
60
  survey = self.get_survey(survey_id)
61
61
  if not survey:
62
62
  return json.dumps({"error": "问卷不存在"})
63
-
64
- return json.dumps({
65
- "survey": survey.to_dict(),
66
- "responses": survey.responses
67
- }, ensure_ascii=False, indent=2)
63
+
64
+ return json.dumps(
65
+ {"survey": survey.to_dict(), "responses": survey.responses},
66
+ ensure_ascii=False,
67
+ indent=2,
68
+ )