pycityagent 2.0.0a3__py3-none-any.whl → 2.0.0a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent.py +140 -5
- pycityagent/environment/simulator.py +9 -10
- pycityagent/memory/const.py +1 -0
- pycityagent/message/__init__.py +3 -0
- pycityagent/message/messager.py +64 -0
- pycityagent/simulation/agentgroup.py +109 -0
- pycityagent/simulation/simulation.py +42 -37
- pycityagent/workflow/tool.py +4 -63
- pycityagent/workflow/trigger.py +1 -1
- {pycityagent-2.0.0a3.dist-info → pycityagent-2.0.0a5.dist-info}/METADATA +3 -1
- {pycityagent-2.0.0a3.dist-info → pycityagent-2.0.0a5.dist-info}/RECORD +12 -11
- pycityagent/config.py +0 -0
- {pycityagent-2.0.0a3.dist-info → pycityagent-2.0.0a5.dist-info}/WHEEL +0 -0
pycityagent/agent.py
CHANGED
@@ -1,10 +1,19 @@
|
|
1
1
|
"""智能体模板类及其定义"""
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
|
+
import asyncio
|
5
|
+
from copy import deepcopy
|
4
6
|
from datetime import datetime
|
5
7
|
from enum import Enum
|
8
|
+
import logging
|
6
9
|
from typing import Dict, List, Optional
|
7
10
|
|
11
|
+
from pycityagent.environment.sim.person_service import PersonService
|
12
|
+
from mosstool.util.format_converter import dict2pb
|
13
|
+
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
14
|
+
|
15
|
+
from pycityagent.message.messager import Messager
|
16
|
+
|
8
17
|
from .economy import EconomyClient
|
9
18
|
from .environment import Simulator
|
10
19
|
from .llm import LLM
|
@@ -35,6 +44,7 @@ class Agent(ABC):
|
|
35
44
|
type: AgentType = AgentType.Unspecified,
|
36
45
|
llm_client: Optional[LLM] = None,
|
37
46
|
economy_client: Optional[EconomyClient] = None,
|
47
|
+
messager: Optional[Messager] = None,
|
38
48
|
simulator: Optional[Simulator] = None,
|
39
49
|
memory: Optional[Memory] = None,
|
40
50
|
) -> None:
|
@@ -46,6 +56,7 @@ class Agent(ABC):
|
|
46
56
|
type (AgentType): The type of the agent. Defaults to `AgentType.Unspecified`
|
47
57
|
llm_client (LLM): The language model client. Defaults to None.
|
48
58
|
economy_client (EconomyClient): The `EconomySim` client. Defaults to None.
|
59
|
+
messager (Messager, optional): The messager object. Defaults to None.
|
49
60
|
simulator (Simulator, optional): The simulator object. Defaults to None.
|
50
61
|
memory (Memory, optional): The memory of the agent. Defaults to None.
|
51
62
|
"""
|
@@ -53,16 +64,36 @@ class Agent(ABC):
|
|
53
64
|
self._type = type
|
54
65
|
self._llm_client = llm_client
|
55
66
|
self._economy_client = economy_client
|
67
|
+
self._messager = messager
|
56
68
|
self._simulator = simulator
|
57
69
|
self._memory = memory
|
58
70
|
self._has_bound_to_simulator = False
|
71
|
+
self._has_bound_to_economy = False
|
72
|
+
self._blocked = False
|
59
73
|
self._interview_history: List[Dict] = [] # 存储采访历史
|
74
|
+
self._person_template = PersonService.default_dict_person()
|
60
75
|
|
61
|
-
def
|
76
|
+
def __getstate__(self):
|
77
|
+
state = self.__dict__.copy()
|
78
|
+
# 排除锁对象
|
79
|
+
del state['_llm_client']
|
80
|
+
return state
|
81
|
+
|
82
|
+
async def bind_to_simulator(self):
|
83
|
+
await self._bind_to_simulator()
|
84
|
+
await self._bind_to_economy()
|
85
|
+
|
86
|
+
def set_messager(self, messager: Messager):
|
62
87
|
"""
|
63
|
-
Set the
|
88
|
+
Set the messager of the agent.
|
64
89
|
"""
|
65
|
-
self.
|
90
|
+
self._messager = messager
|
91
|
+
|
92
|
+
def set_llm_client(self, llm_client: LLM):
|
93
|
+
"""
|
94
|
+
Set the llm_client of the agent.
|
95
|
+
"""
|
96
|
+
self._llm_client = llm_client
|
66
97
|
|
67
98
|
def set_simulator(self, simulator: Simulator):
|
68
99
|
"""
|
@@ -76,6 +107,83 @@ class Agent(ABC):
|
|
76
107
|
"""
|
77
108
|
self._economy_client = economy_client
|
78
109
|
|
110
|
+
def set_memory(self, memory: Memory):
|
111
|
+
"""
|
112
|
+
Set the memory of the agent.
|
113
|
+
"""
|
114
|
+
self._memory = memory
|
115
|
+
|
116
|
+
async def _bind_to_simulator(self):
|
117
|
+
"""
|
118
|
+
Bind Agent to Simulator
|
119
|
+
|
120
|
+
Args:
|
121
|
+
person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
|
122
|
+
"""
|
123
|
+
if self._simulator is None:
|
124
|
+
logging.warning("Simulator is not set")
|
125
|
+
return
|
126
|
+
if not self._has_bound_to_simulator:
|
127
|
+
FROM_MEMORY_KEYS = {
|
128
|
+
"attribute",
|
129
|
+
"home",
|
130
|
+
"work",
|
131
|
+
"vehicle_attribute",
|
132
|
+
"bus_attribute",
|
133
|
+
"pedestrian_attribute",
|
134
|
+
"bike_attribute",
|
135
|
+
}
|
136
|
+
simulator = self._simulator
|
137
|
+
memory = self._memory
|
138
|
+
person_id = await memory.get("id")
|
139
|
+
# ATTENTION:模拟器分配的id从0开始
|
140
|
+
if person_id >= 0:
|
141
|
+
await simulator.get_person(person_id)
|
142
|
+
logging.debug(f"Binding to Person `{person_id}` already in Simulator")
|
143
|
+
else:
|
144
|
+
dict_person = deepcopy(self._person_template)
|
145
|
+
for _key in FROM_MEMORY_KEYS:
|
146
|
+
try:
|
147
|
+
_value = await memory.get(_key)
|
148
|
+
if _value:
|
149
|
+
dict_person[_key] = _value
|
150
|
+
except KeyError as e:
|
151
|
+
continue
|
152
|
+
resp = await simulator.add_person(
|
153
|
+
dict2pb(dict_person, person_pb2.Person())
|
154
|
+
)
|
155
|
+
person_id = resp["person_id"]
|
156
|
+
await memory.update("id", person_id, protect_llm_read_only_fields=False)
|
157
|
+
logging.debug(
|
158
|
+
f"Binding to Person `{person_id}` just added to Simulator"
|
159
|
+
)
|
160
|
+
# 防止模拟器还没有到prepare阶段导致get_person出错
|
161
|
+
self._has_bound_to_simulator = True
|
162
|
+
self._agent_id = person_id
|
163
|
+
|
164
|
+
async def _bind_to_economy(self):
|
165
|
+
if self._economy_client is None:
|
166
|
+
logging.warning("Economy client is not set")
|
167
|
+
return
|
168
|
+
if not self._has_bound_to_economy:
|
169
|
+
if self._has_bound_to_simulator:
|
170
|
+
try:
|
171
|
+
await self._economy_client.remove_agents([self._agent_id])
|
172
|
+
except:
|
173
|
+
pass
|
174
|
+
person_id = await self._memory.get("id")
|
175
|
+
await self._economy_client.add_agents(
|
176
|
+
{
|
177
|
+
"id": person_id,
|
178
|
+
"currency": await self._memory.get("currency"),
|
179
|
+
}
|
180
|
+
)
|
181
|
+
self._has_bound_to_economy = True
|
182
|
+
else:
|
183
|
+
logging.debug(
|
184
|
+
f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
|
185
|
+
)
|
186
|
+
|
79
187
|
@property
|
80
188
|
def LLM(self):
|
81
189
|
"""The Agent's LLM"""
|
@@ -163,12 +271,33 @@ class Agent(ABC):
|
|
163
271
|
def get_interview_history(self) -> List[Dict]:
|
164
272
|
"""获取采访历史记录"""
|
165
273
|
return self._interview_history
|
274
|
+
|
275
|
+
async def handle_message(self, payload: str):
|
276
|
+
"""处理收到的消息,识别发送者"""
|
277
|
+
# 从消息中解析发送者 ID 和消息内容
|
278
|
+
message, sender_id = payload.split("|from:")
|
279
|
+
print(f"Agent {self._agent_id} received message: '{message}' from Agent {sender_id}")
|
280
|
+
|
281
|
+
async def send_message(self, to_agent_id: int, message: str):
|
282
|
+
"""通过 Messager 发送消息,附带发送者的 ID"""
|
283
|
+
if self._messager is None:
|
284
|
+
raise RuntimeError("Messager is not set")
|
285
|
+
topic = f"/agents/{to_agent_id}/chat"
|
286
|
+
await self._messager.send_message(topic, message, self._agent_id)
|
166
287
|
|
167
288
|
@abstractmethod
|
168
289
|
async def forward(self) -> None:
|
169
290
|
"""智能体行为逻辑"""
|
170
291
|
raise NotImplementedError
|
171
292
|
|
293
|
+
async def run(self) -> None:
|
294
|
+
"""
|
295
|
+
统一的Agent执行入口
|
296
|
+
当_blocked为True时,不执行forward方法
|
297
|
+
"""
|
298
|
+
if not self._blocked:
|
299
|
+
await self.forward()
|
300
|
+
|
172
301
|
|
173
302
|
class CitizenAgent(Agent):
|
174
303
|
"""
|
@@ -181,12 +310,15 @@ class CitizenAgent(Agent):
|
|
181
310
|
llm_client: Optional[LLM] = None,
|
182
311
|
simulator: Optional[Simulator] = None,
|
183
312
|
memory: Optional[Memory] = None,
|
313
|
+
economy_client: Optional[EconomyClient] = None,
|
314
|
+
messager: Optional[Messager] = None,
|
184
315
|
) -> None:
|
185
316
|
super().__init__(
|
186
317
|
name,
|
187
318
|
AgentType.Citizen,
|
188
319
|
llm_client,
|
189
|
-
|
320
|
+
economy_client,
|
321
|
+
messager,
|
190
322
|
simulator,
|
191
323
|
memory,
|
192
324
|
)
|
@@ -203,12 +335,15 @@ class InstitutionAgent(Agent):
|
|
203
335
|
llm_client: Optional[LLM] = None,
|
204
336
|
simulator: Optional[Simulator] = None,
|
205
337
|
memory: Optional[Memory] = None,
|
338
|
+
economy_client: Optional[EconomyClient] = None,
|
339
|
+
messager: Optional[Messager] = None,
|
206
340
|
) -> None:
|
207
341
|
super().__init__(
|
208
342
|
name,
|
209
343
|
AgentType.Institution,
|
210
344
|
llm_client,
|
211
|
-
|
345
|
+
economy_client,
|
346
|
+
messager,
|
212
347
|
simulator,
|
213
348
|
memory,
|
214
349
|
)
|
@@ -74,7 +74,6 @@ class Simulator:
|
|
74
74
|
self._client = CityClient(config['simulator']['server'], secure=False)
|
75
75
|
else:
|
76
76
|
logging.warning("No simulator config found, no simulator client will be used")
|
77
|
-
|
78
77
|
self.map = SimMap(
|
79
78
|
mongo_uri=_mongo_uri,
|
80
79
|
mongo_db=_mongo_db,
|
@@ -107,7 +106,7 @@ class Simulator:
|
|
107
106
|
self.set_poi_tree()
|
108
107
|
|
109
108
|
# * Agent相关
|
110
|
-
def
|
109
|
+
def find_agents_by_area(self, req: dict, status=None):
|
111
110
|
"""
|
112
111
|
通过区域范围查找agent/person
|
113
112
|
Get agents/persons in the provided area
|
@@ -148,7 +147,7 @@ class Simulator:
|
|
148
147
|
self.tree_id_2_poi_and_catg = tree_id_2_poi_and_catg
|
149
148
|
self.pois_tree = STRtree(poi_geos)
|
150
149
|
|
151
|
-
def
|
150
|
+
def get_poi_categories(
|
152
151
|
self,
|
153
152
|
center: Optional[Union[tuple[float, float], Point]] = None,
|
154
153
|
radius: Optional[float] = None,
|
@@ -165,7 +164,7 @@ class Simulator:
|
|
165
164
|
categories.append(catg.split("|")[-1])
|
166
165
|
return list(set(categories))
|
167
166
|
|
168
|
-
async def
|
167
|
+
async def get_time(
|
169
168
|
self, format_time: bool = False, format: str = "%H:%M:%S"
|
170
169
|
) -> Union[int, str]:
|
171
170
|
"""
|
@@ -192,19 +191,19 @@ class Simulator:
|
|
192
191
|
else:
|
193
192
|
return t_sec["t"]
|
194
193
|
|
195
|
-
async def
|
194
|
+
async def get_person(self, person_id: int) -> dict:
|
196
195
|
return await self._client.person_service.GetPerson(
|
197
196
|
req={"person_id": person_id}
|
198
197
|
) # type:ignore
|
199
198
|
|
200
|
-
async def
|
199
|
+
async def add_person(self, person: Any) -> dict:
|
201
200
|
if isinstance(person, person_pb2.Person):
|
202
201
|
req = person_service.AddPersonRequest(person=person)
|
203
202
|
else:
|
204
203
|
req = person
|
205
204
|
return await self._client.person_service.AddPerson(req) # type:ignore
|
206
205
|
|
207
|
-
async def
|
206
|
+
async def set_aoi_schedules(
|
208
207
|
self,
|
209
208
|
person_id: int,
|
210
209
|
target_positions: Union[
|
@@ -213,7 +212,7 @@ class Simulator:
|
|
213
212
|
departure_times: Optional[list[float]] = None,
|
214
213
|
modes: Optional[list[TripMode]] = None,
|
215
214
|
):
|
216
|
-
cur_time = float(await self.
|
215
|
+
cur_time = float(await self.get_time())
|
217
216
|
if not isinstance(target_positions, list):
|
218
217
|
target_positions = [target_positions]
|
219
218
|
if departure_times is None:
|
@@ -254,7 +253,7 @@ class Simulator:
|
|
254
253
|
req = {"person_id": person_id, "schedules": _schedules}
|
255
254
|
await self._client.person_service.SetSchedule(req)
|
256
255
|
|
257
|
-
async def
|
256
|
+
async def reset_person_position(
|
258
257
|
self,
|
259
258
|
person_id: int,
|
260
259
|
aoi_id: Optional[int] = None,
|
@@ -291,7 +290,7 @@ class Simulator:
|
|
291
290
|
f"Neither aoi or lane pos provided for person {person_id} position reset!!"
|
292
291
|
)
|
293
292
|
|
294
|
-
def
|
293
|
+
def get_around_poi(
|
295
294
|
self,
|
296
295
|
center: Union[tuple[float, float], Point],
|
297
296
|
radius: float,
|
pycityagent/memory/const.py
CHANGED
pycityagent/message/__init__.py
CHANGED
@@ -0,0 +1,64 @@
|
|
1
|
+
import asyncio
|
2
|
+
from collections import defaultdict
|
3
|
+
import logging
|
4
|
+
import math
|
5
|
+
from aiomqtt import Client
|
6
|
+
|
7
|
+
class Messager:
|
8
|
+
def __init__(self, broker, port=1883, timeout=math.inf):
|
9
|
+
self.client = Client(broker, port=port, timeout=timeout)
|
10
|
+
self.connected = False # 是否已连接标志
|
11
|
+
self.message_queue = asyncio.Queue() # 用于存储接收到的消息
|
12
|
+
self.subscribers = {} # 订阅者信息,topic -> Agent 映射
|
13
|
+
|
14
|
+
async def connect(self):
|
15
|
+
try:
|
16
|
+
await self.client.__aenter__()
|
17
|
+
self.connected = True
|
18
|
+
logging.info("Connected to MQTT Broker")
|
19
|
+
except Exception as e:
|
20
|
+
self.connected = False
|
21
|
+
logging.error(f"Failed to connect to MQTT Broker: {e}")
|
22
|
+
|
23
|
+
async def disconnect(self):
|
24
|
+
await self.client.__aexit__(None, None, None)
|
25
|
+
self.connected = False
|
26
|
+
logging.info("Disconnected from MQTT Broker")
|
27
|
+
|
28
|
+
def is_connected(self):
|
29
|
+
"""检查是否成功连接到 Broker"""
|
30
|
+
return self.connected
|
31
|
+
|
32
|
+
async def subscribe(self, topic, agent):
|
33
|
+
if not self.is_connected():
|
34
|
+
logging.error(f"Cannot subscribe to {topic} because not connected to the Broker.")
|
35
|
+
return
|
36
|
+
await self.client.subscribe(topic)
|
37
|
+
self.subscribers[topic] = agent
|
38
|
+
logging.info(f"Subscribed to {topic} for Agent {agent._agent_id}")
|
39
|
+
|
40
|
+
async def receive_messages(self):
|
41
|
+
"""监听并将消息存入队列"""
|
42
|
+
async for message in self.client.messages:
|
43
|
+
await self.message_queue.put(message)
|
44
|
+
|
45
|
+
async def fetch_messages(self):
|
46
|
+
"""从队列中批量获取消息"""
|
47
|
+
messages = []
|
48
|
+
while not self.message_queue.empty():
|
49
|
+
messages.append(await self.message_queue.get())
|
50
|
+
return messages
|
51
|
+
|
52
|
+
async def send_message(self, topic: str, payload: str, sender_id: int):
|
53
|
+
"""通过 Messager 发送消息,包含发送者 ID"""
|
54
|
+
# 构造消息,payload 中加入 sender_id 以便接收者识别
|
55
|
+
message = f"{payload}|from:{sender_id}"
|
56
|
+
await self.client.publish(topic, message)
|
57
|
+
logging.info(f"Message sent to {topic}: {message}")
|
58
|
+
|
59
|
+
async def start_listening(self):
|
60
|
+
"""启动消息监听任务"""
|
61
|
+
if self.is_connected():
|
62
|
+
asyncio.create_task(self.receive_messages())
|
63
|
+
else:
|
64
|
+
logging.error("Cannot start listening because not connected to the Broker.")
|
@@ -0,0 +1,109 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
import ray
|
4
|
+
from pycityagent.agent import Agent
|
5
|
+
from pycityagent.economy.econ_client import EconomyClient
|
6
|
+
from pycityagent.environment.simulator import Simulator
|
7
|
+
from pycityagent.llm.llm import LLM
|
8
|
+
from pycityagent.llm.llmconfig import LLMConfig
|
9
|
+
from pycityagent.message import Messager
|
10
|
+
|
11
|
+
@ray.remote
|
12
|
+
class AgentGroup:
|
13
|
+
def __init__(self, agents: list[Agent], config: dict):
|
14
|
+
self.agents = agents
|
15
|
+
self.config = config
|
16
|
+
self.messager = Messager(config["simulator_request"]["mqtt"]["server"], config["simulator_request"]["mqtt"]["port"])
|
17
|
+
self.initialized = False
|
18
|
+
|
19
|
+
# Step:1 prepare LLM client
|
20
|
+
llmConfig = LLMConfig(config["llm_request"])
|
21
|
+
logging.info("-----Creating LLM client in remote...")
|
22
|
+
self.llm = LLM(llmConfig)
|
23
|
+
|
24
|
+
# Step:2 prepare Simulator
|
25
|
+
logging.info("-----Creating Simulator in remote...")
|
26
|
+
self.simulator = Simulator(config["simulator_request"])
|
27
|
+
|
28
|
+
# Step:3 prepare Economy client
|
29
|
+
logging.info("-----Creating Economy client in remote...")
|
30
|
+
self.economy_client = EconomyClient(config["simulator_request"]["economy"]['server'])
|
31
|
+
|
32
|
+
for agent in self.agents:
|
33
|
+
agent.set_llm_client(self.llm)
|
34
|
+
agent.set_simulator(self.simulator)
|
35
|
+
agent.set_economy_client(self.economy_client)
|
36
|
+
agent.set_messager(self.messager)
|
37
|
+
|
38
|
+
async def init_agents(self):
|
39
|
+
for agent in self.agents:
|
40
|
+
await agent.bind_to_simulator()
|
41
|
+
self.id2agent = {agent._agent_id: agent for agent in self.agents}
|
42
|
+
await self.messager.connect()
|
43
|
+
if self.messager.is_connected():
|
44
|
+
await self.messager.start_listening()
|
45
|
+
for agent in self.agents:
|
46
|
+
agent.set_messager(self.messager)
|
47
|
+
topic = f"/agents/{agent._agent_id}/chat"
|
48
|
+
await self.messager.subscribe(topic, agent)
|
49
|
+
self.initialized = True
|
50
|
+
|
51
|
+
async def step(self):
|
52
|
+
if not self.initialized:
|
53
|
+
await self.init_agents()
|
54
|
+
|
55
|
+
# Step 1: 如果 Messager 无法连接,则跳过消息接收
|
56
|
+
if not self.messager.is_connected():
|
57
|
+
logging.warning("Messager is not connected. Skipping message processing.")
|
58
|
+
# 跳过接收和分发消息
|
59
|
+
tasks = [agent.run() for agent in self.agents]
|
60
|
+
await asyncio.gather(*tasks)
|
61
|
+
return
|
62
|
+
|
63
|
+
# Step 2: 从 Messager 获取消息
|
64
|
+
messages = await self.messager.fetch_messages()
|
65
|
+
|
66
|
+
# Step 3: 分发消息到对应的 Agent
|
67
|
+
for message in messages:
|
68
|
+
topic = message.topic.value
|
69
|
+
payload = message.payload
|
70
|
+
|
71
|
+
# 添加解码步骤,将bytes转换为str
|
72
|
+
if isinstance(payload, bytes):
|
73
|
+
payload = payload.decode('utf-8')
|
74
|
+
|
75
|
+
# 提取 agent_id(主题格式为 "/agents/{agent_id}/chat")
|
76
|
+
_, agent_id, _ = topic.strip("/").split("/")
|
77
|
+
agent_id = int(agent_id)
|
78
|
+
|
79
|
+
if agent_id in self.id2agent:
|
80
|
+
agent = self.id2agent[agent_id]
|
81
|
+
await agent.handle_message(payload)
|
82
|
+
|
83
|
+
# Step 4: 调用每个 Agent 的运行逻辑
|
84
|
+
tasks = [agent.run() for agent in self.agents]
|
85
|
+
await asyncio.gather(*tasks)
|
86
|
+
|
87
|
+
async def run(self, day: int = 1):
|
88
|
+
"""运行模拟器
|
89
|
+
|
90
|
+
Args:
|
91
|
+
day: 运行天数,默认为1天
|
92
|
+
"""
|
93
|
+
try:
|
94
|
+
# 获取开始时间
|
95
|
+
start_time = await self.simulator.get_time()
|
96
|
+
# 计算结束时间(秒)
|
97
|
+
end_time = start_time + day * 24 * 3600 # 将天数转换为秒
|
98
|
+
|
99
|
+
while True:
|
100
|
+
current_time = await self.simulator.get_time()
|
101
|
+
if current_time >= end_time:
|
102
|
+
break
|
103
|
+
|
104
|
+
await self.step()
|
105
|
+
|
106
|
+
except Exception as e:
|
107
|
+
logging.error(f"模拟器运行错误: {str(e)}")
|
108
|
+
raise
|
109
|
+
|
@@ -6,33 +6,33 @@ import random
|
|
6
6
|
from typing import Dict, List, Optional, Callable
|
7
7
|
from mosstool.map._map_util.const import AOI_START_ID
|
8
8
|
|
9
|
-
from pycityagent.
|
9
|
+
from pycityagent.economy.econ_client import EconomyClient
|
10
|
+
from pycityagent.environment.simulator import Simulator
|
10
11
|
from pycityagent.memory.memory import Memory
|
12
|
+
from pycityagent.message.messager import Messager
|
11
13
|
|
12
14
|
from ..agent import Agent
|
13
|
-
from ..environment import Simulator
|
14
15
|
from .interview import InterviewManager
|
15
16
|
from .survey import QuestionType, SurveyManager
|
16
17
|
from .ui import InterviewUI
|
17
|
-
|
18
|
+
from .agentgroup import AgentGroup
|
18
19
|
logger = logging.getLogger(__name__)
|
19
20
|
|
20
21
|
|
21
22
|
class AgentSimulation:
|
22
23
|
"""城市智能体模拟器"""
|
23
|
-
def __init__(self, agent_class: type[Agent],
|
24
|
+
def __init__(self, agent_class: type[Agent], config: dict, agent_prefix: str = "agent_"):
|
24
25
|
"""
|
25
26
|
Args:
|
26
27
|
agent_class: 智能体类
|
27
|
-
|
28
|
-
llm: 语言模型
|
28
|
+
config: 配置
|
29
29
|
agent_prefix: 智能体名称前缀
|
30
30
|
"""
|
31
31
|
self.agent_class = agent_class
|
32
|
-
self.
|
33
|
-
self.llm = llm
|
32
|
+
self.config = config
|
34
33
|
self.agent_prefix = agent_prefix
|
35
34
|
self._agents: Dict[str, Agent] = {}
|
35
|
+
self._groups: Dict[str, AgentGroup] = {}
|
36
36
|
self._interview_manager = InterviewManager()
|
37
37
|
self._interview_lock = asyncio.Lock()
|
38
38
|
self._start_time = datetime.now()
|
@@ -42,16 +42,17 @@ class AgentSimulation:
|
|
42
42
|
self._blocked_agents: List[str] = [] # 新增:持续阻塞的智能体列表
|
43
43
|
self._survey_manager = SurveyManager()
|
44
44
|
|
45
|
-
def init_agents(self, agent_count: int, memory_config_func: Callable = None) -> None:
|
45
|
+
async def init_agents(self, agent_count: int, group_size: int = 1000, memory_config_func: Callable = None) -> None:
|
46
46
|
"""初始化智能体
|
47
47
|
|
48
48
|
Args:
|
49
|
-
agent_count:
|
49
|
+
agent_count: 要创建的总智能体数量
|
50
|
+
group_size: 每个组的智能体数量,每一个组为一个独立的ray actor
|
50
51
|
memory_config_func: 返回Memory配置的函数,需要返回(EXTRA_ATTRIBUTES, PROFILE, BASE)元组
|
51
52
|
"""
|
52
53
|
if memory_config_func is None:
|
53
|
-
memory_config_func = self.default_memory_config_func
|
54
|
-
|
54
|
+
memory_config_func = self.default_memory_config_func
|
55
|
+
|
55
56
|
for i in range(agent_count):
|
56
57
|
agent_name = f"{self.agent_prefix}{i}"
|
57
58
|
|
@@ -66,13 +67,25 @@ class AgentSimulation:
|
|
66
67
|
# 创建智能体时传入Memory配置
|
67
68
|
agent = self.agent_class(
|
68
69
|
name=agent_name,
|
69
|
-
|
70
|
-
llm=self.llm,
|
71
|
-
memory=memory
|
70
|
+
memory=memory,
|
72
71
|
)
|
73
|
-
|
72
|
+
|
74
73
|
self._agents[agent_name] = agent
|
75
74
|
|
75
|
+
# 计算需要的组数,向上取整以处理不足一组的情况
|
76
|
+
num_group = (agent_count + group_size - 1) // group_size
|
77
|
+
|
78
|
+
for i in range(num_group):
|
79
|
+
# 计算当前组的起始和结束索引
|
80
|
+
start_idx = i * group_size
|
81
|
+
end_idx = min((i + 1) * group_size, agent_count)
|
82
|
+
|
83
|
+
# 获取当前组的agents
|
84
|
+
agents = list(self._agents.values())[start_idx:end_idx]
|
85
|
+
group_name = f"{self.agent_prefix}_group_{i}"
|
86
|
+
group = AgentGroup.remote(agents, self.config)
|
87
|
+
self._groups[group_name] = group
|
88
|
+
|
76
89
|
def default_memory_config_func(self):
|
77
90
|
"""默认的Memory配置函数"""
|
78
91
|
EXTRA_ATTRIBUTES = {
|
@@ -100,16 +113,16 @@ class AgentSimulation:
|
|
100
113
|
"family_consumption": random.choice(["low", "medium", "high"]),
|
101
114
|
"personality": random.choice(["outgoint", "introvert", "ambivert", "extrovert"]),
|
102
115
|
"income": random.randint(1000, 10000),
|
116
|
+
"currency": random.randint(10000, 100000),
|
103
117
|
"residence": random.choice(["city", "suburb", "rural"]),
|
104
118
|
"race": random.choice(["Chinese", "American", "British", "French", "German", "Japanese", "Korean", "Russian", "Other"]),
|
105
119
|
"religion": random.choice(["none", "Christian", "Muslim", "Buddhist", "Hindu", "Other"]),
|
106
120
|
"marital_status": random.choice(["not married", "married", "divorced", "widowed"]),
|
107
121
|
}
|
108
122
|
|
109
|
-
aois = self.simulator.aois.keys()
|
110
123
|
BASE = {
|
111
|
-
"home": {"aoi_position": {"aoi_id": random.
|
112
|
-
"work": {"aoi_position": {"aoi_id": random.
|
124
|
+
"home": {"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}},
|
125
|
+
"work": {"aoi_position": {"aoi_id": AOI_START_ID + random.randint(1, 50000)}},
|
113
126
|
}
|
114
127
|
|
115
128
|
return EXTRA_ATTRIBUTES, PROFILE, BASE
|
@@ -177,9 +190,11 @@ class AgentSimulation:
|
|
177
190
|
|
178
191
|
if blocking and agent_name not in self._blocked_agents:
|
179
192
|
self._blocked_agents.append(agent_name)
|
193
|
+
self._agents[agent_name]._blocked = True
|
180
194
|
return f"已阻塞智能体 {agent_name}"
|
181
195
|
elif not blocking and agent_name in self._blocked_agents:
|
182
196
|
self._blocked_agents.remove(agent_name)
|
197
|
+
self._agents[agent_name]._blocked = False
|
183
198
|
return f"已取消阻塞智能体 {agent_name}"
|
184
199
|
|
185
200
|
return f"智能体 {agent_name} 状态未变"
|
@@ -305,7 +320,7 @@ class AgentSimulation:
|
|
305
320
|
prevent_thread_lock=True,
|
306
321
|
quiet=True,
|
307
322
|
)
|
308
|
-
|
323
|
+
logger.info(
|
309
324
|
f"Gradio Frontend is running on http://{server_name}:{server_port}"
|
310
325
|
)
|
311
326
|
|
@@ -313,8 +328,8 @@ class AgentSimulation:
|
|
313
328
|
"""运行一步, 即每个智能体执行一次forward"""
|
314
329
|
try:
|
315
330
|
tasks = []
|
316
|
-
for
|
317
|
-
tasks.append(
|
331
|
+
for group in self._groups.values():
|
332
|
+
tasks.append(group.step.remote())
|
318
333
|
await asyncio.gather(*tasks)
|
319
334
|
except Exception as e:
|
320
335
|
logger.error(f"运行错误: {str(e)}")
|
@@ -331,22 +346,12 @@ class AgentSimulation:
|
|
331
346
|
"""
|
332
347
|
try:
|
333
348
|
# 获取开始时间
|
334
|
-
|
335
|
-
|
336
|
-
|
349
|
+
tasks = []
|
350
|
+
for group in self._groups.values():
|
351
|
+
tasks.append(group.run.remote(day))
|
337
352
|
|
338
|
-
|
339
|
-
current_time = self.simulator.GetTime()
|
340
|
-
if current_time >= end_time:
|
341
|
-
break
|
342
|
-
|
343
|
-
tasks = []
|
344
|
-
for agent in self._agents.values():
|
345
|
-
if agent.name not in self._blocked_agents:
|
346
|
-
tasks.append(agent.forward())
|
347
|
-
|
348
|
-
await asyncio.gather(*tasks)
|
353
|
+
await asyncio.gather(*tasks)
|
349
354
|
|
350
355
|
except Exception as e:
|
351
356
|
logger.error(f"模拟器运行错误: {str(e)}")
|
352
|
-
raise
|
357
|
+
raise
|
pycityagent/workflow/tool.py
CHANGED
@@ -3,9 +3,6 @@ import logging
|
|
3
3
|
from copy import deepcopy
|
4
4
|
from typing import Any, Callable, Dict, List, Optional, Union
|
5
5
|
|
6
|
-
from mosstool.util.format_converter import dict2pb
|
7
|
-
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
8
|
-
|
9
6
|
from ..agent import Agent
|
10
7
|
from ..environment import (LEVEL_ONE_PRE, POI_TYPE_DICT, AoiService,
|
11
8
|
PersonService)
|
@@ -136,63 +133,8 @@ class SencePOI(Tool):
|
|
136
133
|
|
137
134
|
|
138
135
|
class UpdateWithSimulator(Tool):
|
139
|
-
def __init__(
|
140
|
-
|
141
|
-
person_template_func: Callable[[], dict] = PersonService.default_dict_person,
|
142
|
-
) -> None:
|
143
|
-
self.person_template_func = person_template_func
|
144
|
-
|
145
|
-
async def _bind_to_simulator(
|
146
|
-
self,
|
147
|
-
):
|
148
|
-
"""
|
149
|
-
Bind Agent to Simulator
|
150
|
-
|
151
|
-
Args:
|
152
|
-
person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
|
153
|
-
"""
|
154
|
-
agent = self.agent
|
155
|
-
if agent._simulator is None:
|
156
|
-
return
|
157
|
-
if not agent._has_bound_to_simulator:
|
158
|
-
FROM_MEMORY_KEYS = {
|
159
|
-
"attribute",
|
160
|
-
"home",
|
161
|
-
"work",
|
162
|
-
"vehicle_attribute",
|
163
|
-
"bus_attribute",
|
164
|
-
"pedestrian_attribute",
|
165
|
-
"bike_attribute",
|
166
|
-
}
|
167
|
-
simulator = agent.simulator
|
168
|
-
memory = agent.memory
|
169
|
-
person_id = await memory.get("id")
|
170
|
-
# ATTENTION:模拟器分配的id从0开始
|
171
|
-
if person_id >= 0:
|
172
|
-
await simulator.GetPerson(person_id)
|
173
|
-
logging.debug(f"Binding to Person `{person_id}` already in Simulator")
|
174
|
-
else:
|
175
|
-
dict_person = deepcopy(self.person_template_func())
|
176
|
-
for _key in FROM_MEMORY_KEYS:
|
177
|
-
try:
|
178
|
-
_value = await memory.get(_key)
|
179
|
-
if _value:
|
180
|
-
dict_person[_key] = _value
|
181
|
-
except KeyError as e:
|
182
|
-
continue
|
183
|
-
resp = await simulator.AddPerson(
|
184
|
-
dict2pb(dict_person, person_pb2.Person())
|
185
|
-
)
|
186
|
-
person_id = resp["person_id"]
|
187
|
-
await memory.update("id", person_id, protect_llm_read_only_fields=False)
|
188
|
-
logging.debug(
|
189
|
-
f"Binding to Person `{person_id}` just added to Simulator"
|
190
|
-
)
|
191
|
-
# 防止模拟器还没有到prepare阶段导致GetPerson出错
|
192
|
-
await asyncio.sleep(5)
|
193
|
-
agent._has_bound_to_simulator = True
|
194
|
-
else:
|
195
|
-
pass
|
136
|
+
def __init__(self) -> None:
|
137
|
+
pass
|
196
138
|
|
197
139
|
async def _update_motion_with_sim(
|
198
140
|
self,
|
@@ -205,7 +147,7 @@ class UpdateWithSimulator(Tool):
|
|
205
147
|
simulator = agent.simulator
|
206
148
|
memory = agent.memory
|
207
149
|
person_id = await memory.get("id")
|
208
|
-
resp = await simulator.
|
150
|
+
resp = await simulator.get_person(person_id)
|
209
151
|
resp_dict = resp["person"]
|
210
152
|
for k, v in resp_dict.get("motion", {}).items():
|
211
153
|
try:
|
@@ -220,7 +162,6 @@ class UpdateWithSimulator(Tool):
|
|
220
162
|
self,
|
221
163
|
):
|
222
164
|
agent = self.agent
|
223
|
-
await self._bind_to_simulator()
|
224
165
|
await self._update_motion_with_sim()
|
225
166
|
|
226
167
|
|
@@ -237,7 +178,7 @@ class ResetAgentPosition(Tool):
|
|
237
178
|
):
|
238
179
|
agent = self.agent
|
239
180
|
memory = agent.memory
|
240
|
-
await agent.simulator.
|
181
|
+
await agent.simulator.reset_person_position(
|
241
182
|
person_id=await memory.get("id"),
|
242
183
|
aoi_id=aoi_id,
|
243
184
|
poi_id=poi_id,
|
pycityagent/workflow/trigger.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: pycityagent
|
3
|
-
Version: 2.0.
|
3
|
+
Version: 2.0.0a5
|
4
4
|
Summary: LLM-based城市环境agent构建库
|
5
5
|
License: MIT
|
6
6
|
Author: Yuwei Yan
|
@@ -17,6 +17,7 @@ Requires-Dist: Pillow (==11.0.0)
|
|
17
17
|
Requires-Dist: Requests (==2.32.3)
|
18
18
|
Requires-Dist: Shapely (==2.0.6)
|
19
19
|
Requires-Dist: aiohttp (==3.10.10)
|
20
|
+
Requires-Dist: aiomqtt (>=2.3.0,<3.0.0)
|
20
21
|
Requires-Dist: citystreetview (==1.2.4)
|
21
22
|
Requires-Dist: dashscope (==1.14.1)
|
22
23
|
Requires-Dist: geojson (==3.1.0)
|
@@ -32,6 +33,7 @@ Requires-Dist: protobuf (<=4.24.0)
|
|
32
33
|
Requires-Dist: pycitydata (==1.0.0)
|
33
34
|
Requires-Dist: pycityproto (==2.0.7)
|
34
35
|
Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
|
36
|
+
Requires-Dist: ray (>=2.40.0,<3.0.0)
|
35
37
|
Requires-Dist: sidecar (==0.7.0)
|
36
38
|
Requires-Dist: zhipuai (>=2.1.5.20230904,<3.0.0.0)
|
37
39
|
Description-Content-Type: text/markdown
|
@@ -1,6 +1,5 @@
|
|
1
1
|
pycityagent/__init__.py,sha256=n56bWkAUEcvjDsb7LcJpaGjlrriSKPnR0yBhwRfEYBA,212
|
2
|
-
pycityagent/agent.py,sha256=
|
3
|
-
pycityagent/config.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
pycityagent/agent.py,sha256=VvNTTw5MLY8XRMdOIgstBmru7v29r7q3QY7f-UjN9Bg,11086
|
4
3
|
pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
|
5
4
|
pycityagent/economy/econ_client.py,sha256=qQb_kZneEXGBRaS_y5Jdoi95I8GyjKEsDSC4s6V6R7w,10829
|
6
5
|
pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
|
@@ -22,7 +21,7 @@ pycityagent/environment/sim/person_service.py,sha256=nIvOsoBoqOTDYtsiThg07-4ZBgk
|
|
22
21
|
pycityagent/environment/sim/road_service.py,sha256=Pab182YRcrjLw3UcfoD1Hdd6O8XEdi6Q2hJKzFcpSWE,1272
|
23
22
|
pycityagent/environment/sim/sim_env.py,sha256=HI1LcS_FotDKQ6vBnx0e49prXSABOfA20aU9KM-ZkCY,4625
|
24
23
|
pycityagent/environment/sim/social_service.py,sha256=a9mGZm95EFUIKQJUwQi9f8anmtf2SK4XqGfE2W9IXSQ,2001
|
25
|
-
pycityagent/environment/simulator.py,sha256=
|
24
|
+
pycityagent/environment/simulator.py,sha256=bjzn5mTmsQgePZc3hDzoVS4dXAHLaI_yJWjUCTMAdi8,11872
|
26
25
|
pycityagent/environment/utils/__init__.py,sha256=PUx8etr2p_AA7F50ZR7g27odkgv-nOqFZa61ER8-DLg,221
|
27
26
|
pycityagent/environment/utils/base64.py,sha256=hoREzQo3FXMN79pqQLO2jgsDEvudciomyKii7MWljAM,374
|
28
27
|
pycityagent/environment/utils/const.py,sha256=3RMNy7_bE7-23K90j9DFW_tWEzu8s7hSTgKbV-3BFl4,5327
|
@@ -37,17 +36,19 @@ pycityagent/llm/llm.py,sha256=PFbGCVQOLNJEfd5KaZdXABYIoPTFOGrNXIAE2eDRCP4,19191
|
|
37
36
|
pycityagent/llm/llmconfig.py,sha256=WcyyfjP6XVLCDj2aemEM6ogQCWNxbTXQBiR0WjN4E8E,457
|
38
37
|
pycityagent/llm/utils.py,sha256=emio-WhYh6vYb5Sp7KsnW9hKKe_jStjJsGBBJfcAPgM,153
|
39
38
|
pycityagent/memory/__init__.py,sha256=Hs2NhYpIG-lvpwPWwj4DydB1sxtjz7cuA4iDAzCXnjI,243
|
40
|
-
pycityagent/memory/const.py,sha256=
|
39
|
+
pycityagent/memory/const.py,sha256=m9AidLs7Zu28StpvYetclqx-1qQcy3bYvwagcXB3a04,913
|
41
40
|
pycityagent/memory/memory.py,sha256=I5GJ77eQZBcKcWevZQcHUBTd2XNT4CotE6mmvB5mA8I,18170
|
42
41
|
pycityagent/memory/memory_base.py,sha256=U6odnR_5teaT3N2XF0W2McCxPyxtCpw61DjRVIO6p6I,5639
|
43
42
|
pycityagent/memory/profile.py,sha256=dJPeZ91DZcYdR7zmsfSSQx7epvvE1mrvj4-4WAT0tkM,5217
|
44
43
|
pycityagent/memory/self_define.py,sha256=vH2PrRSIykAOc5FebKN2JFQdx17nfR-IWv4s_ZS0YJk,5214
|
45
44
|
pycityagent/memory/state.py,sha256=xqOxo69sZX1M9eCUMbg5BUlSFe9LdMDH3oTWDXC0y64,5148
|
46
45
|
pycityagent/memory/utils.py,sha256=97lkenn-36wgt7uWb3Z39BXdJ5zlEQTQnQBFpoND1gg,879
|
47
|
-
pycityagent/message/__init__.py,sha256=
|
46
|
+
pycityagent/message/__init__.py,sha256=dEigAhYQ8e-suktavjo-Ip-nkEzf659hrFJwoFQ2ljE,54
|
47
|
+
pycityagent/message/messager.py,sha256=Xd02IxSkVtfrC386frB5kiqlDIDhNKuh5pVM4xKIyyw,2463
|
48
48
|
pycityagent/simulation/__init__.py,sha256=SZQzjcGR-zkTDrE81bQuEdKn7yjk4BcZ9m7o1wTc7EE,111
|
49
|
+
pycityagent/simulation/agentgroup.py,sha256=mHW8ggwG1iQ_-dCL5t_8Nlkp4OnCshdGGKygq9jFZX8,3983
|
49
50
|
pycityagent/simulation/interview.py,sha256=mY4Vpz0vgJo4rrMy3TZnwwM-iVDL6J0LgjOxbEuV27E,1173
|
50
|
-
pycityagent/simulation/simulation.py,sha256
|
51
|
+
pycityagent/simulation/simulation.py,sha256=PcWYx8GqCypqDga9JdHg3sRux5lrGcT8b_hvapjKw34,13177
|
51
52
|
pycityagent/simulation/survey/__init__.py,sha256=hFJ0Q1yo4jwKAIXP17sznBSWwm2Lyh3F3W3Lly40wr8,172
|
52
53
|
pycityagent/simulation/survey/manager.py,sha256=DkNrb12Ay7TiGURoyJTFFeUdV1zh6TgRpTmpZOblADw,2158
|
53
54
|
pycityagent/simulation/survey/models.py,sha256=-3EKe-qvkUJ2TH24ow0A_Lc4teGet7pueN2T5mOR_Qc,1308
|
@@ -62,8 +63,8 @@ pycityagent/utils/parsers/parser_base.py,sha256=k6DVqwAMK3jJdOP4IeLE-aFPm3V2F-St
|
|
62
63
|
pycityagent/workflow/__init__.py,sha256=EyCcjB6LyBim-5iAOPe4m2qfvghEPqu1ZdGfy4KPeZ8,551
|
63
64
|
pycityagent/workflow/block.py,sha256=IXfarqIax6yVP_DniU6ZsPTT8QA4aIDnvZbwP_MtRaw,6054
|
64
65
|
pycityagent/workflow/prompt.py,sha256=cmzKEmlzjdwg50uwfnTnN_6xNJA8OVjo5fdmcsaTbdU,2886
|
65
|
-
pycityagent/workflow/tool.py,sha256=
|
66
|
-
pycityagent/workflow/trigger.py,sha256=
|
67
|
-
pycityagent-2.0.
|
68
|
-
pycityagent-2.0.
|
69
|
-
pycityagent-2.0.
|
66
|
+
pycityagent/workflow/tool.py,sha256=2VDzBISnx7LASdAou8Zu230j7o2SdK1oBeqHtwDHZy0,6711
|
67
|
+
pycityagent/workflow/trigger.py,sha256=8530YfBbLsdtT1QZSqvuha64NNLVJYXVznlvknz9-xI,5737
|
68
|
+
pycityagent-2.0.0a5.dist-info/METADATA,sha256=eGx85-Jh_-wtRP3XhlJne7IpW5gpB-4iJrLsuoe9brQ,7614
|
69
|
+
pycityagent-2.0.0a5.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
70
|
+
pycityagent-2.0.0a5.dist-info/RECORD,,
|
pycityagent/config.py
DELETED
File without changes
|
File without changes
|