pycityagent 2.0.0a14__py3-none-any.whl → 2.0.0a15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +14 -0
- pycityagent/agent.py +44 -31
- pycityagent/environment/simulator.py +5 -4
- pycityagent/memory/memory.py +8 -7
- pycityagent/memory/memory_base.py +6 -4
- pycityagent/message/messager.py +8 -7
- pycityagent/simulation/agentgroup.py +133 -49
- pycityagent/simulation/simulation.py +157 -78
- pycityagent/utils/__init__.py +2 -2
- pycityagent/utils/avro_schema.py +26 -1
- pycityagent/workflow/tool.py +0 -3
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a15.dist-info}/METADATA +1 -1
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a15.dist-info}/RECORD +14 -14
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a15.dist-info}/WHEEL +0 -0
pycityagent/__init__.py
CHANGED
@@ -4,5 +4,19 @@ Pycityagent: 城市智能体构建框架
|
|
4
4
|
|
5
5
|
from .agent import Agent, CitizenAgent, InstitutionAgent
|
6
6
|
from .environment import Simulator
|
7
|
+
import logging
|
8
|
+
|
9
|
+
# 创建一个 pycityagent 记录器
|
10
|
+
logger = logging.getLogger("pycityagent")
|
11
|
+
logger.setLevel(logging.WARNING) # 默认级别
|
12
|
+
|
13
|
+
# 如果没有处理器,则添加一个
|
14
|
+
if not logger.hasHandlers():
|
15
|
+
handler = logging.StreamHandler()
|
16
|
+
formatter = logging.Formatter(
|
17
|
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
18
|
+
)
|
19
|
+
handler.setFormatter(formatter)
|
20
|
+
logger.addHandler(handler)
|
7
21
|
|
8
22
|
__all__ = ["Agent", "Simulator", "CitizenAgent", "InstitutionAgent"]
|
pycityagent/agent.py
CHANGED
@@ -2,11 +2,9 @@
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
import asyncio
|
5
|
-
import json
|
6
5
|
from uuid import UUID
|
7
6
|
from copy import deepcopy
|
8
7
|
from datetime import datetime
|
9
|
-
import time
|
10
8
|
from enum import Enum
|
11
9
|
import logging
|
12
10
|
import random
|
@@ -28,6 +26,8 @@ from .environment import Simulator
|
|
28
26
|
from .llm import LLM
|
29
27
|
from .memory import Memory
|
30
28
|
|
29
|
+
logger = logging.getLogger("pycityagent")
|
30
|
+
|
31
31
|
|
32
32
|
class AgentType(Enum):
|
33
33
|
"""
|
@@ -73,7 +73,7 @@ class Agent(ABC):
|
|
73
73
|
"""
|
74
74
|
self._name = name
|
75
75
|
self._type = type
|
76
|
-
self._uuid = uuid.uuid4()
|
76
|
+
self._uuid = str(uuid.uuid4())
|
77
77
|
self._llm_client = llm_client
|
78
78
|
self._economy_client = economy_client
|
79
79
|
self._messager = messager
|
@@ -123,12 +123,18 @@ class Agent(ABC):
|
|
123
123
|
"""
|
124
124
|
self._memory = memory
|
125
125
|
|
126
|
-
def set_exp_id(self, exp_id: str
|
126
|
+
def set_exp_id(self, exp_id: str):
|
127
127
|
"""
|
128
128
|
Set the exp_id of the agent.
|
129
129
|
"""
|
130
130
|
self._exp_id = exp_id
|
131
131
|
|
132
|
+
def set_avro_file(self, avro_file: Dict[str, str]):
|
133
|
+
"""
|
134
|
+
Set the avro file of the agent.
|
135
|
+
"""
|
136
|
+
self._avro_file = avro_file
|
137
|
+
|
132
138
|
@property
|
133
139
|
def uuid(self):
|
134
140
|
"""The Agent's UUID"""
|
@@ -214,8 +220,10 @@ class Agent(ABC):
|
|
214
220
|
|
215
221
|
async def _process_survey(self, survey: dict):
|
216
222
|
survey_response = await self.generate_user_survey_response(survey)
|
223
|
+
if self._avro_file is None:
|
224
|
+
return
|
217
225
|
response_to_avro = [{
|
218
|
-
"id":
|
226
|
+
"id": self._uuid,
|
219
227
|
"day": await self._simulator.get_simulator_day(),
|
220
228
|
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
221
229
|
"survey_id": survey["id"],
|
@@ -264,7 +272,7 @@ class Agent(ABC):
|
|
264
272
|
|
265
273
|
async def _process_interview(self, payload: dict):
|
266
274
|
auros = [{
|
267
|
-
"id":
|
275
|
+
"id": self._uuid,
|
268
276
|
"day": await self._simulator.get_simulator_day(),
|
269
277
|
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
270
278
|
"type": 2,
|
@@ -275,7 +283,7 @@ class Agent(ABC):
|
|
275
283
|
question = payload["content"]
|
276
284
|
response = await self.generate_user_chat_response(question)
|
277
285
|
auros.append({
|
278
|
-
"id":
|
286
|
+
"id": self._uuid,
|
279
287
|
"day": await self._simulator.get_simulator_day(),
|
280
288
|
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
281
289
|
"type": 2,
|
@@ -283,15 +291,17 @@ class Agent(ABC):
|
|
283
291
|
"content": response,
|
284
292
|
"created_at": int(datetime.now().timestamp() * 1000),
|
285
293
|
})
|
294
|
+
if self._avro_file is None:
|
295
|
+
return
|
286
296
|
with open(self._avro_file["dialog"], "a+b") as f:
|
287
297
|
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
288
298
|
|
289
299
|
async def process_agent_chat_response(self, payload: dict) -> str:
|
290
|
-
|
300
|
+
logger.info(f"Agent {self._uuid} received agent chat response: {payload}")
|
291
301
|
|
292
302
|
async def _process_agent_chat(self, payload: dict):
|
293
303
|
auros = [{
|
294
|
-
"id":
|
304
|
+
"id": self._uuid,
|
295
305
|
"day": payload["day"],
|
296
306
|
"t": payload["t"],
|
297
307
|
"type": 1,
|
@@ -300,6 +310,8 @@ class Agent(ABC):
|
|
300
310
|
"created_at": int(datetime.now().timestamp() * 1000),
|
301
311
|
}]
|
302
312
|
asyncio.create_task(self.process_agent_chat_response(payload))
|
313
|
+
if self._avro_file is None:
|
314
|
+
return
|
303
315
|
with open(self._avro_file["dialog"], "a+b") as f:
|
304
316
|
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
305
317
|
|
@@ -307,19 +319,19 @@ class Agent(ABC):
|
|
307
319
|
async def handle_agent_chat_message(self, payload: dict):
|
308
320
|
"""处理收到的消息,识别发送者"""
|
309
321
|
# 从消息中解析发送者 ID 和消息内容
|
310
|
-
|
322
|
+
logger.info(f"Agent {self._uuid} received agent chat message: {payload}")
|
311
323
|
asyncio.create_task(self._process_agent_chat(payload))
|
312
324
|
|
313
325
|
async def handle_user_chat_message(self, payload: dict):
|
314
326
|
"""处理收到的消息,识别发送者"""
|
315
327
|
# 从消息中解析发送者 ID 和消息内容
|
316
|
-
|
328
|
+
logger.info(f"Agent {self._uuid} received user chat message: {payload}")
|
317
329
|
asyncio.create_task(self._process_interview(payload))
|
318
330
|
|
319
331
|
async def handle_user_survey_message(self, payload: dict):
|
320
332
|
"""处理收到的消息,识别发送者"""
|
321
333
|
# 从消息中解析发送者 ID 和消息内容
|
322
|
-
|
334
|
+
logger.info(f"Agent {self._uuid} received user survey message: {payload}")
|
323
335
|
asyncio.create_task(self._process_survey(payload["data"]))
|
324
336
|
|
325
337
|
async def handle_gather_message(self, payload: str):
|
@@ -327,7 +339,7 @@ class Agent(ABC):
|
|
327
339
|
|
328
340
|
# MQTT send message
|
329
341
|
async def _send_message(
|
330
|
-
self, to_agent_uuid:
|
342
|
+
self, to_agent_uuid: str, payload: dict, sub_topic: str
|
331
343
|
):
|
332
344
|
"""通过 Messager 发送消息"""
|
333
345
|
if self._messager is None:
|
@@ -336,7 +348,7 @@ class Agent(ABC):
|
|
336
348
|
await self._messager.send_message(topic, payload)
|
337
349
|
|
338
350
|
async def send_message_to_agent(
|
339
|
-
self, to_agent_uuid:
|
351
|
+
self, to_agent_uuid: str, content: str
|
340
352
|
):
|
341
353
|
"""通过 Messager 发送消息"""
|
342
354
|
if self._messager is None:
|
@@ -350,14 +362,16 @@ class Agent(ABC):
|
|
350
362
|
}
|
351
363
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
352
364
|
auros = [{
|
353
|
-
"id":
|
365
|
+
"id": self._uuid,
|
354
366
|
"day": await self._simulator.get_simulator_day(),
|
355
367
|
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
356
368
|
"type": 1,
|
357
|
-
"speaker":
|
369
|
+
"speaker": self._uuid,
|
358
370
|
"content": content,
|
359
371
|
"created_at": int(datetime.now().timestamp() * 1000),
|
360
372
|
}]
|
373
|
+
if self._avro_file is None:
|
374
|
+
return
|
361
375
|
with open(self._avro_file["dialog"], "a+b") as f:
|
362
376
|
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
363
377
|
|
@@ -414,7 +428,7 @@ class CitizenAgent(Agent):
|
|
414
428
|
person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
|
415
429
|
"""
|
416
430
|
if self._simulator is None:
|
417
|
-
|
431
|
+
logger.warning("Simulator is not set")
|
418
432
|
return
|
419
433
|
if not self._has_bound_to_simulator:
|
420
434
|
FROM_MEMORY_KEYS = {
|
@@ -432,7 +446,7 @@ class CitizenAgent(Agent):
|
|
432
446
|
# ATTENTION:模拟器分配的id从0开始
|
433
447
|
if person_id >= 0:
|
434
448
|
await simulator.get_person(person_id)
|
435
|
-
|
449
|
+
logger.debug(f"Binding to Person `{person_id}` already in Simulator")
|
436
450
|
else:
|
437
451
|
dict_person = deepcopy(self._person_template)
|
438
452
|
for _key in FROM_MEMORY_KEYS:
|
@@ -447,7 +461,7 @@ class CitizenAgent(Agent):
|
|
447
461
|
)
|
448
462
|
person_id = resp["person_id"]
|
449
463
|
await memory.update("id", person_id, protect_llm_read_only_fields=False)
|
450
|
-
|
464
|
+
logger.debug(
|
451
465
|
f"Binding to Person `{person_id}` just added to Simulator"
|
452
466
|
)
|
453
467
|
# 防止模拟器还没有到prepare阶段导致get_person出错
|
@@ -456,7 +470,7 @@ class CitizenAgent(Agent):
|
|
456
470
|
|
457
471
|
async def _bind_to_economy(self):
|
458
472
|
if self._economy_client is None:
|
459
|
-
|
473
|
+
logger.warning("Economy client is not set")
|
460
474
|
return
|
461
475
|
if not self._has_bound_to_economy:
|
462
476
|
if self._has_bound_to_simulator:
|
@@ -473,7 +487,7 @@ class CitizenAgent(Agent):
|
|
473
487
|
)
|
474
488
|
self._has_bound_to_economy = True
|
475
489
|
else:
|
476
|
-
|
490
|
+
logger.debug(
|
477
491
|
f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
|
478
492
|
)
|
479
493
|
|
@@ -523,7 +537,7 @@ class InstitutionAgent(Agent):
|
|
523
537
|
|
524
538
|
async def _bind_to_economy(self):
|
525
539
|
if self._economy_client is None:
|
526
|
-
|
540
|
+
logger.debug("Economy client is not set")
|
527
541
|
return
|
528
542
|
if not self._has_bound_to_economy:
|
529
543
|
# TODO: More general id generation
|
@@ -599,7 +613,7 @@ class InstitutionAgent(Agent):
|
|
599
613
|
}
|
600
614
|
)
|
601
615
|
except Exception as e:
|
602
|
-
|
616
|
+
logger.error(f"Failed to bind to Economy: {e}")
|
603
617
|
self._has_bound_to_economy = True
|
604
618
|
|
605
619
|
async def handle_gather_message(self, payload: dict):
|
@@ -615,11 +629,11 @@ class InstitutionAgent(Agent):
|
|
615
629
|
"content": content,
|
616
630
|
})
|
617
631
|
|
618
|
-
async def gather_messages(self,
|
632
|
+
async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
|
619
633
|
"""从多个智能体收集消息
|
620
634
|
|
621
635
|
Args:
|
622
|
-
|
636
|
+
agent_uuids: 目标智能体UUID列表
|
623
637
|
target: 要收集的信息类型
|
624
638
|
|
625
639
|
Returns:
|
@@ -627,18 +641,17 @@ class InstitutionAgent(Agent):
|
|
627
641
|
"""
|
628
642
|
# 为每个agent创建Future
|
629
643
|
futures = {}
|
630
|
-
for
|
631
|
-
|
632
|
-
|
633
|
-
self._gather_responses[response_key] = futures[response_key]
|
644
|
+
for agent_uuid in agent_uuids:
|
645
|
+
futures[agent_uuid] = asyncio.Future()
|
646
|
+
self._gather_responses[agent_uuid] = futures[agent_uuid]
|
634
647
|
|
635
648
|
# 发送gather请求
|
636
649
|
payload = {
|
637
650
|
"from": self._uuid,
|
638
651
|
"target": target,
|
639
652
|
}
|
640
|
-
for
|
641
|
-
await self._send_message(
|
653
|
+
for agent_uuid in agent_uuids:
|
654
|
+
await self._send_message(agent_uuid, payload, "gather")
|
642
655
|
|
643
656
|
try:
|
644
657
|
# 等待所有响应
|
@@ -20,6 +20,7 @@ from shapely.strtree import STRtree
|
|
20
20
|
from .sim import CityClient, ControlSimEnv
|
21
21
|
from .utils.const import *
|
22
22
|
|
23
|
+
logger = logging.getLogger("pycityagent")
|
23
24
|
|
24
25
|
class Simulator:
|
25
26
|
"""
|
@@ -72,7 +73,7 @@ class Simulator:
|
|
72
73
|
else:
|
73
74
|
self._client = CityClient(config["simulator"]["server"], secure=False)
|
74
75
|
else:
|
75
|
-
|
76
|
+
logger.warning(
|
76
77
|
"No simulator config found, no simulator client will be used"
|
77
78
|
)
|
78
79
|
self.map = SimMap(
|
@@ -285,7 +286,7 @@ class Simulator:
|
|
285
286
|
reset_position["aoi_position"] = {"aoi_id": aoi_id}
|
286
287
|
if poi_id is not None:
|
287
288
|
reset_position["aoi_position"]["poi_id"] = poi_id
|
288
|
-
|
289
|
+
logger.debug(
|
289
290
|
f"Setting person {person_id} pos to AoiPosition {reset_position}"
|
290
291
|
)
|
291
292
|
await self._client.person_service.ResetPersonPosition(
|
@@ -298,14 +299,14 @@ class Simulator:
|
|
298
299
|
}
|
299
300
|
if s is not None:
|
300
301
|
reset_position["lane_position"]["s"] = s
|
301
|
-
|
302
|
+
logger.debug(
|
302
303
|
f"Setting person {person_id} pos to LanePosition {reset_position}"
|
303
304
|
)
|
304
305
|
await self._client.person_service.ResetPersonPosition(
|
305
306
|
{"person_id": person_id, "position": reset_position}
|
306
307
|
)
|
307
308
|
else:
|
308
|
-
|
309
|
+
logger.debug(
|
309
310
|
f"Neither aoi or lane pos provided for person {person_id} position reset!!"
|
310
311
|
)
|
311
312
|
|
pycityagent/memory/memory.py
CHANGED
@@ -13,6 +13,7 @@ from .profile import ProfileMemory
|
|
13
13
|
from .self_define import DynamicMemory
|
14
14
|
from .state import StateMemory
|
15
15
|
|
16
|
+
logger = logging.getLogger("pycityagent")
|
16
17
|
|
17
18
|
class Memory:
|
18
19
|
"""
|
@@ -83,7 +84,7 @@ class Memory:
|
|
83
84
|
_type.extend(_value)
|
84
85
|
_value = deepcopy(_type)
|
85
86
|
else:
|
86
|
-
|
87
|
+
logger.warning(f"type `{_type}` is not supported!")
|
87
88
|
pass
|
88
89
|
except TypeError as e:
|
89
90
|
pass
|
@@ -99,7 +100,7 @@ class Memory:
|
|
99
100
|
or k in STATE_ATTRIBUTES
|
100
101
|
or k == TIME_STAMP_KEY
|
101
102
|
):
|
102
|
-
|
103
|
+
logger.warning(f"key `{k}` already declared in memory!")
|
103
104
|
continue
|
104
105
|
|
105
106
|
_dynamic_config[k] = deepcopy(_value)
|
@@ -112,19 +113,19 @@ class Memory:
|
|
112
113
|
if profile is not None:
|
113
114
|
for k, v in profile.items():
|
114
115
|
if k not in PROFILE_ATTRIBUTES:
|
115
|
-
|
116
|
+
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
116
117
|
continue
|
117
118
|
_profile_config[k] = v
|
118
119
|
if motion is not None:
|
119
120
|
for k, v in motion.items():
|
120
121
|
if k not in STATE_ATTRIBUTES:
|
121
|
-
|
122
|
+
logger.warning(f"key `{k}` is not a correct `motion` field!")
|
122
123
|
continue
|
123
124
|
_state_config[k] = v
|
124
125
|
if base is not None:
|
125
126
|
for k, v in base.items():
|
126
127
|
if k not in STATE_ATTRIBUTES:
|
127
|
-
|
128
|
+
logger.warning(f"key `{k}` is not a correct `base` field!")
|
128
129
|
continue
|
129
130
|
_state_config[k] = v
|
130
131
|
self._state = StateMemory(
|
@@ -182,7 +183,7 @@ class Memory:
|
|
182
183
|
"""更新记忆值并在必要时更新embedding"""
|
183
184
|
if protect_llm_read_only_fields:
|
184
185
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
185
|
-
|
186
|
+
logger.warning(f"Trying to write protected key `{key}`!")
|
186
187
|
return
|
187
188
|
for _mem in [self._state, self._profile, self._dynamic]:
|
188
189
|
try:
|
@@ -208,7 +209,7 @@ class Memory:
|
|
208
209
|
elif isinstance(original_value, deque):
|
209
210
|
original_value.extend(deque(value))
|
210
211
|
else:
|
211
|
-
|
212
|
+
logger.debug(
|
212
213
|
f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
|
213
214
|
)
|
214
215
|
await _mem.update(key, value, store_snapshot)
|
@@ -10,6 +10,8 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union
|
|
10
10
|
|
11
11
|
from .const import *
|
12
12
|
|
13
|
+
logger = logging.getLogger("pycityagent")
|
14
|
+
|
13
15
|
|
14
16
|
class MemoryUnit:
|
15
17
|
def __init__(
|
@@ -57,7 +59,7 @@ class MemoryUnit:
|
|
57
59
|
orig_v = self._content[k]
|
58
60
|
orig_type, new_type = type(orig_v), type(v)
|
59
61
|
if not orig_type == new_type:
|
60
|
-
|
62
|
+
logger.debug(
|
61
63
|
f"Type warning: The type of the value for key '{k}' is changing from `{orig_type.__name__}` to `{new_type.__name__}`!"
|
62
64
|
)
|
63
65
|
self._content.update(content)
|
@@ -82,7 +84,7 @@ class MemoryUnit:
|
|
82
84
|
await self._lock.acquire()
|
83
85
|
values = self._content[key]
|
84
86
|
if not isinstance(values, Sequence):
|
85
|
-
|
87
|
+
logger.warning(
|
86
88
|
f"the value stored in key `{key}` is not `sequence`, return value `{values}` instead!"
|
87
89
|
)
|
88
90
|
return values
|
@@ -93,7 +95,7 @@ class MemoryUnit:
|
|
93
95
|
)
|
94
96
|
top_k = len(values) if top_k is None else top_k
|
95
97
|
if len(_sorted_values_with_idx) < top_k:
|
96
|
-
|
98
|
+
logger.debug(
|
97
99
|
f"Length of values {len(_sorted_values_with_idx)} is less than top_k {top_k}, returning all values."
|
98
100
|
)
|
99
101
|
self._lock.release()
|
@@ -149,7 +151,7 @@ class MemoryBase(ABC):
|
|
149
151
|
if recent_n is None:
|
150
152
|
return _list_units
|
151
153
|
if len(_memories) < recent_n:
|
152
|
-
|
154
|
+
logger.debug(
|
153
155
|
f"Length of memory {len(_memories)} is less than recent_n {recent_n}, returning all available memories."
|
154
156
|
)
|
155
157
|
return _list_units[-recent_n:]
|
pycityagent/message/messager.py
CHANGED
@@ -5,6 +5,7 @@ import logging
|
|
5
5
|
import math
|
6
6
|
from aiomqtt import Client
|
7
7
|
|
8
|
+
logger = logging.getLogger("pycityagent")
|
8
9
|
|
9
10
|
class Messager:
|
10
11
|
def __init__(
|
@@ -21,15 +22,15 @@ class Messager:
|
|
21
22
|
try:
|
22
23
|
await self.client.__aenter__()
|
23
24
|
self.connected = True
|
24
|
-
|
25
|
+
logger.info("Connected to MQTT Broker")
|
25
26
|
except Exception as e:
|
26
27
|
self.connected = False
|
27
|
-
|
28
|
+
logger.error(f"Failed to connect to MQTT Broker: {e}")
|
28
29
|
|
29
30
|
async def disconnect(self):
|
30
31
|
await self.client.__aexit__(None, None, None)
|
31
32
|
self.connected = False
|
32
|
-
|
33
|
+
logger.info("Disconnected from MQTT Broker")
|
33
34
|
|
34
35
|
def is_connected(self):
|
35
36
|
"""检查是否成功连接到 Broker"""
|
@@ -37,13 +38,13 @@ class Messager:
|
|
37
38
|
|
38
39
|
async def subscribe(self, topic, agent):
|
39
40
|
if not self.is_connected():
|
40
|
-
|
41
|
+
logger.error(
|
41
42
|
f"Cannot subscribe to {topic} because not connected to the Broker."
|
42
43
|
)
|
43
44
|
return
|
44
45
|
await self.client.subscribe(topic)
|
45
46
|
self.subscribers[topic] = agent
|
46
|
-
|
47
|
+
logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
|
47
48
|
|
48
49
|
async def receive_messages(self):
|
49
50
|
"""监听并将消息存入队列"""
|
@@ -61,11 +62,11 @@ class Messager:
|
|
61
62
|
"""通过 Messager 发送消息"""
|
62
63
|
message = json.dumps(payload, default=str)
|
63
64
|
await self.client.publish(topic, message)
|
64
|
-
|
65
|
+
logger.info(f"Message sent to {topic}: {message}")
|
65
66
|
|
66
67
|
async def start_listening(self):
|
67
68
|
"""启动消息监听任务"""
|
68
69
|
if self.is_connected():
|
69
70
|
asyncio.create_task(self.receive_messages())
|
70
71
|
else:
|
71
|
-
|
72
|
+
logger.error("Cannot start listening because not connected to the Broker.")
|
@@ -2,27 +2,41 @@ import asyncio
|
|
2
2
|
from datetime import datetime
|
3
3
|
import json
|
4
4
|
import logging
|
5
|
+
from pathlib import Path
|
5
6
|
import uuid
|
6
7
|
import fastavro
|
7
8
|
import ray
|
8
9
|
from uuid import UUID
|
9
|
-
from pycityagent.agent import Agent, CitizenAgent
|
10
|
+
from pycityagent.agent import Agent, CitizenAgent, InstitutionAgent
|
10
11
|
from pycityagent.economy.econ_client import EconomyClient
|
11
12
|
from pycityagent.environment.simulator import Simulator
|
12
13
|
from pycityagent.llm.llm import LLM
|
13
14
|
from pycityagent.llm.llmconfig import LLMConfig
|
14
15
|
from pycityagent.message import Messager
|
15
|
-
from pycityagent.utils import STATUS_SCHEMA
|
16
|
+
from pycityagent.utils import STATUS_SCHEMA, PROFILE_SCHEMA, DIALOG_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
|
16
17
|
from typing import Any
|
17
18
|
|
19
|
+
logger = logging.getLogger("pycityagent")
|
20
|
+
|
18
21
|
@ray.remote
|
19
22
|
class AgentGroup:
|
20
|
-
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID,
|
23
|
+
def __init__(self, agents: list[Agent], config: dict, exp_id: str|UUID, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
|
24
|
+
logger.setLevel(logging_level)
|
25
|
+
self._uuid = str(uuid.uuid4())
|
21
26
|
self.agents = agents
|
22
27
|
self.config = config
|
23
28
|
self.exp_id = exp_id
|
24
|
-
self.
|
25
|
-
self.
|
29
|
+
self.enable_avro = enable_avro
|
30
|
+
self.avro_path = avro_path / f"{self._uuid}"
|
31
|
+
if enable_avro:
|
32
|
+
self.avro_path.mkdir(parents=True, exist_ok=True)
|
33
|
+
self.avro_file = {
|
34
|
+
"profile": self.avro_path / f"profile.avro",
|
35
|
+
"dialog": self.avro_path / f"dialog.avro",
|
36
|
+
"status": self.avro_path / f"status.avro",
|
37
|
+
"survey": self.avro_path / f"survey.avro",
|
38
|
+
}
|
39
|
+
|
26
40
|
self.messager = Messager(
|
27
41
|
hostname=config["simulator_request"]["mqtt"]["server"],
|
28
42
|
port=config["simulator_request"]["mqtt"]["port"],
|
@@ -33,16 +47,16 @@ class AgentGroup:
|
|
33
47
|
self.id2agent = {}
|
34
48
|
# Step:1 prepare LLM client
|
35
49
|
llmConfig = LLMConfig(config["llm_request"])
|
36
|
-
|
50
|
+
logger.info(f"-----Creating LLM client in AgentGroup {self._uuid} ...")
|
37
51
|
self.llm = LLM(llmConfig)
|
38
52
|
|
39
53
|
# Step:2 prepare Simulator
|
40
|
-
|
54
|
+
logger.info(f"-----Creating Simulator in AgentGroup {self._uuid} ...")
|
41
55
|
self.simulator = Simulator(config["simulator_request"])
|
42
56
|
|
43
57
|
# Step:3 prepare Economy client
|
44
58
|
if "economy" in config["simulator_request"]:
|
45
|
-
|
59
|
+
logger.info(f"-----Creating Economy client in AgentGroup {self._uuid} ...")
|
46
60
|
self.economy_client = EconomyClient(
|
47
61
|
config["simulator_request"]["economy"]["server"]
|
48
62
|
)
|
@@ -56,11 +70,16 @@ class AgentGroup:
|
|
56
70
|
if self.economy_client is not None:
|
57
71
|
agent.set_economy_client(self.economy_client)
|
58
72
|
agent.set_messager(self.messager)
|
73
|
+
if self.enable_avro:
|
74
|
+
agent.set_avro_file(self.avro_file)
|
59
75
|
|
60
76
|
async def init_agents(self):
|
77
|
+
logger.debug(f"-----Initializing Agents in AgentGroup {self._uuid} ...")
|
78
|
+
logger.debug(f"-----Binding Agents to Simulator in AgentGroup {self._uuid} ...")
|
61
79
|
for agent in self.agents:
|
62
80
|
await agent.bind_to_simulator()
|
63
81
|
self.id2agent = {agent._uuid: agent for agent in self.agents}
|
82
|
+
logger.debug(f"-----Binding Agents to Messager in AgentGroup {self._uuid} ...")
|
64
83
|
await self.messager.connect()
|
65
84
|
if self.messager.is_connected():
|
66
85
|
await self.messager.start_listening()
|
@@ -74,27 +93,65 @@ class AgentGroup:
|
|
74
93
|
await self.messager.subscribe(topic, agent)
|
75
94
|
topic = f"exps/{self.exp_id}/agents/{agent._uuid}/gather"
|
76
95
|
await self.messager.subscribe(topic, agent)
|
77
|
-
self.initialized = True
|
78
96
|
self.message_dispatch_task = asyncio.create_task(self.message_dispatch())
|
79
|
-
|
97
|
+
if self.enable_avro:
|
98
|
+
logger.debug(f"-----Creating Avro files in AgentGroup {self._uuid} ...")
|
99
|
+
# profile
|
100
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
101
|
+
filename = self.avro_file["profile"]
|
102
|
+
with open(filename, "wb") as f:
|
103
|
+
profiles = []
|
104
|
+
for agent in self.agents:
|
105
|
+
profile = await agent.memory._profile.export()
|
106
|
+
profile = profile[0]
|
107
|
+
profile['id'] = agent._uuid
|
108
|
+
profiles.append(profile)
|
109
|
+
fastavro.writer(f, PROFILE_SCHEMA, profiles)
|
110
|
+
|
111
|
+
# dialog
|
112
|
+
filename = self.avro_file["dialog"]
|
113
|
+
with open(filename, "wb") as f:
|
114
|
+
dialogs = []
|
115
|
+
fastavro.writer(f, DIALOG_SCHEMA, dialogs)
|
116
|
+
|
117
|
+
# status
|
118
|
+
filename = self.avro_file["status"]
|
119
|
+
with open(filename, "wb") as f:
|
120
|
+
statuses = []
|
121
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
122
|
+
fastavro.writer(f, STATUS_SCHEMA, statuses)
|
123
|
+
else:
|
124
|
+
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, statuses)
|
125
|
+
|
126
|
+
# survey
|
127
|
+
filename = self.avro_file["survey"]
|
128
|
+
with open(filename, "wb") as f:
|
129
|
+
surveys = []
|
130
|
+
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
131
|
+
self.initialized = True
|
132
|
+
logger.debug(f"-----AgentGroup {self._uuid} initialized")
|
133
|
+
|
80
134
|
async def gather(self, content: str):
|
135
|
+
logger.debug(f"-----Gathering {content} from all agents in group {self._uuid}")
|
81
136
|
results = {}
|
82
137
|
for agent in self.agents:
|
83
138
|
results[agent._uuid] = await agent.memory.get(content)
|
84
139
|
return results
|
85
140
|
|
86
141
|
async def update(self, target_agent_uuid: str, target_key: str, content: Any):
|
142
|
+
logger.debug(f"-----Updating {target_key} for agent {target_agent_uuid} in group {self._uuid}")
|
87
143
|
agent = self.id2agent[target_agent_uuid]
|
88
144
|
await agent.memory.update(target_key, content)
|
89
145
|
|
90
146
|
async def message_dispatch(self):
|
147
|
+
logger.debug(f"-----Starting message dispatch for group {self._uuid}")
|
91
148
|
while True:
|
92
149
|
if not self.messager.is_connected():
|
93
|
-
|
150
|
+
logger.warning("Messager is not connected. Skipping message processing.")
|
94
151
|
|
95
152
|
# Step 1: 获取消息
|
96
153
|
messages = await self.messager.fetch_messages()
|
97
|
-
|
154
|
+
logger.info(f"Group {self._uuid} received {len(messages)} messages")
|
98
155
|
|
99
156
|
# Step 2: 分发消息到对应的 Agent
|
100
157
|
for message in messages:
|
@@ -109,8 +166,8 @@ class AgentGroup:
|
|
109
166
|
# 提取 agent_id(主题格式为 "exps/{exp_id}/agents/{agent_uuid}/{topic_type}")
|
110
167
|
_, _, _, agent_uuid, topic_type = topic.strip("/").split("/")
|
111
168
|
|
112
|
-
if
|
113
|
-
agent = self.id2agent[
|
169
|
+
if agent_uuid in self.id2agent:
|
170
|
+
agent = self.id2agent[agent_uuid]
|
114
171
|
# topic_type: agent-chat, user-chat, user-survey, gather
|
115
172
|
if topic_type == "agent-chat":
|
116
173
|
await agent.handle_agent_chat_message(payload)
|
@@ -123,46 +180,73 @@ class AgentGroup:
|
|
123
180
|
|
124
181
|
await asyncio.sleep(0.5)
|
125
182
|
|
183
|
+
async def save_status(self):
|
184
|
+
if self.enable_avro:
|
185
|
+
logger.debug(f"-----Saving status for group {self._uuid}")
|
186
|
+
avros = []
|
187
|
+
if not issubclass(type(self.agents[0]), InstitutionAgent):
|
188
|
+
for agent in self.agents:
|
189
|
+
position = await agent.memory.get("position")
|
190
|
+
lng = position["longlat_position"]["longitude"]
|
191
|
+
lat = position["longlat_position"]["latitude"]
|
192
|
+
if "aoi_position" in position:
|
193
|
+
parent_id = position["aoi_position"]["aoi_id"]
|
194
|
+
elif "lane_position" in position:
|
195
|
+
parent_id = position["lane_position"]["lane_id"]
|
196
|
+
else:
|
197
|
+
# BUG: 需要处理
|
198
|
+
parent_id = -1
|
199
|
+
needs = await agent.memory.get("needs")
|
200
|
+
action = await agent.memory.get("current_step")
|
201
|
+
action = action["intention"]
|
202
|
+
avro = {
|
203
|
+
"id": agent._uuid,
|
204
|
+
"day": await self.simulator.get_simulator_day(),
|
205
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
206
|
+
"lng": lng,
|
207
|
+
"lat": lat,
|
208
|
+
"parent_id": parent_id,
|
209
|
+
"action": action,
|
210
|
+
"hungry": needs["hungry"],
|
211
|
+
"tired": needs["tired"],
|
212
|
+
"safe": needs["safe"],
|
213
|
+
"social": needs["social"],
|
214
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
215
|
+
}
|
216
|
+
avros.append(avro)
|
217
|
+
with open(self.avro_file["status"], "a+b") as f:
|
218
|
+
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
219
|
+
else:
|
220
|
+
for agent in self.agents:
|
221
|
+
avro = {
|
222
|
+
"id": agent._uuid,
|
223
|
+
"day": await self.simulator.get_simulator_day(),
|
224
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
225
|
+
"type": await agent.memory.get("type"),
|
226
|
+
"nominal_gdp": await agent.memory.get("nominal_gdp"),
|
227
|
+
"real_gdp": await agent.memory.get("real_gdp"),
|
228
|
+
"unemployment": await agent.memory.get("unemployment"),
|
229
|
+
"wages": await agent.memory.get("wages"),
|
230
|
+
"prices": await agent.memory.get("prices"),
|
231
|
+
"inventory": await agent.memory.get("inventory"),
|
232
|
+
"price": await agent.memory.get("price"),
|
233
|
+
"interest_rate": await agent.memory.get("interest_rate"),
|
234
|
+
"bracket_cutoffs": await agent.memory.get("bracket_cutoffs"),
|
235
|
+
"bracket_rates": await agent.memory.get("bracket_rates"),
|
236
|
+
"employees": await agent.memory.get("employees"),
|
237
|
+
"customers": await agent.memory.get("customers"),
|
238
|
+
}
|
239
|
+
avros.append(avro)
|
240
|
+
with open(self.avro_file["status"], "a+b") as f:
|
241
|
+
fastavro.writer(f, INSTITUTION_STATUS_SCHEMA, avros, codec="snappy")
|
242
|
+
|
126
243
|
async def step(self):
|
127
244
|
if not self.initialized:
|
128
245
|
await self.init_agents()
|
129
246
|
|
130
247
|
tasks = [agent.run() for agent in self.agents]
|
131
248
|
await asyncio.gather(*tasks)
|
132
|
-
|
133
|
-
for agent in self.agents:
|
134
|
-
if not issubclass(type(agent), CitizenAgent):
|
135
|
-
continue
|
136
|
-
position = await agent.memory.get("position")
|
137
|
-
lng = position["longlat_position"]["longitude"]
|
138
|
-
lat = position["longlat_position"]["latitude"]
|
139
|
-
if "aoi_position" in position:
|
140
|
-
parent_id = position["aoi_position"]["aoi_id"]
|
141
|
-
elif "lane_position" in position:
|
142
|
-
parent_id = position["lane_position"]["lane_id"]
|
143
|
-
else:
|
144
|
-
# BUG: 需要处理
|
145
|
-
parent_id = -1
|
146
|
-
needs = await agent.memory.get("needs")
|
147
|
-
action = await agent.memory.get("current_step")
|
148
|
-
action = action["intention"]
|
149
|
-
avro = {
|
150
|
-
"id": str(agent._uuid), # uuid as string
|
151
|
-
"day": await self.simulator.get_simulator_day(),
|
152
|
-
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
153
|
-
"lng": lng,
|
154
|
-
"lat": lat,
|
155
|
-
"parent_id": parent_id,
|
156
|
-
"action": action,
|
157
|
-
"hungry": needs["hungry"],
|
158
|
-
"tired": needs["tired"],
|
159
|
-
"safe": needs["safe"],
|
160
|
-
"social": needs["social"],
|
161
|
-
"created_at": int(datetime.now().timestamp() * 1000),
|
162
|
-
}
|
163
|
-
avros.append(avro)
|
164
|
-
with open(self.avro_file["status"], "a+b") as f:
|
165
|
-
fastavro.writer(f, STATUS_SCHEMA, avros, codec="snappy")
|
249
|
+
await self.save_status()
|
166
250
|
|
167
251
|
async def run(self, day: int = 1):
|
168
252
|
"""运行模拟器
|
@@ -185,5 +269,5 @@ class AgentGroup:
|
|
185
269
|
await self.step()
|
186
270
|
|
187
271
|
except Exception as e:
|
188
|
-
|
272
|
+
logger.error(f"模拟器运行错误: {str(e)}")
|
189
273
|
raise
|
@@ -4,21 +4,22 @@ import logging
|
|
4
4
|
import os
|
5
5
|
from pathlib import Path
|
6
6
|
import uuid
|
7
|
-
from datetime import datetime
|
7
|
+
from datetime import datetime, timezone
|
8
8
|
import random
|
9
9
|
from typing import Dict, List, Optional, Callable, Union,Any
|
10
|
-
import fastavro
|
11
10
|
from mosstool.map._map_util.const import AOI_START_ID
|
12
11
|
import pycityproto.city.economy.v2.economy_pb2 as economyv2
|
12
|
+
from pycityagent.environment.simulator import Simulator
|
13
13
|
from pycityagent.memory.memory import Memory
|
14
14
|
from pycityagent.message.messager import Messager
|
15
15
|
from pycityagent.survey import Survey
|
16
|
-
|
16
|
+
import yaml
|
17
|
+
from concurrent.futures import ThreadPoolExecutor
|
17
18
|
|
18
19
|
from ..agent import Agent, InstitutionAgent
|
19
20
|
from .agentgroup import AgentGroup
|
20
21
|
|
21
|
-
logger = logging.getLogger(
|
22
|
+
logger = logging.getLogger("pycityagent")
|
22
23
|
|
23
24
|
|
24
25
|
class AgentSimulation:
|
@@ -29,19 +30,24 @@ class AgentSimulation:
|
|
29
30
|
agent_class: Union[type[Agent], list[type[Agent]]],
|
30
31
|
config: dict,
|
31
32
|
agent_prefix: str = "agent_",
|
33
|
+
exp_name: str = "default_experiment",
|
34
|
+
logging_level: int = logging.WARNING
|
32
35
|
):
|
33
36
|
"""
|
34
37
|
Args:
|
35
38
|
agent_class: 智能体类
|
36
39
|
config: 配置
|
37
40
|
agent_prefix: 智能体名称前缀
|
41
|
+
exp_name: 实验名称
|
38
42
|
"""
|
39
|
-
self.exp_id = uuid.uuid4()
|
43
|
+
self.exp_id = str(uuid.uuid4())
|
40
44
|
if isinstance(agent_class, list):
|
41
45
|
self.agent_class = agent_class
|
42
46
|
else:
|
43
47
|
self.agent_class = [agent_class]
|
48
|
+
self.logging_level = logging_level
|
44
49
|
self.config = config
|
50
|
+
self._simulator = Simulator(config["simulator_request"])
|
45
51
|
self.agent_prefix = agent_prefix
|
46
52
|
self._agents: Dict[uuid.UUID, Agent] = {}
|
47
53
|
self._groups: Dict[str, AgentGroup] = {}
|
@@ -61,13 +67,10 @@ class AgentSimulation:
|
|
61
67
|
asyncio.create_task(self._messager.connect())
|
62
68
|
|
63
69
|
self._enable_avro = config["storage"]["avro"]["enabled"]
|
64
|
-
self.
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
"status": self._avro_path / f"{self.exp_id}_status.avro",
|
69
|
-
"survey": self._avro_path / f"{self.exp_id}_survey.avro",
|
70
|
-
}
|
70
|
+
if not self._enable_avro:
|
71
|
+
logger.warning("AVRO is not enabled, NO AVRO LOCAL STORAGE")
|
72
|
+
self._avro_path = Path(config["storage"]["avro"]["path"]) / f"{self.exp_id}"
|
73
|
+
self._avro_path.mkdir(parents=True, exist_ok=True)
|
71
74
|
|
72
75
|
self._enable_pgsql = config["storage"]["pgsql"]["enabled"]
|
73
76
|
self._pgsql_host = config["storage"]["pgsql"]["host"]
|
@@ -76,6 +79,24 @@ class AgentSimulation:
|
|
76
79
|
self._pgsql_user = config["storage"]["pgsql"]["user"]
|
77
80
|
self._pgsql_password = config["storage"]["pgsql"]["password"]
|
78
81
|
|
82
|
+
# 添加实验信息相关的属性
|
83
|
+
self._exp_info = {
|
84
|
+
"id": self.exp_id,
|
85
|
+
"name": exp_name,
|
86
|
+
"num_day": 0, # 将在 run 方法中更新
|
87
|
+
"status": 0,
|
88
|
+
"cur_day": 0,
|
89
|
+
"cur_t": 0.0,
|
90
|
+
"config": json.dumps(config),
|
91
|
+
"error": "",
|
92
|
+
"created_at": datetime.now(timezone.utc).isoformat()
|
93
|
+
}
|
94
|
+
|
95
|
+
# 创建异步任务保存实验信息
|
96
|
+
self._exp_info_file = self._avro_path / "experiment_info.yaml"
|
97
|
+
with open(self._exp_info_file, 'w') as f:
|
98
|
+
yaml.dump(self._exp_info, f)
|
99
|
+
|
79
100
|
@property
|
80
101
|
def agents(self):
|
81
102
|
return self._agents
|
@@ -91,6 +112,11 @@ class AgentSimulation:
|
|
91
112
|
@property
|
92
113
|
def agent_uuid2group(self):
|
93
114
|
return self._agent_uuid2group
|
115
|
+
|
116
|
+
def create_remote_group(self, group_name: str, agents: list[Agent], config: dict, exp_id: str, enable_avro: bool, avro_path: Path, logging_level: int = logging.WARNING):
|
117
|
+
"""创建远程组"""
|
118
|
+
group = AgentGroup.remote(agents, config, exp_id, enable_avro, avro_path, logging_level)
|
119
|
+
return group_name, group, agents
|
94
120
|
|
95
121
|
async def init_agents(
|
96
122
|
self,
|
@@ -112,7 +138,7 @@ class AgentSimulation:
|
|
112
138
|
raise ValueError("agent_class和agent_count的长度不一致")
|
113
139
|
|
114
140
|
if memory_config_func is None:
|
115
|
-
|
141
|
+
logger.warning(
|
116
142
|
"memory_config_func is None, using default memory config function"
|
117
143
|
)
|
118
144
|
memory_config_func = []
|
@@ -125,17 +151,21 @@ class AgentSimulation:
|
|
125
151
|
memory_config_func = [memory_config_func]
|
126
152
|
|
127
153
|
if len(memory_config_func) != len(agent_count):
|
128
|
-
|
154
|
+
logger.warning(
|
129
155
|
"memory_config_func和agent_count的长度不一致,使用默认的memory_config"
|
130
156
|
)
|
131
157
|
memory_config_func = []
|
132
158
|
for agent_class in self.agent_class:
|
133
|
-
if agent_class
|
159
|
+
if issubclass(agent_class, InstitutionAgent):
|
134
160
|
memory_config_func.append(self.default_memory_config_institution)
|
135
161
|
else:
|
136
162
|
memory_config_func.append(self.default_memory_config_citizen)
|
137
163
|
|
164
|
+
# 使用线程池并行创建 AgentGroup
|
165
|
+
group_creation_params = []
|
138
166
|
class_init_index = 0
|
167
|
+
|
168
|
+
# 首先收集所有需要创建的组的参数
|
139
169
|
for i in range(len(self.agent_class)):
|
140
170
|
agent_class = self.agent_class[i]
|
141
171
|
agent_count_i = agent_count[i]
|
@@ -153,7 +183,6 @@ class AgentSimulation:
|
|
153
183
|
agent = agent_class(
|
154
184
|
name=agent_name,
|
155
185
|
memory=memory,
|
156
|
-
avro_file=self._avro_file,
|
157
186
|
)
|
158
187
|
|
159
188
|
self._agents[agent._uuid] = agent
|
@@ -161,66 +190,51 @@ class AgentSimulation:
|
|
161
190
|
|
162
191
|
# 计算需要的组数,向上取整以处理不足一组的情况
|
163
192
|
num_group = (agent_count_i + group_size - 1) // group_size
|
164
|
-
|
193
|
+
|
165
194
|
for k in range(num_group):
|
166
|
-
# 计算当前组的起始和结束索引
|
167
195
|
start_idx = class_init_index + k * group_size
|
168
196
|
end_idx = min(
|
169
|
-
class_init_index +
|
170
|
-
class_init_index + agent_count_i
|
197
|
+
class_init_index + (k + 1) * group_size, # 修正了索引计算
|
198
|
+
class_init_index + agent_count_i
|
171
199
|
)
|
172
|
-
|
173
|
-
# 获取当前组的agents
|
200
|
+
|
174
201
|
agents = list(self._agents.values())[start_idx:end_idx]
|
175
202
|
group_name = f"AgentType_{i}_Group_{k}"
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
203
|
+
|
204
|
+
# 收集创建参数
|
205
|
+
group_creation_params.append((
|
206
|
+
group_name,
|
207
|
+
agents
|
208
|
+
))
|
209
|
+
|
210
|
+
class_init_index += agent_count_i
|
211
|
+
|
212
|
+
# 收集所有创建组的参数
|
213
|
+
creation_tasks = []
|
214
|
+
for group_name, agents in group_creation_params:
|
215
|
+
# 直接创建异步任务
|
216
|
+
group = AgentGroup.remote(agents, self.config, self.exp_id,
|
217
|
+
self._enable_avro, self._avro_path,
|
218
|
+
self.logging_level)
|
219
|
+
creation_tasks.append((group_name, group, agents))
|
220
|
+
|
221
|
+
# 更新数据结构
|
222
|
+
for group_name, group, agents in creation_tasks:
|
223
|
+
self._groups[group_name] = group
|
224
|
+
for agent in agents:
|
225
|
+
self._agent_uuid2group[agent._uuid] = group
|
226
|
+
|
227
|
+
# 并行初始化所有组的agents
|
183
228
|
init_tasks = []
|
184
229
|
for group in self._groups.values():
|
185
230
|
init_tasks.append(group.init_agents.remote())
|
186
231
|
await asyncio.gather(*init_tasks)
|
232
|
+
|
233
|
+
# 设置用户主题
|
187
234
|
for uuid, agent in self._agents.items():
|
188
235
|
self._user_chat_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-chat"
|
189
236
|
self._user_survey_topics[uuid] = f"exps/{self.exp_id}/agents/{uuid}/user-survey"
|
190
237
|
|
191
|
-
# save profile
|
192
|
-
if self._enable_avro:
|
193
|
-
self._avro_path.mkdir(parents=True, exist_ok=True)
|
194
|
-
# profile
|
195
|
-
filename = self._avro_file["profile"]
|
196
|
-
with open(filename, "wb") as f:
|
197
|
-
profiles = []
|
198
|
-
for agent in self._agents.values():
|
199
|
-
profile = await agent.memory._profile.export()
|
200
|
-
profile = profile[0]
|
201
|
-
profile['id'] = str(agent._uuid)
|
202
|
-
profiles.append(profile)
|
203
|
-
fastavro.writer(f, PROFILE_SCHEMA, profiles)
|
204
|
-
|
205
|
-
# dialog
|
206
|
-
filename = self._avro_file["dialog"]
|
207
|
-
with open(filename, "wb") as f:
|
208
|
-
dialogs = []
|
209
|
-
fastavro.writer(f, DIALOG_SCHEMA, dialogs)
|
210
|
-
|
211
|
-
# status
|
212
|
-
filename = self._avro_file["status"]
|
213
|
-
with open(filename, "wb") as f:
|
214
|
-
statuses = []
|
215
|
-
fastavro.writer(f, STATUS_SCHEMA, statuses)
|
216
|
-
|
217
|
-
# survey
|
218
|
-
filename = self._avro_file["survey"]
|
219
|
-
with open(filename, "wb") as f:
|
220
|
-
surveys = []
|
221
|
-
fastavro.writer(f, SURVEY_SCHEMA, surveys)
|
222
|
-
|
223
|
-
|
224
238
|
async def gather(self, content: str):
|
225
239
|
"""收集智能体的特定信息"""
|
226
240
|
gather_tasks = []
|
@@ -228,10 +242,10 @@ class AgentSimulation:
|
|
228
242
|
gather_tasks.append(group.gather.remote(content))
|
229
243
|
return await asyncio.gather(*gather_tasks)
|
230
244
|
|
231
|
-
async def update(self,
|
245
|
+
async def update(self, target_agent_uuid: uuid.UUID, target_key: str, content: Any):
|
232
246
|
"""更新指定智能体的记忆"""
|
233
|
-
group = self._agent_uuid2group[
|
234
|
-
await group.update.remote(
|
247
|
+
group = self._agent_uuid2group[target_agent_uuid]
|
248
|
+
await group.update.remote(target_agent_uuid, target_key, content)
|
235
249
|
|
236
250
|
def default_memory_config_institution(self):
|
237
251
|
"""默认的Memory配置函数"""
|
@@ -388,23 +402,88 @@ class AgentSimulation:
|
|
388
402
|
logger.error(f"运行错误: {str(e)}")
|
389
403
|
raise
|
390
404
|
|
391
|
-
async def
|
392
|
-
|
393
|
-
|
394
|
-
|
395
|
-
|
405
|
+
async def _save_exp_info(self) -> None:
|
406
|
+
"""异步保存实验信息到YAML文件"""
|
407
|
+
try:
|
408
|
+
with open(self._exp_info_file, 'w') as f:
|
409
|
+
yaml.dump(self._exp_info, f)
|
410
|
+
except Exception as e:
|
411
|
+
logger.error(f"保存实验信息失败: {str(e)}")
|
412
|
+
|
413
|
+
async def _update_exp_status(self, status: int, error: str = "") -> None:
|
414
|
+
"""更新实验状态并保存"""
|
415
|
+
self._exp_info["status"] = status
|
416
|
+
self._exp_info["error"] = error
|
417
|
+
await self._save_exp_info()
|
396
418
|
|
419
|
+
async def _monitor_exp_status(self, stop_event: asyncio.Event):
|
420
|
+
"""监控实验状态并更新
|
421
|
+
|
397
422
|
Args:
|
398
|
-
|
423
|
+
stop_event: 用于通知监控任务停止的事件
|
399
424
|
"""
|
400
425
|
try:
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
426
|
+
while not stop_event.is_set():
|
427
|
+
# 更新实验状态
|
428
|
+
# 假设所有group的cur_day和cur_t是同步的,取第一个即可
|
429
|
+
self._exp_info["cur_day"] = await self._simulator.get_simulator_day()
|
430
|
+
self._exp_info["cur_t"] = await self._simulator.get_simulator_second_from_start_of_day()
|
431
|
+
await self._save_exp_info()
|
432
|
+
|
433
|
+
await asyncio.sleep(1) # 避免过于频繁的更新
|
434
|
+
except asyncio.CancelError:
|
435
|
+
# 正常取消,不需要特殊处理
|
436
|
+
pass
|
437
|
+
except Exception as e:
|
438
|
+
logger.error(f"监控实验状态时发生错误: {str(e)}")
|
439
|
+
raise
|
405
440
|
|
406
|
-
|
441
|
+
async def run(
|
442
|
+
self,
|
443
|
+
day: int = 1,
|
444
|
+
):
|
445
|
+
"""运行模拟器"""
|
446
|
+
try:
|
447
|
+
self._exp_info["num_day"] += day
|
448
|
+
await self._update_exp_status(1) # 更新状态为运行中
|
449
|
+
|
450
|
+
# 创建停止事件
|
451
|
+
stop_event = asyncio.Event()
|
452
|
+
# 创建监控任务
|
453
|
+
monitor_task = asyncio.create_task(self._monitor_exp_status(stop_event))
|
454
|
+
|
455
|
+
try:
|
456
|
+
tasks = []
|
457
|
+
for group in self._groups.values():
|
458
|
+
tasks.append(group.run.remote())
|
459
|
+
|
460
|
+
# 等待所有group运行完成
|
461
|
+
await asyncio.gather(*tasks)
|
462
|
+
|
463
|
+
finally:
|
464
|
+
# 设置停止事件
|
465
|
+
stop_event.set()
|
466
|
+
# 等待监控任务结束
|
467
|
+
await monitor_task
|
468
|
+
|
469
|
+
# 运行成功后更新状态
|
470
|
+
await self._update_exp_status(2)
|
407
471
|
|
408
472
|
except Exception as e:
|
409
|
-
|
410
|
-
|
473
|
+
error_msg = f"模拟器运行错误: {str(e)}"
|
474
|
+
logger.error(error_msg)
|
475
|
+
await self._update_exp_status(3, error_msg)
|
476
|
+
raise e
|
477
|
+
|
478
|
+
async def __aenter__(self):
|
479
|
+
"""异步上下文管理器入口"""
|
480
|
+
return self
|
481
|
+
|
482
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
483
|
+
"""异步上下文管理器出口"""
|
484
|
+
if exc_type is not None:
|
485
|
+
# 如果发生异常,更新状态为错误
|
486
|
+
await self._update_exp_status(3, str(exc_val))
|
487
|
+
elif self._exp_info["status"] != 3:
|
488
|
+
# 如果没有发生异常且状态不是错误,则更新为完成
|
489
|
+
await self._update_exp_status(2)
|
pycityagent/utils/__init__.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
|
-
from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA
|
1
|
+
from .avro_schema import PROFILE_SCHEMA, DIALOG_SCHEMA, STATUS_SCHEMA, SURVEY_SCHEMA, INSTITUTION_STATUS_SCHEMA
|
2
2
|
from .survey_util import process_survey_for_llm
|
3
3
|
|
4
4
|
__all__ = [
|
5
|
-
"PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA",
|
5
|
+
"PROFILE_SCHEMA", "DIALOG_SCHEMA", "STATUS_SCHEMA", "SURVEY_SCHEMA", "INSTITUTION_STATUS_SCHEMA",
|
6
6
|
"process_survey_for_llm"
|
7
7
|
]
|
pycityagent/utils/avro_schema.py
CHANGED
@@ -66,6 +66,31 @@ STATUS_SCHEMA = {
|
|
66
66
|
],
|
67
67
|
}
|
68
68
|
|
69
|
+
INSTITUTION_STATUS_SCHEMA = {
|
70
|
+
"doc": "Institution状态",
|
71
|
+
"name": "InstitutionStatus",
|
72
|
+
"namespace": "com.socialcity",
|
73
|
+
"type": "record",
|
74
|
+
"fields": [
|
75
|
+
{"name": "id", "type": "string"}, # uuid as string
|
76
|
+
{"name": "day", "type": "int"},
|
77
|
+
{"name": "t", "type": "float"},
|
78
|
+
{"name": "type", "type": "int"},
|
79
|
+
{"name": "nominal_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
80
|
+
{"name": "real_gdp", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
81
|
+
{"name": "unemployment", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
82
|
+
{"name": "wages", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
83
|
+
{"name": "prices", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
84
|
+
{"name": "inventory", "type": ["int", "null"]},
|
85
|
+
{"name": "price", "type": ["float", "null"]},
|
86
|
+
{"name": "interest_rate", "type": ["float", "null"]},
|
87
|
+
{"name": "bracket_cutoffs", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
88
|
+
{"name": "bracket_rates", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
89
|
+
{"name": "employees", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
90
|
+
{"name": "customers", "type": {"type": "array", "items": ["float", "int", "string", "null"]}},
|
91
|
+
],
|
92
|
+
}
|
93
|
+
|
69
94
|
SURVEY_SCHEMA = {
|
70
95
|
"doc": "Agent问卷",
|
71
96
|
"name": "AgentSurvey",
|
@@ -82,4 +107,4 @@ SURVEY_SCHEMA = {
|
|
82
107
|
"type": {"type": "long", "logicalType": "timestamp-millis"},
|
83
108
|
},
|
84
109
|
],
|
85
|
-
}
|
110
|
+
}
|
pycityagent/workflow/tool.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
|
-
pycityagent/__init__.py,sha256=
|
2
|
-
pycityagent/agent.py,sha256=
|
1
|
+
pycityagent/__init__.py,sha256=EDxt3Su3lH1IMh9suNw7GeGL7UrXeWiZTw5KWNznDzc,637
|
2
|
+
pycityagent/agent.py,sha256=fcuKX6FtMzjNP8lVep9pG-9KHzHQwJ8IymJbmLKMfu0,23109
|
3
3
|
pycityagent/economy/__init__.py,sha256=aonY4WHnx-6EGJ4WKrx4S-2jAkYNLtqUA04jp6q8B7w,75
|
4
4
|
pycityagent/economy/econ_client.py,sha256=wcuNtcpkSijJwNkt2mXw3SshYy4SBy6qbvJ0VQ7Aovo,10854
|
5
5
|
pycityagent/environment/__init__.py,sha256=awHxlOud-btWbk0FCS4RmGJ13W84oVCkbGfcrhKqihA,240
|
@@ -21,7 +21,7 @@ pycityagent/environment/sim/person_service.py,sha256=nIvOsoBoqOTDYtsiThg07-4ZBgk
|
|
21
21
|
pycityagent/environment/sim/road_service.py,sha256=phKTwTyhc_6Ht2mddEXpdENfl-lRXIVY0CHAlw1yHjI,1264
|
22
22
|
pycityagent/environment/sim/sim_env.py,sha256=HI1LcS_FotDKQ6vBnx0e49prXSABOfA20aU9KM-ZkCY,4625
|
23
23
|
pycityagent/environment/sim/social_service.py,sha256=6Iqvq6dz8H2jhLLdtaITc6Js9QnQw-Ylsd5AZgUj3-E,1993
|
24
|
-
pycityagent/environment/simulator.py,sha256=
|
24
|
+
pycityagent/environment/simulator.py,sha256=K7IyhiGC9BxanW28bpML4M0YREdMp1h7yMoWBlbf3RY,12504
|
25
25
|
pycityagent/environment/utils/__init__.py,sha256=1m4Q1EfGvNpUsa1bgQzzCyWhfkpElnskNImjjFD3Znc,237
|
26
26
|
pycityagent/environment/utils/base64.py,sha256=hoREzQo3FXMN79pqQLO2jgsDEvudciomyKii7MWljAM,374
|
27
27
|
pycityagent/environment/utils/const.py,sha256=3RMNy7_bE7-23K90j9DFW_tWEzu8s7hSTgKbV-3BFl4,5327
|
@@ -37,22 +37,22 @@ pycityagent/llm/llmconfig.py,sha256=4Ylf4OFSBEFy8jrOneeX0HvPhWEaF5jGvy1HkXK08Ro,
|
|
37
37
|
pycityagent/llm/utils.py,sha256=hoNPhvomb1u6lhFX0GctFipw74hVKb7bvUBDqwBzBYw,160
|
38
38
|
pycityagent/memory/__init__.py,sha256=Hs2NhYpIG-lvpwPWwj4DydB1sxtjz7cuA4iDAzCXnjI,243
|
39
39
|
pycityagent/memory/const.py,sha256=6zpJPJXWoH9-yf4RARYYff586agCoud9BRn7sPERB1g,932
|
40
|
-
pycityagent/memory/memory.py,sha256=
|
41
|
-
pycityagent/memory/memory_base.py,sha256=
|
40
|
+
pycityagent/memory/memory.py,sha256=FjKVL_MgNBnSc0sox2tuxLqXg9_MQQr9vYdRDHMdDL4,18183
|
41
|
+
pycityagent/memory/memory_base.py,sha256=euKZRCs4dbcKxjlZzpLCTnH066DAtRjj5g1JFKD40qQ,5633
|
42
42
|
pycityagent/memory/profile.py,sha256=s4LnxSPGSjIGZXHXkkd8mMa6uYYZrytgyQdWjcaqGf4,5182
|
43
43
|
pycityagent/memory/self_define.py,sha256=poPiexNhOLq_iTgK8s4mK_xoL_DAAcB8kMvInj7iE5E,5179
|
44
44
|
pycityagent/memory/state.py,sha256=5W0c1yJ-aaPpE74B2LEcw3Ygpm77tyooHv8NylyrozE,5113
|
45
45
|
pycityagent/memory/utils.py,sha256=wLNlNlZ-AY9VB8kbUIy0UQSYh26FOQABbhmKQkit5o8,850
|
46
46
|
pycityagent/message/__init__.py,sha256=TCjazxqb5DVwbTu1fF0sNvaH_EPXVuj2XQ0p6W-QCLU,55
|
47
|
-
pycityagent/message/messager.py,sha256=
|
47
|
+
pycityagent/message/messager.py,sha256=W_OVlNGcreHSBf6v-DrEnfNCXExB78ySr0w26MSncfU,2541
|
48
48
|
pycityagent/simulation/__init__.py,sha256=jYaqaNpzM5M_e_ykISS_M-mIyYdzJXJWhgpfBpA6l5k,111
|
49
|
-
pycityagent/simulation/agentgroup.py,sha256=
|
50
|
-
pycityagent/simulation/simulation.py,sha256=
|
49
|
+
pycityagent/simulation/agentgroup.py,sha256=JwfssUtVrOgSnJCan4jcIcSHLjWBCwYxqOPT-AXA2sE,12514
|
50
|
+
pycityagent/simulation/simulation.py,sha256=G68P1EJ3JceA3zID2O6AGd_KdhhYy5XVZVUgkfJHypc,18897
|
51
51
|
pycityagent/survey/__init__.py,sha256=rxwou8U9KeFSP7rMzXtmtp2fVFZxK4Trzi-psx9LPIs,153
|
52
52
|
pycityagent/survey/manager.py,sha256=N4-Q8vve4L-PLaFikAgVB4Z8BNyFDEd6WjEk3AyMTgs,1764
|
53
53
|
pycityagent/survey/models.py,sha256=Z2gHEazQRj0TkTz5qbh4Uy_JrU_FZGWpOLwjN0RoUrY,3547
|
54
|
-
pycityagent/utils/__init__.py,sha256=
|
55
|
-
pycityagent/utils/avro_schema.py,sha256=
|
54
|
+
pycityagent/utils/__init__.py,sha256=xXEMhVfFeOJUXjczaHv9DJqYNp57rc6FibtS7CfrVbA,305
|
55
|
+
pycityagent/utils/avro_schema.py,sha256=DHM3bOo8m0dJf8oSwyOWzVeXrH6OERmzA_a5vS4So4M,4255
|
56
56
|
pycityagent/utils/decorators.py,sha256=Gk3r41hfk6awui40tbwpq3C7wC7jHaRmLRlcJFlLQCE,3160
|
57
57
|
pycityagent/utils/parsers/__init__.py,sha256=AN2xgiPxszWK4rpX7zrqRsqNwfGF3WnCA5-PFTvbaKk,281
|
58
58
|
pycityagent/utils/parsers/code_block_parser.py,sha256=Cs2Z_hm9VfNCpPPll1TwteaJF-HAQPs-3RApsOekFm4,1173
|
@@ -62,8 +62,8 @@ pycityagent/utils/survey_util.py,sha256=Be9nptmu2JtesFNemPgORh_2GsN7rcDYGQS9Zfvc
|
|
62
62
|
pycityagent/workflow/__init__.py,sha256=EyCcjB6LyBim-5iAOPe4m2qfvghEPqu1ZdGfy4KPeZ8,551
|
63
63
|
pycityagent/workflow/block.py,sha256=6EmiRMLdOZC1wMlmLMIjfrp9TuiI7Gw4s3nnXVMbrnw,6031
|
64
64
|
pycityagent/workflow/prompt.py,sha256=tY69nDO8fgYfF_dOA-iceR8pAhkYmCqoox8uRPqEuGY,2956
|
65
|
-
pycityagent/workflow/tool.py,sha256=
|
65
|
+
pycityagent/workflow/tool.py,sha256=zMvz3BV4QBs5TqyQ3ziJxj4pCfL2uqUI3A1FbT1gd3Q,6626
|
66
66
|
pycityagent/workflow/trigger.py,sha256=t5X_i0WtL32bipZSsq_E3UUyYYudYLxQUpvxbgClp2s,5683
|
67
|
-
pycityagent-2.0.
|
68
|
-
pycityagent-2.0.
|
69
|
-
pycityagent-2.0.
|
67
|
+
pycityagent-2.0.0a15.dist-info/METADATA,sha256=8ONKHaTIPOGja6FKJDE9tZ-pLZL2aeXQufB-LWkElR8,7705
|
68
|
+
pycityagent-2.0.0a15.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
69
|
+
pycityagent-2.0.0a15.dist-info/RECORD,,
|
File without changes
|