pycityagent 2.0.0a49__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a51__cp39-cp39-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/__init__.py +12 -3
- pycityagent/agent/__init__.py +9 -0
- pycityagent/agent/agent.py +324 -0
- pycityagent/{agent.py → agent/agent_base.py} +41 -345
- pycityagent/cityagent/bankagent.py +28 -16
- pycityagent/cityagent/firmagent.py +63 -25
- pycityagent/cityagent/governmentagent.py +35 -19
- pycityagent/cityagent/initial.py +38 -28
- pycityagent/cityagent/memory_config.py +240 -128
- pycityagent/cityagent/nbsagent.py +82 -36
- pycityagent/cityagent/societyagent.py +155 -72
- pycityagent/simulation/agentgroup.py +2 -2
- pycityagent/simulation/simulation.py +94 -55
- pycityagent/tools/__init__.py +11 -0
- pycityagent/{workflow → tools}/tool.py +3 -1
- pycityagent/workflow/__init__.py +0 -5
- pycityagent/workflow/block.py +12 -10
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a51.dist-info}/METADATA +1 -2
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a51.dist-info}/RECORD +23 -20
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a51.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a51.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a51.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a49.dist-info → pycityagent-2.0.0a51.dist-info}/top_level.txt +0 -0
pycityagent/__init__.py
CHANGED
@@ -2,11 +2,12 @@
|
|
2
2
|
Pycityagent: 城市智能体构建框架
|
3
3
|
"""
|
4
4
|
|
5
|
-
|
5
|
+
import logging
|
6
|
+
|
7
|
+
from .agent import Agent, AgentType, CitizenAgent, InstitutionAgent
|
6
8
|
from .environment import Simulator
|
7
9
|
from .llm import SentenceEmbedding
|
8
10
|
from .simulation import AgentSimulation
|
9
|
-
import logging
|
10
11
|
|
11
12
|
# 创建一个 pycityagent 记录器
|
12
13
|
logger = logging.getLogger("pycityagent")
|
@@ -21,4 +22,12 @@ if not logger.hasHandlers():
|
|
21
22
|
handler.setFormatter(formatter)
|
22
23
|
logger.addHandler(handler)
|
23
24
|
|
24
|
-
__all__ = [
|
25
|
+
__all__ = [
|
26
|
+
"Agent",
|
27
|
+
"Simulator",
|
28
|
+
"CitizenAgent",
|
29
|
+
"InstitutionAgent",
|
30
|
+
"SentenceEmbedding",
|
31
|
+
"AgentSimulation",
|
32
|
+
"AgentType",
|
33
|
+
]
|
@@ -0,0 +1,324 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
import random
|
6
|
+
from copy import deepcopy
|
7
|
+
from typing import Any, Optional
|
8
|
+
|
9
|
+
from mosstool.util.format_converter import dict2pb
|
10
|
+
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
11
|
+
|
12
|
+
from ..economy import EconomyClient
|
13
|
+
from ..environment import Simulator
|
14
|
+
from ..llm import LLM
|
15
|
+
from ..memory import Memory
|
16
|
+
from ..message.messager import Messager
|
17
|
+
from ..metrics import MlflowClient
|
18
|
+
from .agent_base import Agent, AgentType
|
19
|
+
|
20
|
+
logger = logging.getLogger("pycityagent")
|
21
|
+
|
22
|
+
|
23
|
+
class CitizenAgent(Agent):
    """
    CitizenAgent: city-resident agent class and its definition.

    Binding first attaches the agent to a person in the Simulator (looking up
    an existing person or adding a new one) and then registers that person id,
    together with its currency, with the economy client.
    """

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,  # type:ignore
        avro_file: Optional[dict] = None,
    ) -> None:
        super().__init__(
            name=name,
            type=AgentType.Citizen,
            llm_client=llm_client,
            economy_client=economy_client,
            messager=messager,
            simulator=simulator,
            mlflow_client=mlflow_client,
            memory=memory,
            avro_file=avro_file,
        )

    async def bind_to_simulator(self):
        """Bind to the Simulator, then to the economy client.

        The order matters: `_bind_to_economy` skips registration unless the
        simulator binding has already succeeded.
        """
        await self._bind_to_simulator()
        await self._bind_to_economy()

    async def _bind_to_simulator(self):
        """Bind this agent to a person in the Simulator.

        If memory already holds a non-negative ``id``, the agent attaches to
        that existing person. Otherwise a person is built from a deep copy of
        ``self._person_template``, overlaid with any truthy memory values for
        the keys in ``FROM_MEMORY_KEYS``. It is added via
        ``simulator.add_person``, and the returned ``person_id`` is written
        back to memory.

        Logs a warning and returns when no simulator is set; does nothing once
        the agent is already bound.
        """
        if self._simulator is None:
            logger.warning("Simulator is not set")
            return
        if self._has_bound_to_simulator:
            return

        FROM_MEMORY_KEYS = {
            "attribute",
            "home",
            "work",
            "vehicle_attribute",
            "bus_attribute",
            "pedestrian_attribute",
            "bike_attribute",
        }
        simulator = self.simulator
        memory = self.memory
        person_id = await memory.get("id")
        # ATTENTION: ids assigned by the simulator start at 0, so a
        # non-negative id in memory is treated as an already-existing person.
        if person_id >= 0:
            await simulator.get_person(person_id)
            logger.debug(f"Binding to Person `{person_id}` already in Simulator")
        else:
            dict_person = deepcopy(self._person_template)
            for _key in FROM_MEMORY_KEYS:
                try:
                    _value = await memory.get(_key)
                except KeyError:
                    # Key not present in memory: keep the template's value.
                    continue
                if _value:
                    dict_person[_key] = _value
            resp = await simulator.add_person(
                dict2pb(dict_person, person_pb2.Person())
            )
            person_id = resp["person_id"]
            await memory.update("id", person_id, protect_llm_read_only_fields=False)
            logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
        # Original note: guards against get_person failing when the simulator
        # has not yet reached its prepare stage.
        self._has_bound_to_simulator = True
        self._agent_id = person_id
        self.memory.set_agent_id(person_id)

    async def _bind_to_economy(self):
        """Register this agent's id and currency with the economy client.

        Requires a prior successful simulator binding. Otherwise the binding
        is skipped with a debug log, and a later call can retry. Any existing
        economy agent with the same id is removed first (best effort).
        """
        if self._economy_client is None:
            logger.warning("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return
        if not self._has_bound_to_simulator:
            logger.debug(
                "Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
            )
            return
        try:
            await self._economy_client.remove_agents([self._agent_id])
        except Exception as e:
            # Best effort: the agent may simply not exist yet. Catching
            # Exception (not a bare except) lets task cancellation propagate.
            logger.debug(f"Could not remove economy agent `{self._agent_id}`: {e}")
        person_id = await self.memory.get("id")
        currency = await self.memory.get("currency")
        await self._economy_client.add_agents(
            {
                "id": person_id,
                "currency": currency,
            }
        )
        self._has_bound_to_economy = True

    async def handle_gather_message(self, payload: dict):
        """Answer a gather request by replying with one value from memory.

        Args:
            payload: Must contain ``"target"`` (the memory key to read) and
                ``"from"`` (the id to send the reply to).
        """
        target = payload["target"]
        sender_id = payload["from"]
        content = await self.memory.get(f"{target}")
        reply = {
            "from": self._uuid,
            "content": content,
        }
        await self._send_message(sender_id, reply, "gather")
|
137
|
+
|
138
|
+
|
139
|
+
class InstitutionAgent(Agent):
    """
    InstitutionAgent: institution agent class and its definition.

    Institutions are not added to the Simulator as persons. Binding registers
    them as an org with the economy client. They can also collect values from
    other agents through `gather_messages`.
    """

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,  # type:ignore
        avro_file: Optional[dict] = None,
    ) -> None:
        super().__init__(
            name=name,
            type=AgentType.Institution,
            llm_client=llm_client,
            economy_client=economy_client,
            mlflow_client=mlflow_client,
            messager=messager,
            simulator=simulator,
            memory=memory,
            avro_file=avro_file,
        )
        # Pending gather responses, keyed by the responding agent's id as str.
        self._gather_responses: dict[str, asyncio.Future] = {}

    async def bind_to_simulator(self):
        """Bind this institution; only the economy binding applies to it."""
        await self._bind_to_economy()

    async def _memory_get_or(self, key: str, default: Any) -> Any:
        """Return ``self.memory.get(key)``, or *default* if the read fails."""
        try:
            return await self.memory.get(key)
        except Exception:
            # Catch Exception rather than everything so cancellation propagates.
            return default

    async def _bind_to_economy(self):
        """Register this institution as an org with the economy client.

        On the first call this:
        1. assigns a random 6-digit id (stored as the agent id and in memory);
        2. places the institution at a random position within the map bounds
           (when a simulator is set);
        3. removes any existing org with that id (best effort);
        4. adds an org record built from memory.

        Memory fields that cannot be read fall back to empty lists or zeros.
        A failure while building or adding the org is logged, and the agent is
        still marked as bound, so registration is attempted only once.
        """
        logger.debug(
            f"Binding institution to economy: client={self._economy_client}, "
            f"bound={self._has_bound_to_economy}"
        )
        if self._economy_client is None:
            logger.debug("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return

        # TODO: More general id generation
        _id = random.randint(100000, 999999)
        self._agent_id = _id
        self.memory.set_agent_id(_id)
        if self._simulator is None:
            logger.warning("Simulator is not set, skip random position assignment")
        else:
            map_header = self.simulator.map.header
            # TODO: remove random position assignment
            x = float(
                random.randrange(
                    start=int(map_header["west"]),
                    stop=int(map_header["east"]),
                )
            )
            y = float(
                random.randrange(
                    start=int(map_header["south"]),
                    stop=int(map_header["north"]),
                )
            )
            await self.memory.update(
                "position",
                {"xy_position": {"x": x, "y": y}},
                protect_llm_read_only_fields=False,
            )
        await self.memory.update("id", _id, protect_llm_read_only_fields=False)
        try:
            await self._economy_client.remove_orgs([self._agent_id])
        except Exception as e:
            # Best effort: the org may simply not exist yet.
            logger.debug(f"Could not remove economy org `{self._agent_id}`: {e}")
        try:
            _memory = self.memory
            # Dict literal entries are evaluated (awaited) in order, top to bottom.
            org = {
                "id": await _memory.get("id"),
                "type": await _memory.get("type"),
                "nominal_gdp": await self._memory_get_or("nominal_gdp", []),
                "real_gdp": await self._memory_get_or("real_gdp", []),
                "unemployment": await self._memory_get_or("unemployment", []),
                "wages": await self._memory_get_or("wages", []),
                "prices": await self._memory_get_or("prices", []),
                "inventory": await self._memory_get_or("inventory", 0),
                "price": await self._memory_get_or("price", 0),
                "currency": await self._memory_get_or("currency", 0.0),
                "interest_rate": await self._memory_get_or("interest_rate", 0.0),
                "bracket_cutoffs": await self._memory_get_or("bracket_cutoffs", []),
                "bracket_rates": await self._memory_get_or("bracket_rates", []),
            }
            await self._economy_client.add_orgs(org)
        except Exception as e:
            logger.error(f"Failed to bind to Economy: {e}")
        self._has_bound_to_economy = True

    async def handle_gather_message(self, payload: dict):
        """Deliver a gather reply to the Future awaiting it, identifying the sender.

        Args:
            payload: Must contain ``"content"`` and ``"from"`` (the replier's id).

        Replies with no pending Future, or whose Future is already resolved
        (e.g. a duplicate reply), are ignored.
        """
        content = payload["content"]
        sender_id = payload["from"]

        # Store the response in the matching Future, if one is still waiting.
        future = self._gather_responses.get(str(sender_id))
        if future is not None and not future.done():
            future.set_result(
                {
                    "from": sender_id,
                    "content": content,
                }
            )

    async def gather_messages(
        self,
        agent_uuids: list[str],
        target: str,
        timeout: Optional[float] = None,
    ) -> list[dict]:
        """Collect a value from several agents.

        Args:
            agent_uuids: UUIDs of the agents to query.
            target: Which piece of information to collect. Presumably a memory
                key on the recipient; confirm against the recipient's
                `handle_gather_message`.
            timeout: Seconds to wait for all replies; ``None`` waits
                indefinitely.

        Returns:
            list[dict]: One ``{"from": ..., "content": ...}`` response per
            unique UUID, in first-occurrence order.

        Raises:
            asyncio.TimeoutError: If ``timeout`` elapses before all replies arrive.
        """
        loop = asyncio.get_running_loop()
        # Register a Future per agent *before* sending, so early replies are caught.
        futures: dict[str, asyncio.Future] = {}
        for agent_uuid in agent_uuids:
            future = loop.create_future()
            futures[agent_uuid] = future
            self._gather_responses[agent_uuid] = future

        payload = {
            "from": self._uuid,
            "target": target,
        }
        try:
            for agent_uuid in agent_uuids:
                await self._send_message(agent_uuid, payload, "gather")
            return await asyncio.wait_for(asyncio.gather(*futures.values()), timeout)
        finally:
            # Clean up only the Futures registered by this call, even if a
            # send failed or the wait timed out / was cancelled.
            for key, future in futures.items():
                if self._gather_responses.get(key) is future:
                    del self._gather_responses[key]
|