pycityagent 2.0.0a13__py3-none-any.whl → 2.0.0a14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +132 -44
- pycityagent/economy/econ_client.py +2 -0
- pycityagent/memory/const.py +1 -0
- pycityagent/simulation/agentgroup.py +41 -3
- pycityagent/simulation/simulation.py +99 -8
- pycityagent/survey/manager.py +58 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +7 -0
- pycityagent/utils/avro_schema.py +85 -0
- pycityagent/utils/survey_util.py +53 -0
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a14.dist-info}/METADATA +3 -1
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a14.dist-info}/RECORD +14 -15
- pycityagent/simulation/interview.py +0 -40
- pycityagent/simulation/survey/manager.py +0 -68
- pycityagent/simulation/survey/models.py +0 -52
- pycityagent/simulation/ui/__init__.py +0 -3
- pycityagent/simulation/ui/interface.py +0 -602
- /pycityagent/{simulation/survey → survey}/__init__.py +0 -0
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a14.dist-info}/WHEEL +0 -0
pycityagent/agent.py
CHANGED
@@ -13,11 +13,15 @@ import random
|
|
13
13
|
import uuid
|
14
14
|
from typing import Dict, List, Optional
|
15
15
|
|
16
|
+
import fastavro
|
17
|
+
|
16
18
|
from pycityagent.environment.sim.person_service import PersonService
|
17
19
|
from mosstool.util.format_converter import dict2pb
|
18
20
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
21
|
+
from pycityagent.utils import process_survey_for_llm
|
19
22
|
|
20
23
|
from pycityagent.message.messager import Messager
|
24
|
+
from pycityagent.utils import SURVEY_SCHEMA, DIALOG_SCHEMA
|
21
25
|
|
22
26
|
from .economy import EconomyClient
|
23
27
|
from .environment import Simulator
|
@@ -52,6 +56,7 @@ class Agent(ABC):
|
|
52
56
|
messager: Optional[Messager] = None,
|
53
57
|
simulator: Optional[Simulator] = None,
|
54
58
|
memory: Optional[Memory] = None,
|
59
|
+
avro_file: Optional[Dict[str, str]] = None,
|
55
60
|
) -> None:
|
56
61
|
"""
|
57
62
|
Initialize the Agent.
|
@@ -64,6 +69,7 @@ class Agent(ABC):
|
|
64
69
|
messager (Messager, optional): The messager object. Defaults to None.
|
65
70
|
simulator (Simulator, optional): The simulator object. Defaults to None.
|
66
71
|
memory (Memory, optional): The memory of the agent. Defaults to None.
|
72
|
+
avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
|
67
73
|
"""
|
68
74
|
self._name = name
|
69
75
|
self._type = type
|
@@ -79,6 +85,7 @@ class Agent(ABC):
|
|
79
85
|
self._blocked = False
|
80
86
|
self._interview_history: List[Dict] = [] # 存储采访历史
|
81
87
|
self._person_template = PersonService.default_dict_person()
|
88
|
+
self._avro_file = avro_file
|
82
89
|
|
83
90
|
def __getstate__(self):
|
84
91
|
state = self.__dict__.copy()
|
@@ -168,11 +175,59 @@ class Agent(ABC):
|
|
168
175
|
)
|
169
176
|
return self._simulator
|
170
177
|
|
171
|
-
async def
|
172
|
-
"""生成回答
|
178
|
+
async def generate_user_survey_response(self, survey: dict) -> str:
|
179
|
+
"""生成回答 —— 可重写
|
180
|
+
基于智能体的记忆和当前状态,生成对问卷调查的回答。
|
181
|
+
Args:
|
182
|
+
survey: 需要回答的问卷 dict
|
183
|
+
Returns:
|
184
|
+
str: 智能体的回答
|
185
|
+
"""
|
186
|
+
survey_prompt = process_survey_for_llm(survey)
|
187
|
+
dialog = []
|
173
188
|
|
174
|
-
|
189
|
+
# 添加系统提示
|
190
|
+
system_prompt = "Please answer the survey question in first person. Follow the format requirements strictly and provide clear and specific answers."
|
191
|
+
dialog.append({"role": "system", "content": system_prompt})
|
192
|
+
|
193
|
+
# 添加记忆上下文
|
194
|
+
if self._memory:
|
195
|
+
relevant_memories = await self._memory.search(survey_prompt)
|
196
|
+
if relevant_memories:
|
197
|
+
dialog.append(
|
198
|
+
{
|
199
|
+
"role": "system",
|
200
|
+
"content": f"Answer based on these memories:\n{relevant_memories}",
|
201
|
+
}
|
202
|
+
)
|
203
|
+
|
204
|
+
# 添加问卷问题
|
205
|
+
dialog.append({"role": "user", "content": survey_prompt})
|
206
|
+
|
207
|
+
# 使用LLM生成回答
|
208
|
+
if not self._llm_client:
|
209
|
+
return "Sorry, I cannot answer survey questions right now."
|
210
|
+
|
211
|
+
response = await self._llm_client.atext_request(dialog) # type:ignore
|
175
212
|
|
213
|
+
return response # type:ignore
|
214
|
+
|
215
|
+
async def _process_survey(self, survey: dict):
|
216
|
+
survey_response = await self.generate_user_survey_response(survey)
|
217
|
+
response_to_avro = [{
|
218
|
+
"id": str(self._uuid),
|
219
|
+
"day": await self._simulator.get_simulator_day(),
|
220
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
221
|
+
"survey_id": survey["id"],
|
222
|
+
"result": survey_response,
|
223
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
224
|
+
}]
|
225
|
+
with open(self._avro_file["survey"], "a+b") as f:
|
226
|
+
fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
|
227
|
+
|
228
|
+
async def generate_user_chat_response(self, question: str) -> str:
|
229
|
+
"""生成回答 —— 可重写
|
230
|
+
基于智能体的记忆和当前状态,生成对问题的回答。
|
176
231
|
Args:
|
177
232
|
question: 需要回答的问题
|
178
233
|
|
@@ -182,7 +237,7 @@ class Agent(ABC):
|
|
182
237
|
dialog = []
|
183
238
|
|
184
239
|
# 添加系统提示
|
185
|
-
system_prompt =
|
240
|
+
system_prompt = "Please answer the question in first person and keep the response concise and clear."
|
186
241
|
dialog.append({"role": "system", "content": system_prompt})
|
187
242
|
|
188
243
|
# 添加记忆上下文
|
@@ -192,7 +247,7 @@ class Agent(ABC):
|
|
192
247
|
dialog.append(
|
193
248
|
{
|
194
249
|
"role": "system",
|
195
|
-
"content": f"
|
250
|
+
"content": f"Answer based on these memories:\n{relevant_memories}",
|
196
251
|
}
|
197
252
|
)
|
198
253
|
|
@@ -201,46 +256,76 @@ class Agent(ABC):
|
|
201
256
|
|
202
257
|
# 使用LLM生成回答
|
203
258
|
if not self._llm_client:
|
204
|
-
return "
|
259
|
+
return "Sorry, I cannot answer questions right now."
|
205
260
|
|
206
261
|
response = await self._llm_client.atext_request(dialog) # type:ignore
|
207
262
|
|
208
|
-
# 记录采访历史
|
209
|
-
self._interview_history.append(
|
210
|
-
{
|
211
|
-
"timestamp": datetime.now().isoformat(),
|
212
|
-
"question": question,
|
213
|
-
"response": response,
|
214
|
-
}
|
215
|
-
)
|
216
|
-
|
217
263
|
return response # type:ignore
|
218
|
-
|
219
|
-
def
|
220
|
-
|
221
|
-
|
222
|
-
|
264
|
+
|
265
|
+
async def _process_interview(self, payload: dict):
|
266
|
+
auros = [{
|
267
|
+
"id": str(self._uuid),
|
268
|
+
"day": await self._simulator.get_simulator_day(),
|
269
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
270
|
+
"type": 2,
|
271
|
+
"speaker": "user",
|
272
|
+
"content": payload["content"],
|
273
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
274
|
+
}]
|
275
|
+
question = payload["content"]
|
276
|
+
response = await self.generate_user_chat_response(question)
|
277
|
+
auros.append({
|
278
|
+
"id": str(self._uuid),
|
279
|
+
"day": await self._simulator.get_simulator_day(),
|
280
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
281
|
+
"type": 2,
|
282
|
+
"speaker": "",
|
283
|
+
"content": response,
|
284
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
285
|
+
})
|
286
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
287
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
288
|
+
|
289
|
+
async def process_agent_chat_response(self, payload: dict) -> str:
|
290
|
+
logging.info(f"Agent {self._uuid} received agent chat response: {payload}")
|
291
|
+
|
292
|
+
async def _process_agent_chat(self, payload: dict):
|
293
|
+
auros = [{
|
294
|
+
"id": str(self._uuid),
|
295
|
+
"day": payload["day"],
|
296
|
+
"t": payload["t"],
|
297
|
+
"type": 1,
|
298
|
+
"speaker": payload["from"],
|
299
|
+
"content": payload["content"],
|
300
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
301
|
+
}]
|
302
|
+
asyncio.create_task(self.process_agent_chat_response(payload))
|
303
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
304
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
305
|
+
|
306
|
+
# Callback functions for MQTT message
|
223
307
|
async def handle_agent_chat_message(self, payload: dict):
|
224
308
|
"""处理收到的消息,识别发送者"""
|
225
309
|
# 从消息中解析发送者 ID 和消息内容
|
226
|
-
|
227
|
-
|
228
|
-
)
|
310
|
+
logging.info(f"Agent {self._uuid} received agent chat message: {payload}")
|
311
|
+
asyncio.create_task(self._process_agent_chat(payload))
|
229
312
|
|
230
313
|
async def handle_user_chat_message(self, payload: dict):
|
231
314
|
"""处理收到的消息,识别发送者"""
|
232
315
|
# 从消息中解析发送者 ID 和消息内容
|
233
|
-
|
234
|
-
|
235
|
-
)
|
316
|
+
logging.info(f"Agent {self._uuid} received user chat message: {payload}")
|
317
|
+
asyncio.create_task(self._process_interview(payload))
|
236
318
|
|
237
319
|
async def handle_user_survey_message(self, payload: dict):
|
238
320
|
"""处理收到的消息,识别发送者"""
|
239
321
|
# 从消息中解析发送者 ID 和消息内容
|
240
|
-
|
241
|
-
|
242
|
-
)
|
322
|
+
logging.info(f"Agent {self._uuid} received user survey message: {payload}")
|
323
|
+
asyncio.create_task(self._process_survey(payload["data"]))
|
243
324
|
|
325
|
+
async def handle_gather_message(self, payload: str):
|
326
|
+
raise NotImplementedError
|
327
|
+
|
328
|
+
# MQTT send message
|
244
329
|
async def _send_message(
|
245
330
|
self, to_agent_uuid: UUID, payload: dict, sub_topic: str
|
246
331
|
):
|
@@ -251,7 +336,7 @@ class Agent(ABC):
|
|
251
336
|
await self._messager.send_message(topic, payload)
|
252
337
|
|
253
338
|
async def send_message_to_agent(
|
254
|
-
self, to_agent_uuid: UUID, content:
|
339
|
+
self, to_agent_uuid: UUID, content: str
|
255
340
|
):
|
256
341
|
"""通过 Messager 发送消息"""
|
257
342
|
if self._messager is None:
|
@@ -259,29 +344,28 @@ class Agent(ABC):
|
|
259
344
|
payload = {
|
260
345
|
"from": self._uuid,
|
261
346
|
"content": content,
|
262
|
-
"timestamp": int(
|
347
|
+
"timestamp": int(datetime.now().timestamp() * 1000),
|
263
348
|
"day": await self._simulator.get_simulator_day(),
|
264
349
|
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
265
350
|
}
|
266
351
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
352
|
+
auros = [{
|
353
|
+
"id": str(self._uuid),
|
354
|
+
"day": await self._simulator.get_simulator_day(),
|
355
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
356
|
+
"type": 1,
|
357
|
+
"speaker": str(self._uuid),
|
358
|
+
"content": content,
|
359
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
360
|
+
}]
|
361
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
362
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
267
363
|
|
268
|
-
|
269
|
-
self, content: dict
|
270
|
-
):
|
271
|
-
pass
|
272
|
-
|
273
|
-
async def send_message_to_survey(
|
274
|
-
self, content: dict
|
275
|
-
):
|
276
|
-
pass
|
277
|
-
|
364
|
+
# Agent logic
|
278
365
|
@abstractmethod
|
279
366
|
async def forward(self) -> None:
|
280
367
|
"""智能体行为逻辑"""
|
281
368
|
raise NotImplementedError
|
282
|
-
|
283
|
-
async def handle_gather_message(self, payload: str):
|
284
|
-
raise NotImplementedError
|
285
369
|
|
286
370
|
async def run(self) -> None:
|
287
371
|
"""
|
@@ -305,6 +389,7 @@ class CitizenAgent(Agent):
|
|
305
389
|
memory: Optional[Memory] = None,
|
306
390
|
economy_client: Optional[EconomyClient] = None,
|
307
391
|
messager: Optional[Messager] = None,
|
392
|
+
avro_file: Optional[dict] = None,
|
308
393
|
) -> None:
|
309
394
|
super().__init__(
|
310
395
|
name,
|
@@ -314,6 +399,7 @@ class CitizenAgent(Agent):
|
|
314
399
|
messager,
|
315
400
|
simulator,
|
316
401
|
memory,
|
402
|
+
avro_file,
|
317
403
|
)
|
318
404
|
|
319
405
|
async def bind_to_simulator(self):
|
@@ -417,6 +503,7 @@ class InstitutionAgent(Agent):
|
|
417
503
|
memory: Optional[Memory] = None,
|
418
504
|
economy_client: Optional[EconomyClient] = None,
|
419
505
|
messager: Optional[Messager] = None,
|
506
|
+
avro_file: Optional[dict] = None,
|
420
507
|
) -> None:
|
421
508
|
super().__init__(
|
422
509
|
name,
|
@@ -426,6 +513,7 @@ class InstitutionAgent(Agent):
|
|
426
513
|
messager,
|
427
514
|
simulator,
|
428
515
|
memory,
|
516
|
+
avro_file,
|
429
517
|
)
|
430
518
|
# 添加响应收集器
|
431
519
|
self._gather_responses: Dict[str, asyncio.Future] = {}
|
pycityagent/memory/const.py
CHANGED
@@ -1,23 +1,27 @@
|
|
1
1
|
import asyncio
|
2
|
+
from datetime import datetime
|
2
3
|
import json
|
3
4
|
import logging
|
4
5
|
import uuid
|
6
|
+
import fastavro
|
5
7
|
import ray
|
6
8
|
from uuid import UUID
|
7
|
-
from pycityagent.agent import Agent
|
9
|
+
from pycityagent.agent import Agent, CitizenAgent
|
8
10
|
from pycityagent.economy.econ_client import EconomyClient
|
9
11
|
from pycityagent.environment.simulator import Simulator
|
10
12
|
from pycityagent.llm.llm import LLM
|
11
13
|
from pycityagent.llm.llmconfig import LLMConfig
|
12
14
|
from pycityagent.message import Messager
|
15
|
+
from pycityagent.utils import STATUS_SCHEMA
|
13
16
|
from typing import Any
|
14
17
|
|
15
18
|
@ray.remote
|
16
19
|
class AgentGroup:
|
17
|
-
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID):
|
20
|
+
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, avro_file: dict):
|
18
21
|
self.agents = agents
|
19
22
|
self.config = config
|
20
23
|
self.exp_id = exp_id
|
24
|
+
self.avro_file = avro_file
|
21
25
|
self._uuid = uuid.uuid4()
|
22
26
|
self.messager = Messager(
|
23
27
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
@@ -117,7 +121,7 @@ class AgentGroup:
|
|
117
121
|
elif topic_type == "gather":
|
118
122
|
await agent.handle_gather_message(payload)
|
119
123
|
|
120
|
-
await asyncio.sleep(
|
124
|
+
await asyncio.sleep(0.5)
|
121
125
|
|
122
126
|
async def step(self):
|
123
127
|
if not self.initialized:
|
@@ -125,6 +129,40 @@ class AgentGroup:
|
|
125
129
|
|
126
130
|
tasks = [agent.run() for agent in self.agents]
|
127
131
|
await asyncio.gather(*tasks)
|
132
|
+
avros = []
|
133
|
+
for agent in self.agents:
|
134
|
+
if not issubclass(type(agent), CitizenAgent):
|
135
|
+
continue
|
136
|
+
position = await agent.memory.get("position")
|
137
|
+
lng = position["longlat_position"]["longitude"]
|
138
|
+
lat = position["longlat_position"]["latitude"]
|
139
|
+
if "aoi_position" in position:
|
140
|
+
parent_id = position["aoi_position"]["aoi_id"]
|
141
|
+
elif "lane_position" in position:
|
142
|
+
parent_id = position["lane_position"]["lane_id"]
|
143
|
+
else:
|
144
|
+
# BUG: 需要处理
|
145
|
+
parent_id = -1
|
146
|
+
needs = await agent.memory.get("needs")
|
147
|
+
action = await agent.memory.get("current_step")
|
148
|
+
action = action["intention"]
|
149
|
+
avro = {
|
150
|
+
"id": str(agent._uuid), # uuid as string
|
151
|
+
"day": await self.simulator.get_simulator_day(),
|
152
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
153
|
+
"lng": lng,
|
154
|
+
"lat": lat,
|
155
|
+
"parent_id": parent_id,
|
156
|
+
"action": action,
|
157
|
+
"hungry": needs["hungry"],
|
158
|
+
"tired": needs["tired"],
|
159
|
+
"safe": needs["safe"],
|
160
|
+
"social": needs["social"],
|
161
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
162
|
+
}
|
163
|
+
avros.append(avro)
|
164
|
+
with open(self.avro_file["status"], "a+b") as f:
|
165
|
+
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
128
166
|
|
129
167
|
async def run(self, day: int = 1):
|
130
168
|
"""运行模拟器
|
@@ -1,18 +1,21 @@
|
|
1
1
|
import asyncio
|
2
2
|
import json
|
3
3
|
import logging
|
4
|
+
import os
|
5
|
+
from pathlib import Path
|
4
6
|
import uuid
|
5
7
|
from datetime import datetime
|
6
8
|
import random
|
7
9
|
from typing import Dict, List, Optional, Callable, Union,Any
|
10
|
+
import fastavro
|
8
11
|
from mosstool.map._map_util.const import AOI_START_ID
|
9
12
|
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
10
13
|
from pycityagent.memory.memory import Memory
|
14
|
+
from pycityagent.message.messager import Messager
|
15
|
+
from pycityagent.survey import Survey
|
16
|
+
from pycityagent.utils.avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
|
11
17
|
|
12
18
|
from ..agent import Agent, InstitutionAgent
|
13
|
-
from .interview import InterviewManager
|
14
|
-
from .survey import QuestionType, SurveyManager
|
15
|
-
from .ui import InterviewUI
|
16
19
|
from .agentgroup import AgentGroup
|
17
20
|
|
18
21
|
logger = logging.getLogger(__name__)
|
@@ -44,10 +47,34 @@ class AgentSimulation:
|
|
44
47
|
self._groups: Dict[str, AgentGroup] = {}
|
45
48
|
self._agent_uuid2group: Dict[uuid.UUID, AgentGroup] = {}
|
46
49
|
self._agent_uuids: List[uuid.UUID] = []
|
47
|
-
|
50
|
+
self._user_chat_topics: Dict[uuid.UUID, str] = {}
|
51
|
+
self._user_survey_topics: Dict[uuid.UUID, str] = {}
|
52
|
+
self._user_interview_topics: Dict[uuid.UUID, str] = {}
|
48
53
|
self._loop = asyncio.get_event_loop()
|
49
|
-
|
50
|
-
self.
|
54
|
+
|
55
|
+
self._messager = Messager(
|
56
|
+
hostname=config["simulator_request"]["mqtt"]["server"],
|
57
|
+
port=config["simulator_request"]["mqtt"]["port"],
|
58
|
+
username=config["simulator_request"]["mqtt"].get("username", None),
|
59
|
+
password=config["simulator_request"]["mqtt"].get("password", None),
|
60
|
+
)
|
61
|
+
asyncio.create_task(self._messager.connect())
|
62
|
+
|
63
|
+
self._enable_avro = config["storage"]["avro"]["enabled"]
|
64
|
+
self._avro_path = Path(config["storage"]["avro"]["path"])
|
65
|
+
self._avro_file = {
|
66
|
+
"profile": self._avro_path / f"{self.exp_id}_profile.avro",
|
67
|
+
"dialog": self._avro_path / f"{self.exp_id}_dialog.avro",
|
68
|
+
"status": self._avro_path / f"{self.exp_id}_status.avro",
|
69
|
+
"survey": self._avro_path / f"{self.exp_id}_survey.avro",
|
70
|
+
}
|
71
|
+
|
72
|
+
self._enable_pgsql = config["storage"]["pgsql"]["enabled"]
|
73
|
+
self._pgsql_host = config["storage"]["pgsql"]["host"]
|
74
|
+
self._pgsql_port = config["storage"]["pgsql"]["port"]
|
75
|
+
self._pgsql_database = config["storage"]["pgsql"]["database"]
|
76
|
+
self._pgsql_user = config["storage"]["pgsql"]["user"]
|
77
|
+
self._pgsql_password = config["storage"]["pgsql"]["password"]
|
51
78
|
|
52
79
|
@property
|
53
80
|
def agents(self):
|
@@ -126,6 +153,7 @@ class AgentSimulation:
|
|
126
153
|
agent = agent_class(
|
127
154
|
name=agent_name,
|
128
155
|
memory=memory,
|
156
|
+
avro_file=self._avro_file,
|
129
157
|
)
|
130
158
|
|
131
159
|
self._agents[agent._uuid] = agent
|
@@ -145,7 +173,7 @@ class AgentSimulation:
|
|
145
173
|
# 获取当前组的agents
|
146
174
|
agents = list(self._agents.values())[start_idx:end_idx]
|
147
175
|
group_name = f"AgentType_{i}_Group_{k}"
|
148
|
-
group = AgentGroup.remote(agents, self.config, self.exp_id)
|
176
|
+
group = AgentGroup.remote(agents, self.config, self.exp_id, self._avro_file)
|
149
177
|
self._groups[group_name] = group
|
150
178
|
for agent in agents:
|
151
179
|
self._agent_uuid2group[agent._uuid] = group
|
@@ -156,6 +184,42 @@ class AgentSimulation:
|
|
156
184
|
for group in self._groups.values():
|
157
185
|
init_tasks.append(group.init_agents.remote())
|
158
186
|
await asyncio.gather(*init_tasks)
|
187
|
+
for uuid, agent in self._agents.items():
|
188
|
+
self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
|
189
|
+
self._user_survey_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-survey"
|
190
|
+
|
191
|
+
# save profile
|
192
|
+
if self._enable_avro:
|
193
|
+
self._avro_path.mkdir(parents=True, exist_ok=True)
|
194
|
+
# profile
|
195
|
+
filename = self._avro_file["profile"]
|
196
|
+
with open(filename, "wb") as f:
|
197
|
+
profiles = []
|
198
|
+
for agent in self._agents.values():
|
199
|
+
profile = await agent.memory._profile.export()
|
200
|
+
profile = profile[0]
|
201
|
+
profile['id'] = str(agent._uuid)
|
202
|
+
profiles.append(profile)
|
203
|
+
fastavro.writer(f, PROFILE_SCHEMA, profiles)
|
204
|
+
|
205
|
+
# dialog
|
206
|
+
filename = self._avro_file["dialog"]
|
207
|
+
with open(filename, "wb") as f:
|
208
|
+
dialogs = []
|
209
|
+
fastavro.writer(f, DIALOG_SCHEMA, dialogs)
|
210
|
+
|
211
|
+
# status
|
212
|
+
filename = self._avro_file["status"]
|
213
|
+
with open(filename, "wb") as f:
|
214
|
+
statuses = []
|
215
|
+
fastavro.writer(f, STATUS_SCHEMA, statuses)
|
216
|
+
|
217
|
+
# survey
|
218
|
+
filename = self._avro_file["survey"]
|
219
|
+
with open(filename, "wb") as f:
|
220
|
+
surveys = []
|
221
|
+
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
222
|
+
|
159
223
|
|
160
224
|
async def gather(self, content: str):
|
161
225
|
"""收集智能体的特定信息"""
|
@@ -219,6 +283,7 @@ class AgentSimulation:
|
|
219
283
|
}
|
220
284
|
|
221
285
|
PROFILE = {
|
286
|
+
"name": "unknown",
|
222
287
|
"gender": random.choice(["male", "female"]),
|
223
288
|
"education": random.choice(
|
224
289
|
["Doctor", "Master", "Bachelor", "College", "High School"]
|
@@ -251,7 +316,7 @@ class AgentSimulation:
|
|
251
316
|
"personality": random.choice(
|
252
317
|
["outgoint", "introvert", "ambivert", "extrovert"]
|
253
318
|
),
|
254
|
-
"income": random.randint(1000, 10000),
|
319
|
+
"income": str(random.randint(1000, 10000)),
|
255
320
|
"currency": random.randint(10000, 100000),
|
256
321
|
"residence": random.choice(["city", "suburb", "rural"]),
|
257
322
|
"race": random.choice(
|
@@ -285,6 +350,32 @@ class AgentSimulation:
|
|
285
350
|
}
|
286
351
|
|
287
352
|
return EXTRA_ATTRIBUTES, PROFILE, BASE
|
353
|
+
|
354
|
+
async def send_survey(self, survey: Survey, agent_uuids: Optional[List[uuid.UUID]] = None):
|
355
|
+
"""发送问卷"""
|
356
|
+
survey = survey.to_dict()
|
357
|
+
if agent_uuids is None:
|
358
|
+
agent_uuids = self._agent_uuids
|
359
|
+
payload = {
|
360
|
+
"from": "none",
|
361
|
+
"survey_id": survey["id"],
|
362
|
+
"timestamp": int(datetime.now().timestamp() * 1000),
|
363
|
+
"data": survey,
|
364
|
+
}
|
365
|
+
for uuid in agent_uuids:
|
366
|
+
topic = self._user_survey_topics[uuid]
|
367
|
+
await self._messager.send_message(topic, payload)
|
368
|
+
|
369
|
+
async def send_interview_message(self, content: str, agent_uuids: Union[uuid.UUID, List[uuid.UUID]]):
|
370
|
+
"""发送面试消息"""
|
371
|
+
payload = {
|
372
|
+
"from": "none",
|
373
|
+
"content": content,
|
374
|
+
"timestamp": int(datetime.now().timestamp() * 1000),
|
375
|
+
}
|
376
|
+
for uuid in agent_uuids:
|
377
|
+
topic = self._user_chat_topics[uuid]
|
378
|
+
await self._messager.send_message(topic, payload)
|
288
379
|
|
289
380
|
async def step(self):
|
290
381
|
"""运行一步, 即每个智能体执行一次forward"""
|
@@ -0,0 +1,58 @@
|
|
1
|
+
from typing import List, Dict, Optional
|
2
|
+
from datetime import datetime
|
3
|
+
import uuid
|
4
|
+
import json
|
5
|
+
from .models import Survey, Question, QuestionType, Page
|
6
|
+
|
7
|
+
|
8
|
+
class SurveyManager:
|
9
|
+
def __init__(self):
|
10
|
+
self._surveys: Dict[str, Survey] = {}
|
11
|
+
|
12
|
+
def create_survey(
|
13
|
+
self, title: str, description: str, pages: List[dict]
|
14
|
+
) -> Survey:
|
15
|
+
"""创建新问卷"""
|
16
|
+
survey_id = uuid.uuid4()
|
17
|
+
|
18
|
+
# 转换页面和问题数据
|
19
|
+
survey_pages = []
|
20
|
+
for page_data in pages:
|
21
|
+
questions = []
|
22
|
+
for q in page_data["elements"]:
|
23
|
+
question = Question(
|
24
|
+
name=q["name"],
|
25
|
+
title=q["title"],
|
26
|
+
type=QuestionType(q["type"]),
|
27
|
+
required=q.get("required", True),
|
28
|
+
choices=q.get("choices", []),
|
29
|
+
columns=q.get("columns", []),
|
30
|
+
rows=q.get("rows", []),
|
31
|
+
min_rating=q.get("min_rating", 1),
|
32
|
+
max_rating=q.get("max_rating", 5),
|
33
|
+
)
|
34
|
+
questions.append(question)
|
35
|
+
|
36
|
+
page = Page(
|
37
|
+
name=page_data["name"],
|
38
|
+
elements=questions
|
39
|
+
)
|
40
|
+
survey_pages.append(page)
|
41
|
+
|
42
|
+
survey = Survey(
|
43
|
+
id=survey_id,
|
44
|
+
title=title,
|
45
|
+
description=description,
|
46
|
+
pages=survey_pages,
|
47
|
+
)
|
48
|
+
|
49
|
+
self._surveys[str(survey_id)] = survey
|
50
|
+
return survey
|
51
|
+
|
52
|
+
def get_survey(self, survey_id: str) -> Optional[Survey]:
|
53
|
+
"""获取指定问卷"""
|
54
|
+
return self._surveys.get(survey_id)
|
55
|
+
|
56
|
+
def get_all_surveys(self) -> List[Survey]:
|
57
|
+
"""获取所有问卷"""
|
58
|
+
return list(self._surveys.values())
|