pycityagent 2.0.0a14__py3-none-any.whl → 2.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +14 -0
- pycityagent/agent.py +80 -63
- pycityagent/environment/simulator.py +5 -4
- pycityagent/memory/memory.py +8 -7
- pycityagent/memory/memory_base.py +6 -4
- pycityagent/message/messager.py +8 -7
- pycityagent/simulation/agentgroup.py +135 -51
- pycityagent/simulation/simulation.py +206 -92
- pycityagent/survey/manager.py +10 -14
- pycityagent/survey/models.py +24 -24
- pycityagent/utils/__init__.py +2 -2
- pycityagent/utils/avro_schema.py +26 -1
- pycityagent/workflow/tool.py +1 -4
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/METADATA +3 -2
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/RECORD +16 -16
- {pycityagent-2.0.0a14.dist-info → pycityagent-2.0.0a16.dist-info}/WHEEL +0 -0
pycityagent/__init__.py
CHANGED
@@ -4,5 +4,19 @@ Pycityagent: 城市智能体构建框架
|
|
4
4
|
|
5
5
|
from .agent import Agent, CitizenAgent, InstitutionAgent
|
6
6
|
from .environment import Simulator
|
7
|
+
import logging
|
8
|
+
|
9
|
+
# 创建一个 pycityagent 记录器
|
10
|
+
logger = logging.getLogger("pycityagent")
|
11
|
+
logger.setLevel(logging.WARNING) # 默认级别
|
12
|
+
|
13
|
+
# 如果没有处理器,则添加一个
|
14
|
+
if not logger.hasHandlers():
|
15
|
+
handler = logging.StreamHandler()
|
16
|
+
formatter = logging.Formatter(
|
17
|
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
18
|
+
)
|
19
|
+
handler.setFormatter(formatter)
|
20
|
+
logger.addHandler(handler)
|
7
21
|
|
8
22
|
__all__ = ["Agent", "Simulator", "CitizenAgent", "InstitutionAgent"]
|
pycityagent/agent.py
CHANGED
@@ -2,16 +2,14 @@
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
import asyncio
|
5
|
-
import json
|
6
5
|
from uuid import UUID
|
7
6
|
from copy import deepcopy
|
8
7
|
from datetime import datetime
|
9
|
-
import time
|
10
8
|
from enum import Enum
|
11
9
|
import logging
|
12
10
|
import random
|
13
11
|
import uuid
|
14
|
-
from typing import Dict, List, Optional
|
12
|
+
from typing import Dict, List, Optional,Any
|
15
13
|
|
16
14
|
import fastavro
|
17
15
|
|
@@ -28,6 +26,8 @@ from .environment import Simulator
|
|
28
26
|
from .llm import LLM
|
29
27
|
from .memory import Memory
|
30
28
|
|
29
|
+
logger = logging.getLogger("pycityagent")
|
30
|
+
|
31
31
|
|
32
32
|
class AgentType(Enum):
|
33
33
|
"""
|
@@ -73,13 +73,14 @@ class Agent(ABC):
|
|
73
73
|
"""
|
74
74
|
self._name = name
|
75
75
|
self._type = type
|
76
|
-
self._uuid = uuid.uuid4()
|
76
|
+
self._uuid = str(uuid.uuid4())
|
77
77
|
self._llm_client = llm_client
|
78
78
|
self._economy_client = economy_client
|
79
79
|
self._messager = messager
|
80
80
|
self._simulator = simulator
|
81
81
|
self._memory = memory
|
82
82
|
self._exp_id = -1
|
83
|
+
self._agent_id = -1
|
83
84
|
self._has_bound_to_simulator = False
|
84
85
|
self._has_bound_to_economy = False
|
85
86
|
self._blocked = False
|
@@ -123,12 +124,18 @@ class Agent(ABC):
|
|
123
124
|
"""
|
124
125
|
self._memory = memory
|
125
126
|
|
126
|
-
def set_exp_id(self, exp_id: str
|
127
|
+
def set_exp_id(self, exp_id: str):
|
127
128
|
"""
|
128
129
|
Set the exp_id of the agent.
|
129
130
|
"""
|
130
131
|
self._exp_id = exp_id
|
131
132
|
|
133
|
+
def set_avro_file(self, avro_file: Dict[str, str]):
|
134
|
+
"""
|
135
|
+
Set the avro file of the agent.
|
136
|
+
"""
|
137
|
+
self._avro_file = avro_file
|
138
|
+
|
132
139
|
@property
|
133
140
|
def uuid(self):
|
134
141
|
"""The Agent's UUID"""
|
@@ -214,10 +221,12 @@ class Agent(ABC):
|
|
214
221
|
|
215
222
|
async def _process_survey(self, survey: dict):
|
216
223
|
survey_response = await self.generate_user_survey_response(survey)
|
224
|
+
if self._avro_file is None:
|
225
|
+
return
|
217
226
|
response_to_avro = [{
|
218
|
-
"id":
|
219
|
-
"day": await self.
|
220
|
-
"t": await self.
|
227
|
+
"id": self._uuid,
|
228
|
+
"day": await self.simulator.get_simulator_day(),
|
229
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
221
230
|
"survey_id": survey["id"],
|
222
231
|
"result": survey_response,
|
223
232
|
"created_at": int(datetime.now().timestamp() * 1000),
|
@@ -264,9 +273,9 @@ class Agent(ABC):
|
|
264
273
|
|
265
274
|
async def _process_interview(self, payload: dict):
|
266
275
|
auros = [{
|
267
|
-
"id":
|
268
|
-
"day": await self.
|
269
|
-
"t": await self.
|
276
|
+
"id": self._uuid,
|
277
|
+
"day": await self.simulator.get_simulator_day(),
|
278
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
270
279
|
"type": 2,
|
271
280
|
"speaker": "user",
|
272
281
|
"content": payload["content"],
|
@@ -275,23 +284,27 @@ class Agent(ABC):
|
|
275
284
|
question = payload["content"]
|
276
285
|
response = await self.generate_user_chat_response(question)
|
277
286
|
auros.append({
|
278
|
-
"id":
|
279
|
-
"day": await self.
|
280
|
-
"t": await self.
|
287
|
+
"id": self._uuid,
|
288
|
+
"day": await self.simulator.get_simulator_day(),
|
289
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
281
290
|
"type": 2,
|
282
291
|
"speaker": "",
|
283
292
|
"content": response,
|
284
293
|
"created_at": int(datetime.now().timestamp() * 1000),
|
285
294
|
})
|
295
|
+
if self._avro_file is None:
|
296
|
+
return
|
286
297
|
with open(self._avro_file["dialog"], "a+b") as f:
|
287
298
|
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
288
299
|
|
289
300
|
async def process_agent_chat_response(self, payload: dict) -> str:
|
290
|
-
|
301
|
+
resp = f"Agent {self._uuid} received agent chat response: {payload}"
|
302
|
+
logger.info(resp)
|
303
|
+
return resp
|
291
304
|
|
292
305
|
async def _process_agent_chat(self, payload: dict):
|
293
306
|
auros = [{
|
294
|
-
"id":
|
307
|
+
"id": self._uuid,
|
295
308
|
"day": payload["day"],
|
296
309
|
"t": payload["t"],
|
297
310
|
"type": 1,
|
@@ -300,6 +313,8 @@ class Agent(ABC):
|
|
300
313
|
"created_at": int(datetime.now().timestamp() * 1000),
|
301
314
|
}]
|
302
315
|
asyncio.create_task(self.process_agent_chat_response(payload))
|
316
|
+
if self._avro_file is None:
|
317
|
+
return
|
303
318
|
with open(self._avro_file["dialog"], "a+b") as f:
|
304
319
|
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
305
320
|
|
@@ -307,27 +322,27 @@ class Agent(ABC):
|
|
307
322
|
async def handle_agent_chat_message(self, payload: dict):
|
308
323
|
"""处理收到的消息,识别发送者"""
|
309
324
|
# 从消息中解析发送者 ID 和消息内容
|
310
|
-
|
325
|
+
logger.info(f"Agent {self._uuid} received agent chat message: {payload}")
|
311
326
|
asyncio.create_task(self._process_agent_chat(payload))
|
312
327
|
|
313
328
|
async def handle_user_chat_message(self, payload: dict):
|
314
329
|
"""处理收到的消息,识别发送者"""
|
315
330
|
# 从消息中解析发送者 ID 和消息内容
|
316
|
-
|
331
|
+
logger.info(f"Agent {self._uuid} received user chat message: {payload}")
|
317
332
|
asyncio.create_task(self._process_interview(payload))
|
318
333
|
|
319
334
|
async def handle_user_survey_message(self, payload: dict):
|
320
335
|
"""处理收到的消息,识别发送者"""
|
321
336
|
# 从消息中解析发送者 ID 和消息内容
|
322
|
-
|
337
|
+
logger.info(f"Agent {self._uuid} received user survey message: {payload}")
|
323
338
|
asyncio.create_task(self._process_survey(payload["data"]))
|
324
339
|
|
325
|
-
async def handle_gather_message(self, payload:
|
340
|
+
async def handle_gather_message(self, payload: Any):
|
326
341
|
raise NotImplementedError
|
327
342
|
|
328
343
|
# MQTT send message
|
329
344
|
async def _send_message(
|
330
|
-
self, to_agent_uuid:
|
345
|
+
self, to_agent_uuid: str, payload: dict, sub_topic: str
|
331
346
|
):
|
332
347
|
"""通过 Messager 发送消息"""
|
333
348
|
if self._messager is None:
|
@@ -336,7 +351,7 @@ class Agent(ABC):
|
|
336
351
|
await self._messager.send_message(topic, payload)
|
337
352
|
|
338
353
|
async def send_message_to_agent(
|
339
|
-
self, to_agent_uuid:
|
354
|
+
self, to_agent_uuid: str, content: str
|
340
355
|
):
|
341
356
|
"""通过 Messager 发送消息"""
|
342
357
|
if self._messager is None:
|
@@ -345,19 +360,21 @@ class Agent(ABC):
|
|
345
360
|
"from": self._uuid,
|
346
361
|
"content": content,
|
347
362
|
"timestamp": int(datetime.now().timestamp() * 1000),
|
348
|
-
"day": await self.
|
349
|
-
"t": await self.
|
363
|
+
"day": await self.simulator.get_simulator_day(),
|
364
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
350
365
|
}
|
351
366
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
352
367
|
auros = [{
|
353
|
-
"id":
|
354
|
-
"day": await self.
|
355
|
-
"t": await self.
|
368
|
+
"id": self._uuid,
|
369
|
+
"day": await self.simulator.get_simulator_day(),
|
370
|
+
"t": await self.simulator.get_simulator_second_from_start_of_day(),
|
356
371
|
"type": 1,
|
357
|
-
"speaker":
|
372
|
+
"speaker": self._uuid,
|
358
373
|
"content": content,
|
359
374
|
"created_at": int(datetime.now().timestamp() * 1000),
|
360
375
|
}]
|
376
|
+
if self._avro_file is None:
|
377
|
+
return
|
361
378
|
with open(self._avro_file["dialog"], "a+b") as f:
|
362
379
|
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
363
380
|
|
@@ -414,7 +431,7 @@ class CitizenAgent(Agent):
|
|
414
431
|
person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
|
415
432
|
"""
|
416
433
|
if self._simulator is None:
|
417
|
-
|
434
|
+
logger.warning("Simulator is not set")
|
418
435
|
return
|
419
436
|
if not self._has_bound_to_simulator:
|
420
437
|
FROM_MEMORY_KEYS = {
|
@@ -426,13 +443,13 @@ class CitizenAgent(Agent):
|
|
426
443
|
"pedestrian_attribute",
|
427
444
|
"bike_attribute",
|
428
445
|
}
|
429
|
-
simulator = self.
|
430
|
-
memory = self.
|
446
|
+
simulator = self.simulator
|
447
|
+
memory = self.memory
|
431
448
|
person_id = await memory.get("id")
|
432
449
|
# ATTENTION:模拟器分配的id从0开始
|
433
450
|
if person_id >= 0:
|
434
451
|
await simulator.get_person(person_id)
|
435
|
-
|
452
|
+
logger.debug(f"Binding to Person `{person_id}` already in Simulator")
|
436
453
|
else:
|
437
454
|
dict_person = deepcopy(self._person_template)
|
438
455
|
for _key in FROM_MEMORY_KEYS:
|
@@ -447,7 +464,7 @@ class CitizenAgent(Agent):
|
|
447
464
|
)
|
448
465
|
person_id = resp["person_id"]
|
449
466
|
await memory.update("id", person_id, protect_llm_read_only_fields=False)
|
450
|
-
|
467
|
+
logger.debug(
|
451
468
|
f"Binding to Person `{person_id}` just added to Simulator"
|
452
469
|
)
|
453
470
|
# 防止模拟器还没有到prepare阶段导致get_person出错
|
@@ -456,7 +473,7 @@ class CitizenAgent(Agent):
|
|
456
473
|
|
457
474
|
async def _bind_to_economy(self):
|
458
475
|
if self._economy_client is None:
|
459
|
-
|
476
|
+
logger.warning("Economy client is not set")
|
460
477
|
return
|
461
478
|
if not self._has_bound_to_economy:
|
462
479
|
if self._has_bound_to_simulator:
|
@@ -464,16 +481,16 @@ class CitizenAgent(Agent):
|
|
464
481
|
await self._economy_client.remove_agents([self._agent_id])
|
465
482
|
except:
|
466
483
|
pass
|
467
|
-
person_id = await self.
|
484
|
+
person_id = await self.memory.get("id")
|
468
485
|
await self._economy_client.add_agents(
|
469
486
|
{
|
470
487
|
"id": person_id,
|
471
|
-
"currency": await self.
|
488
|
+
"currency": await self.memory.get("currency"),
|
472
489
|
}
|
473
490
|
)
|
474
491
|
self._has_bound_to_economy = True
|
475
492
|
else:
|
476
|
-
|
493
|
+
logger.debug(
|
477
494
|
f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
|
478
495
|
)
|
479
496
|
|
@@ -523,68 +540,69 @@ class InstitutionAgent(Agent):
|
|
523
540
|
|
524
541
|
async def _bind_to_economy(self):
|
525
542
|
if self._economy_client is None:
|
526
|
-
|
543
|
+
logger.debug("Economy client is not set")
|
527
544
|
return
|
528
545
|
if not self._has_bound_to_economy:
|
529
546
|
# TODO: More general id generation
|
530
547
|
_id = random.randint(100000, 999999)
|
531
548
|
self._agent_id = _id
|
532
|
-
await self.
|
549
|
+
await self.memory.update("id", _id, protect_llm_read_only_fields=False)
|
533
550
|
try:
|
534
551
|
await self._economy_client.remove_orgs([self._agent_id])
|
535
552
|
except:
|
536
553
|
pass
|
537
554
|
try:
|
538
|
-
|
539
|
-
|
555
|
+
_memory = self.memory
|
556
|
+
_id = await _memory.get("id")
|
557
|
+
_type = await _memory.get("type")
|
540
558
|
try:
|
541
|
-
nominal_gdp = await
|
559
|
+
nominal_gdp = await _memory.get("nominal_gdp")
|
542
560
|
except:
|
543
561
|
nominal_gdp = []
|
544
562
|
try:
|
545
|
-
real_gdp = await
|
563
|
+
real_gdp = await _memory.get("real_gdp")
|
546
564
|
except:
|
547
565
|
real_gdp = []
|
548
566
|
try:
|
549
|
-
unemployment = await
|
567
|
+
unemployment = await _memory.get("unemployment")
|
550
568
|
except:
|
551
569
|
unemployment = []
|
552
570
|
try:
|
553
|
-
wages = await
|
571
|
+
wages = await _memory.get("wages")
|
554
572
|
except:
|
555
573
|
wages = []
|
556
574
|
try:
|
557
|
-
prices = await
|
575
|
+
prices = await _memory.get("prices")
|
558
576
|
except:
|
559
577
|
prices = []
|
560
578
|
try:
|
561
|
-
inventory = await
|
579
|
+
inventory = await _memory.get("inventory")
|
562
580
|
except:
|
563
581
|
inventory = 0
|
564
582
|
try:
|
565
|
-
price = await
|
583
|
+
price = await _memory.get("price")
|
566
584
|
except:
|
567
585
|
price = 0
|
568
586
|
try:
|
569
|
-
currency = await
|
587
|
+
currency = await _memory.get("currency")
|
570
588
|
except:
|
571
589
|
currency = 0.0
|
572
590
|
try:
|
573
|
-
interest_rate = await
|
591
|
+
interest_rate = await _memory.get("interest_rate")
|
574
592
|
except:
|
575
593
|
interest_rate = 0.0
|
576
594
|
try:
|
577
|
-
bracket_cutoffs = await
|
595
|
+
bracket_cutoffs = await _memory.get("bracket_cutoffs")
|
578
596
|
except:
|
579
597
|
bracket_cutoffs = []
|
580
598
|
try:
|
581
|
-
bracket_rates = await
|
599
|
+
bracket_rates = await _memory.get("bracket_rates")
|
582
600
|
except:
|
583
601
|
bracket_rates = []
|
584
602
|
await self._economy_client.add_orgs(
|
585
603
|
{
|
586
|
-
"id":
|
587
|
-
"type":
|
604
|
+
"id": _id,
|
605
|
+
"type": _type,
|
588
606
|
"nominal_gdp": nominal_gdp,
|
589
607
|
"real_gdp": real_gdp,
|
590
608
|
"unemployment": unemployment,
|
@@ -599,7 +617,7 @@ class InstitutionAgent(Agent):
|
|
599
617
|
}
|
600
618
|
)
|
601
619
|
except Exception as e:
|
602
|
-
|
620
|
+
logger.error(f"Failed to bind to Economy: {e}")
|
603
621
|
self._has_bound_to_economy = True
|
604
622
|
|
605
623
|
async def handle_gather_message(self, payload: dict):
|
@@ -615,11 +633,11 @@ class InstitutionAgent(Agent):
|
|
615
633
|
"content": content,
|
616
634
|
})
|
617
635
|
|
618
|
-
async def gather_messages(self,
|
636
|
+
async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
|
619
637
|
"""从多个智能体收集消息
|
620
638
|
|
621
639
|
Args:
|
622
|
-
|
640
|
+
agent_uuids: 目标智能体UUID列表
|
623
641
|
target: 要收集的信息类型
|
624
642
|
|
625
643
|
Returns:
|
@@ -627,18 +645,17 @@ class InstitutionAgent(Agent):
|
|
627
645
|
"""
|
628
646
|
# 为每个agent创建Future
|
629
647
|
futures = {}
|
630
|
-
for
|
631
|
-
|
632
|
-
|
633
|
-
self._gather_responses[response_key] = futures[response_key]
|
648
|
+
for agent_uuid in agent_uuids:
|
649
|
+
futures[agent_uuid] = asyncio.Future()
|
650
|
+
self._gather_responses[agent_uuid] = futures[agent_uuid]
|
634
651
|
|
635
652
|
# 发送gather请求
|
636
653
|
payload = {
|
637
654
|
"from": self._uuid,
|
638
655
|
"target": target,
|
639
656
|
}
|
640
|
-
for
|
641
|
-
await self._send_message(
|
657
|
+
for agent_uuid in agent_uuids:
|
658
|
+
await self._send_message(agent_uuid, payload, "gather")
|
642
659
|
|
643
660
|
try:
|
644
661
|
# 等待所有响应
|
pycityagent/environment/simulator.py
CHANGED
@@ -20,6 +20,7 @@ from shapely.strtree import STRtree
|
|
20
20
|
from .sim import CityClient, ControlSimEnv
|
21
21
|
from .utils.const import *
|
22
22
|
|
23
|
+
logger = logging.getLogger("pycityagent")
|
23
24
|
|
24
25
|
class Simulator:
|
25
26
|
"""
|
@@ -72,7 +73,7 @@ class Simulator:
|
|
72
73
|
else:
|
73
74
|
self._client = CityClient(config["simulator"]["server"], secure=False)
|
74
75
|
else:
|
75
|
-
|
76
|
+
logger.warning(
|
76
77
|
"No simulator config found, no simulator client will be used"
|
77
78
|
)
|
78
79
|
self.map = SimMap(
|
@@ -285,7 +286,7 @@ class Simulator:
|
|
285
286
|
reset_position["aoi_position"] = {"aoi_id": aoi_id}
|
286
287
|
if poi_id is not None:
|
287
288
|
reset_position["aoi_position"]["poi_id"] = poi_id
|
288
|
-
|
289
|
+
logger.debug(
|
289
290
|
f"Setting person {person_id} pos to AoiPosition {reset_position}"
|
290
291
|
)
|
291
292
|
await self._client.person_service.ResetPersonPosition(
|
@@ -298,14 +299,14 @@ class Simulator:
|
|
298
299
|
}
|
299
300
|
if s is not None:
|
300
301
|
reset_position["lane_position"]["s"] = s
|
301
|
-
|
302
|
+
logger.debug(
|
302
303
|
f"Setting person {person_id} pos to LanePosition {reset_position}"
|
303
304
|
)
|
304
305
|
await self._client.person_service.ResetPersonPosition(
|
305
306
|
{"person_id": person_id, "position": reset_position}
|
306
307
|
)
|
307
308
|
else:
|
308
|
-
|
309
|
+
logger.debug(
|
309
310
|
f"Neither aoi or lane pos provided for person {person_id} position reset!!"
|
310
311
|
)
|
311
312
|
|
pycityagent/memory/memory.py
CHANGED
@@ -13,6 +13,7 @@ from .profile import ProfileMemory
|
|
13
13
|
from .self_define import DynamicMemory
|
14
14
|
from .state import StateMemory
|
15
15
|
|
16
|
+
logger = logging.getLogger("pycityagent")
|
16
17
|
|
17
18
|
class Memory:
|
18
19
|
"""
|
@@ -83,7 +84,7 @@ class Memory:
|
|
83
84
|
_type.extend(_value)
|
84
85
|
_value = deepcopy(_type)
|
85
86
|
else:
|
86
|
-
|
87
|
+
logger.warning(f"type `{_type}` is not supported!")
|
87
88
|
pass
|
88
89
|
except TypeError as e:
|
89
90
|
pass
|
@@ -99,7 +100,7 @@ class Memory:
|
|
99
100
|
or k in STATE_ATTRIBUTES
|
100
101
|
or k == TIME_STAMP_KEY
|
101
102
|
):
|
102
|
-
|
103
|
+
logger.warning(f"key `{k}` already declared in memory!")
|
103
104
|
continue
|
104
105
|
|
105
106
|
_dynamic_config[k] = deepcopy(_value)
|
@@ -112,19 +113,19 @@ class Memory:
|
|
112
113
|
if profile is not None:
|
113
114
|
for k, v in profile.items():
|
114
115
|
if k not in PROFILE_ATTRIBUTES:
|
115
|
-
|
116
|
+
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
116
117
|
continue
|
117
118
|
_profile_config[k] = v
|
118
119
|
if motion is not None:
|
119
120
|
for k, v in motion.items():
|
120
121
|
if k not in STATE_ATTRIBUTES:
|
121
|
-
|
122
|
+
logger.warning(f"key `{k}` is not a correct `motion` field!")
|
122
123
|
continue
|
123
124
|
_state_config[k] = v
|
124
125
|
if base is not None:
|
125
126
|
for k, v in base.items():
|
126
127
|
if k not in STATE_ATTRIBUTES:
|
127
|
-
|
128
|
+
logger.warning(f"key `{k}` is not a correct `base` field!")
|
128
129
|
continue
|
129
130
|
_state_config[k] = v
|
130
131
|
self._state = StateMemory(
|
@@ -182,7 +183,7 @@ class Memory:
|
|
182
183
|
"""更新记忆值并在必要时更新embedding"""
|
183
184
|
if protect_llm_read_only_fields:
|
184
185
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
185
|
-
|
186
|
+
logger.warning(f"Trying to write protected key `{key}`!")
|
186
187
|
return
|
187
188
|
for _mem in [self._state, self._profile, self._dynamic]:
|
188
189
|
try:
|
@@ -208,7 +209,7 @@ class Memory:
|
|
208
209
|
elif isinstance(original_value, deque):
|
209
210
|
original_value.extend(deque(value))
|
210
211
|
else:
|
211
|
-
|
212
|
+
logger.debug(
|
212
213
|
f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
|
213
214
|
)
|
214
215
|
await _mem.update(key, value, store_snapshot)
|
pycityagent/memory/memory_base.py
CHANGED
@@ -10,6 +10,8 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union
|
|
10
10
|
|
11
11
|
from .const import *
|
12
12
|
|
13
|
+
logger = logging.getLogger("pycityagent")
|
14
|
+
|
13
15
|
|
14
16
|
class MemoryUnit:
|
15
17
|
def __init__(
|
@@ -57,7 +59,7 @@ class MemoryUnit:
|
|
57
59
|
orig_v = self._content[k]
|
58
60
|
orig_type, new_type = type(orig_v), type(v)
|
59
61
|
if not orig_type == new_type:
|
60
|
-
|
62
|
+
logger.debug(
|
61
63
|
f"Type warning: The type of the value for key '{k}' is changing from `{orig_type.__name__}` to `{new_type.__name__}`!"
|
62
64
|
)
|
63
65
|
self._content.update(content)
|
@@ -82,7 +84,7 @@ class MemoryUnit:
|
|
82
84
|
await self._lock.acquire()
|
83
85
|
values = self._content[key]
|
84
86
|
if not isinstance(values, Sequence):
|
85
|
-
|
87
|
+
logger.warning(
|
86
88
|
f"the value stored in key `{key}` is not `sequence`, return value `{values}` instead!"
|
87
89
|
)
|
88
90
|
return values
|
@@ -93,7 +95,7 @@ class MemoryUnit:
|
|
93
95
|
)
|
94
96
|
top_k = len(values) if top_k is None else top_k
|
95
97
|
if len(_sorted_values_with_idx) < top_k:
|
96
|
-
|
98
|
+
logger.debug(
|
97
99
|
f"Length of values {len(_sorted_values_with_idx)} is less than top_k {top_k}, returning all values."
|
98
100
|
)
|
99
101
|
self._lock.release()
|
@@ -149,7 +151,7 @@ class MemoryBase(ABC):
|
|
149
151
|
if recent_n is None:
|
150
152
|
return _list_units
|
151
153
|
if len(_memories) < recent_n:
|
152
|
-
|
154
|
+
logger.debug(
|
153
155
|
f"Length of memory {len(_memories)} is less than recent_n {recent_n}, returning all available memories."
|
154
156
|
)
|
155
157
|
return _list_units[-recent_n:]
|
pycityagent/message/messager.py
CHANGED
@@ -5,6 +5,7 @@ import logging
|
|
5
5
|
import math
|
6
6
|
from aiomqtt import Client
|
7
7
|
|
8
|
+
logger = logging.getLogger("pycityagent")
|
8
9
|
|
9
10
|
class Messager:
|
10
11
|
def __init__(
|
@@ -21,15 +22,15 @@ class Messager:
|
|
21
22
|
try:
|
22
23
|
await self.client.__aenter__()
|
23
24
|
self.connected = True
|
24
|
-
|
25
|
+
logger.info("Connected to MQTT Broker")
|
25
26
|
except Exception as e:
|
26
27
|
self.connected = False
|
27
|
-
|
28
|
+
logger.error(f"Failed to connect to MQTT Broker: {e}")
|
28
29
|
|
29
30
|
async def disconnect(self):
|
30
31
|
await self.client.__aexit__(None, None, None)
|
31
32
|
self.connected = False
|
32
|
-
|
33
|
+
logger.info("Disconnected from MQTT Broker")
|
33
34
|
|
34
35
|
def is_connected(self):
|
35
36
|
"""检查是否成功连接到 Broker"""
|
@@ -37,13 +38,13 @@ class Messager:
|
|
37
38
|
|
38
39
|
async def subscribe(self, topic, agent):
|
39
40
|
if not self.is_connected():
|
40
|
-
|
41
|
+
logger.error(
|
41
42
|
f"Cannot subscribe to {topic} because not connected to the Broker."
|
42
43
|
)
|
43
44
|
return
|
44
45
|
await self.client.subscribe(topic)
|
45
46
|
self.subscribers[topic] = agent
|
46
|
-
|
47
|
+
logger.info(f"Subscribed to {topic} for Agent {agent._uuid}")
|
47
48
|
|
48
49
|
async def receive_messages(self):
|
49
50
|
"""监听并将消息存入队列"""
|
@@ -61,11 +62,11 @@ class Messager:
|
|
61
62
|
"""通过 Messager 发送消息"""
|
62
63
|
message = json.dumps(payload, default=str)
|
63
64
|
await self.client.publish(topic, message)
|
64
|
-
|
65
|
+
logger.info(f"Message sent to {topic}: {message}")
|
65
66
|
|
66
67
|
async def start_listening(self):
|
67
68
|
"""启动消息监听任务"""
|
68
69
|
if self.is_connected():
|
69
70
|
asyncio.create_task(self.receive_messages())
|
70
71
|
else:
|
71
|
-
|
72
|
+
logger.error("Cannot start listening because not connected to the Broker.")
|