pycityagent 2.0.0a49__cp312-cp312-macosx_11_0_arm64.whl → 2.0.0a51__cp312-cp312-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
pycityagent/__init__.py CHANGED
@@ -2,11 +2,12 @@
2
2
  Pycityagent: 城市智能体构建框架
3
3
  """
4
4
 
5
- from .agent import Agent, CitizenAgent, InstitutionAgent
5
+ import logging
6
+
7
+ from .agent import Agent, AgentType, CitizenAgent, InstitutionAgent
6
8
  from .environment import Simulator
7
9
  from .llm import SentenceEmbedding
8
10
  from .simulation import AgentSimulation
9
- import logging
10
11
 
11
12
  # 创建一个 pycityagent 记录器
12
13
  logger = logging.getLogger("pycityagent")
@@ -21,4 +22,12 @@ if not logger.hasHandlers():
21
22
  handler.setFormatter(formatter)
22
23
  logger.addHandler(handler)
23
24
 
24
- __all__ = ["Agent", "Simulator", "CitizenAgent", "InstitutionAgent","SentenceEmbedding","AgentSimulation"]
25
+ __all__ = [
26
+ "Agent",
27
+ "Simulator",
28
+ "CitizenAgent",
29
+ "InstitutionAgent",
30
+ "SentenceEmbedding",
31
+ "AgentSimulation",
32
+ "AgentType",
33
+ ]
@@ -0,0 +1,9 @@
1
+ from .agent import CitizenAgent, InstitutionAgent
2
+ from .agent_base import Agent, AgentType
3
+
4
+ __all__ = [
5
+ "Agent",
6
+ "CitizenAgent",
7
+ "InstitutionAgent",
8
+ "AgentType",
9
+ ]
@@ -0,0 +1,324 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import random
6
+ from copy import deepcopy
7
+ from typing import Any, Optional
8
+
9
+ from mosstool.util.format_converter import dict2pb
10
+ from pycityproto.city.person.v2 import person_pb2 as person_pb2
11
+
12
+ from ..economy import EconomyClient
13
+ from ..environment import Simulator
14
+ from ..llm import LLM
15
+ from ..memory import Memory
16
+ from ..message.messager import Messager
17
+ from ..metrics import MlflowClient
18
+ from .agent_base import Agent, AgentType
19
+
20
+ logger = logging.getLogger("pycityagent")
21
+
22
+
23
class CitizenAgent(Agent):
    """
    CitizenAgent: urban resident (citizen) agent class and its definition.

    A citizen is first registered in the city simulator as a Person; once that
    has succeeded it is registered in the economy simulator under the same
    simulator-assigned id.
    """

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,  # type:ignore
        avro_file: Optional[dict] = None,
    ) -> None:
        """Create a citizen agent.

        All arguments are forwarded unchanged to ``Agent.__init__`` together
        with ``type=AgentType.Citizen``.
        """
        super().__init__(
            name=name,
            type=AgentType.Citizen,
            llm_client=llm_client,
            economy_client=economy_client,
            messager=messager,
            simulator=simulator,
            mlflow_client=mlflow_client,
            memory=memory,
            avro_file=avro_file,
        )

    async def bind_to_simulator(self):
        """Bind to the city simulator, then to the economy simulator."""
        await self._bind_to_simulator()
        await self._bind_to_economy()

    async def _bind_to_simulator(self):
        """
        Bind the agent to the city simulator as a Person.

        If memory already holds a non-negative ``id``, the person is assumed
        to exist in the simulator and is fetched by that id. Otherwise a
        new person is built from ``self._person_template``, overlaid with the
        truthy values of selected memory fields, and added to the simulator.
        The returned id is written back to memory.

        Logs a warning and returns when no simulator is configured; does
        nothing when the agent is already bound.
        """
        if self._simulator is None:
            logger.warning("Simulator is not set")
            return
        if not self._has_bound_to_simulator:
            # Memory fields copied into the person template when creating a
            # new Person in the simulator.
            FROM_MEMORY_KEYS = {
                "attribute",
                "home",
                "work",
                "vehicle_attribute",
                "bus_attribute",
                "pedestrian_attribute",
                "bike_attribute",
            }
            simulator = self.simulator
            memory = self.memory
            person_id = await memory.get("id")
            # ATTENTION: ids assigned by the simulator start from 0, so a
            # negative id means the person has not been added yet.
            if person_id >= 0:
                await simulator.get_person(person_id)
                logger.debug(f"Binding to Person `{person_id}` already in Simulator")
            else:
                dict_person = deepcopy(self._person_template)
                for _key in FROM_MEMORY_KEYS:
                    try:
                        _value = await memory.get(_key)
                    except KeyError:
                        # Field absent from memory: keep the template value.
                        continue
                    if _value:
                        dict_person[_key] = _value
                resp = await simulator.add_person(
                    dict2pb(dict_person, person_pb2.Person())
                )
                person_id = resp["person_id"]
                await memory.update("id", person_id, protect_llm_read_only_fields=False)
                logger.debug(f"Binding to Person `{person_id}` just added to Simulator")
                # Prevent get_person errors if the simulator has not yet
                # reached its prepare stage.
            self._has_bound_to_simulator = True
            self._agent_id = person_id
            self.memory.set_agent_id(person_id)

    async def _bind_to_economy(self):
        """
        Register the agent in the economy simulator.

        Registration requires a prior simulator binding, because the economy
        agent is keyed by the simulator-assigned ``id`` from memory. Any
        existing economy agent with this id is removed first as best-effort
        cleanup. Logs a warning and returns when no economy client is set.
        """
        if self._economy_client is None:
            logger.warning("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return
        if not self._has_bound_to_simulator:
            logger.debug(
                "Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
            )
            return
        try:
            await self._economy_client.remove_agents([self._agent_id])
        except Exception as e:
            # Best-effort cleanup: a failure here must not block the add below.
            # (A bare `except:` would also swallow asyncio.CancelledError.)
            logger.debug(f"Failed to remove economy agent `{self._agent_id}`: {e}")
        person_id = await self.memory.get("id")
        currency = await self.memory.get("currency")
        await self._economy_client.add_agents(
            {
                "id": person_id,
                "currency": currency,
            }
        )
        self._has_bound_to_economy = True

    async def handle_gather_message(self, payload: dict):
        """
        Answer a gather request from another agent.

        ``payload["target"]`` names the memory entry to report and
        ``payload["from"]`` is the requester. The requester receives
        ``{"from": <own uuid>, "content": <memory value>}`` on the "gather"
        channel.
        """
        target = payload["target"]
        sender_id = payload["from"]
        content = await self.memory.get(f"{target}")
        payload = {
            "from": self._uuid,
            "content": content,
        }
        await self._send_message(sender_id, payload, "gather")
137
+
138
+
139
class InstitutionAgent(Agent):
    """
    InstitutionAgent: institution agent class and its definition.

    An institution is registered only in the economy simulator, as an
    organization via ``add_orgs``. It can collect information from other
    agents through request/response "gather" messages.
    """

    def __init__(
        self,
        name: str,
        llm_client: Optional[LLM] = None,
        simulator: Optional[Simulator] = None,
        mlflow_client: Optional[MlflowClient] = None,
        memory: Optional[Memory] = None,
        economy_client: Optional[EconomyClient] = None,
        messager: Optional[Messager] = None,  # type:ignore
        avro_file: Optional[dict] = None,
    ) -> None:
        """Create an institution agent.

        All arguments are forwarded unchanged to ``Agent.__init__`` together
        with ``type=AgentType.Institution``.
        """
        super().__init__(
            name=name,
            type=AgentType.Institution,
            llm_client=llm_client,
            economy_client=economy_client,
            mlflow_client=mlflow_client,
            messager=messager,
            simulator=simulator,
            memory=memory,
            avro_file=avro_file,
        )
        # Response collector: pending gather requests keyed by responder id,
        # resolved by handle_gather_message.
        self._gather_responses: dict[str, asyncio.Future] = {}

    async def bind_to_simulator(self):
        """Institutions only bind to the economy simulator."""
        await self._bind_to_economy()

    async def _bind_to_economy(self):
        """
        Register this institution in the economy simulator as an organization.

        Steps:
        - assign a random id and store it in memory;
        - when a simulator is available, place the institution at a random
          position inside the map bounding box;
        - remove any existing organization with that id (best effort);
        - call ``add_orgs`` with economic fields read from memory.

        A failure of the final registration is logged rather than raised.
        Logs at debug level and returns when no economy client is set.
        """
        logger.debug(
            f"Economy client: {self._economy_client}, "
            f"has bound to economy: {self._has_bound_to_economy}"
        )
        if self._economy_client is None:
            logger.debug("Economy client is not set")
            return
        if self._has_bound_to_economy:
            return
        # TODO: More general id generation
        _id = random.randint(100000, 999999)
        self._agent_id = _id
        self.memory.set_agent_id(_id)
        await self._assign_random_position()
        await self.memory.update("id", _id, protect_llm_read_only_fields=False)
        try:
            await self._economy_client.remove_orgs([self._agent_id])
        except Exception as e:
            # Best-effort cleanup: a failure here must not block the add below.
            logger.debug(f"Failed to remove economy org `{self._agent_id}`: {e}")
        try:
            await self._economy_client.add_orgs(await self._collect_org_fields())
        except Exception as e:
            logger.error(f"Failed to bind to Economy: {e}")
        self._has_bound_to_economy = True

    async def _assign_random_position(self):
        """Store a random xy position inside the map's bounding box in memory."""
        # TODO: remove random position assignment
        if self._simulator is None:
            logger.warning("Simulator is not set, skip position assignment")
            return
        map_header = self.simulator.map.header
        await self.memory.update(
            "position",
            {
                "xy_position": {
                    "x": float(
                        random.randrange(
                            start=int(map_header["west"]),
                            stop=int(map_header["east"]),
                        )
                    ),
                    "y": float(
                        random.randrange(
                            start=int(map_header["south"]),
                            stop=int(map_header["north"]),
                        )
                    ),
                }
            },
            protect_llm_read_only_fields=False,
        )

    async def _collect_org_fields(self) -> dict[str, Any]:
        """
        Build the organization record passed to ``add_orgs`` from memory.

        ``id`` and ``type`` are required; a failure to read them propagates.
        Every other field falls back to a fresh empty/zero default when it
        cannot be read.
        """
        _memory = self.memory
        org: dict[str, Any] = {
            "id": await _memory.get("id"),
            "type": await _memory.get("type"),
        }
        # (field name, factory for its fallback value); factories ensure each
        # call gets a fresh list instead of a shared mutable default.
        optional_fields = (
            ("nominal_gdp", list),
            ("real_gdp", list),
            ("unemployment", list),
            ("wages", list),
            ("prices", list),
            ("inventory", int),
            ("price", int),
            ("currency", float),
            ("interest_rate", float),
            ("bracket_cutoffs", list),
            ("bracket_rates", list),
        )
        for key, default_factory in optional_fields:
            try:
                org[key] = await _memory.get(key)
            except Exception:
                org[key] = default_factory()
        return org

    async def handle_gather_message(self, payload: dict):
        """
        Resolve a pending gather request with another agent's response.

        ``payload`` carries ``"from"`` (the responder) and ``"content"``.
        The following are ignored instead of raising ``InvalidStateError``:
        - responses from agents with no pending request;
        - duplicate responses;
        - responses to a request that was already cancelled.
        """
        content = payload["content"]
        sender_id = payload["from"]

        # Store the response in the Future for this sender, if one is waiting.
        future = self._gather_responses.get(str(sender_id))
        if future is not None and not future.done():
            future.set_result(
                {
                    "from": sender_id,
                    "content": content,
                }
            )

    async def gather_messages(
        self,
        agent_uuids: list[str],
        target: str,
        *,
        timeout: Optional[float] = None,
    ) -> list[dict]:
        """Gather messages from multiple agents.

        Args:
            agent_uuids: UUIDs of the agents to query.
            target: The kind of information (memory key) to collect.
            timeout: Optional number of seconds to wait for all responses.
                ``None`` (the default) waits indefinitely.

        Returns:
            list[dict]: All collected responses, one per unique UUID, in
            ``agent_uuids`` order.

        Raises:
            TimeoutError: If ``timeout`` is set and not every agent
                responded in time.
        """
        loop = asyncio.get_running_loop()
        # Create one Future per agent before sending, so an early reply
        # cannot arrive before its Future is registered.
        # NOTE: concurrent gathers targeting the same agent share one key in
        # _gather_responses; the later call replaces the earlier Future.
        futures: dict[str, asyncio.Future] = {}
        for agent_uuid in agent_uuids:
            future = loop.create_future()
            futures[agent_uuid] = future
            self._gather_responses[agent_uuid] = future

        payload = {
            "from": self._uuid,
            "target": target,
        }
        try:
            # Send inside the try so pending Futures are cleaned up even if a
            # send fails.
            for agent_uuid in agent_uuids:
                await self._send_message(agent_uuid, payload, "gather")
            # Wait for all responses.
            return await asyncio.wait_for(
                asyncio.gather(*futures.values()), timeout
            )
        finally:
            # Clean up only the Futures this call registered.
            for key, future in futures.items():
                if self._gather_responses.get(key) is future:
                    del self._gather_responses[key]