pycityagent 2.0.0a43__cp39-cp39-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. pycityagent/__init__.py +23 -0
  2. pycityagent/agent.py +833 -0
  3. pycityagent/cli/wrapper.py +44 -0
  4. pycityagent/economy/__init__.py +5 -0
  5. pycityagent/economy/econ_client.py +355 -0
  6. pycityagent/environment/__init__.py +7 -0
  7. pycityagent/environment/interact/__init__.py +0 -0
  8. pycityagent/environment/interact/interact.py +198 -0
  9. pycityagent/environment/message/__init__.py +0 -0
  10. pycityagent/environment/sence/__init__.py +0 -0
  11. pycityagent/environment/sence/static.py +416 -0
  12. pycityagent/environment/sidecar/__init__.py +8 -0
  13. pycityagent/environment/sidecar/sidecarv2.py +109 -0
  14. pycityagent/environment/sim/__init__.py +29 -0
  15. pycityagent/environment/sim/aoi_service.py +39 -0
  16. pycityagent/environment/sim/client.py +126 -0
  17. pycityagent/environment/sim/clock_service.py +44 -0
  18. pycityagent/environment/sim/economy_services.py +192 -0
  19. pycityagent/environment/sim/lane_service.py +111 -0
  20. pycityagent/environment/sim/light_service.py +122 -0
  21. pycityagent/environment/sim/person_service.py +295 -0
  22. pycityagent/environment/sim/road_service.py +39 -0
  23. pycityagent/environment/sim/sim_env.py +145 -0
  24. pycityagent/environment/sim/social_service.py +59 -0
  25. pycityagent/environment/simulator.py +331 -0
  26. pycityagent/environment/utils/__init__.py +14 -0
  27. pycityagent/environment/utils/base64.py +16 -0
  28. pycityagent/environment/utils/const.py +244 -0
  29. pycityagent/environment/utils/geojson.py +24 -0
  30. pycityagent/environment/utils/grpc.py +57 -0
  31. pycityagent/environment/utils/map_utils.py +157 -0
  32. pycityagent/environment/utils/port.py +11 -0
  33. pycityagent/environment/utils/protobuf.py +41 -0
  34. pycityagent/llm/__init__.py +11 -0
  35. pycityagent/llm/embeddings.py +231 -0
  36. pycityagent/llm/llm.py +377 -0
  37. pycityagent/llm/llmconfig.py +13 -0
  38. pycityagent/llm/utils.py +6 -0
  39. pycityagent/memory/__init__.py +13 -0
  40. pycityagent/memory/const.py +43 -0
  41. pycityagent/memory/faiss_query.py +302 -0
  42. pycityagent/memory/memory.py +448 -0
  43. pycityagent/memory/memory_base.py +170 -0
  44. pycityagent/memory/profile.py +165 -0
  45. pycityagent/memory/self_define.py +165 -0
  46. pycityagent/memory/state.py +173 -0
  47. pycityagent/memory/utils.py +28 -0
  48. pycityagent/message/__init__.py +3 -0
  49. pycityagent/message/messager.py +88 -0
  50. pycityagent/metrics/__init__.py +6 -0
  51. pycityagent/metrics/mlflow_client.py +147 -0
  52. pycityagent/metrics/utils/const.py +0 -0
  53. pycityagent/pycityagent-sim +0 -0
  54. pycityagent/pycityagent-ui +0 -0
  55. pycityagent/simulation/__init__.py +8 -0
  56. pycityagent/simulation/agentgroup.py +580 -0
  57. pycityagent/simulation/simulation.py +634 -0
  58. pycityagent/simulation/storage/pg.py +184 -0
  59. pycityagent/survey/__init__.py +4 -0
  60. pycityagent/survey/manager.py +54 -0
  61. pycityagent/survey/models.py +120 -0
  62. pycityagent/utils/__init__.py +11 -0
  63. pycityagent/utils/avro_schema.py +109 -0
  64. pycityagent/utils/decorators.py +99 -0
  65. pycityagent/utils/parsers/__init__.py +13 -0
  66. pycityagent/utils/parsers/code_block_parser.py +37 -0
  67. pycityagent/utils/parsers/json_parser.py +86 -0
  68. pycityagent/utils/parsers/parser_base.py +60 -0
  69. pycityagent/utils/pg_query.py +92 -0
  70. pycityagent/utils/survey_util.py +53 -0
  71. pycityagent/workflow/__init__.py +26 -0
  72. pycityagent/workflow/block.py +211 -0
  73. pycityagent/workflow/prompt.py +79 -0
  74. pycityagent/workflow/tool.py +240 -0
  75. pycityagent/workflow/trigger.py +163 -0
  76. pycityagent-2.0.0a43.dist-info/LICENSE +21 -0
  77. pycityagent-2.0.0a43.dist-info/METADATA +235 -0
  78. pycityagent-2.0.0a43.dist-info/RECORD +81 -0
  79. pycityagent-2.0.0a43.dist-info/WHEEL +5 -0
  80. pycityagent-2.0.0a43.dist-info/entry_points.txt +3 -0
  81. pycityagent-2.0.0a43.dist-info/top_level.txt +3 -0
pycityagent/agent.py ADDED
@@ -0,0 +1,833 @@
1
+ """智能体模板类及其定义"""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import random
7
+ import uuid
8
+ from abc import ABC, abstractmethod
9
+ from copy import deepcopy
10
+ from datetime import datetime, timezone
11
+ from enum import Enum
12
+ from typing import Any, Optional
13
+ from uuid import UUID
14
+
15
+ import fastavro
16
+ import ray
17
+ from mosstool.util.format_converter import dict2pb
18
+ from pycityproto.city.person.v2 import person_pb2 as person_pb2
19
+
20
+ from .economy import EconomyClient
21
+ from .environment import Simulator
22
+ from .environment.sim.person_service import PersonService
23
+ from .llm import LLM
24
+ from .memory import Memory
25
+ from .message.messager import Messager
26
+ from .metrics import MlflowClient
27
+ from .utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
28
+
29
# Shared logger named "pycityagent"; every agent class in this module logs through it.
logger = logging.getLogger("pycityagent")
30
+
31
+
32
class AgentType(Enum):
    """
    Category of an agent; each member's value is its own name as a string.

    - Unspecified: no particular category.
    - Citizen: an individual (resident) agent.
    - Institution: an organization or institution agent.
    """

    # Member order is significant for iteration; keep it stable.
    Unspecified = "Unspecified"
    Citizen = "Citizen"
    Institution = "Institution"
43
+
44
+
45
class Agent(ABC):
    """
    Agent base class.

    An agent holds optional handles to an LLM client, an economy-simulator
    client, a messager, the city simulator, an MLflow client and its own
    memory. Each handle may be passed to the constructor or injected later
    through the matching ``set_*`` method; the read-only properties raise
    ``RuntimeError`` when a handle is accessed before it has been assigned.

    Incoming MQTT messages are dispatched to the ``handle_*`` callbacks.
    Survey answers and dialog turns are appended to the Avro files named in
    ``avro_file`` and/or sent to the PostgreSQL copy writer, whichever are
    configured. Subclasses must implement :meth:`forward`.
    """

    def __init__(
        self,
        name: str,
        type: AgentType = AgentType.Unspecified,
        llm_client: Optional[LLM] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        avro_file: Optional[dict[str, str]] = None,
        copy_writer: Optional[ray.ObjectRef] = None,
    ) -> None:
        """
        Initialize the Agent.

        Args:
            name (str): The name of the agent.
            type (AgentType): The type of the agent. Defaults to `AgentType.Unspecified`.
            llm_client (LLM, optional): The language model client. Defaults to None.
            economy_client (EconomyClient, optional): The `EconomySim` client. Defaults to None.
            messager (Messager, optional): The messager object. Defaults to None.
            simulator (Simulator, optional): The simulator object. Defaults to None.
            mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
            memory (Memory, optional): The memory of the agent. Defaults to None.
            avro_file (dict[str, str], optional): Avro output paths of the agent,
                read with the keys ``"survey"`` and ``"dialog"``. Defaults to None.
            copy_writer (ray.ObjectRef, optional): The PostgreSQL copy writer of
                the agent. Defaults to None.
        """
        self._name = name
        self._type = type
        self._uuid = str(uuid.uuid4())
        self._llm_client = llm_client
        self._economy_client = economy_client
        self._messager = messager
        self._simulator = simulator
        self._mlflow_client = mlflow_client
        self._memory = memory
        self._exp_id = -1
        self._agent_id = -1
        self._has_bound_to_simulator = False
        self._has_bound_to_economy = False
        self._blocked = False
        self._interview_history: list[dict] = []  # interview history storage
        self._person_template = PersonService.default_dict_person()
        self._avro_file = avro_file
        self._pgsql_writer = copy_writer
        # Pending PostgreSQL write. It is awaited just before the next write is
        # issued, so the SQL I/O overlaps with the agent's computation.
        self._last_asyncio_pg_task = None
        # Strong references to fire-and-forget tasks (see _spawn_background_task).
        self._background_tasks: set = set()

    def __getstate__(self):
        """Return the picklable state of the agent.

        The LLM client is excluded (the original author notes it holds lock
        objects), and so are the live asyncio tasks, which cannot be pickled.
        """
        state = self.__dict__.copy()
        state.pop("_llm_client", None)
        state.pop("_background_tasks", None)
        return state

    def __setstate__(self, state):
        """Restore a pickled agent.

        The attributes dropped by ``__getstate__`` are reset, so that the
        ``llm`` property raises its documented ``RuntimeError`` (rather than
        ``AttributeError``) until ``set_llm_client`` is called again.
        """
        self.__dict__.update(state)
        self._llm_client = None
        self._background_tasks = set()

    def set_messager(self, messager: Messager):
        """
        Set the messager of the agent.
        """
        self._messager = messager

    def set_llm_client(self, llm_client: LLM):
        """
        Set the llm_client of the agent.
        """
        self._llm_client = llm_client

    def set_simulator(self, simulator: Simulator):
        """
        Set the simulator of the agent.
        """
        self._simulator = simulator

    def set_mlflow_client(self, mlflow_client: MlflowClient):
        """
        Set the mlflow_client of the agent.
        """
        self._mlflow_client = mlflow_client

    def set_economy_client(self, economy_client: EconomyClient):
        """
        Set the economy_client of the agent.
        """
        self._economy_client = economy_client

    def set_memory(self, memory: Memory):
        """
        Set the memory of the agent.
        """
        self._memory = memory

    def set_exp_id(self, exp_id: str):
        """
        Set the exp_id of the agent (used in outgoing MQTT topics).
        """
        self._exp_id = exp_id

    def set_avro_file(self, avro_file: dict[str, str]):
        """
        Set the avro file of the agent.
        """
        self._avro_file = avro_file

    def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
        """
        Set the PostgreSQL copy writer of the agent.
        """
        self._pgsql_writer = pgsql_writer

    @property
    def uuid(self):
        """The Agent's UUID"""
        return self._uuid

    @property
    def sim_id(self):
        """The Agent's Simulator ID"""
        return self._agent_id

    @property
    def llm(self):
        """The Agent's LLM"""
        if self._llm_client is None:
            raise RuntimeError(
                "LLM access before assignment, please `set_llm_client` first!"
            )
        return self._llm_client

    @property
    def economy_client(self):
        """The Agent's EconomyClient"""
        if self._economy_client is None:
            raise RuntimeError(
                "EconomyClient access before assignment, please `set_economy_client` first!"
            )
        return self._economy_client

    @property
    def mlflow_client(self):
        """The Agent's MlflowClient"""
        if self._mlflow_client is None:
            raise RuntimeError(
                "MlflowClient access before assignment, please `set_mlflow_client` first!"
            )
        return self._mlflow_client

    @property
    def memory(self):
        """The Agent's Memory"""
        if self._memory is None:
            raise RuntimeError(
                "Memory access before assignment, please `set_memory` first!"
            )
        return self._memory

    @property
    def simulator(self):
        """The Simulator"""
        if self._simulator is None:
            raise RuntimeError(
                "Simulator access before assignment, please `set_simulator` first!"
            )
        return self._simulator

    @property
    def copy_writer(self):
        """Pg Copy Writer"""
        if self._pgsql_writer is None:
            raise RuntimeError(
                "Copy Writer access before assignment, please `set_pgsql_writer` first!"
            )
        return self._pgsql_writer

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _spawn_background_task(self, coro) -> asyncio.Task:
        """Schedule ``coro`` as a fire-and-forget task, keeping a strong reference.

        The event loop only keeps weak references to tasks, so a task created
        with ``asyncio.create_task`` and not stored anywhere may be garbage
        collected before it finishes. The reference is dropped once the task
        is done.
        """
        tasks = getattr(self, "_background_tasks", None)
        if tasks is None:
            tasks = self._background_tasks = set()
        task = asyncio.create_task(coro)
        tasks.add(task)
        task.add_done_callback(tasks.discard)
        return task

    async def _get_simulator_time(self) -> tuple[Any, Any]:
        """Return ``(day, second_of_day)`` as reported by the simulator."""
        day = await self.simulator.get_simulator_day()
        t = await self.simulator.get_simulator_second_from_start_of_day()
        return day, t

    async def _await_last_pg_task(self) -> None:
        """Wait for the previously issued PostgreSQL write, if any."""
        if self._last_asyncio_pg_task is not None:
            await self._last_asyncio_pg_task

    def _make_dialog_record(
        self, day: Any, t: Any, dialog_type: int, speaker: str, content: str
    ) -> tuple[dict, datetime]:
        """
        Build one dialog row together with its creation time.

        Args:
            day: Simulator day of the utterance.
            t: Simulator second-of-day of the utterance.
            dialog_type (int): Dialog type code (this class uses 1 for
                agent-to-agent chat and 2 for user interviews).
            speaker (str): Speaker identifier ("" marks the agent's own
                interview answer).
            content (str): The utterance.

        Returns:
            tuple[dict, datetime]: The record written with ``DIALOG_SCHEMA``
            (``created_at`` in epoch milliseconds), and the timezone-aware UTC
            datetime written to PostgreSQL's ``created_at`` column.
        """
        created_at = datetime.now(timezone.utc)
        record = {
            "id": self._uuid,
            "day": day,
            "t": t,
            "type": dialog_type,
            "speaker": speaker,
            "content": content,
            "created_at": int(created_at.timestamp() * 1000),
        }
        return record, created_at

    async def _write_dialog_records(self, records: list[tuple[dict, datetime]]) -> None:
        """
        Append dialog records to the Avro ``"dialog"`` file and/or PostgreSQL.

        Args:
            records: ``(record, created_at)`` pairs as built by
                :meth:`_make_dialog_record`.
        """
        # Avro
        if self._avro_file is not None:
            with open(self._avro_file["dialog"], "a+b") as f:
                fastavro.writer(
                    f, DIALOG_SCHEMA, [record for record, _ in records], codec="snappy"
                )
        # Pg
        if self._pgsql_writer is not None:
            await self._await_last_pg_task()
            _keys = ["id", "day", "t", "type", "speaker", "content", "created_at"]
            # PostgreSQL gets the datetime object for created_at, not the epoch ms.
            _data = [
                tuple(record[k] if k != "created_at" else created_at for k in _keys)
                for record, created_at in records
            ]
            self._last_asyncio_pg_task = (
                self._pgsql_writer.async_write_dialog.remote(  # type:ignore
                    _data
                )
            )

    async def _answer_with_memories(
        self, system_prompt: str, query: str, fallback: str
    ) -> str:
        """
        Ask the LLM to answer ``query``, adding relevant memories as context.

        Args:
            system_prompt (str): Leading system instruction.
            query (str): The user message; also used as the memory search query.
            fallback (str): Returned when no LLM client is set.

        Returns:
            str: The LLM's answer, or ``fallback``.
        """
        dialog = [{"role": "system", "content": system_prompt}]
        if self._memory:
            relevant_memories = await self._memory.search(query)
            if relevant_memories:
                dialog.append(
                    {
                        "role": "system",
                        "content": f"Answer based on these memories:\n{relevant_memories}",
                    }
                )
        dialog.append({"role": "user", "content": query})
        if not self._llm_client:
            return fallback
        response = await self._llm_client.atext_request(dialog)  # type:ignore
        return response  # type:ignore

    # ------------------------------------------------------------------
    # Surveys, interviews and agent chat
    # ------------------------------------------------------------------
    async def generate_user_survey_response(self, survey: dict) -> str:
        """
        Generate the agent's answer to a survey. May be overridden.

        The survey is rendered into a prompt with ``process_survey_for_llm``.
        When the agent has a memory, memories relevant to that prompt are
        added as system context before the request is sent to the LLM.

        Args:
            survey (dict): The survey to answer.

        Returns:
            str: The agent's answer, or a fixed apology when no LLM client is set.
        """
        survey_prompt = process_survey_for_llm(survey)
        return await self._answer_with_memories(
            "Please answer the survey question in first person. Follow the format requirements strictly and provide clear and specific answers.",
            survey_prompt,
            "Sorry, I cannot answer survey questions right now.",
        )

    async def _process_survey(self, survey: dict):
        """
        Answer ``survey`` and persist the response.

        The response is appended to the Avro ``"survey"`` file and/or sent to
        the PostgreSQL copy writer (``async_write_survey``), whichever is
        configured. In PostgreSQL the answer is stored as the JSON string
        ``{"result": <answer>}``.
        """
        survey_response = await self.generate_user_survey_response(survey)
        _date_time = datetime.now(timezone.utc)
        day, t = await self._get_simulator_time()
        response_to_avro = [
            {
                "id": self._uuid,
                "day": day,
                "t": t,
                "survey_id": survey["id"],
                "result": survey_response,
                "created_at": int(_date_time.timestamp() * 1000),
            }
        ]
        # Avro
        if self._avro_file is not None:
            with open(self._avro_file["survey"], "a+b") as f:
                fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
        # Pg
        if self._pgsql_writer is not None:
            await self._await_last_pg_task()
            _data_tuples: list[tuple] = [
                (
                    record["id"],
                    record["day"],
                    record["t"],
                    record["survey_id"],
                    json.dumps({"result": record["result"]}),
                    _date_time,  # created_at column
                )
                for record in response_to_avro
            ]
            self._last_asyncio_pg_task = (
                self._pgsql_writer.async_write_survey.remote(  # type:ignore
                    _data_tuples
                )
            )

    async def generate_user_chat_response(self, question: str) -> str:
        """
        Generate the agent's answer to a user question. May be overridden.

        When the agent has a memory, memories relevant to the question are
        added as system context before the request is sent to the LLM.

        Args:
            question (str): The question to answer.

        Returns:
            str: The agent's answer, or a fixed apology when no LLM client is set.
        """
        return await self._answer_with_memories(
            "Please answer the question in first person and keep the response concise and clear.",
            question,
            "Sorry, I cannot answer questions right now.",
        )

    async def _process_interview(self, payload: dict):
        """
        Answer a user's interview question and record both dialog turns.

        Args:
            payload (dict): Incoming message; ``payload["content"]`` is the question.
        """
        question = payload["content"]
        day, t = await self._get_simulator_time()
        records = [self._make_dialog_record(day, t, 2, "user", question)]
        response = await self.generate_user_chat_response(question)
        day, t = await self._get_simulator_time()
        records.append(self._make_dialog_record(day, t, 2, "", response))
        await self._write_dialog_records(records)

    async def process_agent_chat_response(self, payload: dict) -> str:
        """Handle a chat message from another agent. May be overridden; by default only logs it."""
        resp = f"Agent {self._uuid} received agent chat response: {payload}"
        logger.info(resp)
        return resp

    async def _process_agent_chat(self, payload: dict):
        """
        Record a chat message from another agent and schedule its handling.

        Args:
            payload (dict): Message with keys ``"day"``, ``"t"``, ``"from"``
                and ``"content"``.
        """
        records = [
            self._make_dialog_record(
                payload["day"], payload["t"], 1, payload["from"], payload["content"]
            )
        ]
        self._spawn_background_task(self.process_agent_chat_response(payload))
        await self._write_dialog_records(records)

    # Callback functions for MQTT message
    async def handle_agent_chat_message(self, payload: dict):
        """Callback for an agent-chat message: log it and process it in the background."""
        logger.info(f"Agent {self._uuid} received agent chat message: {payload}")
        self._spawn_background_task(self._process_agent_chat(payload))

    async def handle_user_chat_message(self, payload: dict):
        """Callback for a user-chat (interview) message: log it and answer it in the background."""
        logger.info(f"Agent {self._uuid} received user chat message: {payload}")
        self._spawn_background_task(self._process_interview(payload))

    async def handle_user_survey_message(self, payload: dict):
        """Callback for a survey message: log it and answer ``payload["data"]`` in the background."""
        logger.info(f"Agent {self._uuid} received user survey message: {payload}")
        self._spawn_background_task(self._process_survey(payload["data"]))

    async def handle_gather_message(self, payload: Any):
        """Callback for a gather message; implemented by subclasses."""
        raise NotImplementedError

    # MQTT send message
    async def _send_message(self, to_agent_uuid: str, payload: dict, sub_topic: str):
        """
        Publish ``payload`` to ``exps/{exp_id}/agents/{to_agent_uuid}/{sub_topic}``.

        Raises:
            RuntimeError: If no messager is set.
        """
        if self._messager is None:
            raise RuntimeError("Messager is not set")
        topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
        await self._messager.send_message.remote(topic, payload)

    async def send_message_to_agent(
        self, to_agent_uuid: str, content: str, type: str = "social"
    ):
        """
        Send a chat message to another agent via the messager.

        Messages of type ``"social"`` are also recorded as dialog rows. The
        recorded row uses the same simulator day/second as the sent payload.

        Args:
            to_agent_uuid (str): UUID of the receiving agent.
            content (str): Message text.
            type (str): ``"social"`` or ``"economy"``; other values are sent
                anyway but logged as a warning.

        Raises:
            RuntimeError: If no messager is set.
        """
        if self._messager is None:
            raise RuntimeError("Messager is not set")
        if type not in ["social", "economy"]:
            logger.warning(f"Invalid message type: {type}, sent from {self._uuid}")
        day, t = await self._get_simulator_time()
        payload = {
            "from": self._uuid,
            "content": content,
            "type": type,
            "timestamp": int(datetime.now().timestamp() * 1000),
            "day": day,
            "t": t,
        }
        await self._send_message(to_agent_uuid, payload, "agent-chat")
        if type == "social":
            await self._write_dialog_records(
                [self._make_dialog_record(day, t, 1, self._uuid, content)]
            )

    # Agent logic
    @abstractmethod
    async def forward(self) -> None:
        """The agent's behavior logic for one step."""
        raise NotImplementedError

    async def run(self) -> None:
        """
        Unified agent entry point.

        Calls :meth:`forward` unless the agent is blocked (``_blocked`` is True).
        """
        if not self._blocked:
            await self.forward()
532
+
533
+
534
class CitizenAgent(Agent):
    """
    CitizenAgent: an urban resident agent.

    :meth:`bind_to_simulator` registers the agent as a person in the city
    simulator, and then as an agent in the economy simulator.
    """

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,
        avro_file: Optional[dict] = None,
        copy_writer: Optional[ray.ObjectRef] = None,
    ) -> None:
        """
        Initialize a citizen agent (``AgentType.Citizen``).

        All arguments are forwarded to :class:`Agent`. ``copy_writer`` is the
        optional PostgreSQL copy writer; it can also be set later with
        ``set_pgsql_writer``.
        """
        super().__init__(
            name=name,
            type=AgentType.Citizen,
            llm_client=llm_client,
            economy_client=economy_client,
            messager=messager,
            simulator=simulator,
            mlflow_client=mlflow_client,
            memory=memory,
            avro_file=avro_file,
            copy_writer=copy_writer,
        )

    async def bind_to_simulator(self):
        """Bind to the city simulator, then to the economy simulator."""
        await self._bind_to_simulator()
        await self._bind_to_economy()

    async def _bind_to_simulator(self):
        """
        Bind the agent to a person in the city simulator.

        If the memory field ``id`` is non-negative, the agent binds to that
        existing person (simulator ids start at 0). Otherwise a new person is
        created from ``self._person_template``, overlaid with the non-empty
        memory values of the keys listed below. The id returned by the
        simulator is then written back to memory.

        Only logs a warning when no simulator is set, and does nothing once
        the agent is already bound.
        """
        if self._simulator is None:
            logger.warning("Simulator is not set")
            return
        if self._has_bound_to_simulator:
            return
        # Memory fields copied onto the person template (a tuple, so the
        # copy order is deterministic).
        from_memory_keys = (
            "attribute",
            "home",
            "work",
            "vehicle_attribute",
            "bus_attribute",
            "pedestrian_attribute",
            "bike_attribute",
        )
        simulator = self.simulator
        memory = self.memory
        person_id = await memory.get("id")
        # ATTENTION: ids assigned by the simulator start at 0
        if person_id >= 0:
            await simulator.get_person(person_id)
            logger.debug(f"Binding to Person `{person_id}` already in Simulator")
        else:
            dict_person = deepcopy(self._person_template)
            for key in from_memory_keys:
                try:
                    value = await memory.get(key)
                except KeyError:
                    # Field not present in this agent's memory: keep the template value.
                    continue
                if value:
                    dict_person[key] = value
            resp = await simulator.add_person(
                dict2pb(dict_person, person_pb2.Person())
            )
            person_id = resp["person_id"]
            await memory.update("id", person_id, protect_llm_read_only_fields=False)
            logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
        # Original note: prevents get_person errors when the simulator has not
        # yet reached its prepare stage. NOTE(review): intent unclear — confirm.
        self._has_bound_to_simulator = True
        self._agent_id = person_id
        self.memory.set_agent_id(person_id)

    async def _bind_to_economy(self):
        """
        Register the agent (id and currency from memory) with the economy simulator.

        Skipped with a log message when no economy client is set, or when the
        agent has not been bound to the city simulator yet. Does nothing once
        the agent is already bound.
        """
        if self._economy_client is None:
            logger.warning("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return
        if not self._has_bound_to_simulator:
            logger.debug(
                "Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
            )
            return
        try:
            # Best-effort removal of any stale registration under this id.
            await self._economy_client.remove_agents([self._agent_id])
        except Exception as e:
            logger.debug(f"remove_agents({self._agent_id}) failed, ignored: {e}")
        person_id = await self.memory.get("id")
        await self._economy_client.add_agents(
            {
                "id": person_id,
                "currency": await self.memory.get("currency"),
            }
        )
        self._has_bound_to_economy = True

    async def handle_gather_message(self, payload: dict):
        """
        Reply to a gather request with the requested memory value.

        Args:
            payload (dict): Request with ``"target"`` (the memory key to read)
                and ``"from"`` (the UUID of the requesting agent).
        """
        target = payload["target"]
        sender_id = payload["from"]
        content = await self.memory.get(f"{target}")
        reply = {
            "from": self._uuid,
            "content": content,
        }
        await self._send_message(sender_id, reply, "gather")
647
+
648
+
649
class InstitutionAgent(Agent):
    """
    InstitutionAgent: an organization / institution agent.

    :meth:`bind_to_simulator` only registers the agent as an organization in
    the economy simulator. :meth:`gather_messages` collects memory values
    from other agents via "gather" messages.
    """

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,
        avro_file: Optional[dict] = None,
        copy_writer: Optional[ray.ObjectRef] = None,
    ) -> None:
        """
        Initialize an institution agent (``AgentType.Institution``).

        All arguments are forwarded to :class:`Agent`. ``copy_writer`` is the
        optional PostgreSQL copy writer; it can also be set later with
        ``set_pgsql_writer``.
        """
        super().__init__(
            name=name,
            type=AgentType.Institution,
            llm_client=llm_client,
            economy_client=economy_client,
            mlflow_client=mlflow_client,
            messager=messager,
            simulator=simulator,
            memory=memory,
            avro_file=avro_file,
            copy_writer=copy_writer,
        )
        # Pending gather requests: responder UUID -> Future resolved by
        # handle_gather_message.
        self._gather_responses: dict[str, asyncio.Future] = {}

    async def bind_to_simulator(self):
        """Bind the institution to the economy simulator."""
        await self._bind_to_economy()

    async def _bind_to_economy(self):
        """
        Register the institution as an organization in the economy simulator.

        On first call this:
        - assigns a random 6-digit id;
        - places the agent at a random position inside the map bounds;
        - stores both in memory;
        - calls ``add_orgs`` with the economy fields read from memory, using
          a default for any field that cannot be read.

        A failure in the registration step is logged, and the agent is still
        marked as bound.
        """
        if self._economy_client is None:
            logger.debug("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return
        # TODO: More general id generation
        _id = random.randint(100000, 999999)
        self._agent_id = _id
        self.memory.set_agent_id(_id)
        map_header = self.simulator.map.header
        # TODO: remove random position assignment
        await self.memory.update(
            "position",
            {
                "xy_position": {
                    "x": float(
                        random.randrange(
                            start=int(map_header["west"]),
                            stop=int(map_header["east"]),
                        )
                    ),
                    "y": float(
                        random.randrange(
                            start=int(map_header["south"]),
                            stop=int(map_header["north"]),
                        )
                    ),
                }
            },
            protect_llm_read_only_fields=False,
        )
        await self.memory.update("id", _id, protect_llm_read_only_fields=False)
        try:
            # Best-effort removal of any stale registration under this id.
            await self._economy_client.remove_orgs([self._agent_id])
        except Exception as e:
            logger.debug(f"remove_orgs({self._agent_id}) failed, ignored: {e}")
        try:
            _memory = self.memory
            org: dict[str, Any] = {
                "id": await _memory.get("id"),
                "type": await _memory.get("type"),
            }
            # Optional fields, in registration order, with their fallback
            # values. The table is rebuilt on each call, so the list defaults
            # are never shared between calls.
            optional_fields = (
                ("nominal_gdp", []),
                ("real_gdp", []),
                ("unemployment", []),
                ("wages", []),
                ("prices", []),
                ("inventory", 0),
                ("price", 0),
                ("currency", 0.0),
                ("interest_rate", 0.0),
                ("bracket_cutoffs", []),
                ("bracket_rates", []),
            )
            for key, default in optional_fields:
                try:
                    org[key] = await _memory.get(key)
                except Exception:
                    # Any failure to read the field falls back to its default.
                    org[key] = default
            await self._economy_client.add_orgs(org)
        except Exception as e:
            logger.error(f"Failed to bind to Economy: {e}")
        self._has_bound_to_economy = True

    async def handle_gather_message(self, payload: dict):
        """
        Resolve the pending gather Future for the replying agent.

        Args:
            payload (dict): Reply with ``"from"`` (the responder UUID) and
                ``"content"``.

        Replies without a pending request, and duplicate replies, are ignored.
        (Calling ``set_result`` on a finished or cancelled Future would raise
        ``InvalidStateError``.)
        """
        content = payload["content"]
        sender_id = payload["from"]
        future = self._gather_responses.get(str(sender_id))
        if future is not None and not future.done():
            future.set_result(
                {
                    "from": sender_id,
                    "content": content,
                }
            )

    async def gather_messages(self, agent_uuids: list[str], target: str) -> list[dict]:
        """
        Collect the memory value ``target`` from several agents.

        Args:
            agent_uuids (list[str]): UUIDs of the agents to ask.
            target (str): The memory key each agent should report.

        Returns:
            list[dict]: One ``{"from": ..., "content": ...}`` reply per distinct
            UUID, in first-occurrence order of ``agent_uuids``.

        Note:
            Waits until every agent has replied. There is no timeout; wrap the
            call in ``asyncio.wait_for`` if one is needed.
        """
        loop = asyncio.get_running_loop()
        futures: dict[str, asyncio.Future] = {}
        for agent_uuid in agent_uuids:
            futures[agent_uuid] = loop.create_future()
            self._gather_responses[agent_uuid] = futures[agent_uuid]

        payload = {
            "from": self._uuid,
            "target": target,
        }
        try:
            # Sending is inside the try, so a failed send still cleans up
            # every registered Future.
            for agent_uuid in agent_uuids:
                await self._send_message(agent_uuid, payload, "gather")
            responses = await asyncio.gather(*futures.values())
            return responses
        finally:
            for key, fut in futures.items():
                # Only drop our own Future, not one registered by a concurrent gather.
                if self._gather_responses.get(key) is fut:
                    del self._gather_responses[key]