pycityagent 2.0.0a52__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a54__cp39-cp39-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +83 -62
- pycityagent/agent/agent_base.py +81 -54
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +45 -57
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +325 -127
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +18 -14
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +733 -247
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +137 -96
- pycityagent/simulation/simulation.py +184 -38
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/tools/tool.py +7 -9
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +14 -7
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
pycityagent/message/__init__.py
CHANGED
@@ -1,3 +1,10 @@
|
|
1
|
+
from .message_interceptor import (MessageBlockBase, MessageBlockListenerBase,
|
2
|
+
MessageInterceptor)
|
1
3
|
from .messager import Messager
|
2
4
|
|
3
|
-
__all__ = [
|
5
|
+
__all__ = [
|
6
|
+
"Messager",
|
7
|
+
"MessageBlockBase",
|
8
|
+
"MessageBlockListenerBase",
|
9
|
+
"MessageInterceptor",
|
10
|
+
]
|
@@ -0,0 +1,322 @@
|
|
1
|
+
import asyncio
|
2
|
+
import inspect
|
3
|
+
import logging
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from collections import defaultdict
|
6
|
+
from collections.abc import Callable, Sequence
|
7
|
+
from copy import deepcopy
|
8
|
+
from typing import Any, Optional, Union
|
9
|
+
|
10
|
+
import ray
|
11
|
+
from ray.util.queue import Queue
|
12
|
+
|
13
|
+
from ..llm import LLM, LLMConfig
|
14
|
+
from ..utils.decorators import lock_decorator
|
15
|
+
|
16
|
+
DEFAULT_ERROR_STRING = """
|
17
|
+
From `{from_uuid}` To `{to_uuid}` abort due to block `{block_name}`
|
18
|
+
"""
|
19
|
+
|
20
|
+
logger = logging.getLogger("message_interceptor")
|
21
|
+
|
22
|
+
|
23
|
+
class MessageBlockBase(ABC):
    """
    Base class for a block used to filter messages.

    A block carries a name and an optional LLM client. Its `forward` coroutine
    judges a single message and returns `(is_valid, error_text)`; the default
    implementation here accepts every message.
    """

    def __init__(self, name: str = "") -> None:
        # `_lock` is presumably what `lock_decorator` acquires around the
        # decorated coroutines below -- confirm in `utils.decorators`.
        self._lock = asyncio.Lock()
        self._llm = None
        self._name = name

    @property
    def name(self):
        """Name given to this block."""
        return self._name

    @property
    def has_llm(self) -> bool:
        """Whether an LLM client has been assigned."""
        return self._llm is not None

    @property
    def llm(self) -> LLM:
        """
        The assigned LLM client.

        Raises:
            RuntimeError: if no client has been assigned yet.
        """
        client = self._llm
        if client is not None:
            return client
        raise RuntimeError("LLM access before assignment, please `set_llm` first!")

    @lock_decorator
    async def set_name(self, name: str):
        """
        Set the name of the block.
        """
        self._name = name

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the llm_client of the block.
        """
        self._llm = llm

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
        violation_counts: dict[str, int],
        black_list: list[tuple[str, str]],
    ) -> tuple[bool, str]:
        """
        Judge one message.

        Args:
            from_uuid: identifier of the sender.
            to_uuid: identifier of the receiver.
            msg: the message text.
            violation_counts: violation counter keyed by uuid.
            black_list: `(from_uuid, to_uuid)` pairs.

        Returns:
            `(is_valid, error_text)`. This base implementation always
            returns `(True, "")`.
        """
        verdict = (True, "")
        return verdict
|
80
|
+
|
81
|
+
|
82
|
+
def _as_pair_list(
    pairs: Union[list[tuple[str, str]], tuple[str, str]],
) -> list[tuple[str, str]]:
    """Normalize one `(from_uuid, to_uuid)` pair or a collection of pairs into a list of tuples."""
    if not pairs:
        # `all()` is True for an empty iterable, so an empty input would
        # otherwise be mistaken for a single pair and produce a bogus `()`.
        return []
    if all(isinstance(s, str) for s in pairs):
        # A flat sequence of strings is a single pair.
        return [tuple(pairs)]  # type: ignore
    return [tuple(p) for p in pairs]  # type: ignore


@ray.remote
class MessageInterceptor:
    """
    Message interceptor.

    Holds an ordered list of `MessageBlockBase` blocks, a black list of
    `(from_uuid, to_uuid)` pairs and a per-sender violation counter.
    `forward` runs a message through every block in order and rejects it
    as soon as one block reports it invalid.
    """

    def __init__(
        self,
        blocks: Optional[list[MessageBlockBase]] = None,
        black_list: Optional[list[tuple[str, str]]] = None,
        llm_config: Optional[dict] = None,
        queue: Optional[Queue] = None,
    ) -> None:
        """
        Args:
            blocks: blocks applied, in order, to every message. Defaults to none.
            black_list: initial `(from_uuid, to_uuid)` pairs. Defaults to empty.
            llm_config: if truthy, an `LLM` is built from it via `LLMConfig`.
            queue: optional queue that receives the error text of rejected messages.
        """
        self._blocks: list[MessageBlockBase] = blocks if blocks is not None else []
        self._violation_counts: dict[str, int] = defaultdict(int)
        # list[tuple(from_uuid, to_uuid)] `None` means forbidden for everyone.
        # NOTE(review): the `None` wildcard is not interpreted anywhere in this
        # class; it is presumably left to the blocks -- confirm.
        self._black_list: list[tuple[str, str]] = (
            black_list if black_list is not None else []
        )
        self._llm = LLM(LLMConfig(llm_config)) if llm_config else None
        self._queue = queue
        # Presumably acquired by `lock_decorator` around the decorated
        # coroutines -- confirm in `utils.decorators`.
        self._lock = asyncio.Lock()

    @property
    def llm(
        self,
    ) -> LLM:
        """
        The LLM client shared with blocks that have none.

        Raises:
            RuntimeError: if no client has been assigned yet.
        """
        if self._llm is None:
            raise RuntimeError("LLM access before assignment, please `set_llm` first!")
        return self._llm

    @property
    def has_llm(
        self,
    ) -> bool:
        """Whether an LLM client is available."""
        return self._llm is not None

    @property
    def has_queue(
        self,
    ) -> bool:
        """Whether a queue has been assigned."""
        return self._queue is not None

    @property
    def queue(
        self,
    ) -> Queue:
        """
        The queue that receives error texts of rejected messages.

        Raises:
            RuntimeError: if no queue has been assigned yet.
        """
        if self._queue is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return self._queue

    @lock_decorator
    async def blocks(
        self,
    ) -> list[MessageBlockBase]:
        """Return the internal list of blocks (not a copy)."""
        return self._blocks

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the llm_client of the MessageInterceptor.

        An already assigned client is kept; the call is then a no-op.
        """
        if self._llm is None:
            self._llm = llm

    @lock_decorator
    async def violation_counts(
        self,
    ) -> dict[str, int]:
        """Return a deep copy of the per-sender violation counter."""
        return deepcopy(self._violation_counts)

    @lock_decorator
    async def black_list(
        self,
    ) -> list[tuple[str, str]]:
        """Return a deep copy of the black list."""
        return deepcopy(self._black_list)

    @lock_decorator
    async def add_to_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Add one pair or a list of pairs to the black list, dropping duplicates.

        An empty argument leaves the black list unchanged.
        """
        self._black_list.extend(_as_pair_list(black_list))
        # `dict.fromkeys` de-duplicates while keeping insertion order.
        self._black_list = list(dict.fromkeys(self._black_list))

    @lock_decorator
    async def remove_from_black_list(
        self, to_remove_black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """Remove one pair or a list of pairs from the black list."""
        _to_remove = set(_as_pair_list(to_remove_black_list))
        self._black_list = [p for p in self._black_list if p not in _to_remove]

    @lock_decorator
    async def set_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Replace the black list with one pair or a list of pairs, dropping duplicates.

        An empty argument clears the black list.
        """
        self._black_list = list(dict.fromkeys(_as_pair_list(black_list)))

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageInterceptor.
        """
        self._queue = queue

    @lock_decorator
    async def insert_block(self, block: MessageBlockBase, index: Optional[int] = None):
        """Insert `block` at `index`; append it when `index` is None."""
        if index is None:
            index = len(self._blocks)
        self._blocks.insert(index, block)

    @lock_decorator
    async def pop_block(self, index: Optional[int] = None) -> MessageBlockBase:
        """Remove and return the block at `index`; the last one when `index` is None."""
        if index is None:
            index = -1
        return self._blocks.pop(index)

    @lock_decorator
    async def set_blocks(self, blocks: list[MessageBlockBase]):
        """Replace the list of blocks."""
        self._blocks = blocks

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
    ) -> bool:
        """
        Run a message through every block, in order.

        Blocks without an LLM client receive this interceptor's client when
        one is available. Each block's `forward` is called with only those of
        `from_uuid`, `to_uuid`, `msg`, `violation_counts` and `black_list`
        that appear in its signature. A block may return either
        `(is_valid, error_text)` or a bare boolean.

        On the first rejection the error text is put on the queue (if one is
        set), the sender's violation count is incremented and the remaining
        blocks are skipped.

        Returns:
            True if every block accepted the message, False otherwise.
        """
        for _block in self._blocks:
            if not _block.has_llm and self.has_llm:
                await _block.set_llm(self.llm)
            func_params = inspect.signature(_block.forward).parameters
            _args = {
                "from_uuid": from_uuid,
                "to_uuid": to_uuid,
                "msg": msg,
                "violation_counts": self._violation_counts,
                "black_list": self._black_list,
            }
            _required_args = {k: v for k, v in _args.items() if k in func_params}
            res = await _block.forward(**_required_args)
            try:
                is_valid, err = res
            except TypeError:
                # The block returned a bare boolean instead of a pair.
                is_valid: bool = res  # type:ignore
                err = (
                    DEFAULT_ERROR_STRING.format(
                        from_uuid=from_uuid,
                        to_uuid=to_uuid,
                        block_name=f"{_block.__class__.__name__} `{_block.name}`",
                    )
                    if not is_valid
                    else ""
                )
            if not is_valid:
                if self.has_queue:
                    logger.debug("put `%s` into queue", err)
                    await self.queue.put_async(err)  # type:ignore
                self._violation_counts[from_uuid] += 1
                return False
        return True
|
274
|
+
|
275
|
+
|
276
|
+
class MessageBlockListenerBase(ABC):
    """
    Base class for a listener that polls a queue in an endless loop.

    Values taken from the queue are logged and, when `save_queue_values`
    is set, also collected in `_values_from_queue`.
    """

    def __init__(
        self,
        save_queue_values: bool = False,
        get_queue_period: float = 0.1,
    ) -> None:
        """
        Args:
            save_queue_values: keep every value taken from the queue.
            get_queue_period: seconds to sleep between two polls.
        """
        self._get_queue_period = get_queue_period
        self._save_queue_values = save_queue_values
        self._values_from_queue: list[Any] = []
        # Presumably acquired by `lock_decorator` around the decorated
        # coroutines -- confirm in `utils.decorators`.
        self._lock = asyncio.Lock()
        self._queue = None

    @property
    def has_queue(self) -> bool:
        """Whether a queue has been assigned."""
        return self._queue is not None

    @property
    def queue(self) -> Queue:
        """
        The queue this listener reads from.

        Raises:
            RuntimeError: if no queue has been assigned yet.
        """
        assigned = self._queue
        if assigned is not None:
            return assigned
        raise RuntimeError(
            "Queue access before assignment, please `set_queue` first!"
        )

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageBlockListenerBase.
        """
        self._queue = queue

    @lock_decorator
    async def forward(self):
        """
        Poll the queue forever, sleeping `get_queue_period` seconds per round.

        This coroutine never returns.
        """
        # NOTE(review): this loop never exits; if `lock_decorator` holds
        # `self._lock` for the whole call, `set_queue` cannot complete once
        # `forward` is running -- confirm and set the queue beforehand.
        while True:
            if not self.has_queue:
                # Nothing to read from yet; wait one period and look again.
                await asyncio.sleep(self._get_queue_period)
                continue
            item = await self.queue.get_async()  # type: ignore
            if self._save_queue_values:
                self._values_from_queue.append(item)
            logger.debug(f"get `{item}` from queue")
            # do something with the value
            await asyncio.sleep(self._get_queue_period)
|
pycityagent/message/messager.py
CHANGED
@@ -1,19 +1,26 @@
|
|
1
1
|
import asyncio
|
2
2
|
import json
|
3
3
|
import logging
|
4
|
-
import
|
5
|
-
from typing import Any, List, Union
|
4
|
+
from typing import Any, Optional, Union
|
6
5
|
|
7
6
|
import ray
|
8
7
|
from aiomqtt import Client
|
9
8
|
|
9
|
+
from .message_interceptor import MessageInterceptor
|
10
|
+
|
10
11
|
logger = logging.getLogger("pycityagent")
|
11
12
|
|
12
13
|
|
13
14
|
@ray.remote
|
14
15
|
class Messager:
|
15
16
|
def __init__(
|
16
|
-
self,
|
17
|
+
self,
|
18
|
+
hostname: str,
|
19
|
+
port: int = 1883,
|
20
|
+
username=None,
|
21
|
+
password=None,
|
22
|
+
timeout=60,
|
23
|
+
message_interceptor: Optional[ray.ObjectRef] = None,
|
17
24
|
):
|
18
25
|
self.client = Client(
|
19
26
|
hostname, port=port, username=username, password=password, timeout=timeout
|
@@ -21,6 +28,16 @@ class Messager:
|
|
21
28
|
self.connected = False # 是否已连接标志
|
22
29
|
self.message_queue = asyncio.Queue() # 用于存储接收到的消息
|
23
30
|
self.receive_messages_task = None
|
31
|
+
self._message_interceptor = message_interceptor
|
32
|
+
|
33
|
+
@property
|
34
|
+
def message_interceptor(
|
35
|
+
self,
|
36
|
+
) -> Union[None, ray.ObjectRef]:
|
37
|
+
return self._message_interceptor
|
38
|
+
|
39
|
+
def set_message_interceptor(self, message_interceptor: ray.ObjectRef):
|
40
|
+
self._message_interceptor = message_interceptor
|
24
41
|
|
25
42
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
26
43
|
await self.stop()
|
@@ -50,8 +67,9 @@ class Messager:
|
|
50
67
|
"""检查是否成功连接到 Broker"""
|
51
68
|
return self.connected
|
52
69
|
|
70
|
+
# TODO:add message interceptor
|
53
71
|
async def subscribe(
|
54
|
-
self, topics: Union[str,
|
72
|
+
self, topics: Union[str, list[str]], agents: Union[Any, list[Any]]
|
55
73
|
):
|
56
74
|
if not await self.is_connected():
|
57
75
|
logger.error(
|
@@ -62,7 +80,7 @@ class Messager:
|
|
62
80
|
topics = [topics]
|
63
81
|
if not isinstance(agents, list):
|
64
82
|
agents = [agents]
|
65
|
-
await self.client.subscribe(topics, qos=1)
|
83
|
+
await self.client.subscribe(topics, qos=1) # type: ignore
|
66
84
|
|
67
85
|
async def receive_messages(self):
|
68
86
|
"""监听并将消息存入队列"""
|
@@ -76,11 +94,26 @@ class Messager:
|
|
76
94
|
messages.append(await self.message_queue.get())
|
77
95
|
return messages
|
78
96
|
|
79
|
-
async def send_message(
|
97
|
+
async def send_message(
|
98
|
+
self,
|
99
|
+
topic: str,
|
100
|
+
payload: dict,
|
101
|
+
from_uuid: Optional[str] = None,
|
102
|
+
to_uuid: Optional[str] = None,
|
103
|
+
):
|
80
104
|
"""通过 Messager 发送消息"""
|
81
105
|
message = json.dumps(payload, default=str)
|
82
|
-
|
83
|
-
|
106
|
+
interceptor = self.message_interceptor
|
107
|
+
is_valid: bool = True
|
108
|
+
if interceptor is not None and (from_uuid is not None and to_uuid is not None):
|
109
|
+
is_valid = await interceptor.forward.remote( # type:ignore
|
110
|
+
from_uuid, to_uuid, message
|
111
|
+
)
|
112
|
+
if is_valid:
|
113
|
+
await self.client.publish(topic=topic, payload=message, qos=1)
|
114
|
+
logger.info(f"Message sent to {topic}: {message}")
|
115
|
+
else:
|
116
|
+
logger.info(f"Message not sent to {topic}: {message} due to interceptor")
|
84
117
|
|
85
118
|
async def start_listening(self):
|
86
119
|
"""启动消息监听任务"""
|
@@ -92,7 +125,5 @@ class Messager:
|
|
92
125
|
async def stop(self):
|
93
126
|
assert self.receive_messages_task is not None
|
94
127
|
self.receive_messages_task.cancel()
|
95
|
-
await asyncio.gather(
|
96
|
-
self.receive_messages_task, return_exceptions=True
|
97
|
-
)
|
128
|
+
await asyncio.gather(self.receive_messages_task, return_exceptions=True)
|
98
129
|
await self.disconnect()
|
pycityagent/pycityagent-sim
CHANGED
Binary file
|