pycityagent 2.0.0a53__cp39-cp39-macosx_11_0_arm64.whl → 2.0.0a54__cp39-cp39-macosx_11_0_arm64.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,322 @@
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from collections import defaultdict
6
+ from collections.abc import Callable, Sequence
7
+ from copy import deepcopy
8
+ from typing import Any, Optional, Union
9
+
10
+ import ray
11
+ from ray.util.queue import Queue
12
+
13
+ from ..llm import LLM, LLMConfig
14
+ from ..utils.decorators import lock_decorator
15
+
16
# Error text put on the interceptor's queue when a block rejects a message by
# returning a bare bool (i.e. without supplying its own error message).
# Placeholders: from_uuid, to_uuid, block_name.
DEFAULT_ERROR_STRING = """
From `{from_uuid}` To `{to_uuid}` abort due to block `{block_name}`
"""

# Logger shared by the interceptor, the blocks and the listeners in this module.
logger = logging.getLogger("message_interceptor")
21
+
22
+
23
class MessageBlockBase(ABC):
    """
    Base class for a message-filtering block.

    A block inspects a single message (sender uuid, receiver uuid, text) and
    decides whether it may pass. This base implementation accepts every
    message; subclasses override `forward` with their own rule.
    """

    def __init__(
        self,
        name: str = "",
    ) -> None:
        self._name = name
        # No LLM client until `set_llm` is called.
        self._llm = None
        # Presumably acquired by `lock_decorator` (defined in ..utils.decorators)
        # to serialize the decorated coroutines below -- confirm there.
        self._lock = asyncio.Lock()

    @property
    def llm(self) -> LLM:
        """The assigned LLM client; raises RuntimeError if none has been set."""
        client = self._llm
        if client is not None:
            return client
        raise RuntimeError("LLM access before assignment, please `set_llm` first!")

    @property
    def name(self):
        """The block's display name."""
        return self._name

    @property
    def has_llm(self) -> bool:
        """True once an LLM client has been assigned."""
        return self._llm is not None

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Assign the LLM client used by this block (replaces any existing one).
        """
        self._llm = llm

    @lock_decorator
    async def set_name(self, name: str):
        """
        Rename the block.
        """
        self._name = name

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
        violation_counts: dict[str, int],
        black_list: list[tuple[str, str]],
    ) -> tuple[bool, str]:
        """
        Decide whether `msg` sent from `from_uuid` to `to_uuid` may pass.

        Returns a tuple `(is_valid, error_message)`. The default accepts
        every message and reports an empty error message.
        """
        accepted = True
        reason = ""
        return accepted, reason
+
81
+
82
def _normalize_black_list_pairs(
    black_list: Optional[Union[list[tuple[str, str]], tuple[str, str]]],
) -> list[tuple]:
    """
    Normalize black-list input into a list of hashable tuples.

    Accepts either a single `(from_uuid, to_uuid)` pair or a sequence of such
    pairs. Pairs given as lists (e.g. loaded from a JSON/YAML config) are
    converted to tuples so they can be hashed and compared consistently.
    `None` or an empty input yields an empty list.
    """
    # Checked explicitly: `all()` over an empty sequence is True, which would
    # otherwise misread `[]` as a single pair and insert `()` into the list.
    if not black_list:
        return []
    if all(item is None or isinstance(item, str) for item in black_list):
        # A single pair. `None` items are allowed because `None` is used in a
        # pair to mean "everyone" (see the note in MessageInterceptor.__init__).
        return [tuple(black_list)]
    return [tuple(pair) for pair in black_list]


@ray.remote
class MessageInterceptor:
    """
    Message interceptor (Ray actor).

    Runs each message through an ordered chain of `MessageBlockBase` blocks.
    The first block that rejects a message stops the chain. The sender's
    violation count is incremented and, when a queue is attached, an error
    string is pushed onto it for listeners.
    """

    def __init__(
        self,
        blocks: Optional[list[MessageBlockBase]] = None,
        black_list: Optional[list[tuple[str, str]]] = None,
        llm_config: Optional[dict] = None,
        queue: Optional[Queue] = None,
    ) -> None:
        self._blocks: list[MessageBlockBase] = blocks if blocks is not None else []
        # Number of rejected messages per sender, keyed by `from_uuid`.
        self._violation_counts: dict[str, int] = defaultdict(int)
        # list[tuple(from_uuid, to_uuid)]; `None` means forbidden for everyone.
        # The interceptor does not consult this list itself: it is handed to
        # blocks whose `forward` declares a `black_list` parameter.
        # Entries are normalized to tuples so the set-based add/remove below
        # also work when the caller passed lists (lists are unhashable).
        self._black_list: list[tuple[str, str]] = list(
            dict.fromkeys(_normalize_black_list_pairs(black_list))
        )
        self._llm = LLM(LLMConfig(llm_config)) if llm_config else None
        self._queue = queue
        self._lock = asyncio.Lock()

    @property
    def llm(
        self,
    ) -> LLM:
        """The interceptor's LLM client; raises RuntimeError if unset."""
        if self._llm is None:
            raise RuntimeError("LLM access before assignment, please `set_llm` first!")
        return self._llm

    @lock_decorator
    async def blocks(
        self,
    ) -> list[MessageBlockBase]:
        """Return the ordered list of filtering blocks."""
        return self._blocks

    @lock_decorator
    async def set_llm(self, llm: LLM):
        """
        Set the interceptor's LLM client.

        Only takes effect if no client is set yet (e.g. one created from
        `llm_config` is kept). During `forward` the client is shared with
        blocks that have none.
        """
        if self._llm is None:
            self._llm = llm

    @lock_decorator
    async def violation_counts(
        self,
    ) -> dict[str, int]:
        """Return a deep copy of the per-sender violation counts."""
        return deepcopy(self._violation_counts)

    @property
    def has_llm(
        self,
    ) -> bool:
        """True if the interceptor has an LLM client."""
        return self._llm is not None

    @lock_decorator
    async def black_list(
        self,
    ) -> list[tuple[str, str]]:
        """Return a deep copy of the black list."""
        return deepcopy(self._black_list)

    @lock_decorator
    async def add_to_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """
        Add one `(from_uuid, to_uuid)` pair or a list of pairs.

        Duplicates are dropped; the existing order is preserved.
        """
        merged = self._black_list + _normalize_black_list_pairs(black_list)
        self._black_list = list(dict.fromkeys(merged))  # type: ignore

    @property
    def has_queue(
        self,
    ) -> bool:
        """True if an error queue is attached."""
        return self._queue is not None

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageInterceptor.
        """
        self._queue = queue

    @lock_decorator
    async def remove_from_black_list(
        self, to_remove_black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """Remove one pair or a list of pairs; unknown pairs are ignored."""
        to_remove = set(_normalize_black_list_pairs(to_remove_black_list))
        self._black_list = [p for p in self._black_list if p not in to_remove]

    @property
    def queue(
        self,
    ) -> Queue:
        """The attached error queue; raises RuntimeError if unset."""
        if self._queue is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return self._queue

    @lock_decorator
    async def insert_block(self, block: MessageBlockBase, index: Optional[int] = None):
        """Insert `block` at `index` (appends when `index` is None)."""
        if index is None:
            index = len(self._blocks)
        self._blocks.insert(index, block)

    @lock_decorator
    async def pop_block(self, index: Optional[int] = None) -> MessageBlockBase:
        """Remove and return the block at `index` (the last one by default)."""
        if index is None:
            index = -1
        return self._blocks.pop(index)

    @lock_decorator
    async def set_black_list(
        self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
    ):
        """Replace the black list with one pair or a list of pairs (deduplicated)."""
        self._black_list = list(  # type: ignore
            dict.fromkeys(_normalize_black_list_pairs(black_list))
        )

    @lock_decorator
    async def set_blocks(self, blocks: list[MessageBlockBase]):
        """Replace the whole block chain."""
        self._blocks = blocks

    @lock_decorator
    async def forward(
        self,
        from_uuid: str,
        to_uuid: str,
        msg: str,
    ) -> bool:
        """
        Run `msg` through every block in order.

        Returns True if all blocks accept the message. On the first
        rejection, the sender's violation count is incremented. The block's
        error message (or `DEFAULT_ERROR_STRING` for a bare-bool result) is
        put on the queue when one is attached. False is then returned without
        consulting the remaining blocks.
        """
        # Live objects, not copies: blocks see (and may update) current state.
        available_args = {
            "from_uuid": from_uuid,
            "to_uuid": to_uuid,
            "msg": msg,
            "violation_counts": self._violation_counts,
            "black_list": self._black_list,
        }
        for _block in self._blocks:
            # Share the interceptor's LLM with blocks that have none.
            if not _block.has_llm and self.has_llm:
                await _block.set_llm(self.llm)
            func_params = inspect.signature(_block.forward).parameters
            # Pass only the arguments `forward` declares, or all of them if it
            # accepts **kwargs. NOTE(review): if `lock_decorator` does not use
            # functools.wraps, this is the wrapper's signature -- confirm.
            if any(
                p.kind is inspect.Parameter.VAR_KEYWORD for p in func_params.values()
            ):
                _required_args = available_args
            else:
                _required_args = {
                    k: v for k, v in available_args.items() if k in func_params
                }
            res = await _block.forward(**_required_args)
            try:
                # Blocks normally return (is_valid, error_message) ...
                is_valid, err = res
            except TypeError:
                # ... but a bare bool is accepted too; synthesize the error.
                is_valid: bool = res  # type:ignore
                err = (
                    DEFAULT_ERROR_STRING.format(
                        from_uuid=from_uuid,
                        to_uuid=to_uuid,
                        block_name=f"{_block.__class__.__name__} `{_block.name}`",
                    )
                    if not is_valid
                    else ""
                )
            if not is_valid:
                if self.has_queue:
                    logger.debug(f"put `{err}` into queue")
                    await self.queue.put_async(err)  # type:ignore
                self._violation_counts[from_uuid] += 1
                return False
        return True
274
+
275
+
276
class MessageBlockListenerBase(ABC):
    """
    Base class for consumers of the message interceptor's error queue.

    `forward` polls the attached queue forever. Each value read is logged at
    debug level and, if `save_queue_values` is True, appended to
    `_values_from_queue`. Subclasses presumably extend `forward` to act on
    the values -- TODO confirm the intended extension point.
    """

    def __init__(
        self,
        save_queue_values: bool = False,
        get_queue_period: float = 0.1,
    ) -> None:
        self._queue = None
        self._lock = asyncio.Lock()
        # Values read from the queue; only filled when `save_queue_values`.
        self._values_from_queue: list[Any] = []
        self._save_queue_values = save_queue_values
        # Seconds to sleep between iterations of the `forward` loop.
        self._get_queue_period = get_queue_period

    @property
    def queue(
        self,
    ) -> Queue:
        """The attached queue; raises RuntimeError if unset."""
        if self._queue is None:
            raise RuntimeError(
                "Queue access before assignment, please `set_queue` first!"
            )
        return self._queue

    @property
    def has_queue(
        self,
    ) -> bool:
        """True if a queue is attached."""
        return self._queue is not None

    @lock_decorator
    async def set_queue(self, queue: Queue):
        """
        Set the queue of the MessageBlockListenerBase.
        """
        self._queue = queue

    async def forward(
        self,
    ):
        """
        Consume the queue forever (never returns).

        Deliberately NOT wrapped in `lock_decorator`. This loop never exits,
        so holding the instance lock for its whole lifetime would stop
        `set_queue` (also lock-decorated) from ever completing. A listener
        started before its queue was attached would then never receive one.
        NOTE(review): assumes `lock_decorator` serializes on `self._lock`.
        """
        while True:
            if self.has_queue:
                value = await self.queue.get_async()  # type: ignore
                if self._save_queue_values:
                    self._values_from_queue.append(value)
                logger.debug(f"get `{value}` from queue")
                # do something with the value
            # Yield to the event loop even when no queue is attached yet.
            await asyncio.sleep(self._get_queue_period)
@@ -1,19 +1,26 @@
1
1
  import asyncio
2
2
  import json
3
3
  import logging
4
- import math
5
- from typing import Any, List, Union
4
+ from typing import Any, Optional, Union
6
5
 
7
6
  import ray
8
7
  from aiomqtt import Client
9
8
 
9
+ from .message_interceptor import MessageInterceptor
10
+
10
11
  logger = logging.getLogger("pycityagent")
11
12
 
12
13
 
13
14
  @ray.remote
14
15
  class Messager:
15
16
  def __init__(
16
- self, hostname: str, port: int = 1883, username=None, password=None, timeout=60
17
+ self,
18
+ hostname: str,
19
+ port: int = 1883,
20
+ username=None,
21
+ password=None,
22
+ timeout=60,
23
+ message_interceptor: Optional[ray.ObjectRef] = None,
17
24
  ):
18
25
  self.client = Client(
19
26
  hostname, port=port, username=username, password=password, timeout=timeout
@@ -21,6 +28,16 @@ class Messager:
21
28
  self.connected = False # 是否已连接标志
22
29
  self.message_queue = asyncio.Queue() # 用于存储接收到的消息
23
30
  self.receive_messages_task = None
31
+ self._message_interceptor = message_interceptor
32
+
33
+ @property
34
+ def message_interceptor(
35
+ self,
36
+ ) -> Union[None, ray.ObjectRef]:
37
+ return self._message_interceptor
38
+
39
+ def set_message_interceptor(self, message_interceptor: ray.ObjectRef):
40
+ self._message_interceptor = message_interceptor
24
41
 
25
42
  async def __aexit__(self, exc_type, exc_value, traceback):
26
43
  await self.stop()
@@ -50,8 +67,9 @@ class Messager:
50
67
  """检查是否成功连接到 Broker"""
51
68
  return self.connected
52
69
 
70
+ # TODO:add message interceptor
53
71
  async def subscribe(
54
- self, topics: Union[str, List[str]], agents: Union[Any, List[Any]]
72
+ self, topics: Union[str, list[str]], agents: Union[Any, list[Any]]
55
73
  ):
56
74
  if not await self.is_connected():
57
75
  logger.error(
@@ -62,7 +80,7 @@ class Messager:
62
80
  topics = [topics]
63
81
  if not isinstance(agents, list):
64
82
  agents = [agents]
65
- await self.client.subscribe(topics, qos=1)
83
+ await self.client.subscribe(topics, qos=1) # type: ignore
66
84
 
67
85
  async def receive_messages(self):
68
86
  """监听并将消息存入队列"""
@@ -76,11 +94,26 @@ class Messager:
76
94
  messages.append(await self.message_queue.get())
77
95
  return messages
78
96
 
79
- async def send_message(self, topic: str, payload: dict):
97
+ async def send_message(
98
+ self,
99
+ topic: str,
100
+ payload: dict,
101
+ from_uuid: Optional[str] = None,
102
+ to_uuid: Optional[str] = None,
103
+ ):
80
104
  """通过 Messager 发送消息"""
81
105
  message = json.dumps(payload, default=str)
82
- await self.client.publish(topic=topic, payload=message, qos=1)
83
- logger.info(f"Message sent to {topic}: {message}")
106
+ interceptor = self.message_interceptor
107
+ is_valid: bool = True
108
+ if interceptor is not None and (from_uuid is not None and to_uuid is not None):
109
+ is_valid = await interceptor.forward.remote( # type:ignore
110
+ from_uuid, to_uuid, message
111
+ )
112
+ if is_valid:
113
+ await self.client.publish(topic=topic, payload=message, qos=1)
114
+ logger.info(f"Message sent to {topic}: {message}")
115
+ else:
116
+ logger.info(f"Message not sent to {topic}: {message} due to interceptor")
84
117
 
85
118
  async def start_listening(self):
86
119
  """启动消息监听任务"""
@@ -92,7 +125,5 @@ class Messager:
92
125
  async def stop(self):
93
126
  assert self.receive_messages_task is not None
94
127
  self.receive_messages_task.cancel()
95
- await asyncio.gather(
96
- self.receive_messages_task, return_exceptions=True
97
- )
128
+ await asyncio.gather(self.receive_messages_task, return_exceptions=True)
98
129
  await self.disconnect()
@@ -16,11 +16,12 @@ from langchain_core.embeddings import Embeddings
16
16
 
17
17
  from ..agent import Agent, InstitutionAgent
18
18
  from ..economy.econ_client import EconomyClient
19
- from ..environment.simulator import Simulator
19
+ from ..environment import Simulator
20
20
  from ..llm.llm import LLM
21
21
  from ..llm.llmconfig import LLMConfig
22
22
  from ..memory import FaissQuery, Memory
23
23
  from ..message import Messager
24
+ from ..metrics import MlflowClient
24
25
  from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
25
26
  STATUS_SCHEMA, SURVEY_SCHEMA)
26
27
 
@@ -38,11 +39,14 @@ class AgentGroup:
38
39
  list[Callable[[], tuple[dict, dict, dict]]],
39
40
  ],
40
41
  config: dict,
42
+ exp_name: str,
41
43
  exp_id: str | UUID,
42
44
  enable_avro: bool,
43
45
  avro_path: Path,
44
46
  enable_pgsql: bool,
45
47
  pgsql_writer: ray.ObjectRef,
48
+ message_interceptor: ray.ObjectRef,
49
+ mlflow_run_id: str,
46
50
  embedding_model: Embeddings,
47
51
  logging_level: int,
48
52
  agent_config_file: Optional[Union[str, list[str]]] = None,
@@ -77,6 +81,18 @@ class AgentGroup:
77
81
  }
78
82
  if self.enable_pgsql:
79
83
  pass
84
+ # Mlflow
85
+ _mlflow_config = config.get("metric_request", {}).get("mlflow")
86
+ if _mlflow_config:
87
+ logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
88
+ self.mlflow_client = MlflowClient(
89
+ config=_mlflow_config,
90
+ mlflow_run_name=f"{exp_name}_{1000*int(time.time())}",
91
+ experiment_name=exp_name,
92
+ run_id=mlflow_run_id,
93
+ )
94
+ else:
95
+ self.mlflow_client = None
80
96
 
81
97
  # prepare Messager
82
98
  if "mqtt" in config["simulator_request"]:
@@ -91,6 +107,7 @@ class AgentGroup:
91
107
 
92
108
  self.message_dispatch_task = None
93
109
  self._pgsql_writer = pgsql_writer
110
+ self._message_interceptor = message_interceptor
94
111
  self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
95
112
  self.initialized = False
96
113
  self.id2agent = {}
@@ -126,13 +143,9 @@ class AgentGroup:
126
143
  for j in range(number_of_agents_i):
127
144
  memory_config_function_group_i = memory_config_function_group[i]
128
145
  extra_attributes, profile, base = memory_config_function_group_i()
129
- memory = Memory(
130
- config=extra_attributes,
131
- profile=profile,
132
- base=base
133
- )
146
+ memory = Memory(config=extra_attributes, profile=profile, base=base)
134
147
  agent = agent_class_i(
135
- name=f"{agent_class_i.__name__}_{i}",
148
+ name=f"{agent_class_i.__name__}_{i}", # type: ignore
136
149
  memory=memory,
137
150
  llm_client=self.llm,
138
151
  economy_client=self.economy_client,
@@ -141,12 +154,16 @@ class AgentGroup:
141
154
  agent.set_exp_id(self.exp_id) # type: ignore
142
155
  if self.messager is not None:
143
156
  agent.set_messager(self.messager)
157
+ if self.mlflow_client is not None:
158
+ agent.set_mlflow_client(self.mlflow_client) # type: ignore
144
159
  if self.enable_avro:
145
160
  agent.set_avro_file(self.avro_file) # type: ignore
146
161
  if self.enable_pgsql:
147
162
  agent.set_pgsql_writer(self._pgsql_writer)
148
163
  if self.agent_config_file is not None and self.agent_config_file[i]:
149
164
  agent.load_from_file(self.agent_config_file[i])
165
+ if self._message_interceptor is not None:
166
+ agent.set_message_interceptor(self._message_interceptor)
150
167
  self.agents.append(agent)
151
168
  self.id2agent[agent._uuid] = agent
152
169
 
@@ -287,11 +304,13 @@ class AgentGroup:
287
304
  embedding_tasks = []
288
305
  for agent in self.agents:
289
306
  embedding_tasks.append(agent.memory.initialize_embeddings())
290
- agent.memory.set_search_components(self.faiss_query, self.embedding_model)
307
+ agent.memory.set_search_components(
308
+ self.faiss_query, self.embedding_model
309
+ )
291
310
  agent.memory.set_simulator(self.simulator)
292
311
  await asyncio.gather(*embedding_tasks)
293
312
  logger.debug(f"-----Embedding initialized in AgentGroup {self._uuid} ...")
294
-
313
+
295
314
  self.initialized = True
296
315
  logger.debug(f"-----AgentGroup {self._uuid} initialized")
297
316
 
@@ -324,7 +343,9 @@ class AgentGroup:
324
343
  filtered_uuids.append(agent._uuid)
325
344
  return filtered_uuids
326
345
 
327
- async def gather(self, content: str, target_agent_uuids: Optional[list[str]] = None):
346
+ async def gather(
347
+ self, content: str, target_agent_uuids: Optional[list[str]] = None
348
+ ):
328
349
  logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
329
350
  results = {}
330
351
  if target_agent_uuids is None:
@@ -522,6 +543,11 @@ class AgentGroup:
522
543
  ]:
523
544
  if key not in _status_dict:
524
545
  _status_dict[key] = ""
546
+ for key in [
547
+ "friend_ids",
548
+ ]:
549
+ if key not in _status_dict:
550
+ _status_dict[key] = []
525
551
  _status_dict["created_at"] = _date_time
526
552
  else:
527
553
  if not issubclass(type(self.agents[0]), InstitutionAgent):
@@ -537,10 +563,19 @@ class AgentGroup:
537
563
  parent_id = position["lane_position"]["lane_id"]
538
564
  else:
539
565
  parent_id = -1
540
- hunger_satisfaction = await agent.status.get("hunger_satisfaction")
541
- energy_satisfaction = await agent.status.get("energy_satisfaction")
542
- safety_satisfaction = await agent.status.get("safety_satisfaction")
543
- social_satisfaction = await agent.status.get("social_satisfaction")
566
+ hunger_satisfaction = await agent.status.get(
567
+ "hunger_satisfaction"
568
+ )
569
+ energy_satisfaction = await agent.status.get(
570
+ "energy_satisfaction"
571
+ )
572
+ safety_satisfaction = await agent.status.get(
573
+ "safety_satisfaction"
574
+ )
575
+ social_satisfaction = await agent.status.get(
576
+ "social_satisfaction"
577
+ )
578
+ friend_ids = await agent.status.get("friends")
544
579
  action = await agent.status.get("current_step")
545
580
  action = action["intention"]
546
581
  _status_dict = {
@@ -550,6 +585,9 @@ class AgentGroup:
550
585
  "lng": lng,
551
586
  "lat": lat,
552
587
  "parent_id": parent_id,
588
+ "friend_ids": [
589
+ str(_friend_id) for _friend_id in friend_ids
590
+ ],
553
591
  "action": action,
554
592
  "hungry": hunger_satisfaction,
555
593
  "tired": energy_satisfaction,
@@ -612,6 +650,10 @@ class AgentGroup:
612
650
  employees = await agent.status.get("employees")
613
651
  except:
614
652
  employees = []
653
+ try:
654
+ friend_ids = await agent.status.get("friends")
655
+ except:
656
+ friend_ids = []
615
657
  _status_dict = {
616
658
  "id": agent._uuid,
617
659
  "day": _day,
@@ -619,6 +661,9 @@ class AgentGroup:
619
661
  "lng": lng,
620
662
  "lat": lat,
621
663
  "parent_id": parent_id,
664
+ "friend_ids": [
665
+ str(_friend_id) for _friend_id in friend_ids
666
+ ],
622
667
  "action": "",
623
668
  "type": await agent.status.get("type"),
624
669
  "nominal_gdp": nominal_gdp,
@@ -644,6 +689,7 @@ class AgentGroup:
644
689
  "lng",
645
690
  "lat",
646
691
  "parent_id",
692
+ "friend_ids",
647
693
  "action",
648
694
  "created_at",
649
695
  ]
@@ -668,6 +714,7 @@ class AgentGroup:
668
714
  await asyncio.gather(*tasks)
669
715
  except Exception as e:
670
716
  import traceback
717
+
671
718
  logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
672
719
  raise RuntimeError(str(e)) from e
673
720
 
@@ -676,5 +723,6 @@ class AgentGroup:
676
723
  await self.save_status(day, t)
677
724
  except Exception as e:
678
725
  import traceback
726
+
679
727
  logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
680
728
  raise RuntimeError(str(e)) from e