pycityagent 2.0.0a42__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +23 -0
- pycityagent/agent.py +833 -0
- pycityagent/cli/wrapper.py +44 -0
- pycityagent/economy/__init__.py +5 -0
- pycityagent/economy/econ_client.py +355 -0
- pycityagent/environment/__init__.py +7 -0
- pycityagent/environment/interact/__init__.py +0 -0
- pycityagent/environment/interact/interact.py +198 -0
- pycityagent/environment/message/__init__.py +0 -0
- pycityagent/environment/sence/__init__.py +0 -0
- pycityagent/environment/sence/static.py +416 -0
- pycityagent/environment/sidecar/__init__.py +8 -0
- pycityagent/environment/sidecar/sidecarv2.py +109 -0
- pycityagent/environment/sim/__init__.py +29 -0
- pycityagent/environment/sim/aoi_service.py +39 -0
- pycityagent/environment/sim/client.py +126 -0
- pycityagent/environment/sim/clock_service.py +44 -0
- pycityagent/environment/sim/economy_services.py +192 -0
- pycityagent/environment/sim/lane_service.py +111 -0
- pycityagent/environment/sim/light_service.py +122 -0
- pycityagent/environment/sim/person_service.py +295 -0
- pycityagent/environment/sim/road_service.py +39 -0
- pycityagent/environment/sim/sim_env.py +145 -0
- pycityagent/environment/sim/social_service.py +59 -0
- pycityagent/environment/simulator.py +331 -0
- pycityagent/environment/utils/__init__.py +14 -0
- pycityagent/environment/utils/base64.py +16 -0
- pycityagent/environment/utils/const.py +244 -0
- pycityagent/environment/utils/geojson.py +24 -0
- pycityagent/environment/utils/grpc.py +57 -0
- pycityagent/environment/utils/map_utils.py +157 -0
- pycityagent/environment/utils/port.py +11 -0
- pycityagent/environment/utils/protobuf.py +41 -0
- pycityagent/llm/__init__.py +11 -0
- pycityagent/llm/embeddings.py +231 -0
- pycityagent/llm/llm.py +377 -0
- pycityagent/llm/llmconfig.py +13 -0
- pycityagent/llm/utils.py +6 -0
- pycityagent/memory/__init__.py +13 -0
- pycityagent/memory/const.py +43 -0
- pycityagent/memory/faiss_query.py +302 -0
- pycityagent/memory/memory.py +448 -0
- pycityagent/memory/memory_base.py +170 -0
- pycityagent/memory/profile.py +165 -0
- pycityagent/memory/self_define.py +165 -0
- pycityagent/memory/state.py +173 -0
- pycityagent/memory/utils.py +28 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +88 -0
- pycityagent/metrics/__init__.py +6 -0
- pycityagent/metrics/mlflow_client.py +147 -0
- pycityagent/metrics/utils/const.py +0 -0
- pycityagent/pycityagent-sim +0 -0
- pycityagent/pycityagent-ui +0 -0
- pycityagent/simulation/__init__.py +8 -0
- pycityagent/simulation/agentgroup.py +580 -0
- pycityagent/simulation/simulation.py +634 -0
- pycityagent/simulation/storage/pg.py +184 -0
- pycityagent/survey/__init__.py +4 -0
- pycityagent/survey/manager.py +54 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +11 -0
- pycityagent/utils/avro_schema.py +109 -0
- pycityagent/utils/decorators.py +99 -0
- pycityagent/utils/parsers/__init__.py +13 -0
- pycityagent/utils/parsers/code_block_parser.py +37 -0
- pycityagent/utils/parsers/json_parser.py +86 -0
- pycityagent/utils/parsers/parser_base.py +60 -0
- pycityagent/utils/pg_query.py +92 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/__init__.py +26 -0
- pycityagent/workflow/block.py +211 -0
- pycityagent/workflow/prompt.py +79 -0
- pycityagent/workflow/tool.py +240 -0
- pycityagent/workflow/trigger.py +163 -0
- pycityagent-2.0.0a42.dist-info/LICENSE +21 -0
- pycityagent-2.0.0a42.dist-info/METADATA +235 -0
- pycityagent-2.0.0a42.dist-info/RECORD +81 -0
- pycityagent-2.0.0a42.dist-info/WHEEL +5 -0
- pycityagent-2.0.0a42.dist-info/entry_points.txt +3 -0
- pycityagent-2.0.0a42.dist-info/top_level.txt +3 -0
pycityagent/agent.py
ADDED
@@ -0,0 +1,833 @@
|
|
1
|
+
"""Agent template classes and their definitions."""
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
import random
|
7
|
+
import uuid
|
8
|
+
from abc import ABC, abstractmethod
|
9
|
+
from copy import deepcopy
|
10
|
+
from datetime import datetime, timezone
|
11
|
+
from enum import Enum
|
12
|
+
from typing import Any, Optional
|
13
|
+
from uuid import UUID
|
14
|
+
|
15
|
+
import fastavro
|
16
|
+
import ray
|
17
|
+
from mosstool.util.format_converter import dict2pb
|
18
|
+
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
19
|
+
|
20
|
+
from .economy import EconomyClient
|
21
|
+
from .environment import Simulator
|
22
|
+
from .environment.sim.person_service import PersonService
|
23
|
+
from .llm import LLM
|
24
|
+
from .memory import Memory
|
25
|
+
from .message.messager import Messager
|
26
|
+
from .metrics import MlflowClient
|
27
|
+
from .utils import DIALOG_SCHEMA, SURVEY_SCHEMA, process_survey_for_llm
|
28
|
+
|
29
|
+
# Logger used by all agent classes in this module (named after the package, not the module).
logger = logging.getLogger("pycityagent")
|
30
|
+
|
31
|
+
|
32
|
+
class AgentType(Enum):
    """
    Kind of an agent.

    - Unspecified: no concrete kind was given (the default for the `Agent` base class)
    - Citizen: an urban-resident (citizen) agent
    - Institution: an organization or institution agent
    """

    Unspecified = "Unspecified"
    Citizen = "Citizen"
    Institution = "Institution"
|
43
|
+
|
44
|
+
|
45
|
+
class Agent(ABC):
    """
    Agent base class.

    Holds the clients an agent talks to (LLM, economy, simulator, messager,
    MLflow), its memory, and the optional Avro file / PostgreSQL copy writer
    used to persist survey answers and dialogs.
    """

    # Column order of a dialog row sent to the PostgreSQL writer; the
    # `created_at` datetime is appended after these values.
    _DIALOG_PG_KEYS = ("id", "day", "t", "type", "speaker", "content")

    def __init__(
        self,
        name: str,
        type: AgentType = AgentType.Unspecified,
        llm_client: Optional[LLM] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        avro_file: Optional[dict[str, str]] = None,
        copy_writer: Optional[ray.ObjectRef] = None,
    ) -> None:
        """
        Initialize the Agent.

        Args:
            name (str): The name of the agent.
            type (AgentType): The type of the agent. Defaults to `AgentType.Unspecified`
            llm_client (LLM): The language model client. Defaults to None.
            economy_client (EconomyClient): The `EconomySim` client. Defaults to None.
            messager (Messager, optional): The messager object. Defaults to None.
            simulator (Simulator, optional): The simulator object. Defaults to None.
            mlflow_client (MlflowClient, optional): The Mlflow object. Defaults to None.
            memory (Memory, optional): The memory of the agent. Defaults to None.
            avro_file (dict[str, str], optional): Avro output paths, keyed by record
                kind ("survey", "dialog"). Defaults to None (no Avro output).
            copy_writer (ray.ObjectRef): The PostgreSQL copy writer of the agent. Defaults to None.
        """
        self._name = name
        self._type = type
        self._uuid = str(uuid.uuid4())
        self._llm_client = llm_client
        self._economy_client = economy_client
        self._messager = messager
        self._simulator = simulator
        self._mlflow_client = mlflow_client
        self._memory = memory
        self._exp_id = -1
        self._agent_id = -1
        self._has_bound_to_simulator = False
        self._has_bound_to_economy = False
        self._blocked = False
        self._interview_history: list[dict] = []  # interview history
        self._person_template = PersonService.default_dict_person()
        self._avro_file = avro_file
        self._pgsql_writer = copy_writer
        # The pending PostgreSQL write; it is only awaited right before the next
        # write, so the SQL I/O overlaps with the agent's other work.
        self._last_asyncio_pg_task = None
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected mid-run.
        self._background_tasks: set[asyncio.Task] = set()

    def __getstate__(self):
        state = self.__dict__.copy()
        # Exclude the LLM client (the original note says it holds lock objects).
        state.pop("_llm_client", None)
        # Live asyncio tasks belong to the current event loop and cannot be pickled.
        state.pop("_background_tasks", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Restore the attributes dropped by __getstate__. Without this, an
        # unpickled agent raises AttributeError in `llm` / `generate_*` instead
        # of the intended "not assigned" handling.
        self.__dict__.setdefault("_llm_client", None)
        self.__dict__.setdefault("_background_tasks", set())

    def set_messager(self, messager: Messager):
        """
        Set the messager of the agent.
        """
        self._messager = messager

    def set_llm_client(self, llm_client: LLM):
        """
        Set the llm_client of the agent.
        """
        self._llm_client = llm_client

    def set_simulator(self, simulator: Simulator):
        """
        Set the simulator of the agent.
        """
        self._simulator = simulator

    def set_mlflow_client(self, mlflow_client: MlflowClient):
        """
        Set the mlflow_client of the agent.
        """
        self._mlflow_client = mlflow_client

    def set_economy_client(self, economy_client: EconomyClient):
        """
        Set the economy_client of the agent.
        """
        self._economy_client = economy_client

    def set_memory(self, memory: Memory):
        """
        Set the memory of the agent.
        """
        self._memory = memory

    def set_exp_id(self, exp_id: str):
        """
        Set the exp_id of the agent (used in outgoing message topics).
        """
        self._exp_id = exp_id

    def set_avro_file(self, avro_file: dict[str, str]):
        """
        Set the avro file of the agent.
        """
        self._avro_file = avro_file

    def set_pgsql_writer(self, pgsql_writer: ray.ObjectRef):
        """
        Set the PostgreSQL copy writer of the agent.
        """
        self._pgsql_writer = pgsql_writer

    @property
    def uuid(self):
        """The Agent's UUID"""
        return self._uuid

    @property
    def sim_id(self):
        """The Agent's Simulator ID (-1 until bound)"""
        return self._agent_id

    @property
    def llm(self):
        """The Agent's LLM; raises RuntimeError if not assigned."""
        if self._llm_client is None:
            raise RuntimeError(
                "LLM access before assignment, please `set_llm_client` first!"
            )
        return self._llm_client

    @property
    def economy_client(self):
        """The Agent's EconomyClient; raises RuntimeError if not assigned."""
        if self._economy_client is None:
            raise RuntimeError(
                "EconomyClient access before assignment, please `set_economy_client` first!"
            )
        return self._economy_client

    @property
    def mlflow_client(self):
        """The Agent's MlflowClient; raises RuntimeError if not assigned."""
        if self._mlflow_client is None:
            raise RuntimeError(
                "MlflowClient access before assignment, please `set_mlflow_client` first!"
            )
        return self._mlflow_client

    @property
    def memory(self):
        """The Agent's Memory; raises RuntimeError if not assigned."""
        if self._memory is None:
            raise RuntimeError(
                "Memory access before assignment, please `set_memory` first!"
            )
        return self._memory

    @property
    def simulator(self):
        """The Simulator; raises RuntimeError if not assigned."""
        if self._simulator is None:
            raise RuntimeError(
                "Simulator access before assignment, please `set_simulator` first!"
            )
        return self._simulator

    @property
    def copy_writer(self):
        """Pg Copy Writer; raises RuntimeError if not assigned."""
        if self._pgsql_writer is None:
            raise RuntimeError(
                "Copy Writer access before assignment, please `set_pgsql_writer` first!"
            )
        return self._pgsql_writer

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _spawn_background_task(self, coro) -> asyncio.Task:
        """Schedule `coro` as a task, keep it referenced until done, and log failures."""
        tasks = self.__dict__.setdefault("_background_tasks", set())
        task = asyncio.create_task(coro)
        tasks.add(task)

        def _on_done(t: asyncio.Task) -> None:
            tasks.discard(t)
            if not t.cancelled() and t.exception() is not None:
                logger.error(
                    f"Agent {self._uuid} background task failed: {t!r}",
                    exc_info=t.exception(),
                )

        task.add_done_callback(_on_done)
        return task

    async def _current_sim_time(self) -> tuple[Any, Any]:
        """Return the simulator's (day, seconds-from-start-of-day), in that call order."""
        day = await self.simulator.get_simulator_day()
        t = await self.simulator.get_simulator_second_from_start_of_day()
        return day, t

    def _dialog_record(
        self,
        created_at: datetime,
        day: Any,
        t: Any,
        dialog_type: int,
        speaker: str,
        content: Any,
    ) -> tuple[dict, datetime]:
        """Build one dialog record (Avro dict) paired with its `created_at` datetime."""
        record = {
            "id": self._uuid,
            "day": day,
            "t": t,
            "type": dialog_type,
            "speaker": speaker,
            "content": content,
            "created_at": int(created_at.timestamp() * 1000),
        }
        return record, created_at

    async def _wait_for_last_pg_task(self) -> None:
        """Wait for the previously submitted PostgreSQL write, if any."""
        if self._last_asyncio_pg_task is not None:
            await self._last_asyncio_pg_task

    async def _persist_dialog(self, records: list[tuple[dict, datetime]]) -> None:
        """Append dialog records to the Avro file and submit them to the PostgreSQL writer.

        Each sink is skipped when it is not configured. The PostgreSQL write is
        submitted but not awaited; the next write waits for it.
        """
        if self._avro_file is not None:
            with open(self._avro_file["dialog"], "a+b") as f:
                fastavro.writer(
                    f,
                    DIALOG_SCHEMA,
                    [record for record, _ in records],
                    codec="snappy",
                )
        if self._pgsql_writer is not None:
            await self._wait_for_last_pg_task()
            rows = [
                tuple(record[k] for k in self._DIALOG_PG_KEYS) + (created_at,)
                for record, created_at in records
            ]
            self._last_asyncio_pg_task = (
                self._pgsql_writer.async_write_dialog.remote(rows)  # type:ignore
            )

    # ------------------------------------------------------------------
    # Survey / interview / chat handling
    # ------------------------------------------------------------------
    async def generate_user_survey_response(self, survey: dict) -> str:
        """Generate an answer to a survey (may be overridden).

        Builds a first-person prompt from the survey, optionally enriched with
        memories relevant to it, and asks the LLM for an answer.

        Args:
            survey: The survey to answer, as a dict.

        Returns:
            str: The agent's answer, or a fixed apology when no LLM client is set.
        """
        # Check the LLM first so no memory search is done when it cannot be used.
        if not self._llm_client:
            return "Sorry, I cannot answer survey questions right now."

        survey_prompt = process_survey_for_llm(survey)
        system_prompt = "Please answer the survey question in first person. Follow the format requirements strictly and provide clear and specific answers."
        dialog = [{"role": "system", "content": system_prompt}]

        # Add memory context.
        if self._memory:
            relevant_memories = await self._memory.search(survey_prompt)
            if relevant_memories:
                dialog.append(
                    {
                        "role": "system",
                        "content": f"Answer based on these memories:\n{relevant_memories}",
                    }
                )

        dialog.append({"role": "user", "content": survey_prompt})
        response = await self._llm_client.atext_request(dialog)  # type:ignore
        return response  # type:ignore

    async def _process_survey(self, survey: dict):
        """Answer `survey` and persist the answer to the Avro file and/or PostgreSQL."""
        survey_response = await self.generate_user_survey_response(survey)
        created_at = datetime.now(timezone.utc)
        day, t = await self._current_sim_time()
        record = {
            "id": self._uuid,
            "day": day,
            "t": t,
            "survey_id": survey["id"],
            "result": survey_response,
            "created_at": int(created_at.timestamp() * 1000),
        }
        # Avro
        if self._avro_file is not None:
            with open(self._avro_file["survey"], "a+b") as f:
                fastavro.writer(f, SURVEY_SCHEMA, [record], codec="snappy")
        # Pg: the result column stores the answer wrapped as a JSON object.
        if self._pgsql_writer is not None:
            await self._wait_for_last_pg_task()
            row = (
                record["id"],
                record["day"],
                record["t"],
                record["survey_id"],
                json.dumps({"result": record["result"]}),
                created_at,
            )
            self._last_asyncio_pg_task = (
                self._pgsql_writer.async_write_survey.remote([row])  # type:ignore
            )

    async def generate_user_chat_response(self, question: str) -> str:
        """Generate an answer to a user question (may be overridden).

        Builds a concise first-person prompt, optionally enriched with memories
        relevant to the question, and asks the LLM for an answer.

        Args:
            question: The question to answer.

        Returns:
            str: The agent's answer, or a fixed apology when no LLM client is set.
        """
        # Check the LLM first so no memory search is done when it cannot be used.
        if not self._llm_client:
            return "Sorry, I cannot answer questions right now."

        system_prompt = "Please answer the question in first person and keep the response concise and clear."
        dialog = [{"role": "system", "content": system_prompt}]

        # Add memory context.
        if self._memory:
            relevant_memories = await self._memory.search(question)
            if relevant_memories:
                dialog.append(
                    {
                        "role": "system",
                        "content": f"Answer based on these memories:\n{relevant_memories}",
                    }
                )

        dialog.append({"role": "user", "content": question})
        response = await self._llm_client.atext_request(dialog)  # type:ignore
        return response  # type:ignore

    async def _process_interview(self, payload: dict):
        """Answer a user interview question and persist both turns (type 2 dialog).

        The user's turn has speaker "user"; the agent's reply has an empty speaker.
        """
        question = payload["content"]
        created_at = datetime.now(timezone.utc)
        day, t = await self._current_sim_time()
        records = [self._dialog_record(created_at, day, t, 2, "user", question)]

        response = await self.generate_user_chat_response(question)
        created_at = datetime.now(timezone.utc)
        day, t = await self._current_sim_time()
        records.append(self._dialog_record(created_at, day, t, 2, "", response))

        await self._persist_dialog(records)

    async def process_agent_chat_response(self, payload: dict) -> str:
        """React to a chat message from another agent (may be overridden); logs it by default."""
        resp = f"Agent {self._uuid} received agent chat response: {payload}"
        logger.info(resp)
        return resp

    async def _process_agent_chat(self, payload: dict):
        """Persist an incoming agent chat message (type 1 dialog) and react to it in the background."""
        created_at = datetime.now(timezone.utc)
        record = self._dialog_record(
            created_at,
            payload["day"],
            payload["t"],
            1,
            payload["from"],
            payload["content"],
        )
        self._spawn_background_task(self.process_agent_chat_response(payload))
        await self._persist_dialog([record])

    # Callback functions for MQTT message
    async def handle_agent_chat_message(self, payload: dict):
        """Handle a chat message from another agent (processed in the background)."""
        logger.info(f"Agent {self._uuid} received agent chat message: {payload}")
        self._spawn_background_task(self._process_agent_chat(payload))

    async def handle_user_chat_message(self, payload: dict):
        """Handle an interview message from a user (processed in the background)."""
        logger.info(f"Agent {self._uuid} received user chat message: {payload}")
        self._spawn_background_task(self._process_interview(payload))

    async def handle_user_survey_message(self, payload: dict):
        """Handle a survey message from a user; the survey is read from `payload["data"]`."""
        logger.info(f"Agent {self._uuid} received user survey message: {payload}")
        self._spawn_background_task(self._process_survey(payload["data"]))

    async def handle_gather_message(self, payload: Any):
        """Handle a gather message; subclasses must implement this."""
        raise NotImplementedError

    # MQTT send message
    async def _send_message(self, to_agent_uuid: str, payload: dict, sub_topic: str):
        """Send `payload` to `to_agent_uuid` on `sub_topic` through the Messager.

        Raises:
            RuntimeError: If no messager has been set.
        """
        if self._messager is None:
            raise RuntimeError("Messager is not set")
        topic = f"exps/{self._exp_id}/agents/{to_agent_uuid}/{sub_topic}"
        await self._messager.send_message.remote(topic, payload)

    async def send_message_to_agent(
        self, to_agent_uuid: str, content: str, type: str = "social"
    ):
        """Send a chat message to another agent through the Messager.

        Messages of type "social" are also persisted as type 1 dialog records.
        Other types only trigger a warning if they are not "social"/"economy",
        and are still sent.

        Raises:
            RuntimeError: If no messager has been set.
        """
        if self._messager is None:
            raise RuntimeError("Messager is not set")
        if type not in ("social", "economy"):
            logger.warning(f"Invalid message type: {type}, sent from {self._uuid}")
        # Take one wall-clock instant and one simulator time, so the sent payload
        # and the persisted record agree.
        created_at = datetime.now(timezone.utc)
        day, t = await self._current_sim_time()
        payload = {
            "from": self._uuid,
            "content": content,
            "type": type,
            "timestamp": int(created_at.timestamp() * 1000),
            "day": day,
            "t": t,
        }
        await self._send_message(to_agent_uuid, payload, "agent-chat")
        if type == "social":
            await self._persist_dialog(
                [self._dialog_record(created_at, day, t, 1, self._uuid, content)]
            )

    # Agent logic
    @abstractmethod
    async def forward(self) -> None:
        """The agent's behavior logic."""
        raise NotImplementedError

    async def run(self) -> None:
        """
        Unified agent execution entry point.

        `forward` is skipped while `_blocked` is True.
        """
        if not self._blocked:
            await self.forward()
|
532
|
+
|
533
|
+
|
534
|
+
class CitizenAgent(Agent):
    """
    CitizenAgent: an urban-resident agent, bound to a person in the simulator
    and to an agent in the economy simulator.
    """

    # Person fields copied from memory into the person template when a new
    # person has to be created in the simulator.
    _FROM_MEMORY_KEYS = (
        "attribute",
        "home",
        "work",
        "vehicle_attribute",
        "bus_attribute",
        "pedestrian_attribute",
        "bike_attribute",
    )

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,
        avro_file: Optional[dict] = None,
        copy_writer: Optional[ray.ObjectRef] = None,
    ) -> None:
        """Initialize a citizen agent; see `Agent.__init__` for the arguments."""
        super().__init__(
            name=name,
            type=AgentType.Citizen,
            llm_client=llm_client,
            economy_client=economy_client,
            messager=messager,
            simulator=simulator,
            mlflow_client=mlflow_client,
            memory=memory,
            avro_file=avro_file,
            copy_writer=copy_writer,
        )

    async def bind_to_simulator(self):
        """Bind to the simulator first, then to the economy simulator."""
        await self._bind_to_simulator()
        await self._bind_to_economy()

    async def _bind_to_simulator(self):
        """
        Bind the agent to a person in the simulator.

        If the memory's "id" is non-negative, that existing person is fetched.
        Otherwise a new person is created from the default person template,
        filled with the non-empty memory values for `_FROM_MEMORY_KEYS`, and the
        id assigned by the simulator is written back to memory.
        """
        if self._simulator is None:
            logger.warning("Simulator is not set")
            return
        if self._has_bound_to_simulator:
            return

        simulator = self.simulator
        memory = self.memory
        person_id = await memory.get("id")
        # ATTENTION: simulator-assigned ids start at 0, so a negative id means
        # no person has been created yet.
        if person_id >= 0:
            await simulator.get_person(person_id)
            logger.debug(f"Binding to Person `{person_id}` already in Simulator")
        else:
            dict_person = deepcopy(self._person_template)
            for key in self._FROM_MEMORY_KEYS:
                try:
                    value = await memory.get(key)
                except KeyError:
                    continue
                if value:
                    dict_person[key] = value
            resp = await simulator.add_person(
                dict2pb(dict_person, person_pb2.Person())
            )
            person_id = resp["person_id"]
            await memory.update("id", person_id, protect_llm_read_only_fields=False)
            logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
        # Original note: guards against get_person failing when the simulator
        # has not yet reached its prepare stage.
        self._has_bound_to_simulator = True
        self._agent_id = person_id
        self.memory.set_agent_id(person_id)

    async def _bind_to_economy(self):
        """Register the agent (id and currency from memory) with the economy simulator.

        This only happens after the simulator binding. Any previous registration
        under the same id is removed first, on a best-effort basis.
        """
        if self._economy_client is None:
            logger.warning("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return
        if not self._has_bound_to_simulator:
            logger.debug(
                "Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
            )
            return
        try:
            await self._economy_client.remove_agents([self._agent_id])
        except Exception as e:
            # Best effort: the agent may simply not be registered yet. A bare
            # `except` here would also swallow task cancellation.
            logger.debug(f"remove_agents([{self._agent_id}]) failed, ignored: {e}")
        person_id = await self.memory.get("id")
        await self._economy_client.add_agents(
            {
                "id": person_id,
                "currency": await self.memory.get("currency"),
            }
        )
        self._has_bound_to_economy = True

    async def handle_gather_message(self, payload: dict):
        """Answer a gather request: send the memory value for `payload["target"]` back to the sender."""
        target = payload["target"]
        sender_id = payload["from"]
        content = await self.memory.get(f"{target}")
        reply = {
            "from": self._uuid,
            "content": content,
        }
        await self._send_message(sender_id, reply, "gather")
|
647
|
+
|
648
|
+
|
649
|
+
class InstitutionAgent(Agent):
    """
    InstitutionAgent: an organization / institution agent, registered as an
    org in the economy simulator.
    """

    # Economy fields read from memory when registering the org, each paired
    # with a factory for the fallback value used when the read fails.
    _ECONOMY_FIELD_DEFAULTS = (
        ("nominal_gdp", list),
        ("real_gdp", list),
        ("unemployment", list),
        ("wages", list),
        ("prices", list),
        ("inventory", int),
        ("price", int),
        ("currency", float),
        ("interest_rate", float),
        ("bracket_cutoffs", list),
        ("bracket_rates", list),
    )

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,
        avro_file: Optional[dict] = None,
        copy_writer: Optional[ray.ObjectRef] = None,
    ) -> None:
        """Initialize an institution agent; see `Agent.__init__` for the arguments."""
        super().__init__(
            name=name,
            type=AgentType.Institution,
            llm_client=llm_client,
            economy_client=economy_client,
            mlflow_client=mlflow_client,
            messager=messager,
            simulator=simulator,
            memory=memory,
            avro_file=avro_file,
            copy_writer=copy_writer,
        )
        # Pending gather responses, keyed by the responding agent's UUID.
        self._gather_responses: dict[str, asyncio.Future] = {}

    async def bind_to_simulator(self):
        """Institutions are not persons in the simulator; only bind to the economy."""
        await self._bind_to_economy()

    async def _collect_economy_org(self) -> dict[str, Any]:
        """Read the org description from memory.

        "id" and "type" are required, so read errors propagate. Every other
        field falls back to its default when it cannot be read.
        """
        memory = self.memory
        org: dict[str, Any] = {
            "id": await memory.get("id"),
            "type": await memory.get("type"),
        }
        for key, default_factory in self._ECONOMY_FIELD_DEFAULTS:
            try:
                org[key] = await memory.get(key)
            except Exception:
                # `except Exception` (not bare `except`) so cancellation still propagates.
                org[key] = default_factory()
        return org

    async def _bind_to_economy(self):
        """Assign a random id and position, then register the org with the economy simulator.

        A failed registration is logged. The agent is still marked as bound, so
        no retry is attempted.
        """
        if self._economy_client is None:
            logger.debug("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return

        # TODO: More general id generation
        _id = random.randint(100000, 999999)
        self._agent_id = _id
        self.memory.set_agent_id(_id)
        map_header = self.simulator.map.header
        # TODO: remove random position assignment
        x = float(random.randrange(int(map_header["west"]), int(map_header["east"])))
        y = float(random.randrange(int(map_header["south"]), int(map_header["north"])))
        await self.memory.update(
            "position",
            {"xy_position": {"x": x, "y": y}},
            protect_llm_read_only_fields=False,
        )
        await self.memory.update("id", _id, protect_llm_read_only_fields=False)
        try:
            await self._economy_client.remove_orgs([self._agent_id])
        except Exception as e:
            # Best effort: the org may simply not be registered yet.
            logger.debug(f"remove_orgs([{self._agent_id}]) failed, ignored: {e}")
        try:
            await self._economy_client.add_orgs(await self._collect_economy_org())
        except Exception as e:
            logger.error(f"Failed to bind to Economy: {e}")
        self._has_bound_to_economy = True

    async def handle_gather_message(self, payload: dict):
        """Deliver a gather response to the future waiting for that sender, if any."""
        content = payload["content"]
        sender_id = payload["from"]

        future = self._gather_responses.get(str(sender_id))
        # Ignore unsolicited, duplicate, or late (timed-out/cancelled) responses;
        # calling set_result on a done future raises InvalidStateError.
        if future is not None and not future.done():
            future.set_result(
                {
                    "from": sender_id,
                    "content": content,
                }
            )

    async def gather_messages(
        self,
        agent_uuids: list[str],
        target: str,
        timeout: Optional[float] = None,
    ) -> list[dict]:
        """Collect the value of `target` from several agents.

        Args:
            agent_uuids: UUIDs of the agents to ask.
            target: The kind of information to collect (a memory key on the responder).
            timeout: Maximum seconds to wait for all responses; None waits indefinitely.

        Returns:
            list[dict]: One `{"from", "content"}` response per distinct agent UUID.

        Raises:
            asyncio.TimeoutError: If `timeout` elapses before all agents respond.
        """
        loop = asyncio.get_running_loop()
        futures: dict[str, asyncio.Future] = {}
        for agent_uuid in agent_uuids:
            future = loop.create_future()
            futures[agent_uuid] = future
            self._gather_responses[agent_uuid] = future

        payload = {
            "from": self._uuid,
            "target": target,
        }
        try:
            # The sends are inside the try so a failed send cannot leak the
            # registered futures.
            for agent_uuid in agent_uuids:
                await self._send_message(agent_uuid, payload, "gather")
            responses = await asyncio.wait_for(
                asyncio.gather(*futures.values()), timeout
            )
            return responses
        finally:
            # Only unregister our own futures; a concurrent gather may have
            # re-registered the same key.
            for key, future in futures.items():
                if self._gather_responses.get(key) is future:
                    del self._gather_responses[key]
|