pycityagent 2.0.0a53__cp311-cp311-macosx_11_0_arm64.whl → 2.0.0a54__cp311-cp311-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- pycityagent/agent/agent.py +39 -4
- pycityagent/agent/agent_base.py +39 -25
- pycityagent/cityagent/blocks/plan_block.py +1 -1
- pycityagent/cityagent/message_intercept.py +99 -0
- pycityagent/cityagent/societyagent.py +145 -32
- pycityagent/cli/wrapper.py +4 -0
- pycityagent/economy/econ_client.py +0 -2
- pycityagent/environment/__init__.py +7 -1
- pycityagent/environment/sim/sim_env.py +34 -46
- pycityagent/environment/simulator.py +2 -3
- pycityagent/llm/llm.py +18 -10
- pycityagent/memory/memory.py +151 -113
- pycityagent/message/__init__.py +8 -1
- pycityagent/message/message_interceptor.py +322 -0
- pycityagent/message/messager.py +42 -11
- pycityagent/simulation/agentgroup.py +62 -14
- pycityagent/simulation/simulation.py +95 -24
- pycityagent/simulation/storage/pg.py +2 -2
- pycityagent/utils/__init__.py +7 -2
- pycityagent/utils/pg_query.py +1 -0
- pycityagent/utils/survey_util.py +26 -23
- pycityagent/workflow/block.py +3 -3
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a54.dist-info}/METADATA +2 -2
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a54.dist-info}/RECORD +28 -26
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a54.dist-info}/LICENSE +0 -0
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a54.dist-info}/WHEEL +0 -0
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a54.dist-info}/entry_points.txt +0 -0
- {pycityagent-2.0.0a53.dist-info → pycityagent-2.0.0a54.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,322 @@
|
|
1
|
+
import asyncio
|
2
|
+
import inspect
|
3
|
+
import logging
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from collections import defaultdict
|
6
|
+
from collections.abc import Callable, Sequence
|
7
|
+
from copy import deepcopy
|
8
|
+
from typing import Any, Optional, Union
|
9
|
+
|
10
|
+
import ray
|
11
|
+
from ray.util.queue import Queue
|
12
|
+
|
13
|
+
from ..llm import LLM, LLMConfig
|
14
|
+
from ..utils.decorators import lock_decorator
|
15
|
+
|
16
|
+
DEFAULT_ERROR_STRING = """
|
17
|
+
From `{from_uuid}` To `{to_uuid}` abort due to block `{block_name}`
|
18
|
+
"""
|
19
|
+
|
20
|
+
logger = logging.getLogger("message_interceptor")
|
21
|
+
|
22
|
+
|
23
|
+
class MessageBlockBase(ABC):
|
24
|
+
"""
|
25
|
+
用于过滤的block
|
26
|
+
"""
|
27
|
+
|
28
|
+
def __init__(
|
29
|
+
self,
|
30
|
+
name: str = "",
|
31
|
+
) -> None:
|
32
|
+
self._name = name
|
33
|
+
self._llm = None
|
34
|
+
self._lock = asyncio.Lock()
|
35
|
+
|
36
|
+
@property
|
37
|
+
def llm(
|
38
|
+
self,
|
39
|
+
) -> LLM:
|
40
|
+
if self._llm is None:
|
41
|
+
raise RuntimeError(f"LLM access before assignment, please `set_llm` first!")
|
42
|
+
return self._llm
|
43
|
+
|
44
|
+
@property
|
45
|
+
def name(
|
46
|
+
self,
|
47
|
+
):
|
48
|
+
return self._name
|
49
|
+
|
50
|
+
@property
|
51
|
+
def has_llm(
|
52
|
+
self,
|
53
|
+
) -> bool:
|
54
|
+
return self._llm is not None
|
55
|
+
|
56
|
+
@lock_decorator
|
57
|
+
async def set_llm(self, llm: LLM):
|
58
|
+
"""
|
59
|
+
Set the llm_client of the block.
|
60
|
+
"""
|
61
|
+
self._llm = llm
|
62
|
+
|
63
|
+
@lock_decorator
|
64
|
+
async def set_name(self, name: str):
|
65
|
+
"""
|
66
|
+
Set the name of the block.
|
67
|
+
"""
|
68
|
+
self._name = name
|
69
|
+
|
70
|
+
@lock_decorator
|
71
|
+
async def forward(
|
72
|
+
self,
|
73
|
+
from_uuid: str,
|
74
|
+
to_uuid: str,
|
75
|
+
msg: str,
|
76
|
+
violation_counts: dict[str, int],
|
77
|
+
black_list: list[tuple[str, str]],
|
78
|
+
) -> tuple[bool, str]:
|
79
|
+
return True, ""
|
80
|
+
|
81
|
+
|
82
|
+
@ray.remote
|
83
|
+
class MessageInterceptor:
|
84
|
+
"""
|
85
|
+
信息拦截器
|
86
|
+
"""
|
87
|
+
|
88
|
+
def __init__(
|
89
|
+
self,
|
90
|
+
blocks: Optional[list[MessageBlockBase]] = None,
|
91
|
+
black_list: Optional[list[tuple[str, str]]] = None,
|
92
|
+
llm_config: Optional[dict] = None,
|
93
|
+
queue: Optional[Queue] = None,
|
94
|
+
) -> None:
|
95
|
+
if blocks is not None:
|
96
|
+
self._blocks: list[MessageBlockBase] = blocks
|
97
|
+
else:
|
98
|
+
self._blocks: list[MessageBlockBase] = []
|
99
|
+
self._violation_counts: dict[str, int] = defaultdict(int)
|
100
|
+
if black_list is not None:
|
101
|
+
self._black_list: list[tuple[str, str]] = black_list
|
102
|
+
else:
|
103
|
+
self._black_list: list[tuple[str, str]] = (
|
104
|
+
[]
|
105
|
+
) # list[tuple(from_uuid, to_uuid)] `None` means forbidden for everyone.
|
106
|
+
if llm_config:
|
107
|
+
self._llm = LLM(LLMConfig(llm_config))
|
108
|
+
else:
|
109
|
+
self._llm = None
|
110
|
+
self._queue = queue
|
111
|
+
self._lock = asyncio.Lock()
|
112
|
+
|
113
|
+
@property
|
114
|
+
def llm(
|
115
|
+
self,
|
116
|
+
) -> LLM:
|
117
|
+
if self._llm is None:
|
118
|
+
raise RuntimeError(f"LLM access before assignment, please `set_llm` first!")
|
119
|
+
return self._llm
|
120
|
+
|
121
|
+
@lock_decorator
|
122
|
+
async def blocks(
|
123
|
+
self,
|
124
|
+
) -> list[MessageBlockBase]:
|
125
|
+
return self._blocks
|
126
|
+
|
127
|
+
@lock_decorator
|
128
|
+
async def set_llm(self, llm: LLM):
|
129
|
+
"""
|
130
|
+
Set the llm_client of the block.
|
131
|
+
"""
|
132
|
+
if self._llm is None:
|
133
|
+
self._llm = llm
|
134
|
+
|
135
|
+
@lock_decorator
|
136
|
+
async def violation_counts(
|
137
|
+
self,
|
138
|
+
) -> dict[str, int]:
|
139
|
+
return deepcopy(self._violation_counts)
|
140
|
+
|
141
|
+
@property
|
142
|
+
def has_llm(
|
143
|
+
self,
|
144
|
+
) -> bool:
|
145
|
+
return self._llm is not None
|
146
|
+
|
147
|
+
@lock_decorator
|
148
|
+
async def black_list(
|
149
|
+
self,
|
150
|
+
) -> list[tuple[str, str]]:
|
151
|
+
return deepcopy(self._black_list)
|
152
|
+
|
153
|
+
@lock_decorator
|
154
|
+
async def add_to_black_list(
|
155
|
+
self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
|
156
|
+
):
|
157
|
+
if all(isinstance(s, str) for s in black_list):
|
158
|
+
# tuple[str,str]
|
159
|
+
_black_list = [black_list]
|
160
|
+
else:
|
161
|
+
_black_list = black_list
|
162
|
+
_black_list = [tuple(p) for p in _black_list]
|
163
|
+
self._black_list.extend(_black_list) # type: ignore
|
164
|
+
self._black_list = list(set(self._black_list))
|
165
|
+
|
166
|
+
@property
|
167
|
+
def has_queue(
|
168
|
+
self,
|
169
|
+
) -> bool:
|
170
|
+
return self._queue is not None
|
171
|
+
|
172
|
+
@lock_decorator
|
173
|
+
async def set_queue(self, queue: Queue):
|
174
|
+
"""
|
175
|
+
Set the queue of the MessageInterceptor.
|
176
|
+
"""
|
177
|
+
self._queue = queue
|
178
|
+
|
179
|
+
@lock_decorator
|
180
|
+
async def remove_from_black_list(
|
181
|
+
self, to_remove_black_list: Union[list[tuple[str, str]], tuple[str, str]]
|
182
|
+
):
|
183
|
+
if all(isinstance(s, str) for s in to_remove_black_list):
|
184
|
+
# tuple[str,str]
|
185
|
+
_black_list = [to_remove_black_list]
|
186
|
+
else:
|
187
|
+
_black_list = to_remove_black_list
|
188
|
+
_black_list_set = {tuple(p) for p in _black_list}
|
189
|
+
self._black_list = [p for p in self._black_list if p not in _black_list_set]
|
190
|
+
|
191
|
+
@property
|
192
|
+
def queue(
|
193
|
+
self,
|
194
|
+
) -> Queue:
|
195
|
+
if self._queue is None:
|
196
|
+
raise RuntimeError(
|
197
|
+
f"Queue access before assignment, please `set_queue` first!"
|
198
|
+
)
|
199
|
+
return self._queue
|
200
|
+
|
201
|
+
@lock_decorator
|
202
|
+
async def insert_block(self, block: MessageBlockBase, index: Optional[int] = None):
|
203
|
+
if index is None:
|
204
|
+
index = len(self._blocks)
|
205
|
+
self._blocks.insert(index, block)
|
206
|
+
|
207
|
+
@lock_decorator
|
208
|
+
async def pop_block(self, index: Optional[int] = None) -> MessageBlockBase:
|
209
|
+
if index is None:
|
210
|
+
index = -1
|
211
|
+
return self._blocks.pop(index)
|
212
|
+
|
213
|
+
@lock_decorator
|
214
|
+
async def set_black_list(
|
215
|
+
self, black_list: Union[list[tuple[str, str]], tuple[str, str]]
|
216
|
+
):
|
217
|
+
if all(isinstance(s, str) for s in black_list):
|
218
|
+
# tuple[str,str]
|
219
|
+
_black_list = [black_list]
|
220
|
+
else:
|
221
|
+
_black_list = black_list
|
222
|
+
_black_list = [tuple(p) for p in _black_list]
|
223
|
+
self._black_list = list(set(_black_list)) # type: ignore
|
224
|
+
|
225
|
+
@lock_decorator
|
226
|
+
async def set_blocks(self, blocks: list[MessageBlockBase]):
|
227
|
+
self._blocks = blocks
|
228
|
+
|
229
|
+
@lock_decorator
|
230
|
+
async def forward(
|
231
|
+
self,
|
232
|
+
from_uuid: str,
|
233
|
+
to_uuid: str,
|
234
|
+
msg: str,
|
235
|
+
):
|
236
|
+
for _block in self._blocks:
|
237
|
+
if not _block.has_llm and self.has_llm:
|
238
|
+
await _block.set_llm(self.llm)
|
239
|
+
func_params = inspect.signature(_block.forward).parameters
|
240
|
+
_args = {
|
241
|
+
"from_uuid": from_uuid,
|
242
|
+
"to_uuid": to_uuid,
|
243
|
+
"msg": msg,
|
244
|
+
"violation_counts": self._violation_counts,
|
245
|
+
"black_list": self._black_list,
|
246
|
+
}
|
247
|
+
_required_args = {k: v for k, v in _args.items() if k in func_params}
|
248
|
+
res = await _block.forward(**_required_args)
|
249
|
+
try:
|
250
|
+
is_valid, err = res
|
251
|
+
except TypeError as e:
|
252
|
+
is_valid: bool = res # type:ignore
|
253
|
+
err = (
|
254
|
+
DEFAULT_ERROR_STRING.format(
|
255
|
+
from_uuid=from_uuid,
|
256
|
+
to_uuid=to_uuid,
|
257
|
+
block_name=f"{_block.__class__.__name__} `{_block.name}`",
|
258
|
+
)
|
259
|
+
if not is_valid
|
260
|
+
else ""
|
261
|
+
)
|
262
|
+
if not is_valid:
|
263
|
+
if self.has_queue:
|
264
|
+
logger.debug(f"put `{err}` into queue")
|
265
|
+
await self.queue.put_async(err) # type:ignore
|
266
|
+
self._violation_counts[from_uuid] += 1
|
267
|
+
print(self._black_list)
|
268
|
+
return False
|
269
|
+
else:
|
270
|
+
# valid
|
271
|
+
pass
|
272
|
+
print(self._black_list)
|
273
|
+
return True
|
274
|
+
|
275
|
+
|
276
|
+
class MessageBlockListenerBase(ABC):
|
277
|
+
def __init__(
|
278
|
+
self,
|
279
|
+
save_queue_values: bool = False,
|
280
|
+
get_queue_period: float = 0.1,
|
281
|
+
) -> None:
|
282
|
+
self._queue = None
|
283
|
+
self._lock = asyncio.Lock()
|
284
|
+
self._values_from_queue: list[Any] = []
|
285
|
+
self._save_queue_values = save_queue_values
|
286
|
+
self._get_queue_period = get_queue_period
|
287
|
+
|
288
|
+
@property
|
289
|
+
def queue(
|
290
|
+
self,
|
291
|
+
) -> Queue:
|
292
|
+
if self._queue is None:
|
293
|
+
raise RuntimeError(
|
294
|
+
f"Queue access before assignment, please `set_queue` first!"
|
295
|
+
)
|
296
|
+
return self._queue
|
297
|
+
|
298
|
+
@property
|
299
|
+
def has_queue(
|
300
|
+
self,
|
301
|
+
) -> bool:
|
302
|
+
return self._queue is not None
|
303
|
+
|
304
|
+
@lock_decorator
|
305
|
+
async def set_queue(self, queue: Queue):
|
306
|
+
"""
|
307
|
+
Set the queue of the MessageBlockListenerBase.
|
308
|
+
"""
|
309
|
+
self._queue = queue
|
310
|
+
|
311
|
+
@lock_decorator
|
312
|
+
async def forward(
|
313
|
+
self,
|
314
|
+
):
|
315
|
+
while True:
|
316
|
+
if self.has_queue:
|
317
|
+
value = await self.queue.get_async() # type: ignore
|
318
|
+
if self._save_queue_values:
|
319
|
+
self._values_from_queue.append(value)
|
320
|
+
logger.debug(f"get `{value}` from queue")
|
321
|
+
# do something with the value
|
322
|
+
await asyncio.sleep(self._get_queue_period)
|
pycityagent/message/messager.py
CHANGED
@@ -1,19 +1,26 @@
|
|
1
1
|
import asyncio
|
2
2
|
import json
|
3
3
|
import logging
|
4
|
-
import
|
5
|
-
from typing import Any, List, Union
|
4
|
+
from typing import Any, Optional, Union
|
6
5
|
|
7
6
|
import ray
|
8
7
|
from aiomqtt import Client
|
9
8
|
|
9
|
+
from .message_interceptor import MessageInterceptor
|
10
|
+
|
10
11
|
logger = logging.getLogger("pycityagent")
|
11
12
|
|
12
13
|
|
13
14
|
@ray.remote
|
14
15
|
class Messager:
|
15
16
|
def __init__(
|
16
|
-
self,
|
17
|
+
self,
|
18
|
+
hostname: str,
|
19
|
+
port: int = 1883,
|
20
|
+
username=None,
|
21
|
+
password=None,
|
22
|
+
timeout=60,
|
23
|
+
message_interceptor: Optional[ray.ObjectRef] = None,
|
17
24
|
):
|
18
25
|
self.client = Client(
|
19
26
|
hostname, port=port, username=username, password=password, timeout=timeout
|
@@ -21,6 +28,16 @@ class Messager:
|
|
21
28
|
self.connected = False # 是否已连接标志
|
22
29
|
self.message_queue = asyncio.Queue() # 用于存储接收到的消息
|
23
30
|
self.receive_messages_task = None
|
31
|
+
self._message_interceptor = message_interceptor
|
32
|
+
|
33
|
+
@property
|
34
|
+
def message_interceptor(
|
35
|
+
self,
|
36
|
+
) -> Union[None, ray.ObjectRef]:
|
37
|
+
return self._message_interceptor
|
38
|
+
|
39
|
+
def set_message_interceptor(self, message_interceptor: ray.ObjectRef):
|
40
|
+
self._message_interceptor = message_interceptor
|
24
41
|
|
25
42
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
26
43
|
await self.stop()
|
@@ -50,8 +67,9 @@ class Messager:
|
|
50
67
|
"""检查是否成功连接到 Broker"""
|
51
68
|
return self.connected
|
52
69
|
|
70
|
+
# TODO:add message interceptor
|
53
71
|
async def subscribe(
|
54
|
-
self, topics: Union[str,
|
72
|
+
self, topics: Union[str, list[str]], agents: Union[Any, list[Any]]
|
55
73
|
):
|
56
74
|
if not await self.is_connected():
|
57
75
|
logger.error(
|
@@ -62,7 +80,7 @@ class Messager:
|
|
62
80
|
topics = [topics]
|
63
81
|
if not isinstance(agents, list):
|
64
82
|
agents = [agents]
|
65
|
-
await self.client.subscribe(topics, qos=1)
|
83
|
+
await self.client.subscribe(topics, qos=1) # type: ignore
|
66
84
|
|
67
85
|
async def receive_messages(self):
|
68
86
|
"""监听并将消息存入队列"""
|
@@ -76,11 +94,26 @@ class Messager:
|
|
76
94
|
messages.append(await self.message_queue.get())
|
77
95
|
return messages
|
78
96
|
|
79
|
-
async def send_message(
|
97
|
+
async def send_message(
|
98
|
+
self,
|
99
|
+
topic: str,
|
100
|
+
payload: dict,
|
101
|
+
from_uuid: Optional[str] = None,
|
102
|
+
to_uuid: Optional[str] = None,
|
103
|
+
):
|
80
104
|
"""通过 Messager 发送消息"""
|
81
105
|
message = json.dumps(payload, default=str)
|
82
|
-
|
83
|
-
|
106
|
+
interceptor = self.message_interceptor
|
107
|
+
is_valid: bool = True
|
108
|
+
if interceptor is not None and (from_uuid is not None and to_uuid is not None):
|
109
|
+
is_valid = await interceptor.forward.remote( # type:ignore
|
110
|
+
from_uuid, to_uuid, message
|
111
|
+
)
|
112
|
+
if is_valid:
|
113
|
+
await self.client.publish(topic=topic, payload=message, qos=1)
|
114
|
+
logger.info(f"Message sent to {topic}: {message}")
|
115
|
+
else:
|
116
|
+
logger.info(f"Message not sent to {topic}: {message} due to interceptor")
|
84
117
|
|
85
118
|
async def start_listening(self):
|
86
119
|
"""启动消息监听任务"""
|
@@ -92,7 +125,5 @@ class Messager:
|
|
92
125
|
async def stop(self):
|
93
126
|
assert self.receive_messages_task is not None
|
94
127
|
self.receive_messages_task.cancel()
|
95
|
-
await asyncio.gather(
|
96
|
-
self.receive_messages_task, return_exceptions=True
|
97
|
-
)
|
128
|
+
await asyncio.gather(self.receive_messages_task, return_exceptions=True)
|
98
129
|
await self.disconnect()
|
@@ -16,11 +16,12 @@ from langchain_core.embeddings import Embeddings
|
|
16
16
|
|
17
17
|
from ..agent import Agent, InstitutionAgent
|
18
18
|
from ..economy.econ_client import EconomyClient
|
19
|
-
from ..environment
|
19
|
+
from ..environment import Simulator
|
20
20
|
from ..llm.llm import LLM
|
21
21
|
from ..llm.llmconfig import LLMConfig
|
22
22
|
from ..memory import FaissQuery, Memory
|
23
23
|
from ..message import Messager
|
24
|
+
from ..metrics import MlflowClient
|
24
25
|
from ..utils import (DIALOG_SCHEMA, INSTITUTION_STATUS_SCHEMA, PROFILE_SCHEMA,
|
25
26
|
STATUS_SCHEMA, SURVEY_SCHEMA)
|
26
27
|
|
@@ -38,11 +39,14 @@ class AgentGroup:
|
|
38
39
|
list[Callable[[], tuple[dict, dict, dict]]],
|
39
40
|
],
|
40
41
|
config: dict,
|
42
|
+
exp_name: str,
|
41
43
|
exp_id: str | UUID,
|
42
44
|
enable_avro: bool,
|
43
45
|
avro_path: Path,
|
44
46
|
enable_pgsql: bool,
|
45
47
|
pgsql_writer: ray.ObjectRef,
|
48
|
+
message_interceptor: ray.ObjectRef,
|
49
|
+
mlflow_run_id: str,
|
46
50
|
embedding_model: Embeddings,
|
47
51
|
logging_level: int,
|
48
52
|
agent_config_file: Optional[Union[str, list[str]]] = None,
|
@@ -77,6 +81,18 @@ class AgentGroup:
|
|
77
81
|
}
|
78
82
|
if self.enable_pgsql:
|
79
83
|
pass
|
84
|
+
# Mlflow
|
85
|
+
_mlflow_config = config.get("metric_request", {}).get("mlflow")
|
86
|
+
if _mlflow_config:
|
87
|
+
logger.info(f"-----Creating Mlflow client in AgentGroup {self._uuid} ...")
|
88
|
+
self.mlflow_client = MlflowClient(
|
89
|
+
config=_mlflow_config,
|
90
|
+
mlflow_run_name=f"{exp_name}_{1000*int(time.time())}",
|
91
|
+
experiment_name=exp_name,
|
92
|
+
run_id=mlflow_run_id,
|
93
|
+
)
|
94
|
+
else:
|
95
|
+
self.mlflow_client = None
|
80
96
|
|
81
97
|
# prepare Messager
|
82
98
|
if "mqtt" in config["simulator_request"]:
|
@@ -91,6 +107,7 @@ class AgentGroup:
|
|
91
107
|
|
92
108
|
self.message_dispatch_task = None
|
93
109
|
self._pgsql_writer = pgsql_writer
|
110
|
+
self._message_interceptor = message_interceptor
|
94
111
|
self._last_asyncio_pg_task = None # 将SQL写入的IO隐藏到计算任务后
|
95
112
|
self.initialized = False
|
96
113
|
self.id2agent = {}
|
@@ -126,13 +143,9 @@ class AgentGroup:
|
|
126
143
|
for j in range(number_of_agents_i):
|
127
144
|
memory_config_function_group_i = memory_config_function_group[i]
|
128
145
|
extra_attributes, profile, base = memory_config_function_group_i()
|
129
|
-
memory = Memory(
|
130
|
-
config=extra_attributes,
|
131
|
-
profile=profile,
|
132
|
-
base=base
|
133
|
-
)
|
146
|
+
memory = Memory(config=extra_attributes, profile=profile, base=base)
|
134
147
|
agent = agent_class_i(
|
135
|
-
name=f"{agent_class_i.__name__}_{i}",
|
148
|
+
name=f"{agent_class_i.__name__}_{i}", # type: ignore
|
136
149
|
memory=memory,
|
137
150
|
llm_client=self.llm,
|
138
151
|
economy_client=self.economy_client,
|
@@ -141,12 +154,16 @@ class AgentGroup:
|
|
141
154
|
agent.set_exp_id(self.exp_id) # type: ignore
|
142
155
|
if self.messager is not None:
|
143
156
|
agent.set_messager(self.messager)
|
157
|
+
if self.mlflow_client is not None:
|
158
|
+
agent.set_mlflow_client(self.mlflow_client) # type: ignore
|
144
159
|
if self.enable_avro:
|
145
160
|
agent.set_avro_file(self.avro_file) # type: ignore
|
146
161
|
if self.enable_pgsql:
|
147
162
|
agent.set_pgsql_writer(self._pgsql_writer)
|
148
163
|
if self.agent_config_file is not None and self.agent_config_file[i]:
|
149
164
|
agent.load_from_file(self.agent_config_file[i])
|
165
|
+
if self._message_interceptor is not None:
|
166
|
+
agent.set_message_interceptor(self._message_interceptor)
|
150
167
|
self.agents.append(agent)
|
151
168
|
self.id2agent[agent._uuid] = agent
|
152
169
|
|
@@ -287,11 +304,13 @@ class AgentGroup:
|
|
287
304
|
embedding_tasks = []
|
288
305
|
for agent in self.agents:
|
289
306
|
embedding_tasks.append(agent.memory.initialize_embeddings())
|
290
|
-
agent.memory.set_search_components(
|
307
|
+
agent.memory.set_search_components(
|
308
|
+
self.faiss_query, self.embedding_model
|
309
|
+
)
|
291
310
|
agent.memory.set_simulator(self.simulator)
|
292
311
|
await asyncio.gather(*embedding_tasks)
|
293
312
|
logger.debug(f"-----Embedding initialized in AgentGroup {self._uuid} ...")
|
294
|
-
|
313
|
+
|
295
314
|
self.initialized = True
|
296
315
|
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
297
316
|
|
@@ -324,7 +343,9 @@ class AgentGroup:
|
|
324
343
|
filtered_uuids.append(agent._uuid)
|
325
344
|
return filtered_uuids
|
326
345
|
|
327
|
-
async def gather(
|
346
|
+
async def gather(
|
347
|
+
self, content: str, target_agent_uuids: Optional[list[str]] = None
|
348
|
+
):
|
328
349
|
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
329
350
|
results = {}
|
330
351
|
if target_agent_uuids is None:
|
@@ -522,6 +543,11 @@ class AgentGroup:
|
|
522
543
|
]:
|
523
544
|
if key not in _status_dict:
|
524
545
|
_status_dict[key] = ""
|
546
|
+
for key in [
|
547
|
+
"friend_ids",
|
548
|
+
]:
|
549
|
+
if key not in _status_dict:
|
550
|
+
_status_dict[key] = []
|
525
551
|
_status_dict["created_at"] = _date_time
|
526
552
|
else:
|
527
553
|
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
@@ -537,10 +563,19 @@ class AgentGroup:
|
|
537
563
|
parent_id = position["lane_position"]["lane_id"]
|
538
564
|
else:
|
539
565
|
parent_id = -1
|
540
|
-
hunger_satisfaction = await agent.status.get(
|
541
|
-
|
542
|
-
|
543
|
-
|
566
|
+
hunger_satisfaction = await agent.status.get(
|
567
|
+
"hunger_satisfaction"
|
568
|
+
)
|
569
|
+
energy_satisfaction = await agent.status.get(
|
570
|
+
"energy_satisfaction"
|
571
|
+
)
|
572
|
+
safety_satisfaction = await agent.status.get(
|
573
|
+
"safety_satisfaction"
|
574
|
+
)
|
575
|
+
social_satisfaction = await agent.status.get(
|
576
|
+
"social_satisfaction"
|
577
|
+
)
|
578
|
+
friend_ids = await agent.status.get("friends")
|
544
579
|
action = await agent.status.get("current_step")
|
545
580
|
action = action["intention"]
|
546
581
|
_status_dict = {
|
@@ -550,6 +585,9 @@ class AgentGroup:
|
|
550
585
|
"lng": lng,
|
551
586
|
"lat": lat,
|
552
587
|
"parent_id": parent_id,
|
588
|
+
"friend_ids": [
|
589
|
+
str(_friend_id) for _friend_id in friend_ids
|
590
|
+
],
|
553
591
|
"action": action,
|
554
592
|
"hungry": hunger_satisfaction,
|
555
593
|
"tired": energy_satisfaction,
|
@@ -612,6 +650,10 @@ class AgentGroup:
|
|
612
650
|
employees = await agent.status.get("employees")
|
613
651
|
except:
|
614
652
|
employees = []
|
653
|
+
try:
|
654
|
+
friend_ids = await agent.status.get("friends")
|
655
|
+
except:
|
656
|
+
friend_ids = []
|
615
657
|
_status_dict = {
|
616
658
|
"id": agent._uuid,
|
617
659
|
"day": _day,
|
@@ -619,6 +661,9 @@ class AgentGroup:
|
|
619
661
|
"lng": lng,
|
620
662
|
"lat": lat,
|
621
663
|
"parent_id": parent_id,
|
664
|
+
"friend_ids": [
|
665
|
+
str(_friend_id) for _friend_id in friend_ids
|
666
|
+
],
|
622
667
|
"action": "",
|
623
668
|
"type": await agent.status.get("type"),
|
624
669
|
"nominal_gdp": nominal_gdp,
|
@@ -644,6 +689,7 @@ class AgentGroup:
|
|
644
689
|
"lng",
|
645
690
|
"lat",
|
646
691
|
"parent_id",
|
692
|
+
"friend_ids",
|
647
693
|
"action",
|
648
694
|
"created_at",
|
649
695
|
]
|
@@ -668,6 +714,7 @@ class AgentGroup:
|
|
668
714
|
await asyncio.gather(*tasks)
|
669
715
|
except Exception as e:
|
670
716
|
import traceback
|
717
|
+
|
671
718
|
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
672
719
|
raise RuntimeError(str(e)) from e
|
673
720
|
|
@@ -676,5 +723,6 @@ class AgentGroup:
|
|
676
723
|
await self.save_status(day, t)
|
677
724
|
except Exception as e:
|
678
725
|
import traceback
|
726
|
+
|
679
727
|
logger.error(f"模拟器运行错误: {str(e)}\n{traceback.format_exc()}")
|
680
728
|
raise RuntimeError(str(e)) from e
|