pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/agent/agent.py +83 -62
- pycityagent/agent/agent_base.py +81 -54
- pycityagent/cityagent/bankagent.py +5 -7
- pycityagent/cityagent/blocks/__init__.py +0 -2
- pycityagent/cityagent/blocks/cognition_block.py +149 -172
- pycityagent/cityagent/blocks/economy_block.py +90 -129
- pycityagent/cityagent/blocks/mobility_block.py +56 -29
- pycityagent/cityagent/blocks/needs_block.py +163 -145
- pycityagent/cityagent/blocks/other_block.py +17 -9
- pycityagent/cityagent/blocks/plan_block.py +45 -57
- pycityagent/cityagent/blocks/social_block.py +70 -51
- pycityagent/cityagent/blocks/utils.py +2 -0
- pycityagent/cityagent/firmagent.py +6 -7
- pycityagent/cityagent/governmentagent.py +7 -9
- pycityagent/cityagent/memory_config.py +48 -48
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/nbsagent.py +6 -29
- pycityagent/cityagent/societyagent.py +325 -127
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/client.py +10 -1
- pycityagent/environment/sim/clock_service.py +2 -2
- pycityagent/environment/sim/pause_service.py +61 -0
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +18 -14
- pycityagent/llm/embeddings.py +0 -24
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/faiss_query.py +29 -26
- pycityagent/memory/memory.py +733 -247
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/pycityagent-sim +0 -0
- pycityagent/simulation/agentgroup.py +137 -96
- pycityagent/simulation/simulation.py +184 -38
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/tools/tool.py +7 -9
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +14 -7
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
- pycityagent/cityagent/blocks/time_block.py +0 -116
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
pycityagent/message/__init__.py
CHANGED
@@ -1,3 +1,10 @@
|
|
1
|
+
from .message_interceptor import (MessageBlockBase, MessageBlockListenerBase,
|
2
|
+
MessageInterceptor)
|
1
3
|
from .messager import Messager
|
2
4
|
|
3
|
-
__all__ = [
|
5
|
+
__all__ = [
|
6
|
+
"Messager",
|
7
|
+
"MessageBlockBase",
|
8
|
+
"MessageBlockListenerBase",
|
9
|
+
"MessageInterceptor",
|
10
|
+
]
|
@@ -0,0 +1,322 @@
|
|
1
|
+
import asyncio
|
2
|
+
import inspect
|
3
|
+
import logging
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from collections import defaultdict
|
6
|
+
from collections.abc import Callable, Sequence
|
7
|
+
from copy import deepcopy
|
8
|
+
from typing import Any, Optional, Union
|
9
|
+
|
10
|
+
import ray
|
11
|
+
from ray.util.queue import Queue
|
12
|
+
|
13
|
+
from ..llm import LLM, LLMConfig
|
14
|
+
from ..utils.decorators import lock_decorator
|
15
|
+
|
16
|
+
DEFAULT_ERROR_STRING = """
|
17
|
+
From `{from_uuid}` To `{to_uuid}` abort due to block `{block_name}`
|
18
|
+
"""
|
19
|
+
|
20
|
+
logger = logging.getLogger("message_interceptor")
|
21
|
+
|
22
|
+
|
23
|
+
class MessageBlockBase(ABC):
    """
    Base class for a block used to filter messages.

    Subclasses override `forward` to decide whether a message may pass.
    """

    def __init__(self, name: str = "") -> None:
        # The lock is stored on the instance for `lock_decorator`;
        # presumably the decorator serializes the wrapped coroutines on
        # it -- TODO confirm against `utils.decorators`.
        self._lock = asyncio.Lock()
        self._llm = None
        self._name = name

    @property
    def name(self):
        """Name of this block."""
        return self._name

    @property
    def has_llm(self) -> bool:
        """Whether an LLM client has been assigned to this block."""
        return self._llm is not None

    @property
    def llm(self) -> LLM:
        """
        The LLM client assigned to this block.

        Raises:
            RuntimeError: if no client has been assigned yet.
        """
        if self._llm is not None:
            return self._llm
        raise RuntimeError(f"LLM access before assignment, please `set_llm` first!")

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the llm_client of the block.
        """
        self._llm = llm

    @lock_decorator
    async def set_name(self, name: str):
        """
        Set the name of the block.
        """
        self._name = name

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
        violation_counts: dict[str, int],
        black_list: list[tuple[str, str]],
    ) -> tuple[bool, str]:
        """
        Check one message; the base implementation accepts everything.

        Returns:
            A `(is_valid, error_message)` pair; here always `(True, "")`.
        """
        accepted = True
        reason = ""
        return accepted, reason
|
80
|
+
|
81
|
+
|
82
|
+
def _normalize_black_list(
    pairs: Union[list[tuple[str, str]], tuple[str, str]],
) -> list[tuple]:
    """
    Turn a single `(from_uuid, to_uuid)` pair or a list of pairs into a
    list of tuples.

    An empty input yields an empty list. Without that guard the
    "all elements are strings" test is vacuously true for `[]`, and an
    empty tuple `()` would be treated as a black-list entry.
    """
    if not pairs:
        return []
    if all(isinstance(item, str) for item in pairs):
        # A flat sequence of strings is one pair, not a list of pairs.
        return [tuple(pairs)]
    return [tuple(pair) for pair in pairs]


@ray.remote
class MessageInterceptor:
    """
    Message interceptor.

    Runs every message through an ordered list of `MessageBlockBase`
    blocks; the first block that rejects the message stops the chain.
    Methods wrapped with `lock_decorator` presumably run under
    `self._lock` -- TODO confirm against `utils.decorators`.
    """

    def __init__(
        self,
        blocks: Optional[list[MessageBlockBase]] = None,
        black_list: Optional[list[tuple[str, str]]] = None,
        llm_config: Optional[dict] = None,
        queue: Optional[Queue] = None,
    ) -> None:
        """
        Args:
            blocks: Filter blocks, applied in order. Defaults to none.
            black_list: Initial `(from_uuid, to_uuid)` pairs.
            llm_config: If truthy, used to build the interceptor's LLM.
            queue: Queue that receives the error text of rejected messages.
        """
        self._blocks: list[MessageBlockBase] = blocks if blocks is not None else []
        # Number of rejected messages, keyed by sender uuid.
        self._violation_counts: dict[str, int] = defaultdict(int)
        # list[tuple(from_uuid, to_uuid)] `None` means forbidden for everyone.
        self._black_list: list[tuple[str, str]] = (
            black_list if black_list is not None else []
        )
        if llm_config:
            self._llm = LLM(LLMConfig(llm_config))
        else:
            self._llm = None
        self._queue = queue
        self._lock = asyncio.Lock()

    @property
    def llm(
        self,
    ) -> LLM:
        """
        The interceptor's LLM client.

        Raises:
            RuntimeError: if no client has been configured or assigned.
        """
        if self._llm is None:
            raise RuntimeError("LLM access before assignment, please `set_llm` first!")
        return self._llm

    @lock_decorator
    async def blocks(
        self,
    ) -> list[MessageBlockBase]:
        """Return the current list of filter blocks (not a copy)."""
        return self._blocks

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the llm_client of the interceptor, unless one is already set.
        """
        if self._llm is None:
            self._llm = llm

    @lock_decorator
    async def violation_counts(
        self,
    ) -> dict[str, int]:
        """Return a deep copy of the per-sender violation counters."""
        return deepcopy(self._violation_counts)

    @property
    def has_llm(
        self,
    ) -> bool:
        """Whether an LLM client is available."""
        return self._llm is not None

    @lock_decorator
    async def black_list(
        self,
    ) -> list[tuple[str, str]]:
        """Return a deep copy of the black list."""
        return deepcopy(self._black_list)

    @lock_decorator
    async def add_to_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Add one pair or a list of pairs to the black list, dropping
        duplicates. An empty argument adds nothing.
        """
        self._black_list.extend(_normalize_black_list(black_list))  # type: ignore
        self._black_list = list(set(self._black_list))

    @property
    def has_queue(
        self,
    ) -> bool:
        """Whether a queue has been assigned."""
        return self._queue is not None

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageInterceptor.
        """
        self._queue = queue

    @lock_decorator
    async def remove_from_black_list(
        self, to_remove_black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Remove one pair or a list of pairs from the black list. An empty
        argument removes nothing.
        """
        _black_list_set = set(_normalize_black_list(to_remove_black_list))
        self._black_list = [p for p in self._black_list if p not in _black_list_set]

    @property
    def queue(
        self,
    ) -> Queue:
        """
        The queue that receives rejection messages.

        Raises:
            RuntimeError: if no queue has been assigned.
        """
        if self._queue is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return self._queue

    @lock_decorator
    async def insert_block(self, block: MessageBlockBase, index: Optional[int] = None):
        """Insert `block` at `index`; append when `index` is None."""
        if index is None:
            index = len(self._blocks)
        self._blocks.insert(index, block)

    @lock_decorator
    async def pop_block(self, index: Optional[int] = None) -> MessageBlockBase:
        """Remove and return the block at `index` (the last one by default)."""
        if index is None:
            index = -1
        return self._blocks.pop(index)

    @lock_decorator
    async def set_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Replace the black list with one pair or a list of pairs, dropping
        duplicates. An empty argument clears the black list.
        """
        self._black_list = list(set(_normalize_black_list(black_list)))  # type: ignore

    @lock_decorator
    async def set_blocks(self, blocks: list[MessageBlockBase]):
        """Replace the list of filter blocks."""
        self._blocks = blocks

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
    ) -> bool:
        """
        Run a message through every block in order.

        Each block's `forward` receives only the keyword arguments its
        signature names. A block may return `(is_valid, err)` or a bare
        truth value; in the latter case a default error text is built.

        Returns:
            False as soon as one block rejects the message (the error is
            pushed to the queue when one is set and the sender's
            violation counter is incremented); True when all blocks pass.
        """
        # Loop-invariant: the same objects are offered to every block.
        # The counters and black list are passed by reference, not copied.
        _args = {
            "from_uuid": from_uuid,
            "to_uuid": to_uuid,
            "msg": msg,
            "violation_counts": self._violation_counts,
            "black_list": self._black_list,
        }
        for _block in self._blocks:
            if not _block.has_llm and self.has_llm:
                await _block.set_llm(self.llm)
            func_params = inspect.signature(_block.forward).parameters
            _required_args = {k: v for k, v in _args.items() if k in func_params}
            res = await _block.forward(**_required_args)
            try:
                is_valid, err = res
            except TypeError:
                # The block returned a bare (non-iterable) verdict.
                is_valid = res  # type:ignore
                err = (
                    DEFAULT_ERROR_STRING.format(
                        from_uuid=from_uuid,
                        to_uuid=to_uuid,
                        block_name=f"{_block.__class__.__name__} `{_block.name}`",
                    )
                    if not is_valid
                    else ""
                )
            if not is_valid:
                if self.has_queue:
                    logger.debug(f"put `{err}` into queue")
                    await self.queue.put_async(err)  # type:ignore
                self._violation_counts[from_uuid] += 1
                logger.debug("black list after rejection: %s", self._black_list)
                return False
        logger.debug("black list after acceptance: %s", self._black_list)
        return True
|
274
|
+
|
275
|
+
|
276
|
+
class MessageBlockListenerBase(ABC):
    """
    Base class for a listener that drains values from a queue.
    """

    def __init__(
        self,
        save_queue_values: bool = False,
        get_queue_period: float = 0.1,
    ) -> None:
        """
        Args:
            save_queue_values: Keep every value read from the queue.
            get_queue_period: Seconds to sleep between polling rounds.
        """
        self._get_queue_period = get_queue_period
        self._save_queue_values = save_queue_values
        self._values_from_queue: list[Any] = []
        self._lock = asyncio.Lock()
        self._queue = None

    @property
    def has_queue(self) -> bool:
        """Whether a queue has been assigned."""
        return self._queue is not None

    @property
    def queue(self) -> Queue:
        """
        The queue this listener reads from.

        Raises:
            RuntimeError: if no queue has been assigned.
        """
        if self._queue is not None:
            return self._queue
        raise RuntimeError(
            f"Queue access before assignment, please `set_queue` first!"
        )

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageBlockListenerBase.
        """
        self._queue = queue

    # NOTE(review): if `lock_decorator` holds `self._lock` for the whole
    # call, this never-ending loop would keep `set_queue` from ever
    # running once `forward` has started -- confirm against
    # `utils.decorators` before relying on late queue assignment.
    @lock_decorator
    async def forward(self):
        """
        Poll the queue forever, sleeping `get_queue_period` seconds after
        each round; values are stored when `save_queue_values` is set.
        """
        while True:
            if not self.has_queue:
                await asyncio.sleep(self._get_queue_period)
                continue
            item = await self.queue.get_async()  # type: ignore
            if self._save_queue_values:
                self._values_from_queue.append(item)
            logger.debug(f"get `{item}` from queue")
            # do something with the value
            await asyncio.sleep(self._get_queue_period)
|
pycityagent/message/messager.py
CHANGED
@@ -1,19 +1,26 @@
|
|
1
1
|
import asyncio
|
2
2
|
import json
|
3
3
|
import logging
|
4
|
-
import
|
5
|
-
from typing import Any, List, Union
|
4
|
+
from typing import Any, Optional, Union
|
6
5
|
|
7
6
|
import ray
|
8
7
|
from aiomqtt import Client
|
9
8
|
|
9
|
+
from .message_interceptor import MessageInterceptor
|
10
|
+
|
10
11
|
logger = logging.getLogger("pycityagent")
|
11
12
|
|
12
13
|
|
13
14
|
@ray.remote
|
14
15
|
class Messager:
|
15
16
|
def __init__(
|
16
|
-
self,
|
17
|
+
self,
|
18
|
+
hostname: str,
|
19
|
+
port: int = 1883,
|
20
|
+
username=None,
|
21
|
+
password=None,
|
22
|
+
timeout=60,
|
23
|
+
message_interceptor: Optional[ray.ObjectRef] = None,
|
17
24
|
):
|
18
25
|
self.client = Client(
|
19
26
|
hostname, port=port, username=username, password=password, timeout=timeout
|
@@ -21,6 +28,16 @@ class Messager:
|
|
21
28
|
self.connected = False # 是否已连接标志
|
22
29
|
self.message_queue = asyncio.Queue() # 用于存储接收到的消息
|
23
30
|
self.receive_messages_task = None
|
31
|
+
self._message_interceptor = message_interceptor
|
32
|
+
|
33
|
+
@property
|
34
|
+
def message_interceptor(
|
35
|
+
self,
|
36
|
+
) -> Union[None, ray.ObjectRef]:
|
37
|
+
return self._message_interceptor
|
38
|
+
|
39
|
+
def set_message_interceptor(self, message_interceptor: ray.ObjectRef):
|
40
|
+
self._message_interceptor = message_interceptor
|
24
41
|
|
25
42
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
26
43
|
await self.stop()
|
@@ -50,8 +67,9 @@ class Messager:
|
|
50
67
|
"""检查是否成功连接到 Broker"""
|
51
68
|
return self.connected
|
52
69
|
|
70
|
+
# TODO:add message interceptor
|
53
71
|
async def subscribe(
|
54
|
-
self, topics: Union[str,
|
72
|
+
self, topics: Union[str, list[str]], agents: Union[Any, list[Any]]
|
55
73
|
):
|
56
74
|
if not await self.is_connected():
|
57
75
|
logger.error(
|
@@ -62,7 +80,7 @@ class Messager:
|
|
62
80
|
topics = [topics]
|
63
81
|
if not isinstance(agents, list):
|
64
82
|
agents = [agents]
|
65
|
-
await self.client.subscribe(topics, qos=1)
|
83
|
+
await self.client.subscribe(topics, qos=1) # type: ignore
|
66
84
|
|
67
85
|
async def receive_messages(self):
|
68
86
|
"""监听并将消息存入队列"""
|
@@ -76,11 +94,26 @@ class Messager:
|
|
76
94
|
messages.append(await self.message_queue.get())
|
77
95
|
return messages
|
78
96
|
|
79
|
-
async def send_message(
|
97
|
+
async def send_message(
|
98
|
+
self,
|
99
|
+
topic: str,
|
100
|
+
payload: dict,
|
101
|
+
from_uuid: Optional[str] = None,
|
102
|
+
to_uuid: Optional[str] = None,
|
103
|
+
):
|
80
104
|
"""通过 Messager 发送消息"""
|
81
105
|
message = json.dumps(payload, default=str)
|
82
|
-
|
83
|
-
|
106
|
+
interceptor = self.message_interceptor
|
107
|
+
is_valid: bool = True
|
108
|
+
if interceptor is not None and (from_uuid is not None and to_uuid is not None):
|
109
|
+
is_valid = await interceptor.forward.remote( # type:ignore
|
110
|
+
from_uuid, to_uuid, message
|
111
|
+
)
|
112
|
+
if is_valid:
|
113
|
+
await self.client.publish(topic=topic, payload=message, qos=1)
|
114
|
+
logger.info(f"Message sent to {topic}: {message}")
|
115
|
+
else:
|
116
|
+
logger.info(f"Message not sent to {topic}: {message} due to interceptor")
|
84
117
|
|
85
118
|
async def start_listening(self):
|
86
119
|
"""启动消息监听任务"""
|
@@ -92,7 +125,5 @@ class Messager:
|
|
92
125
|
async def stop(self):
|
93
126
|
assert self.receive_messages_task is not None
|
94
127
|
self.receive_messages_task.cancel()
|
95
|
-
await asyncio.gather(
|
96
|
-
self.receive_messages_task, return_exceptions=True
|
97
|
-
)
|
128
|
+
await asyncio.gather(self.receive_messages_task, return_exceptions=True)
|
98
129
|
await self.disconnect()
|
pycityagent/pycityagent-sim
CHANGED
Binary file
|