pycityagent 2.0.0a12__py3-none-any.whl → 2.0.0a14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pycityagent/agent.py CHANGED
@@ -13,11 +13,15 @@ import random
13
13
  import uuid
14
14
  from typing import Dict, List, Optional
15
15
 
16
+ import fastavro
17
+
16
18
  from pycityagent.environment.sim.person_service import PersonService
17
19
  from mosstool.util.format_converter import dict2pb
18
20
  from pycityproto.city.person.v2 import person_pb2 as person_pb2
21
+ from pycityagent.utils import process_survey_for_llm
19
22
 
20
23
  from pycityagent.message.messager import Messager
24
+ from pycityagent.utils import SURVEY_SCHEMA, DIALOG_SCHEMA
21
25
 
22
26
  from .economy import EconomyClient
23
27
  from .environment import Simulator
@@ -52,6 +56,7 @@ class Agent(ABC):
52
56
  messager: Optional[Messager] = None,
53
57
  simulator: Optional[Simulator] = None,
54
58
  memory: Optional[Memory] = None,
59
+ avro_file: Optional[Dict[str, str]] = None,
55
60
  ) -> None:
56
61
  """
57
62
  Initialize the Agent.
@@ -64,6 +69,7 @@ class Agent(ABC):
64
69
  messager (Messager, optional): The messager object. Defaults to None.
65
70
  simulator (Simulator, optional): The simulator object. Defaults to None.
66
71
  memory (Memory, optional): The memory of the agent. Defaults to None.
72
+ avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
67
73
  """
68
74
  self._name = name
69
75
  self._type = type
@@ -79,6 +85,7 @@ class Agent(ABC):
79
85
  self._blocked = False
80
86
  self._interview_history: List[Dict] = [] # 存储采访历史
81
87
  self._person_template = PersonService.default_dict_person()
88
+ self._avro_file = avro_file
82
89
 
83
90
  def __getstate__(self):
84
91
  state = self.__dict__.copy()
@@ -168,11 +175,59 @@ class Agent(ABC):
168
175
  )
169
176
  return self._simulator
170
177
 
171
- async def generate_response(self, question: str) -> str:
172
- """生成回答
178
+ async def generate_user_survey_response(self, survey: dict) -> str:
179
+ """生成回答 —— 可重写
180
+ 基于智能体的记忆和当前状态,生成对问卷调查的回答。
181
+ Args:
182
+ survey: 需要回答的问卷 dict
183
+ Returns:
184
+ str: 智能体的回答
185
+ """
186
+ survey_prompt = process_survey_for_llm(survey)
187
+ dialog = []
173
188
 
174
- 基于智能体的记忆和当前状态,生成对问题的回答。
189
+ # 添加系统提示
190
+ system_prompt = "Please answer the survey question in first person. Follow the format requirements strictly and provide clear and specific answers."
191
+ dialog.append({"role": "system", "content": system_prompt})
192
+
193
+ # 添加记忆上下文
194
+ if self._memory:
195
+ relevant_memories = await self._memory.search(survey_prompt)
196
+ if relevant_memories:
197
+ dialog.append(
198
+ {
199
+ "role": "system",
200
+ "content": f"Answer based on these memories:\n{relevant_memories}",
201
+ }
202
+ )
203
+
204
+ # 添加问卷问题
205
+ dialog.append({"role": "user", "content": survey_prompt})
206
+
207
+ # 使用LLM生成回答
208
+ if not self._llm_client:
209
+ return "Sorry, I cannot answer survey questions right now."
210
+
211
+ response = await self._llm_client.atext_request(dialog) # type:ignore
175
212
 
213
+ return response # type:ignore
214
+
215
+ async def _process_survey(self, survey: dict):
216
+ survey_response = await self.generate_user_survey_response(survey)
217
+ response_to_avro = [{
218
+ "id": str(self._uuid),
219
+ "day": await self._simulator.get_simulator_day(),
220
+ "t": await self._simulator.get_simulator_second_from_start_of_day(),
221
+ "survey_id": survey["id"],
222
+ "result": survey_response,
223
+ "created_at": int(datetime.now().timestamp() * 1000),
224
+ }]
225
+ with open(self._avro_file["survey"], "a+b") as f:
226
+ fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
227
+
228
+ async def generate_user_chat_response(self, question: str) -> str:
229
+ """生成回答 —— 可重写
230
+ 基于智能体的记忆和当前状态,生成对问题的回答。
176
231
  Args:
177
232
  question: 需要回答的问题
178
233
 
@@ -182,7 +237,7 @@ class Agent(ABC):
182
237
  dialog = []
183
238
 
184
239
  # 添加系统提示
185
- system_prompt = f"请以第一人称的方式回答问题,保持回答简洁明了。"
240
+ system_prompt = "Please answer the question in first person and keep the response concise and clear."
186
241
  dialog.append({"role": "system", "content": system_prompt})
187
242
 
188
243
  # 添加记忆上下文
@@ -192,7 +247,7 @@ class Agent(ABC):
192
247
  dialog.append(
193
248
  {
194
249
  "role": "system",
195
- "content": f"基于以下记忆回答问题:\n{relevant_memories}",
250
+ "content": f"Answer based on these memories:\n{relevant_memories}",
196
251
  }
197
252
  )
198
253
 
@@ -201,46 +256,76 @@ class Agent(ABC):
201
256
 
202
257
  # 使用LLM生成回答
203
258
  if not self._llm_client:
204
- return "抱歉,我现在无法回答问题。"
259
+ return "Sorry, I cannot answer questions right now."
205
260
 
206
261
  response = await self._llm_client.atext_request(dialog) # type:ignore
207
262
 
208
- # 记录采访历史
209
- self._interview_history.append(
210
- {
211
- "timestamp": datetime.now().isoformat(),
212
- "question": question,
213
- "response": response,
214
- }
215
- )
216
-
217
263
  return response # type:ignore
218
-
219
- def get_interview_history(self) -> List[Dict]:
220
- """获取采访历史记录"""
221
- return self._interview_history
222
-
264
+
265
+ async def _process_interview(self, payload: dict):
266
+ auros = [{
267
+ "id": str(self._uuid),
268
+ "day": await self._simulator.get_simulator_day(),
269
+ "t": await self._simulator.get_simulator_second_from_start_of_day(),
270
+ "type": 2,
271
+ "speaker": "user",
272
+ "content": payload["content"],
273
+ "created_at": int(datetime.now().timestamp() * 1000),
274
+ }]
275
+ question = payload["content"]
276
+ response = await self.generate_user_chat_response(question)
277
+ auros.append({
278
+ "id": str(self._uuid),
279
+ "day": await self._simulator.get_simulator_day(),
280
+ "t": await self._simulator.get_simulator_second_from_start_of_day(),
281
+ "type": 2,
282
+ "speaker": "",
283
+ "content": response,
284
+ "created_at": int(datetime.now().timestamp() * 1000),
285
+ })
286
+ with open(self._avro_file["dialog"], "a+b") as f:
287
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
288
+
289
+ async def process_agent_chat_response(self, payload: dict) -> str:
290
+ logging.info(f"Agent {self._uuid} received agent chat response: {payload}")
291
+
292
+ async def _process_agent_chat(self, payload: dict):
293
+ auros = [{
294
+ "id": str(self._uuid),
295
+ "day": payload["day"],
296
+ "t": payload["t"],
297
+ "type": 1,
298
+ "speaker": payload["from"],
299
+ "content": payload["content"],
300
+ "created_at": int(datetime.now().timestamp() * 1000),
301
+ }]
302
+ asyncio.create_task(self.process_agent_chat_response(payload))
303
+ with open(self._avro_file["dialog"], "a+b") as f:
304
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
305
+
306
+ # Callback functions for MQTT message
223
307
  async def handle_agent_chat_message(self, payload: dict):
224
308
  """处理收到的消息,识别发送者"""
225
309
  # 从消息中解析发送者 ID 和消息内容
226
- print(
227
- f"Agent {self._uuid} received agent chat message: '{payload['content']}' from Agent {payload['from']}"
228
- )
310
+ logging.info(f"Agent {self._uuid} received agent chat message: {payload}")
311
+ asyncio.create_task(self._process_agent_chat(payload))
229
312
 
230
313
  async def handle_user_chat_message(self, payload: dict):
231
314
  """处理收到的消息,识别发送者"""
232
315
  # 从消息中解析发送者 ID 和消息内容
233
- print(
234
- f"Agent {self._uuid} received user chat message: '{payload['content']}' from User"
235
- )
316
+ logging.info(f"Agent {self._uuid} received user chat message: {payload}")
317
+ asyncio.create_task(self._process_interview(payload))
236
318
 
237
319
  async def handle_user_survey_message(self, payload: dict):
238
320
  """处理收到的消息,识别发送者"""
239
321
  # 从消息中解析发送者 ID 和消息内容
240
- print(
241
- f"Agent {self._uuid} received user survey message: '{payload['content']}' from User"
242
- )
322
+ logging.info(f"Agent {self._uuid} received user survey message: {payload}")
323
+ asyncio.create_task(self._process_survey(payload["data"]))
324
+
325
+ async def handle_gather_message(self, payload: str):
326
+ raise NotImplementedError
243
327
 
328
+ # MQTT send message
244
329
  async def _send_message(
245
330
  self, to_agent_uuid: UUID, payload: dict, sub_topic: str
246
331
  ):
@@ -251,7 +336,7 @@ class Agent(ABC):
251
336
  await self._messager.send_message(topic, payload)
252
337
 
253
338
  async def send_message_to_agent(
254
- self, to_agent_uuid: UUID, content: dict
339
+ self, to_agent_uuid: UUID, content: str
255
340
  ):
256
341
  """通过 Messager 发送消息"""
257
342
  if self._messager is None:
@@ -259,29 +344,28 @@ class Agent(ABC):
259
344
  payload = {
260
345
  "from": self._uuid,
261
346
  "content": content,
262
- "timestamp": int(time.time()),
347
+ "timestamp": int(datetime.now().timestamp() * 1000),
263
348
  "day": await self._simulator.get_simulator_day(),
264
349
  "t": await self._simulator.get_simulator_second_from_start_of_day(),
265
350
  }
266
351
  await self._send_message(to_agent_uuid, payload, "agent-chat")
352
+ auros = [{
353
+ "id": str(self._uuid),
354
+ "day": await self._simulator.get_simulator_day(),
355
+ "t": await self._simulator.get_simulator_second_from_start_of_day(),
356
+ "type": 1,
357
+ "speaker": str(self._uuid),
358
+ "content": content,
359
+ "created_at": int(datetime.now().timestamp() * 1000),
360
+ }]
361
+ with open(self._avro_file["dialog"], "a+b") as f:
362
+ fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
267
363
 
268
- async def send_message_to_user(
269
- self, content: dict
270
- ):
271
- pass
272
-
273
- async def send_message_to_survey(
274
- self, content: dict
275
- ):
276
- pass
277
-
364
+ # Agent logic
278
365
  @abstractmethod
279
366
  async def forward(self) -> None:
280
367
  """智能体行为逻辑"""
281
368
  raise NotImplementedError
282
-
283
- async def handle_gather_message(self, payload: str):
284
- raise NotImplementedError
285
369
 
286
370
  async def run(self) -> None:
287
371
  """
@@ -305,6 +389,7 @@ class CitizenAgent(Agent):
305
389
  memory: Optional[Memory] = None,
306
390
  economy_client: Optional[EconomyClient] = None,
307
391
  messager: Optional[Messager] = None,
392
+ avro_file: Optional[dict] = None,
308
393
  ) -> None:
309
394
  super().__init__(
310
395
  name,
@@ -314,6 +399,7 @@ class CitizenAgent(Agent):
314
399
  messager,
315
400
  simulator,
316
401
  memory,
402
+ avro_file,
317
403
  )
318
404
 
319
405
  async def bind_to_simulator(self):
@@ -417,6 +503,7 @@ class InstitutionAgent(Agent):
417
503
  memory: Optional[Memory] = None,
418
504
  economy_client: Optional[EconomyClient] = None,
419
505
  messager: Optional[Messager] = None,
506
+ avro_file: Optional[dict] = None,
420
507
  ) -> None:
421
508
  super().__init__(
422
509
  name,
@@ -426,8 +513,11 @@ class InstitutionAgent(Agent):
426
513
  messager,
427
514
  simulator,
428
515
  memory,
516
+ avro_file,
429
517
  )
430
-
518
+ # 添加响应收集器
519
+ self._gather_responses: Dict[str, asyncio.Future] = {}
520
+
431
521
  async def bind_to_simulator(self):
432
522
  await self._bind_to_economy()
433
523
 
@@ -514,18 +604,47 @@ class InstitutionAgent(Agent):
514
604
 
515
605
  async def handle_gather_message(self, payload: dict):
516
606
  """处理收到的消息,识别发送者"""
517
- # 从消息中解析发送者 ID 和消息内容
518
607
  content = payload["content"]
519
608
  sender_id = payload["from"]
520
- print(
521
- f"Agent {self._uuid} received gather message: '{content}' from Agent {sender_id}"
522
- )
523
-
524
- async def gather_messages(self, agent_ids: list[UUID], target: str):
525
- """从多个智能体收集消息"""
609
+
610
+ # 将响应存储到对应的Future中
611
+ response_key = str(sender_id)
612
+ if response_key in self._gather_responses:
613
+ self._gather_responses[response_key].set_result({
614
+ "from": sender_id,
615
+ "content": content,
616
+ })
617
+
618
+ async def gather_messages(self, agent_ids: list[UUID], target: str) -> List[dict]:
619
+ """从多个智能体收集消息
620
+
621
+ Args:
622
+ agent_ids: 目标智能体ID列表
623
+ target: 要收集的信息类型
624
+
625
+ Returns:
626
+ List[dict]: 收集到的所有响应
627
+ """
628
+ # 为每个agent创建Future
629
+ futures = {}
630
+ for agent_id in agent_ids:
631
+ response_key = str(agent_id)
632
+ futures[response_key] = asyncio.Future()
633
+ self._gather_responses[response_key] = futures[response_key]
634
+
635
+ # 发送gather请求
526
636
  payload = {
527
637
  "from": self._uuid,
528
638
  "target": target,
529
639
  }
530
640
  for agent_id in agent_ids:
531
641
  await self._send_message(agent_id, payload, "gather")
642
+
643
+ try:
644
+ # 等待所有响应
645
+ responses = await asyncio.gather(*futures.values())
646
+ return responses
647
+ finally:
648
+ # 清理Future
649
+ for key in futures:
650
+ self._gather_responses.pop(key, None)
@@ -8,6 +8,8 @@ import pycityproto.city.economy.v2.org_service_pb2 as org_service
8
8
  import pycityproto.city.economy.v2.org_service_pb2_grpc as org_grpc
9
9
  from google.protobuf import descriptor
10
10
 
11
+ economyv2.ORG_TYPE_BANK
12
+
11
13
  __all__ = [
12
14
  "EconomyClient",
13
15
  ]
@@ -1,6 +1,7 @@
1
1
  from pycityproto.city.person.v2.motion_pb2 import Status
2
2
 
3
3
  PROFILE_ATTRIBUTES = {
4
+ "name": str(),
4
5
  "gender": str(),
5
6
  "age": float(),
6
7
  "education": str(),
@@ -1,23 +1,27 @@
1
1
  import asyncio
2
+ from datetime import datetime
2
3
  import json
3
4
  import logging
4
5
  import uuid
6
+ import fastavro
5
7
  import ray
6
8
  from uuid import UUID
7
- from pycityagent.agent import Agent
9
+ from pycityagent.agent import Agent, CitizenAgent
8
10
  from pycityagent.economy.econ_client import EconomyClient
9
11
  from pycityagent.environment.simulator import Simulator
10
12
  from pycityagent.llm.llm import LLM
11
13
  from pycityagent.llm.llmconfig import LLMConfig
12
14
  from pycityagent.message import Messager
15
+ from pycityagent.utils import STATUS_SCHEMA
13
16
  from typing import Any
14
17
 
15
18
  @ray.remote
16
19
  class AgentGroup:
17
- def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID):
20
+ def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, avro_file: dict):
18
21
  self.agents = agents
19
22
  self.config = config
20
23
  self.exp_id = exp_id
24
+ self.avro_file = avro_file
21
25
  self._uuid = uuid.uuid4()
22
26
  self.messager = Messager(
23
27
  hostname=config["simulator_request"]["mqtt"]["server"],
@@ -71,7 +75,8 @@ class AgentGroup:
71
75
  topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
72
76
  await self.messager.subscribe(topic, agent)
73
77
  self.initialized = True
74
-
78
+ self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
79
+
75
80
  async def gather(self, content: str):
76
81
  results = {}
77
82
  for agent in self.agents:
@@ -82,51 +87,82 @@ class AgentGroup:
82
87
  agent = self.id2agent[target_agent_uuid]
83
88
  await agent.memory.update(target_key, content)
84
89
 
90
+ async def message_dispatch(self):
91
+ while True:
92
+ if not self.messager.is_connected():
93
+ logging.warning("Messager is not connected. Skipping message processing.")
94
+
95
+ # Step 1: 获取消息
96
+ messages = await self.messager.fetch_messages()
97
+ logging.info(f"Group {self._uuid} received {len(messages)} messages")
98
+
99
+ # Step 2: 分发消息到对应的 Agent
100
+ for message in messages:
101
+ topic = message.topic.value
102
+ payload = message.payload
103
+
104
+ # 添加解码步骤,将bytes转换为str
105
+ if isinstance(payload, bytes):
106
+ payload = payload.decode("utf-8")
107
+ payload = json.loads(payload)
108
+
109
+ # 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
110
+ _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
111
+
112
+ if uuid.UUID(agent_uuid) in self.id2agent:
113
+ agent = self.id2agent[uuid.UUID(agent_uuid)]
114
+ # topic_type: agent-chat, user-chat, user-survey, gather
115
+ if topic_type == "agent-chat":
116
+ await agent.handle_agent_chat_message(payload)
117
+ elif topic_type == "user-chat":
118
+ await agent.handle_user_chat_message(payload)
119
+ elif topic_type == "user-survey":
120
+ await agent.handle_user_survey_message(payload)
121
+ elif topic_type == "gather":
122
+ await agent.handle_gather_message(payload)
123
+
124
+ await asyncio.sleep(0.5)
125
+
85
126
  async def step(self):
86
127
  if not self.initialized:
87
128
  await self.init_agents()
88
129
 
89
- # Step 1: 如果 Messager 无法连接,则跳过消息接收
90
- if not self.messager.is_connected():
91
- logging.warning("Messager is not connected. Skipping message processing.")
92
- # 跳过接收和分发消息
93
- tasks = [agent.run() for agent in self.agents]
94
- await asyncio.gather(*tasks)
95
- return
96
-
97
- # Step 2: 从 Messager 获取消息
98
- messages = await self.messager.fetch_messages()
99
-
100
- logging.info(f"Group {self._uuid} received {len(messages)} messages")
101
-
102
- # Step 3: 分发消息到对应的 Agent
103
- for message in messages:
104
- topic = message.topic.value
105
- payload = message.payload
106
-
107
- # 添加解码步骤,将bytes转换为str
108
- if isinstance(payload, bytes):
109
- payload = payload.decode("utf-8")
110
- payload = json.loads(payload)
111
-
112
- # 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
113
- _, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
114
-
115
- if uuid.UUID(agent_uuid) in self.id2agent:
116
- agent = self.id2agent[uuid.UUID(agent_uuid)]
117
- # topic_type: agent-chat, user-chat, user-survey, gather
118
- if topic_type == "agent-chat":
119
- await agent.handle_agent_chat_message(payload)
120
- elif topic_type == "user-chat":
121
- await agent.handle_user_chat_message(payload)
122
- elif topic_type == "user-survey":
123
- await agent.handle_user_survey_message(payload)
124
- elif topic_type == "gather":
125
- await agent.handle_gather_message(payload)
126
-
127
- # Step 4: 调用每个 Agent 的运行逻辑
128
130
  tasks = [agent.run() for agent in self.agents]
129
131
  await asyncio.gather(*tasks)
132
+ avros = []
133
+ for agent in self.agents:
134
+ if not issubclass(type(agent), CitizenAgent):
135
+ continue
136
+ position = await agent.memory.get("position")
137
+ lng = position["longlat_position"]["longitude"]
138
+ lat = position["longlat_position"]["latitude"]
139
+ if "aoi_position" in position:
140
+ parent_id = position["aoi_position"]["aoi_id"]
141
+ elif "lane_position" in position:
142
+ parent_id = position["lane_position"]["lane_id"]
143
+ else:
144
+ # BUG: 需要处理
145
+ parent_id = -1
146
+ needs = await agent.memory.get("needs")
147
+ action = await agent.memory.get("current_step")
148
+ action = action["intention"]
149
+ avro = {
150
+ "id": str(agent._uuid), # uuid as string
151
+ "day": await self.simulator.get_simulator_day(),
152
+ "t": await self.simulator.get_simulator_second_from_start_of_day(),
153
+ "lng": lng,
154
+ "lat": lat,
155
+ "parent_id": parent_id,
156
+ "action": action,
157
+ "hungry": needs["hungry"],
158
+ "tired": needs["tired"],
159
+ "safe": needs["safe"],
160
+ "social": needs["social"],
161
+ "created_at": int(datetime.now().timestamp() * 1000),
162
+ }
163
+ avros.append(avro)
164
+ with open(self.avro_file["status"], "a+b") as f:
165
+ fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
130
166
 
131
167
  async def run(self, day: int = 1):
132
168
  """运行模拟器