pycityagent 2.0.0a52__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. pycityagent/agent/agent.py +83 -62
  2. pycityagent/agent/agent_base.py +81 -54
  3. pycityagent/cityagent/bankagent.py +5 -7
  4. pycityagent/cityagent/blocks/__init__.py +0 -2
  5. pycityagent/cityagent/blocks/cognition_block.py +149 -172
  6. pycityagent/cityagent/blocks/economy_block.py +90 -129
  7. pycityagent/cityagent/blocks/mobility_block.py +56 -29
  8. pycityagent/cityagent/blocks/needs_block.py +163 -145
  9. pycityagent/cityagent/blocks/other_block.py +17 -9
  10. pycityagent/cityagent/blocks/plan_block.py +45 -57
  11. pycityagent/cityagent/blocks/social_block.py +70 -51
  12. pycityagent/cityagent/blocks/utils.py +2 -0
  13. pycityagent/cityagent/firmagent.py +6 -7
  14. pycityagent/cityagent/governmentagent.py +7 -9
  15. pycityagent/cityagent/memory_config.py +48 -48
  16. pycityagent/cityagent/message_intercept.py +99 -0
  17. pycityagent/cityagent/nbsagent.py +6 -29
  18. pycityagent/cityagent/societyagent.py +325 -127
  19. pycityagent/cli/wrapper.py +4 -0
  20. pycityagent/economy/econ_client.py +0 -2
  21. pycityagent/environment/__init__.py +7 -1
  22. pycityagent/environment/sim/client.py +10 -1
  23. pycityagent/environment/sim/clock_service.py +2 -2
  24. pycityagent/environment/sim/pause_service.py +61 -0
  25. pycityagent/environment/sim/sim_env.py +34 -46
  26. pycityagent/environment/simulator.py +18 -14
  27. pycityagent/llm/embeddings.py +0 -24
  28. pycityagent/llm/llm.py +18 -10
  29. pycityagent/memory/faiss_query.py +29 -26
  30. pycityagent/memory/memory.py +733 -247
  31. pycityagent/message/__init__.py +8 -1
  32. pycityagent/message/message_interceptor.py +322 -0
  33. pycityagent/message/messager.py +42 -11
  34. pycityagent/pycityagent-sim +0 -0
  35. pycityagent/simulation/agentgroup.py +137 -96
  36. pycityagent/simulation/simulation.py +184 -38
  37. pycityagent/simulation/storage/pg.py +2 -2
  38. pycityagent/tools/tool.py +7 -9
  39. pycityagent/utils/__init__.py +7 -2
  40. pycityagent/utils/pg_query.py +1 -0
  41. pycityagent/utils/survey_util.py +26 -23
  42. pycityagent/workflow/block.py +14 -7
  43. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
  44. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +48 -46
  45. pycityagent/cityagent/blocks/time_block.py +0 -116
  46. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
  47. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
  48. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
  49. {pycityagent-2.0.0a52.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,10 @@
1
+ from .message_interceptor import (MessageBlockBase, MessageBlockListenerBase,
2
+ MessageInterceptor)
1
3
  from .messager import Messager
2
4
 
3
- __all__ = ["Messager"]
5
+ __all__ = [
6
+ "Messager",
7
+ "MessageBlockBase",
8
+ "MessageBlockListenerBase",
9
+ "MessageInterceptor",
10
+ ]
@@ -0,0 +1,322 @@
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from collections import defaultdict
6
+ from collections.abc import Callable, Sequence
7
+ from copy import deepcopy
8
+ from typing import Any, Optional, Union
9
+
10
+ import ray
11
+ from ray.util.queue import Queue
12
+
13
+ from ..llm import LLM, LLMConfig
14
+ from ..utils.decorators import lock_decorator
15
+
16
# Error reported when a block rejects a message without giving its own reason;
# filled in with `from_uuid`, `to_uuid` and `block_name` via `str.format`.
DEFAULT_ERROR_STRING = (
    "\n"
    "From `{from_uuid}` To `{to_uuid}` abort due to block `{block_name}`"
    "\n"
)

# Logger shared by the interceptor and listener classes in this module.
logger = logging.getLogger("message_interceptor")
21
+
22
+
23
class MessageBlockBase(ABC):
    """
    Base class for a message-filtering block.

    Subclasses override :meth:`forward` to decide whether a message from
    ``from_uuid`` to ``to_uuid`` may pass. The default implementation accepts
    every message.
    """

    def __init__(
        self,
        name: str = "",
    ) -> None:
        self._name = name
        # Assigned later via `set_llm`; `None` means "not assigned yet".
        self._llm = None
        # presumably acquired by the `lock_decorator`-wrapped methods below
        # — confirm in ..utils.decorators
        self._lock = asyncio.Lock()

    @property
    def llm(self) -> LLM:
        """The assigned LLM client; raises ``RuntimeError`` when unset."""
        client = self._llm
        if client is None:
            raise RuntimeError("LLM access before assignment, please `set_llm` first!")
        return client

    @property
    def name(self):
        """The block's display name."""
        return self._name

    @property
    def has_llm(self) -> bool:
        """Whether an LLM client has been assigned."""
        return self._llm is not None

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """Attach ``llm`` as this block's LLM client."""
        self._llm = llm

    @lock_decorator
    async def set_name(self, name: str):
        """Rename the block."""
        self._name = name

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
        violation_counts: dict[str, int],
        black_list: list[tuple[str, str]],
    ) -> tuple[bool, str]:
        """
        Judge a single message.

        The interceptor passes only the arguments named in this signature, so
        subclasses may declare any subset of these parameters.

        Args:
            from_uuid: sender id.
            to_uuid: receiver id.
            msg: the message as serialized by the sender.
            violation_counts: per-sender rejection counts kept by the interceptor.
            black_list: ``(from_uuid, to_uuid)`` pairs held by the interceptor.

        Returns:
            ``(is_valid, error_message)``; the base implementation always
            returns ``(True, "")``.
        """
        accepted, reason = True, ""
        return accepted, reason
80
+
81
+
82
def _normalize_pairs(
    pairs: Union[list[tuple[str, str]], tuple[str, str]],
) -> list[tuple[str, ...]]:
    """
    Turn black-list input into a list of hashable tuples.

    ``pairs`` is either a single ``(from_uuid, to_uuid)`` pair or a sequence
    of such pairs. Lists are accepted in place of tuples. An empty input
    yields an empty list.
    """
    items = list(pairs)
    # `all()` is True for an empty sequence, so guard explicitly; otherwise
    # `[]` would be mistaken for a single pair and stored as an empty tuple.
    if items and all(isinstance(s, str) for s in items):
        return [tuple(items)]
    return [tuple(p) for p in items]


@ray.remote
class MessageInterceptor:
    """
    Message interceptor.

    Runs each message through an ordered chain of ``MessageBlockBase`` blocks.
    The first block that rejects a message stops the chain. On rejection, the
    sender's violation count is incremented and, if a queue is set, the error
    message is put on it. The black list is not checked here directly; it is
    handed to the blocks.
    """

    def __init__(
        self,
        blocks: Optional[list[MessageBlockBase]] = None,
        black_list: Optional[list[tuple[str, str]]] = None,
        llm_config: Optional[dict] = None,
        queue: Optional[Queue] = None,
    ) -> None:
        """
        Args:
            blocks: filtering blocks, applied in order.
            black_list: initial ``(from_uuid, to_uuid)`` pairs handed to blocks.
            llm_config: if given, an LLM is built from it and lent to blocks
                that have none.
            queue: optional Ray queue receiving the error text of rejections.
        """
        self._blocks: list[MessageBlockBase] = blocks if blocks is not None else []
        self._violation_counts: dict[str, int] = defaultdict(int)
        # list[tuple(from_uuid, to_uuid)], normalized to de-duplicated tuples so
        # the hashing-based black-list operations never meet unhashable lists.
        # NOTE(review): an earlier comment said `None` means "forbidden for
        # everyone"; nothing in this class implements that — it is up to blocks.
        self._black_list: list[tuple[str, str]] = (
            list(dict.fromkeys(_normalize_pairs(black_list)))  # type: ignore
            if black_list is not None
            else []
        )
        self._llm = LLM(LLMConfig(llm_config)) if llm_config else None
        self._queue = queue
        self._lock = asyncio.Lock()

    @property
    def llm(
        self,
    ) -> LLM:
        if self._llm is None:
            raise RuntimeError("LLM access before assignment, please `set_llm` first!")
        return self._llm

    @lock_decorator
    async def blocks(
        self,
    ) -> list[MessageBlockBase]:
        """Return the (live) list of filtering blocks."""
        return self._blocks

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the interceptor's LLM client if none is assigned yet.

        An LLM already present (e.g. built from ``llm_config``) is kept.
        """
        if self._llm is None:
            self._llm = llm

    @lock_decorator
    async def violation_counts(
        self,
    ) -> dict[str, int]:
        """Return a copy of the per-sender rejection counts."""
        return deepcopy(self._violation_counts)

    @property
    def has_llm(
        self,
    ) -> bool:
        return self._llm is not None

    @lock_decorator
    async def black_list(
        self,
    ) -> list[tuple[str, str]]:
        """Return a copy of the current black list."""
        return deepcopy(self._black_list)

    @lock_decorator
    async def add_to_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """Add one pair or a list of pairs; order is kept and duplicates dropped."""
        merged = self._black_list + _normalize_pairs(black_list)  # type: ignore
        self._black_list = list(dict.fromkeys(merged))  # type: ignore

    @property
    def has_queue(
        self,
    ) -> bool:
        return self._queue is not None

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageInterceptor.
        """
        self._queue = queue

    @lock_decorator
    async def remove_from_black_list(
        self, to_remove_black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """Remove one pair or a list of pairs from the black list."""
        to_remove = set(_normalize_pairs(to_remove_black_list))
        self._black_list = [p for p in self._black_list if p not in to_remove]

    @property
    def queue(
        self,
    ) -> Queue:
        if self._queue is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return self._queue

    @lock_decorator
    async def insert_block(self, block: MessageBlockBase, index: Optional[int] = None):
        """Insert ``block`` at ``index`` (appended when ``index`` is None)."""
        if index is None:
            index = len(self._blocks)
        self._blocks.insert(index, block)

    @lock_decorator
    async def pop_block(self, index: Optional[int] = None) -> MessageBlockBase:
        """Remove and return the block at ``index`` (the last one by default)."""
        return self._blocks.pop(-1 if index is None else index)

    @lock_decorator
    async def set_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """Replace the black list with one pair or a list of pairs (de-duplicated)."""
        self._black_list = list(dict.fromkeys(_normalize_pairs(black_list)))  # type: ignore

    @lock_decorator
    async def set_blocks(self, blocks: list[MessageBlockBase]):
        self._blocks = blocks

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
    ):
        """
        Run ``msg`` through every block in order.

        Each block receives only those of ``from_uuid``, ``to_uuid``, ``msg``,
        ``violation_counts`` and ``black_list`` that its ``forward`` signature
        names. A block may return ``(is_valid, err)`` or a bare truthy/falsy
        value, in which case ``DEFAULT_ERROR_STRING`` supplies the error text.

        Returns:
            ``False`` at the first rejecting block (after queueing the error, if
            a queue is set, and counting the sender's violation), else ``True``.
        """
        for _block in self._blocks:
            # Lend the interceptor's LLM to blocks that have none.
            if not _block.has_llm and self.has_llm:
                await _block.set_llm(self.llm)
            func_params = inspect.signature(_block.forward).parameters
            _args = {
                "from_uuid": from_uuid,
                "to_uuid": to_uuid,
                "msg": msg,
                "violation_counts": self._violation_counts,
                "black_list": self._black_list,
            }
            _required_args = {k: v for k, v in _args.items() if k in func_params}
            res = await _block.forward(**_required_args)
            try:
                is_valid, err = res
            except TypeError:
                # The block returned a bare value instead of an (is_valid, err) pair.
                is_valid = res
                err = (
                    DEFAULT_ERROR_STRING.format(
                        from_uuid=from_uuid,
                        to_uuid=to_uuid,
                        block_name=f"{_block.__class__.__name__} `{_block.name}`",
                    )
                    if not is_valid
                    else ""
                )
            if not is_valid:
                if self.has_queue:
                    logger.debug(f"put `{err}` into queue")
                    await self.queue.put_async(err)  # type:ignore
                self._violation_counts[from_uuid] += 1
                logger.debug("current black list: %s", self._black_list)
                return False
        logger.debug("current black list: %s", self._black_list)
        return True
274
+
275
+
276
class MessageBlockListenerBase(ABC):
    """
    Base class for a consumer of the interceptor's error queue.

    :meth:`forward` polls the assigned queue until cancelled, optionally
    keeping every value it reads. Subclasses can extend it to act on values.
    """

    def __init__(
        self,
        save_queue_values: bool = False,
        get_queue_period: float = 0.1,
    ) -> None:
        """
        Args:
            save_queue_values: append every value read from the queue to
                ``self._values_from_queue``.
            get_queue_period: seconds to sleep between polling iterations.
        """
        self._queue = None
        self._lock = asyncio.Lock()
        self._values_from_queue: list[Any] = []
        self._save_queue_values = save_queue_values
        self._get_queue_period = get_queue_period

    @property
    def queue(
        self,
    ) -> Queue:
        if self._queue is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return self._queue

    @property
    def has_queue(
        self,
    ) -> bool:
        return self._queue is not None

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageBlockListenerBase.
        """
        self._queue = queue

    async def forward(
        self,
    ):
        """
        Consume values from the queue forever (until the task is cancelled).

        Deliberately not wrapped in ``lock_decorator``. That decorator
        presumably holds ``self._lock`` for the whole call (confirm in
        ..utils.decorators). Because this coroutine never returns, holding the
        lock here would block the lock-protected ``set_queue`` forever. A
        listener started before its queue was assigned could then never get one.
        """
        while True:
            if self.has_queue:
                value = await self.queue.get_async()  # type: ignore
                if self._save_queue_values:
                    self._values_from_queue.append(value)
                logger.debug(f"get `{value}` from queue")
                # do something with the value
            await asyncio.sleep(self._get_queue_period)
@@ -1,19 +1,26 @@
1
1
  import asyncio
2
2
  import json
3
3
  import logging
4
- import math
5
- from typing import Any, List, Union
4
+ from typing import Any, Optional, Union
6
5
 
7
6
  import ray
8
7
  from aiomqtt import Client
9
8
 
9
+ from .message_interceptor import MessageInterceptor
10
+
10
11
  logger = logging.getLogger("pycityagent")
11
12
 
12
13
 
13
14
  @ray.remote
14
15
  class Messager:
15
16
  def __init__(
16
- self, hostname: str, port: int = 1883, username=None, password=None, timeout=60
17
+ self,
18
+ hostname: str,
19
+ port: int = 1883,
20
+ username=None,
21
+ password=None,
22
+ timeout=60,
23
+ message_interceptor: Optional[ray.ObjectRef] = None,
17
24
  ):
18
25
  self.client = Client(
19
26
  hostname, port=port, username=username, password=password, timeout=timeout
@@ -21,6 +28,16 @@ class Messager:
21
28
  self.connected = False # 是否已连接标志
22
29
  self.message_queue = asyncio.Queue() # 用于存储接收到的消息
23
30
  self.receive_messages_task = None
31
+ self._message_interceptor = message_interceptor
32
+
33
@property
def message_interceptor(
    self,
) -> Union[None, ray.ObjectRef]:
    """
    The interceptor consulted by ``send_message``, or ``None`` when unset.

    NOTE(review): annotated as ``ray.ObjectRef``, but ``send_message`` calls
    ``.forward.remote`` on it, i.e. it is used as a ``MessageInterceptor``
    actor handle.
    """
    current = self._message_interceptor
    return current

def set_message_interceptor(self, message_interceptor: ray.ObjectRef):
    """Replace the interceptor consulted by ``send_message``."""
    self._message_interceptor = message_interceptor
24
41
 
25
42
  async def __aexit__(self, exc_type, exc_value, traceback):
26
43
  await self.stop()
@@ -50,8 +67,9 @@ class Messager:
50
67
  """检查是否成功连接到 Broker"""
51
68
  return self.connected
52
69
 
70
+ # TODO:add message interceptor
53
71
  async def subscribe(
54
- self, topics: Union[str, List[str]], agents: Union[Any, List[Any]]
72
+ self, topics: Union[str, list[str]], agents: Union[Any, list[Any]]
55
73
  ):
56
74
  if not await self.is_connected():
57
75
  logger.error(
@@ -62,7 +80,7 @@ class Messager:
62
80
  topics = [topics]
63
81
  if not isinstance(agents, list):
64
82
  agents = [agents]
65
- await self.client.subscribe(topics, qos=1)
83
+ await self.client.subscribe(topics, qos=1) # type: ignore
66
84
 
67
85
  async def receive_messages(self):
68
86
  """监听并将消息存入队列"""
@@ -76,11 +94,26 @@ class Messager:
76
94
  messages.append(await self.message_queue.get())
77
95
  return messages
78
96
 
79
- async def send_message(self, topic: str, payload: dict):
97
async def send_message(
    self,
    topic: str,
    payload: dict,
    from_uuid: Optional[str] = None,
    to_uuid: Optional[str] = None,
):
    """
    Serialize ``payload`` to JSON and publish it on ``topic`` (QoS 1).

    When a message interceptor is set and both ``from_uuid`` and ``to_uuid``
    are given, the interceptor's ``forward`` decides first. A rejected message
    is not published, only logged.
    """
    message = json.dumps(payload, default=str)
    interceptor = self.message_interceptor
    must_check = (
        interceptor is not None and from_uuid is not None and to_uuid is not None
    )
    allowed: bool = True
    if must_check:
        allowed = await interceptor.forward.remote(  # type:ignore
            from_uuid, to_uuid, message
        )
    if not allowed:
        logger.info(f"Message not sent to {topic}: {message} due to interceptor")
        return
    await self.client.publish(topic=topic, payload=message, qos=1)
    logger.info(f"Message sent to {topic}: {message}")
84
117
 
85
118
  async def start_listening(self):
86
119
  """启动消息监听任务"""
@@ -92,7 +125,5 @@ class Messager:
92
125
  async def stop(self):
93
126
  assert self.receive_messages_task is not None
94
127
  self.receive_messages_task.cancel()
95
- await asyncio.gather(
96
- self.receive_messages_task, return_exceptions=True
97
- )
128
+ await asyncio.gather(self.receive_messages_task, return_exceptions=True)
98
129
  await self.disconnect()
Binary file