pycityagent 2.0.0a12__py3-none-any.whl → 2.0.0a14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +171 -52
- pycityagent/economy/econ_client.py +2 -0
- pycityagent/memory/const.py +1 -0
- pycityagent/simulation/agentgroup.py +78 -42
- pycityagent/simulation/simulation.py +99 -8
- pycityagent/survey/manager.py +58 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +7 -0
- pycityagent/utils/avro_schema.py +85 -0
- pycityagent/utils/survey_util.py +53 -0
- {pycityagent-2.0.0a12.dist-info → pycityagent-2.0.0a14.dist-info}/METADATA +3 -1
- {pycityagent-2.0.0a12.dist-info → pycityagent-2.0.0a14.dist-info}/RECORD +14 -15
- pycityagent/simulation/interview.py +0 -40
- pycityagent/simulation/survey/manager.py +0 -68
- pycityagent/simulation/survey/models.py +0 -52
- pycityagent/simulation/ui/__init__.py +0 -3
- pycityagent/simulation/ui/interface.py +0 -602
- /pycityagent/{simulation/survey → survey}/__init__.py +0 -0
- {pycityagent-2.0.0a12.dist-info → pycityagent-2.0.0a14.dist-info}/WHEEL +0 -0
pycityagent/agent.py
CHANGED
@@ -13,11 +13,15 @@ import random
|
|
13
13
|
import uuid
|
14
14
|
from typing import Dict, List, Optional
|
15
15
|
|
16
|
+
import fastavro
|
17
|
+
|
16
18
|
from pycityagent.environment.sim.person_service import PersonService
|
17
19
|
from mosstool.util.format_converter import dict2pb
|
18
20
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
21
|
+
from pycityagent.utils import process_survey_for_llm
|
19
22
|
|
20
23
|
from pycityagent.message.messager import Messager
|
24
|
+
from pycityagent.utils import SURVEY_SCHEMA, DIALOG_SCHEMA
|
21
25
|
|
22
26
|
from .economy import EconomyClient
|
23
27
|
from .environment import Simulator
|
@@ -52,6 +56,7 @@ class Agent(ABC):
|
|
52
56
|
messager: Optional[Messager] = None,
|
53
57
|
simulator: Optional[Simulator] = None,
|
54
58
|
memory: Optional[Memory] = None,
|
59
|
+
avro_file: Optional[Dict[str, str]] = None,
|
55
60
|
) -> None:
|
56
61
|
"""
|
57
62
|
Initialize the Agent.
|
@@ -64,6 +69,7 @@ class Agent(ABC):
|
|
64
69
|
messager (Messager, optional): The messager object. Defaults to None.
|
65
70
|
simulator (Simulator, optional): The simulator object. Defaults to None.
|
66
71
|
memory (Memory, optional): The memory of the agent. Defaults to None.
|
72
|
+
avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
|
67
73
|
"""
|
68
74
|
self._name = name
|
69
75
|
self._type = type
|
@@ -79,6 +85,7 @@ class Agent(ABC):
|
|
79
85
|
self._blocked = False
|
80
86
|
self._interview_history: List[Dict] = [] # 存储采访历史
|
81
87
|
self._person_template = PersonService.default_dict_person()
|
88
|
+
self._avro_file = avro_file
|
82
89
|
|
83
90
|
def __getstate__(self):
|
84
91
|
state = self.__dict__.copy()
|
@@ -168,11 +175,59 @@ class Agent(ABC):
|
|
168
175
|
)
|
169
176
|
return self._simulator
|
170
177
|
|
171
|
-
async def
|
172
|
-
"""生成回答
|
178
|
+
async def generate_user_survey_response(self, survey: dict) -> str:
|
179
|
+
"""生成回答 —— 可重写
|
180
|
+
基于智能体的记忆和当前状态,生成对问卷调查的回答。
|
181
|
+
Args:
|
182
|
+
survey: 需要回答的问卷 dict
|
183
|
+
Returns:
|
184
|
+
str: 智能体的回答
|
185
|
+
"""
|
186
|
+
survey_prompt = process_survey_for_llm(survey)
|
187
|
+
dialog = []
|
173
188
|
|
174
|
-
|
189
|
+
# 添加系统提示
|
190
|
+
system_prompt = "Please answer the survey question in first person. Follow the format requirements strictly and provide clear and specific answers."
|
191
|
+
dialog.append({"role": "system", "content": system_prompt})
|
192
|
+
|
193
|
+
# 添加记忆上下文
|
194
|
+
if self._memory:
|
195
|
+
relevant_memories = await self._memory.search(survey_prompt)
|
196
|
+
if relevant_memories:
|
197
|
+
dialog.append(
|
198
|
+
{
|
199
|
+
"role": "system",
|
200
|
+
"content": f"Answer based on these memories:\n{relevant_memories}",
|
201
|
+
}
|
202
|
+
)
|
203
|
+
|
204
|
+
# 添加问卷问题
|
205
|
+
dialog.append({"role": "user", "content": survey_prompt})
|
206
|
+
|
207
|
+
# 使用LLM生成回答
|
208
|
+
if not self._llm_client:
|
209
|
+
return "Sorry, I cannot answer survey questions right now."
|
210
|
+
|
211
|
+
response = await self._llm_client.atext_request(dialog) # type:ignore
|
175
212
|
|
213
|
+
return response # type:ignore
|
214
|
+
|
215
|
+
async def _process_survey(self, survey: dict):
|
216
|
+
survey_response = await self.generate_user_survey_response(survey)
|
217
|
+
response_to_avro = [{
|
218
|
+
"id": str(self._uuid),
|
219
|
+
"day": await self._simulator.get_simulator_day(),
|
220
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
221
|
+
"survey_id": survey["id"],
|
222
|
+
"result": survey_response,
|
223
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
224
|
+
}]
|
225
|
+
with open(self._avro_file["survey"], "a+b") as f:
|
226
|
+
fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
|
227
|
+
|
228
|
+
async def generate_user_chat_response(self, question: str) -> str:
|
229
|
+
"""生成回答 —— 可重写
|
230
|
+
基于智能体的记忆和当前状态,生成对问题的回答。
|
176
231
|
Args:
|
177
232
|
question: 需要回答的问题
|
178
233
|
|
@@ -182,7 +237,7 @@ class Agent(ABC):
|
|
182
237
|
dialog = []
|
183
238
|
|
184
239
|
# 添加系统提示
|
185
|
-
system_prompt =
|
240
|
+
system_prompt = "Please answer the question in first person and keep the response concise and clear."
|
186
241
|
dialog.append({"role": "system", "content": system_prompt})
|
187
242
|
|
188
243
|
# 添加记忆上下文
|
@@ -192,7 +247,7 @@ class Agent(ABC):
|
|
192
247
|
dialog.append(
|
193
248
|
{
|
194
249
|
"role": "system",
|
195
|
-
"content": f"
|
250
|
+
"content": f"Answer based on these memories:\n{relevant_memories}",
|
196
251
|
}
|
197
252
|
)
|
198
253
|
|
@@ -201,46 +256,76 @@ class Agent(ABC):
|
|
201
256
|
|
202
257
|
# 使用LLM生成回答
|
203
258
|
if not self._llm_client:
|
204
|
-
return "
|
259
|
+
return "Sorry, I cannot answer questions right now."
|
205
260
|
|
206
261
|
response = await self._llm_client.atext_request(dialog) # type:ignore
|
207
262
|
|
208
|
-
# 记录采访历史
|
209
|
-
self._interview_history.append(
|
210
|
-
{
|
211
|
-
"timestamp": datetime.now().isoformat(),
|
212
|
-
"question": question,
|
213
|
-
"response": response,
|
214
|
-
}
|
215
|
-
)
|
216
|
-
|
217
263
|
return response # type:ignore
|
218
|
-
|
219
|
-
def
|
220
|
-
|
221
|
-
|
222
|
-
|
264
|
+
|
265
|
+
async def _process_interview(self, payload: dict):
|
266
|
+
auros = [{
|
267
|
+
"id": str(self._uuid),
|
268
|
+
"day": await self._simulator.get_simulator_day(),
|
269
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
270
|
+
"type": 2,
|
271
|
+
"speaker": "user",
|
272
|
+
"content": payload["content"],
|
273
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
274
|
+
}]
|
275
|
+
question = payload["content"]
|
276
|
+
response = await self.generate_user_chat_response(question)
|
277
|
+
auros.append({
|
278
|
+
"id": str(self._uuid),
|
279
|
+
"day": await self._simulator.get_simulator_day(),
|
280
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
281
|
+
"type": 2,
|
282
|
+
"speaker": "",
|
283
|
+
"content": response,
|
284
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
285
|
+
})
|
286
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
287
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
288
|
+
|
289
|
+
async def process_agent_chat_response(self, payload: dict) -> str:
|
290
|
+
logging.info(f"Agent {self._uuid} received agent chat response: {payload}")
|
291
|
+
|
292
|
+
async def _process_agent_chat(self, payload: dict):
|
293
|
+
auros = [{
|
294
|
+
"id": str(self._uuid),
|
295
|
+
"day": payload["day"],
|
296
|
+
"t": payload["t"],
|
297
|
+
"type": 1,
|
298
|
+
"speaker": payload["from"],
|
299
|
+
"content": payload["content"],
|
300
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
301
|
+
}]
|
302
|
+
asyncio.create_task(self.process_agent_chat_response(payload))
|
303
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
304
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
305
|
+
|
306
|
+
# Callback functions for MQTT message
|
223
307
|
async def handle_agent_chat_message(self, payload: dict):
|
224
308
|
"""处理收到的消息,识别发送者"""
|
225
309
|
# 从消息中解析发送者 ID 和消息内容
|
226
|
-
|
227
|
-
|
228
|
-
)
|
310
|
+
logging.info(f"Agent {self._uuid} received agent chat message: {payload}")
|
311
|
+
asyncio.create_task(self._process_agent_chat(payload))
|
229
312
|
|
230
313
|
async def handle_user_chat_message(self, payload: dict):
|
231
314
|
"""处理收到的消息,识别发送者"""
|
232
315
|
# 从消息中解析发送者 ID 和消息内容
|
233
|
-
|
234
|
-
|
235
|
-
)
|
316
|
+
logging.info(f"Agent {self._uuid} received user chat message: {payload}")
|
317
|
+
asyncio.create_task(self._process_interview(payload))
|
236
318
|
|
237
319
|
async def handle_user_survey_message(self, payload: dict):
|
238
320
|
"""处理收到的消息,识别发送者"""
|
239
321
|
# 从消息中解析发送者 ID 和消息内容
|
240
|
-
|
241
|
-
|
242
|
-
|
322
|
+
logging.info(f"Agent {self._uuid} received user survey message: {payload}")
|
323
|
+
asyncio.create_task(self._process_survey(payload["data"]))
|
324
|
+
|
325
|
+
async def handle_gather_message(self, payload: str):
|
326
|
+
raise NotImplementedError
|
243
327
|
|
328
|
+
# MQTT send message
|
244
329
|
async def _send_message(
|
245
330
|
self, to_agent_uuid: UUID, payload: dict, sub_topic: str
|
246
331
|
):
|
@@ -251,7 +336,7 @@ class Agent(ABC):
|
|
251
336
|
await self._messager.send_message(topic, payload)
|
252
337
|
|
253
338
|
async def send_message_to_agent(
|
254
|
-
self, to_agent_uuid: UUID, content:
|
339
|
+
self, to_agent_uuid: UUID, content: str
|
255
340
|
):
|
256
341
|
"""通过 Messager 发送消息"""
|
257
342
|
if self._messager is None:
|
@@ -259,29 +344,28 @@ class Agent(ABC):
|
|
259
344
|
payload = {
|
260
345
|
"from": self._uuid,
|
261
346
|
"content": content,
|
262
|
-
"timestamp": int(
|
347
|
+
"timestamp": int(datetime.now().timestamp() * 1000),
|
263
348
|
"day": await self._simulator.get_simulator_day(),
|
264
349
|
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
265
350
|
}
|
266
351
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
352
|
+
auros = [{
|
353
|
+
"id": str(self._uuid),
|
354
|
+
"day": await self._simulator.get_simulator_day(),
|
355
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
356
|
+
"type": 1,
|
357
|
+
"speaker": str(self._uuid),
|
358
|
+
"content": content,
|
359
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
360
|
+
}]
|
361
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
362
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
267
363
|
|
268
|
-
|
269
|
-
self, content: dict
|
270
|
-
):
|
271
|
-
pass
|
272
|
-
|
273
|
-
async def send_message_to_survey(
|
274
|
-
self, content: dict
|
275
|
-
):
|
276
|
-
pass
|
277
|
-
|
364
|
+
# Agent logic
|
278
365
|
@abstractmethod
|
279
366
|
async def forward(self) -> None:
|
280
367
|
"""智能体行为逻辑"""
|
281
368
|
raise NotImplementedError
|
282
|
-
|
283
|
-
async def handle_gather_message(self, payload: str):
|
284
|
-
raise NotImplementedError
|
285
369
|
|
286
370
|
async def run(self) -> None:
|
287
371
|
"""
|
@@ -305,6 +389,7 @@ class CitizenAgent(Agent):
|
|
305
389
|
memory: Optional[Memory] = None,
|
306
390
|
economy_client: Optional[EconomyClient] = None,
|
307
391
|
messager: Optional[Messager] = None,
|
392
|
+
avro_file: Optional[dict] = None,
|
308
393
|
) -> None:
|
309
394
|
super().__init__(
|
310
395
|
name,
|
@@ -314,6 +399,7 @@ class CitizenAgent(Agent):
|
|
314
399
|
messager,
|
315
400
|
simulator,
|
316
401
|
memory,
|
402
|
+
avro_file,
|
317
403
|
)
|
318
404
|
|
319
405
|
async def bind_to_simulator(self):
|
@@ -417,6 +503,7 @@ class InstitutionAgent(Agent):
|
|
417
503
|
memory: Optional[Memory] = None,
|
418
504
|
economy_client: Optional[EconomyClient] = None,
|
419
505
|
messager: Optional[Messager] = None,
|
506
|
+
avro_file: Optional[dict] = None,
|
420
507
|
) -> None:
|
421
508
|
super().__init__(
|
422
509
|
name,
|
@@ -426,8 +513,11 @@ class InstitutionAgent(Agent):
|
|
426
513
|
messager,
|
427
514
|
simulator,
|
428
515
|
memory,
|
516
|
+
avro_file,
|
429
517
|
)
|
430
|
-
|
518
|
+
# 添加响应收集器
|
519
|
+
self._gather_responses: Dict[str, asyncio.Future] = {}
|
520
|
+
|
431
521
|
async def bind_to_simulator(self):
|
432
522
|
await self._bind_to_economy()
|
433
523
|
|
@@ -514,18 +604,47 @@ class InstitutionAgent(Agent):
|
|
514
604
|
|
515
605
|
async def handle_gather_message(self, payload: dict):
|
516
606
|
"""处理收到的消息,识别发送者"""
|
517
|
-
# 从消息中解析发送者 ID 和消息内容
|
518
607
|
content = payload["content"]
|
519
608
|
sender_id = payload["from"]
|
520
|
-
|
521
|
-
|
522
|
-
)
|
523
|
-
|
524
|
-
|
525
|
-
|
609
|
+
|
610
|
+
# 将响应存储到对应的Future中
|
611
|
+
response_key = str(sender_id)
|
612
|
+
if response_key in self._gather_responses:
|
613
|
+
self._gather_responses[response_key].set_result({
|
614
|
+
"from": sender_id,
|
615
|
+
"content": content,
|
616
|
+
})
|
617
|
+
|
618
|
+
async def gather_messages(self, agent_ids: list[UUID], target: str) -> List[dict]:
|
619
|
+
"""从多个智能体收集消息
|
620
|
+
|
621
|
+
Args:
|
622
|
+
agent_ids: 目标智能体ID列表
|
623
|
+
target: 要收集的信息类型
|
624
|
+
|
625
|
+
Returns:
|
626
|
+
List[dict]: 收集到的所有响应
|
627
|
+
"""
|
628
|
+
# 为每个agent创建Future
|
629
|
+
futures = {}
|
630
|
+
for agent_id in agent_ids:
|
631
|
+
response_key = str(agent_id)
|
632
|
+
futures[response_key] = asyncio.Future()
|
633
|
+
self._gather_responses[response_key] = futures[response_key]
|
634
|
+
|
635
|
+
# 发送gather请求
|
526
636
|
payload = {
|
527
637
|
"from": self._uuid,
|
528
638
|
"target": target,
|
529
639
|
}
|
530
640
|
for agent_id in agent_ids:
|
531
641
|
await self._send_message(agent_id, payload, "gather")
|
642
|
+
|
643
|
+
try:
|
644
|
+
# 等待所有响应
|
645
|
+
responses = await asyncio.gather(*futures.values())
|
646
|
+
return responses
|
647
|
+
finally:
|
648
|
+
# 清理Future
|
649
|
+
for key in futures:
|
650
|
+
self._gather_responses.pop(key, None)
|
pycityagent/simulation/agentgroup.py
CHANGED
@@ -1,23 +1,27 @@
|
|
1
1
|
import asyncio
|
2
|
+
from datetime import datetime
|
2
3
|
import json
|
3
4
|
import logging
|
4
5
|
import uuid
|
6
|
+
import fastavro
|
5
7
|
import ray
|
6
8
|
from uuid import UUID
|
7
|
-
from pycityagent.agent import Agent
|
9
|
+
from pycityagent.agent import Agent, CitizenAgent
|
8
10
|
from pycityagent.economy.econ_client import EconomyClient
|
9
11
|
from pycityagent.environment.simulator import Simulator
|
10
12
|
from pycityagent.llm.llm import LLM
|
11
13
|
from pycityagent.llm.llmconfig import LLMConfig
|
12
14
|
from pycityagent.message import Messager
|
15
|
+
from pycityagent.utils import STATUS_SCHEMA
|
13
16
|
from typing import Any
|
14
17
|
|
15
18
|
@ray.remote
|
16
19
|
class AgentGroup:
|
17
|
-
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID):
|
20
|
+
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, avro_file: dict):
|
18
21
|
self.agents = agents
|
19
22
|
self.config = config
|
20
23
|
self.exp_id = exp_id
|
24
|
+
self.avro_file = avro_file
|
21
25
|
self._uuid = uuid.uuid4()
|
22
26
|
self.messager = Messager(
|
23
27
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
@@ -71,7 +75,8 @@ class AgentGroup:
|
|
71
75
|
topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
|
72
76
|
await self.messager.subscribe(topic, agent)
|
73
77
|
self.initialized = True
|
74
|
-
|
78
|
+
self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
|
79
|
+
|
75
80
|
async def gather(self, content: str):
|
76
81
|
results = {}
|
77
82
|
for agent in self.agents:
|
@@ -82,51 +87,82 @@ class AgentGroup:
|
|
82
87
|
agent = self.id2agent[target_agent_uuid]
|
83
88
|
await agent.memory.update(target_key, content)
|
84
89
|
|
90
|
+
async def message_dispatch(self):
|
91
|
+
while True:
|
92
|
+
if not self.messager.is_connected():
|
93
|
+
logging.warning("Messager is not connected. Skipping message processing.")
|
94
|
+
|
95
|
+
# Step 1: 获取消息
|
96
|
+
messages = await self.messager.fetch_messages()
|
97
|
+
logging.info(f"Group {self._uuid} received {len(messages)} messages")
|
98
|
+
|
99
|
+
# Step 2: 分发消息到对应的 Agent
|
100
|
+
for message in messages:
|
101
|
+
topic = message.topic.value
|
102
|
+
payload = message.payload
|
103
|
+
|
104
|
+
# 添加解码步骤,将bytes转换为str
|
105
|
+
if isinstance(payload, bytes):
|
106
|
+
payload = payload.decode("utf-8")
|
107
|
+
payload = json.loads(payload)
|
108
|
+
|
109
|
+
# 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
|
110
|
+
_, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
|
111
|
+
|
112
|
+
if uuid.UUID(agent_uuid) in self.id2agent:
|
113
|
+
agent = self.id2agent[uuid.UUID(agent_uuid)]
|
114
|
+
# topic_type: agent-chat, user-chat, user-survey, gather
|
115
|
+
if topic_type == "agent-chat":
|
116
|
+
await agent.handle_agent_chat_message(payload)
|
117
|
+
elif topic_type == "user-chat":
|
118
|
+
await agent.handle_user_chat_message(payload)
|
119
|
+
elif topic_type == "user-survey":
|
120
|
+
await agent.handle_user_survey_message(payload)
|
121
|
+
elif topic_type == "gather":
|
122
|
+
await agent.handle_gather_message(payload)
|
123
|
+
|
124
|
+
await asyncio.sleep(0.5)
|
125
|
+
|
85
126
|
async def step(self):
|
86
127
|
if not self.initialized:
|
87
128
|
await self.init_agents()
|
88
129
|
|
89
|
-
# Step 1: 如果 Messager 无法连接,则跳过消息接收
|
90
|
-
if not self.messager.is_connected():
|
91
|
-
logging.warning("Messager is not connected. Skipping message processing.")
|
92
|
-
# 跳过接收和分发消息
|
93
|
-
tasks = [agent.run() for agent in self.agents]
|
94
|
-
await asyncio.gather(*tasks)
|
95
|
-
return
|
96
|
-
|
97
|
-
# Step 2: 从 Messager 获取消息
|
98
|
-
messages = await self.messager.fetch_messages()
|
99
|
-
|
100
|
-
logging.info(f"Group {self._uuid} received {len(messages)} messages")
|
101
|
-
|
102
|
-
# Step 3: 分发消息到对应的 Agent
|
103
|
-
for message in messages:
|
104
|
-
topic = message.topic.value
|
105
|
-
payload = message.payload
|
106
|
-
|
107
|
-
# 添加解码步骤,将bytes转换为str
|
108
|
-
if isinstance(payload, bytes):
|
109
|
-
payload = payload.decode("utf-8")
|
110
|
-
payload = json.loads(payload)
|
111
|
-
|
112
|
-
# 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
|
113
|
-
_, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
|
114
|
-
|
115
|
-
if uuid.UUID(agent_uuid) in self.id2agent:
|
116
|
-
agent = self.id2agent[uuid.UUID(agent_uuid)]
|
117
|
-
# topic_type: agent-chat, user-chat, user-survey, gather
|
118
|
-
if topic_type == "agent-chat":
|
119
|
-
await agent.handle_agent_chat_message(payload)
|
120
|
-
elif topic_type == "user-chat":
|
121
|
-
await agent.handle_user_chat_message(payload)
|
122
|
-
elif topic_type == "user-survey":
|
123
|
-
await agent.handle_user_survey_message(payload)
|
124
|
-
elif topic_type == "gather":
|
125
|
-
await agent.handle_gather_message(payload)
|
126
|
-
|
127
|
-
# Step 4: 调用每个 Agent 的运行逻辑
|
128
130
|
tasks = [agent.run() for agent in self.agents]
|
129
131
|
await asyncio.gather(*tasks)
|
132
|
+
avros = []
|
133
|
+
for agent in self.agents:
|
134
|
+
if not issubclass(type(agent), CitizenAgent):
|
135
|
+
continue
|
136
|
+
position = await agent.memory.get("position")
|
137
|
+
lng = position["longlat_position"]["longitude"]
|
138
|
+
lat = position["longlat_position"]["latitude"]
|
139
|
+
if "aoi_position" in position:
|
140
|
+
parent_id = position["aoi_position"]["aoi_id"]
|
141
|
+
elif "lane_position" in position:
|
142
|
+
parent_id = position["lane_position"]["lane_id"]
|
143
|
+
else:
|
144
|
+
# BUG: 需要处理
|
145
|
+
parent_id = -1
|
146
|
+
needs = await agent.memory.get("needs")
|
147
|
+
action = await agent.memory.get("current_step")
|
148
|
+
action = action["intention"]
|
149
|
+
avro = {
|
150
|
+
"id": str(agent._uuid), # uuid as string
|
151
|
+
"day": await self.simulator.get_simulator_day(),
|
152
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
153
|
+
"lng": lng,
|
154
|
+
"lat": lat,
|
155
|
+
"parent_id": parent_id,
|
156
|
+
"action": action,
|
157
|
+
"hungry": needs["hungry"],
|
158
|
+
"tired": needs["tired"],
|
159
|
+
"safe": needs["safe"],
|
160
|
+
"social": needs["social"],
|
161
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
162
|
+
}
|
163
|
+
avros.append(avro)
|
164
|
+
with open(self.avro_file["status"], "a+b") as f:
|
165
|
+
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
130
166
|
|
131
167
|
async def run(self, day: int = 1):
|
132
168
|
"""运行模拟器
|