pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (49) hide show
  1. pycityagent/agent/agent.py +83 -62
  2. pycityagent/agent/agent_base.py +81 -54
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +45 -57
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/message_intercept.py +99 -0
  17. pycityagent/cityagent/nbsagent.py +6 -29
  18. pycityagent/cityagent/societyagent.py +325 -127
  19. pycityagent/cli/wrapper.py +4 -0
  20. pycityagent/economy/econ_client.py +0 -2
  21. pycityagent/environment/__init__.py +7 -1
  22. pycityagent/environment/sim/client.py +10 -1
  23. pycityagent/environment/sim/clock_service.py +2 -2
  24. pycityagent/environment/sim/pause_service.py +61 -0
  25. pycityagent/environment/sim/sim_env.py +34 -46
  26. pycityagent/environment/simulator.py +18 -14
  27. pycityagent/llm/embeddings.py +0 -24
  28. pycityagent/llm/llm.py +18 -10
  29. pycityagent/memory/faiss_query.py +29 -26
  30. pycityagent/memory/memory.py +733 -247
  31. pycityagent/message/__init__.py +8 -1
  32. pycityagent/message/message_interceptor.py +322 -0
  33. pycityagent/message/messager.py +42 -11
  34. pycityagent/pycityagent-sim +0 -0
  35. pycityagent/simulation/agentgroup.py +137 -96
  36. pycityagent/simulation/simulation.py +184 -38
  37. pycityagent/simulation/storage/pg.py +2 -2
  38. pycityagent/tools/tool.py +7 -9
  39. pycityagent/utils/__init__.py +7 -2
  40. pycityagent/utils/pg_query.py +1 -0
  41. pycityagent/utils/survey_util.py +26 -23
  42. pycityagent/workflow/block.py +14 -7
  43. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
  44. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
  45. pycityagent/cityagent/blocks/time_block.py +0 -116
  46. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
  47. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
  48. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
  49. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,10 @@
1
+ from .message_interceptor import (MessageBlockBase, MessageBlockListenerBase,
2
+ MessageInterceptor)
1
3
  from .messager import Messager
2
4
 
3
- __all__ = ["Messager"]
5
+ __all__ = [
6
+ "Messager",
7
+ "MessageBlockBase",
8
+ "MessageBlockListenerBase",
9
+ "MessageInterceptor",
10
+ ]
@@ -0,0 +1,322 @@
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from collections import defaultdict
6
+ from collections.abc import Callable, Sequence
7
+ from copy import deepcopy
8
+ from typing import Any, Optional, Union
9
+
10
+ import ray
11
+ from ray.util.queue import Queue
12
+
13
+ from ..llm import LLM, LLMConfig
14
+ from ..utils.decorators import lock_decorator
15
+
16
+ DEFAULT_ERROR_STRING = """
17
+ From `{from_uuid}` To `{to_uuid}` abort due to block `{block_name}`
18
+ """
19
+
20
+ logger = logging.getLogger("message_interceptor")
21
+
22
+
23
class MessageBlockBase(ABC):
    """
    Base class for blocks used to filter messages.

    A block carries a human-readable name and an optional LLM client.
    Subclasses override `forward` to decide whether a message may pass.
    """

    def __init__(self, name: str = "") -> None:
        self._llm = None
        self._name = name
        # Lock object kept on the instance for `lock_decorator`.
        # NOTE(review): presumably the decorator acquires `self._lock`
        # around each decorated coroutine — confirm in utils.decorators.
        self._lock = asyncio.Lock()

    @property
    def name(self):
        """Name given to this block."""
        return self._name

    @property
    def has_llm(self) -> bool:
        """Whether an LLM client has been assigned."""
        return self._llm is not None

    @property
    def llm(self) -> LLM:
        """
        The assigned LLM client.

        Raises:
            RuntimeError: if no client has been assigned yet.
        """
        client = self._llm
        if client is None:
            raise RuntimeError("LLM access before assignment, please `set_llm` first!")
        return client

    @lock_decorator
    async def set_name(self, name: str):
        """
        Set the name of the block.
        """
        self._name = name

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the llm_client of the block.
        """
        self._llm = llm

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
        violation_counts: dict[str, int],
        black_list: list[tuple[str, str]],
    ) -> tuple[bool, str]:
        """
        Judge a single message.

        The base implementation accepts everything.

        Returns:
            A `(is_valid, error_message)` pair; here always `(True, "")`.
        """
        verdict, reason = True, ""
        return verdict, reason
80
+
81
+
82
def _normalize_black_list_pairs(pairs) -> list[tuple]:
    """
    Turn a single pair or a collection of pairs into a list of tuples.

    - Empty input (`None`, `[]`, `()`) yields an empty list. `all()` over an
      empty iterable is true, so without this guard an empty list would be
      mistaken for a single pair and produce a bogus `()` entry.
    - A flat sequence whose items are all `str` or `None` is treated as ONE
      pair. `None` is accepted because the constructor's black-list comment
      describes `None` entries as "forbidden for everyone".
    - Anything else is treated as a collection of pairs.
    """
    if not pairs:
        return []
    if all(item is None or isinstance(item, str) for item in pairs):
        pairs = [pairs]
    return [tuple(pair) for pair in pairs]


@ray.remote
class MessageInterceptor:
    """
    Message interceptor.

    Runs every message through an ordered list of `MessageBlockBase`
    instances. The first block that rejects the message stops the chain,
    increments the sender's violation count and, when a queue is set,
    publishes the error string to it.
    """

    def __init__(
        self,
        blocks: Optional[list[MessageBlockBase]] = None,
        black_list: Optional[list[tuple[str, str]]] = None,
        llm_config: Optional[dict] = None,
        queue: Optional[Queue] = None,
    ) -> None:
        """
        Args:
            blocks: Initial filter blocks, evaluated in list order.
            black_list: Initial `(from_uuid, to_uuid)` pairs.
            llm_config: When truthy, used to build an `LLM` via `LLMConfig`.
            queue: Optional queue that receives error strings of rejected
                messages.
        """
        self._blocks: list[MessageBlockBase] = blocks if blocks is not None else []
        self._violation_counts: dict[str, int] = defaultdict(int)
        # list[tuple(from_uuid, to_uuid)] `None` means forbidden for everyone.
        # Entries are coerced to tuples so they stay hashable for the
        # de-duplication done by the black-list mutators.
        self._black_list: list[tuple[str, str]] = _normalize_black_list_pairs(
            black_list
        )  # type: ignore
        self._llm = LLM(LLMConfig(llm_config)) if llm_config else None
        self._queue = queue
        self._lock = asyncio.Lock()

    @property
    def llm(
        self,
    ) -> LLM:
        """
        The LLM client of the interceptor.

        Raises:
            RuntimeError: if no client is available.
        """
        if self._llm is None:
            raise RuntimeError("LLM access before assignment, please `set_llm` first!")
        return self._llm

    @lock_decorator
    async def blocks(
        self,
    ) -> list[MessageBlockBase]:
        """Return the current list of blocks (not a copy)."""
        return self._blocks

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the llm_client of the interceptor.

        An already assigned client is kept; the call is then a no-op.
        """
        if self._llm is None:
            self._llm = llm

    @lock_decorator
    async def violation_counts(
        self,
    ) -> dict[str, int]:
        """Return a deep copy of the per-sender violation counters."""
        return deepcopy(self._violation_counts)

    @property
    def has_llm(
        self,
    ) -> bool:
        """Whether an LLM client is available."""
        return self._llm is not None

    @lock_decorator
    async def black_list(
        self,
    ) -> list[tuple[str, str]]:
        """Return a deep copy of the black list."""
        return deepcopy(self._black_list)

    @lock_decorator
    async def add_to_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Add one pair or a list of pairs to the black list.

        Duplicates are dropped; insertion order is preserved.
        """
        merged = [*self._black_list, *_normalize_black_list_pairs(black_list)]
        self._black_list = list(dict.fromkeys(merged))  # type: ignore

    @property
    def has_queue(
        self,
    ) -> bool:
        """Whether a queue has been assigned."""
        return self._queue is not None

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageInterceptor.
        """
        self._queue = queue

    @lock_decorator
    async def remove_from_black_list(
        self, to_remove_black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Remove one pair or a list of pairs from the black list.

        Pairs that are not present are ignored.
        """
        _black_list_set = set(_normalize_black_list_pairs(to_remove_black_list))
        self._black_list = [p for p in self._black_list if p not in _black_list_set]

    @property
    def queue(
        self,
    ) -> Queue:
        """
        The queue receiving error strings.

        Raises:
            RuntimeError: if no queue has been assigned.
        """
        if self._queue is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return self._queue

    @lock_decorator
    async def insert_block(self, block: MessageBlockBase, index: Optional[int] = None):
        """Insert `block` at `index`; append when `index` is None."""
        if index is None:
            index = len(self._blocks)
        self._blocks.insert(index, block)

    @lock_decorator
    async def pop_block(self, index: Optional[int] = None) -> MessageBlockBase:
        """Remove and return the block at `index` (the last one when None)."""
        if index is None:
            index = -1
        return self._blocks.pop(index)

    @lock_decorator
    async def set_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Replace the black list with one pair or a list of pairs.

        Duplicates are dropped; insertion order is preserved.
        """
        self._black_list = list(
            dict.fromkeys(_normalize_black_list_pairs(black_list))
        )  # type: ignore

    @lock_decorator
    async def set_blocks(self, blocks: list[MessageBlockBase]):
        """Replace the list of blocks."""
        self._blocks = blocks

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
    ):
        """
        Run a message through every block in order.

        Each block's `forward` is called with only those keyword arguments
        that appear in its signature. A block may return either
        `(is_valid, error_message)` or a bare bool; for a bare bool the
        default error string is generated on rejection.

        Returns:
            False as soon as one block rejects the message, True otherwise.
        """
        for _block in self._blocks:
            # Lend the interceptor's LLM to blocks that have none.
            if not _block.has_llm and self.has_llm:
                await _block.set_llm(self.llm)
            func_params = inspect.signature(_block.forward).parameters
            # The counters and the black list are passed by reference.
            _args = {
                "from_uuid": from_uuid,
                "to_uuid": to_uuid,
                "msg": msg,
                "violation_counts": self._violation_counts,
                "black_list": self._black_list,
            }
            _required_args = {k: v for k, v in _args.items() if k in func_params}
            res = await _block.forward(**_required_args)
            try:
                is_valid, err = res
            except TypeError:
                # The block returned a bare bool instead of a pair.
                is_valid: bool = res  # type:ignore
                err = (
                    DEFAULT_ERROR_STRING.format(
                        from_uuid=from_uuid,
                        to_uuid=to_uuid,
                        block_name=f"{_block.__class__.__name__} `{_block.name}`",
                    )
                    if not is_valid
                    else ""
                )
            if not is_valid:
                if self.has_queue:
                    logger.debug(f"put `{err}` into queue")
                    await self.queue.put_async(err)  # type:ignore
                self._violation_counts[from_uuid] += 1
                # Diagnostic output goes through the logger, not stdout.
                logger.debug(f"black list after rejection: {self._black_list}")
                return False
        logger.debug(f"black list after acceptance: {self._black_list}")
        return True
274
+
275
+
276
class MessageBlockListenerBase(ABC):
    """
    Base class for listeners that consume values from a queue.

    `forward` loops forever: whenever a queue is set it awaits one value,
    optionally stores it, logs it, and then sleeps for the polling period.
    """

    def __init__(
        self,
        save_queue_values: bool = False,
        get_queue_period: float = 0.1,
    ) -> None:
        """
        Args:
            save_queue_values: Keep every received value in an internal list.
            get_queue_period: Seconds slept after each loop iteration.
        """
        self._save_queue_values = save_queue_values
        self._get_queue_period = get_queue_period
        self._values_from_queue: list[Any] = []
        self._queue = None
        self._lock = asyncio.Lock()

    @property
    def has_queue(
        self,
    ) -> bool:
        """Whether a queue has been assigned."""
        return self._queue is not None

    @property
    def queue(
        self,
    ) -> Queue:
        """
        The queue this listener reads from.

        Raises:
            RuntimeError: if no queue has been assigned.
        """
        assigned = self._queue
        if assigned is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return assigned

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageBlockListenerBase.
        """
        self._queue = queue

    @lock_decorator
    async def forward(
        self,
    ):
        """
        Poll the queue forever; this coroutine never returns.

        NOTE(review): this method and `set_queue` are both wrapped by
        `lock_decorator`. If the decorator holds `self._lock` for the whole
        call, `set_queue` cannot complete once this loop has started —
        confirm against utils.decorators.
        """
        while True:
            if not self.has_queue:
                # Nothing to read yet; just wait one period.
                await asyncio.sleep(self._get_queue_period)
                continue
            received = await self.queue.get_async()  # type: ignore
            if self._save_queue_values:
                self._values_from_queue.append(received)
            logger.debug(f"get `{received}` from queue")
            # do something with the value
            await asyncio.sleep(self._get_queue_period)
@@ -1,19 +1,26 @@
1
1
  import asyncio
2
2
  import json
3
3
  import logging
4
- import math
5
- from typing import Any, List, Union
4
+ from typing import Any, Optional, Union
6
5
 
7
6
  import ray
8
7
  from aiomqtt import Client
9
8
 
9
+ from .message_interceptor import MessageInterceptor
10
+
10
11
  logger = logging.getLogger("pycityagent")
11
12
 
12
13
 
13
14
  @ray.remote
14
15
  class Messager:
15
16
  def __init__(
16
- self, hostname: str, port: int = 1883, username=None, password=None, timeout=60
17
+ self,
18
+ hostname: str,
19
+ port: int = 1883,
20
+ username=None,
21
+ password=None,
22
+ timeout=60,
23
+ message_interceptor: Optional[ray.ObjectRef] = None,
17
24
  ):
18
25
  self.client = Client(
19
26
  hostname, port=port, username=username, password=password, timeout=timeout
@@ -21,6 +28,16 @@ class Messager:
21
28
  self.connected = False # 是否已连接标志
22
29
  self.message_queue = asyncio.Queue() # 用于存储接收到的消息
23
30
  self.receive_messages_task = None
31
+ self._message_interceptor = message_interceptor
32
+
33
+ @property
34
+ def message_interceptor(
35
+ self,
36
+ ) -> Union[None, ray.ObjectRef]:
37
+ return self._message_interceptor
38
+
39
+ def set_message_interceptor(self, message_interceptor: ray.ObjectRef):
40
+ self._message_interceptor = message_interceptor
24
41
 
25
42
  async def __aexit__(self, exc_type, exc_value, traceback):
26
43
  await self.stop()
@@ -50,8 +67,9 @@ class Messager:
50
67
  """检查是否成功连接到 Broker"""
51
68
  return self.connected
52
69
 
70
+ # TODO:add message interceptor
53
71
  async def subscribe(
54
- self, topics: Union[str, List[str]], agents: Union[Any, List[Any]]
72
+ self, topics: Union[str, list[str]], agents: Union[Any, list[Any]]
55
73
  ):
56
74
  if not await self.is_connected():
57
75
  logger.error(
@@ -62,7 +80,7 @@ class Messager:
62
80
  topics = [topics]
63
81
  if not isinstance(agents, list):
64
82
  agents = [agents]
65
- await self.client.subscribe(topics, qos=1)
83
+ await self.client.subscribe(topics, qos=1) # type: ignore
66
84
 
67
85
  async def receive_messages(self):
68
86
  """监听并将消息存入队列"""
@@ -76,11 +94,26 @@ class Messager:
76
94
  messages.append(await self.message_queue.get())
77
95
  return messages
78
96
 
79
- async def send_message(self, topic: str, payload: dict):
97
+ async def send_message(
98
+ self,
99
+ topic: str,
100
+ payload: dict,
101
+ from_uuid: Optional[str] = None,
102
+ to_uuid: Optional[str] = None,
103
+ ):
80
104
  """通过 Messager 发送消息"""
81
105
  message = json.dumps(payload, default=str)
82
- await self.client.publish(topic=topic, payload=message, qos=1)
83
- logger.info(f"Message sent to {topic}: {message}")
106
+ interceptor = self.message_interceptor
107
+ is_valid: bool = True
108
+ if interceptor is not None and (from_uuid is not None and to_uuid is not None):
109
+ is_valid = await interceptor.forward.remote( # type:ignore
110
+ from_uuid, to_uuid, message
111
+ )
112
+ if is_valid:
113
+ await self.client.publish(topic=topic, payload=message, qos=1)
114
+ logger.info(f"Message sent to {topic}: {message}")
115
+ else:
116
+ logger.info(f"Message not sent to {topic}: {message} due to interceptor")
84
117
 
85
118
  async def start_listening(self):
86
119
  """启动消息监听任务"""
@@ -92,7 +125,5 @@ class Messager:
92
125
  async def stop(self):
93
126
  assert self.receive_messages_task is not None
94
127
  self.receive_messages_task.cancel()
95
- await asyncio.gather(
96
- self.receive_messages_task, return_exceptions=True
97
- )
128
+ await asyncio.gather(self.receive_messages_task, return_exceptions=True)
98
129
  await self.disconnect()
Binary file