pycityagent 2.0.0a13__py3-none-any.whl → 2.0.0a15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycityagent/__init__.py +14 -0
- pycityagent/agent.py +164 -63
- pycityagent/economy/econ_client.py +2 -0
- pycityagent/environment/simulator.py +5 -4
- pycityagent/memory/const.py +1 -0
- pycityagent/memory/memory.py +8 -7
- pycityagent/memory/memory_base.py +6 -4
- pycityagent/message/messager.py +8 -7
- pycityagent/simulation/agentgroup.py +136 -14
- pycityagent/simulation/simulation.py +212 -42
- pycityagent/survey/manager.py +58 -0
- pycityagent/survey/models.py +120 -0
- pycityagent/utils/__init__.py +7 -0
- pycityagent/utils/avro_schema.py +110 -0
- pycityagent/utils/survey_util.py +53 -0
- pycityagent/workflow/tool.py +0 -3
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/METADATA +3 -1
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/RECORD +20 -21
- pycityagent/simulation/interview.py +0 -40
- pycityagent/simulation/survey/manager.py +0 -68
- pycityagent/simulation/survey/models.py +0 -52
- pycityagent/simulation/ui/__init__.py +0 -3
- pycityagent/simulation/ui/interface.py +0 -602
- /pycityagent/{simulation/survey → survey}/__init__.py +0 -0
- {pycityagent-2.0.0a13.dist-info → pycityagent-2.0.0a15.dist-info}/WHEEL +0 -0
pycityagent/__init__.py
CHANGED
@@ -4,5 +4,19 @@ Pycityagent: 城市智能体构建框架
|
|
4
4
|
|
5
5
|
from .agent import Agent, CitizenAgent, InstitutionAgent
|
6
6
|
from .environment import Simulator
|
7
|
+
import logging
|
8
|
+
|
9
|
+
# 创建一个 pycityagent 记录器
|
10
|
+
logger = logging.getLogger("pycityagent")
|
11
|
+
logger.setLevel(logging.WARNING) # 默认级别
|
12
|
+
|
13
|
+
# 如果没有处理器,则添加一个
|
14
|
+
if not logger.hasHandlers():
|
15
|
+
handler = logging.StreamHandler()
|
16
|
+
formatter = logging.Formatter(
|
17
|
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
18
|
+
)
|
19
|
+
handler.setFormatter(formatter)
|
20
|
+
logger.addHandler(handler)
|
7
21
|
|
8
22
|
__all__ = ["Agent", "Simulator", "CitizenAgent", "InstitutionAgent"]
|
pycityagent/agent.py
CHANGED
@@ -2,28 +2,32 @@
|
|
2
2
|
|
3
3
|
from abc import ABC, abstractmethod
|
4
4
|
import asyncio
|
5
|
-
import json
|
6
5
|
from uuid import UUID
|
7
6
|
from copy import deepcopy
|
8
7
|
from datetime import datetime
|
9
|
-
import time
|
10
8
|
from enum import Enum
|
11
9
|
import logging
|
12
10
|
import random
|
13
11
|
import uuid
|
14
12
|
from typing import Dict, List, Optional
|
15
13
|
|
14
|
+
import fastavro
|
15
|
+
|
16
16
|
from pycityagent.environment.sim.person_service import PersonService
|
17
17
|
from mosstool.util.format_converter import dict2pb
|
18
18
|
from pycityproto.city.person.v2 import person_pb2 as person_pb2
|
19
|
+
from pycityagent.utils import process_survey_for_llm
|
19
20
|
|
20
21
|
from pycityagent.message.messager import Messager
|
22
|
+
from pycityagent.utils import SURVEY_SCHEMA, DIALOG_SCHEMA
|
21
23
|
|
22
24
|
from .economy import EconomyClient
|
23
25
|
from .environment import Simulator
|
24
26
|
from .llm import LLM
|
25
27
|
from .memory import Memory
|
26
28
|
|
29
|
+
logger = logging.getLogger("pycityagent")
|
30
|
+
|
27
31
|
|
28
32
|
class AgentType(Enum):
|
29
33
|
"""
|
@@ -52,6 +56,7 @@ class Agent(ABC):
|
|
52
56
|
messager: Optional[Messager] = None,
|
53
57
|
simulator: Optional[Simulator] = None,
|
54
58
|
memory: Optional[Memory] = None,
|
59
|
+
avro_file: Optional[Dict[str, str]] = None,
|
55
60
|
) -> None:
|
56
61
|
"""
|
57
62
|
Initialize the Agent.
|
@@ -64,10 +69,11 @@ class Agent(ABC):
|
|
64
69
|
messager (Messager, optional): The messager object. Defaults to None.
|
65
70
|
simulator (Simulator, optional): The simulator object. Defaults to None.
|
66
71
|
memory (Memory, optional): The memory of the agent. Defaults to None.
|
72
|
+
avro_file (Dict[str, str], optional): The avro file of the agent. Defaults to None.
|
67
73
|
"""
|
68
74
|
self._name = name
|
69
75
|
self._type = type
|
70
|
-
self._uuid = uuid.uuid4()
|
76
|
+
self._uuid = str(uuid.uuid4())
|
71
77
|
self._llm_client = llm_client
|
72
78
|
self._economy_client = economy_client
|
73
79
|
self._messager = messager
|
@@ -79,6 +85,7 @@ class Agent(ABC):
|
|
79
85
|
self._blocked = False
|
80
86
|
self._interview_history: List[Dict] = [] # 存储采访历史
|
81
87
|
self._person_template = PersonService.default_dict_person()
|
88
|
+
self._avro_file = avro_file
|
82
89
|
|
83
90
|
def __getstate__(self):
|
84
91
|
state = self.__dict__.copy()
|
@@ -116,12 +123,18 @@ class Agent(ABC):
|
|
116
123
|
"""
|
117
124
|
self._memory = memory
|
118
125
|
|
119
|
-
def set_exp_id(self, exp_id: str
|
126
|
+
def set_exp_id(self, exp_id: str):
|
120
127
|
"""
|
121
128
|
Set the exp_id of the agent.
|
122
129
|
"""
|
123
130
|
self._exp_id = exp_id
|
124
131
|
|
132
|
+
def set_avro_file(self, avro_file: Dict[str, str]):
|
133
|
+
"""
|
134
|
+
Set the avro file of the agent.
|
135
|
+
"""
|
136
|
+
self._avro_file = avro_file
|
137
|
+
|
125
138
|
@property
|
126
139
|
def uuid(self):
|
127
140
|
"""The Agent's UUID"""
|
@@ -168,11 +181,61 @@ class Agent(ABC):
|
|
168
181
|
)
|
169
182
|
return self._simulator
|
170
183
|
|
171
|
-
async def
|
172
|
-
"""生成回答
|
184
|
+
async def generate_user_survey_response(self, survey: dict) -> str:
|
185
|
+
"""生成回答 —— 可重写
|
186
|
+
基于智能体的记忆和当前状态,生成对问卷调查的回答。
|
187
|
+
Args:
|
188
|
+
survey: 需要回答的问卷 dict
|
189
|
+
Returns:
|
190
|
+
str: 智能体的回答
|
191
|
+
"""
|
192
|
+
survey_prompt = process_survey_for_llm(survey)
|
193
|
+
dialog = []
|
173
194
|
|
174
|
-
|
195
|
+
# 添加系统提示
|
196
|
+
system_prompt = "Please answer the survey question in first person. Follow the format requirements strictly and provide clear and specific answers."
|
197
|
+
dialog.append({"role": "system", "content": system_prompt})
|
175
198
|
|
199
|
+
# 添加记忆上下文
|
200
|
+
if self._memory:
|
201
|
+
relevant_memories = await self._memory.search(survey_prompt)
|
202
|
+
if relevant_memories:
|
203
|
+
dialog.append(
|
204
|
+
{
|
205
|
+
"role": "system",
|
206
|
+
"content": f"Answer based on these memories:\n{relevant_memories}",
|
207
|
+
}
|
208
|
+
)
|
209
|
+
|
210
|
+
# 添加问卷问题
|
211
|
+
dialog.append({"role": "user", "content": survey_prompt})
|
212
|
+
|
213
|
+
# 使用LLM生成回答
|
214
|
+
if not self._llm_client:
|
215
|
+
return "Sorry, I cannot answer survey questions right now."
|
216
|
+
|
217
|
+
response = await self._llm_client.atext_request(dialog) # type:ignore
|
218
|
+
|
219
|
+
return response # type:ignore
|
220
|
+
|
221
|
+
async def _process_survey(self, survey: dict):
|
222
|
+
survey_response = await self.generate_user_survey_response(survey)
|
223
|
+
if self._avro_file is None:
|
224
|
+
return
|
225
|
+
response_to_avro = [{
|
226
|
+
"id": self._uuid,
|
227
|
+
"day": await self._simulator.get_simulator_day(),
|
228
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
229
|
+
"survey_id": survey["id"],
|
230
|
+
"result": survey_response,
|
231
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
232
|
+
}]
|
233
|
+
with open(self._avro_file["survey"], "a+b") as f:
|
234
|
+
fastavro.writer(f, SURVEY_SCHEMA, response_to_avro, codec="snappy")
|
235
|
+
|
236
|
+
async def generate_user_chat_response(self, question: str) -> str:
|
237
|
+
"""生成回答 —— 可重写
|
238
|
+
基于智能体的记忆和当前状态,生成对问题的回答。
|
176
239
|
Args:
|
177
240
|
question: 需要回答的问题
|
178
241
|
|
@@ -182,7 +245,7 @@ class Agent(ABC):
|
|
182
245
|
dialog = []
|
183
246
|
|
184
247
|
# 添加系统提示
|
185
|
-
system_prompt =
|
248
|
+
system_prompt = "Please answer the question in first person and keep the response concise and clear."
|
186
249
|
dialog.append({"role": "system", "content": system_prompt})
|
187
250
|
|
188
251
|
# 添加记忆上下文
|
@@ -192,7 +255,7 @@ class Agent(ABC):
|
|
192
255
|
dialog.append(
|
193
256
|
{
|
194
257
|
"role": "system",
|
195
|
-
"content": f"
|
258
|
+
"content": f"Answer based on these memories:\n{relevant_memories}",
|
196
259
|
}
|
197
260
|
)
|
198
261
|
|
@@ -201,48 +264,82 @@ class Agent(ABC):
|
|
201
264
|
|
202
265
|
# 使用LLM生成回答
|
203
266
|
if not self._llm_client:
|
204
|
-
return "
|
267
|
+
return "Sorry, I cannot answer questions right now."
|
205
268
|
|
206
269
|
response = await self._llm_client.atext_request(dialog) # type:ignore
|
207
270
|
|
208
|
-
# 记录采访历史
|
209
|
-
self._interview_history.append(
|
210
|
-
{
|
211
|
-
"timestamp": datetime.now().isoformat(),
|
212
|
-
"question": question,
|
213
|
-
"response": response,
|
214
|
-
}
|
215
|
-
)
|
216
|
-
|
217
271
|
return response # type:ignore
|
272
|
+
|
273
|
+
async def _process_interview(self, payload: dict):
|
274
|
+
auros = [{
|
275
|
+
"id": self._uuid,
|
276
|
+
"day": await self._simulator.get_simulator_day(),
|
277
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
278
|
+
"type": 2,
|
279
|
+
"speaker": "user",
|
280
|
+
"content": payload["content"],
|
281
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
282
|
+
}]
|
283
|
+
question = payload["content"]
|
284
|
+
response = await self.generate_user_chat_response(question)
|
285
|
+
auros.append({
|
286
|
+
"id": self._uuid,
|
287
|
+
"day": await self._simulator.get_simulator_day(),
|
288
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
289
|
+
"type": 2,
|
290
|
+
"speaker": "",
|
291
|
+
"content": response,
|
292
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
293
|
+
})
|
294
|
+
if self._avro_file is None:
|
295
|
+
return
|
296
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
297
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
298
|
+
|
299
|
+
async def process_agent_chat_response(self, payload: dict) -> str:
|
300
|
+
logger.info(f"Agent {self._uuid} received agent chat response: {payload}")
|
301
|
+
|
302
|
+
async def _process_agent_chat(self, payload: dict):
|
303
|
+
auros = [{
|
304
|
+
"id": self._uuid,
|
305
|
+
"day": payload["day"],
|
306
|
+
"t": payload["t"],
|
307
|
+
"type": 1,
|
308
|
+
"speaker": payload["from"],
|
309
|
+
"content": payload["content"],
|
310
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
311
|
+
}]
|
312
|
+
asyncio.create_task(self.process_agent_chat_response(payload))
|
313
|
+
if self._avro_file is None:
|
314
|
+
return
|
315
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
316
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
218
317
|
|
219
|
-
|
220
|
-
"""获取采访历史记录"""
|
221
|
-
return self._interview_history
|
222
|
-
|
318
|
+
# Callback functions for MQTT message
|
223
319
|
async def handle_agent_chat_message(self, payload: dict):
|
224
320
|
"""处理收到的消息,识别发送者"""
|
225
321
|
# 从消息中解析发送者 ID 和消息内容
|
226
|
-
|
227
|
-
|
228
|
-
)
|
322
|
+
logger.info(f"Agent {self._uuid} received agent chat message: {payload}")
|
323
|
+
asyncio.create_task(self._process_agent_chat(payload))
|
229
324
|
|
230
325
|
async def handle_user_chat_message(self, payload: dict):
|
231
326
|
"""处理收到的消息,识别发送者"""
|
232
327
|
# 从消息中解析发送者 ID 和消息内容
|
233
|
-
|
234
|
-
|
235
|
-
)
|
328
|
+
logger.info(f"Agent {self._uuid} received user chat message: {payload}")
|
329
|
+
asyncio.create_task(self._process_interview(payload))
|
236
330
|
|
237
331
|
async def handle_user_survey_message(self, payload: dict):
|
238
332
|
"""处理收到的消息,识别发送者"""
|
239
333
|
# 从消息中解析发送者 ID 和消息内容
|
240
|
-
|
241
|
-
|
242
|
-
|
334
|
+
logger.info(f"Agent {self._uuid} received user survey message: {payload}")
|
335
|
+
asyncio.create_task(self._process_survey(payload["data"]))
|
336
|
+
|
337
|
+
async def handle_gather_message(self, payload: str):
|
338
|
+
raise NotImplementedError
|
243
339
|
|
340
|
+
# MQTT send message
|
244
341
|
async def _send_message(
|
245
|
-
self, to_agent_uuid:
|
342
|
+
self, to_agent_uuid: str, payload: dict, sub_topic: str
|
246
343
|
):
|
247
344
|
"""通过 Messager 发送消息"""
|
248
345
|
if self._messager is None:
|
@@ -251,7 +348,7 @@ class Agent(ABC):
|
|
251
348
|
await self._messager.send_message(topic, payload)
|
252
349
|
|
253
350
|
async def send_message_to_agent(
|
254
|
-
self, to_agent_uuid:
|
351
|
+
self, to_agent_uuid: str, content: str
|
255
352
|
):
|
256
353
|
"""通过 Messager 发送消息"""
|
257
354
|
if self._messager is None:
|
@@ -259,29 +356,30 @@ class Agent(ABC):
|
|
259
356
|
payload = {
|
260
357
|
"from": self._uuid,
|
261
358
|
"content": content,
|
262
|
-
"timestamp": int(
|
359
|
+
"timestamp": int(datetime.now().timestamp() * 1000),
|
263
360
|
"day": await self._simulator.get_simulator_day(),
|
264
361
|
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
265
362
|
}
|
266
363
|
await self._send_message(to_agent_uuid, payload, "agent-chat")
|
364
|
+
auros = [{
|
365
|
+
"id": self._uuid,
|
366
|
+
"day": await self._simulator.get_simulator_day(),
|
367
|
+
"t": await self._simulator.get_simulator_second_from_start_of_day(),
|
368
|
+
"type": 1,
|
369
|
+
"speaker": self._uuid,
|
370
|
+
"content": content,
|
371
|
+
"created_at": int(datetime.now().timestamp() * 1000),
|
372
|
+
}]
|
373
|
+
if self._avro_file is None:
|
374
|
+
return
|
375
|
+
with open(self._avro_file["dialog"], "a+b") as f:
|
376
|
+
fastavro.writer(f, DIALOG_SCHEMA, auros, codec="snappy")
|
267
377
|
|
268
|
-
|
269
|
-
self, content: dict
|
270
|
-
):
|
271
|
-
pass
|
272
|
-
|
273
|
-
async def send_message_to_survey(
|
274
|
-
self, content: dict
|
275
|
-
):
|
276
|
-
pass
|
277
|
-
|
378
|
+
# Agent logic
|
278
379
|
@abstractmethod
|
279
380
|
async def forward(self) -> None:
|
280
381
|
"""智能体行为逻辑"""
|
281
382
|
raise NotImplementedError
|
282
|
-
|
283
|
-
async def handle_gather_message(self, payload: str):
|
284
|
-
raise NotImplementedError
|
285
383
|
|
286
384
|
async def run(self) -> None:
|
287
385
|
"""
|
@@ -305,6 +403,7 @@ class CitizenAgent(Agent):
|
|
305
403
|
memory: Optional[Memory] = None,
|
306
404
|
economy_client: Optional[EconomyClient] = None,
|
307
405
|
messager: Optional[Messager] = None,
|
406
|
+
avro_file: Optional[dict] = None,
|
308
407
|
) -> None:
|
309
408
|
super().__init__(
|
310
409
|
name,
|
@@ -314,6 +413,7 @@ class CitizenAgent(Agent):
|
|
314
413
|
messager,
|
315
414
|
simulator,
|
316
415
|
memory,
|
416
|
+
avro_file,
|
317
417
|
)
|
318
418
|
|
319
419
|
async def bind_to_simulator(self):
|
@@ -328,7 +428,7 @@ class CitizenAgent(Agent):
|
|
328
428
|
person_template (dict, optional): The person template in dict format. Defaults to PersonService.default_dict_person().
|
329
429
|
"""
|
330
430
|
if self._simulator is None:
|
331
|
-
|
431
|
+
logger.warning("Simulator is not set")
|
332
432
|
return
|
333
433
|
if not self._has_bound_to_simulator:
|
334
434
|
FROM_MEMORY_KEYS = {
|
@@ -346,7 +446,7 @@ class CitizenAgent(Agent):
|
|
346
446
|
# ATTENTION:模拟器分配的id从0开始
|
347
447
|
if person_id >= 0:
|
348
448
|
await simulator.get_person(person_id)
|
349
|
-
|
449
|
+
logger.debug(f"Binding to Person `{person_id}` already in Simulator")
|
350
450
|
else:
|
351
451
|
dict_person = deepcopy(self._person_template)
|
352
452
|
for _key in FROM_MEMORY_KEYS:
|
@@ -361,7 +461,7 @@ class CitizenAgent(Agent):
|
|
361
461
|
)
|
362
462
|
person_id = resp["person_id"]
|
363
463
|
await memory.update("id", person_id, protect_llm_read_only_fields=False)
|
364
|
-
|
464
|
+
logger.debug(
|
365
465
|
f"Binding to Person `{person_id}` just added to Simulator"
|
366
466
|
)
|
367
467
|
# 防止模拟器还没有到prepare阶段导致get_person出错
|
@@ -370,7 +470,7 @@ class CitizenAgent(Agent):
|
|
370
470
|
|
371
471
|
async def _bind_to_economy(self):
|
372
472
|
if self._economy_client is None:
|
373
|
-
|
473
|
+
logger.warning("Economy client is not set")
|
374
474
|
return
|
375
475
|
if not self._has_bound_to_economy:
|
376
476
|
if self._has_bound_to_simulator:
|
@@ -387,7 +487,7 @@ class CitizenAgent(Agent):
|
|
387
487
|
)
|
388
488
|
self._has_bound_to_economy = True
|
389
489
|
else:
|
390
|
-
|
490
|
+
logger.debug(
|
391
491
|
f"Binding to Economy before binding to Simulator, skip binding to Economy Simulator"
|
392
492
|
)
|
393
493
|
|
@@ -417,6 +517,7 @@ class InstitutionAgent(Agent):
|
|
417
517
|
memory: Optional[Memory] = None,
|
418
518
|
economy_client: Optional[EconomyClient] = None,
|
419
519
|
messager: Optional[Messager] = None,
|
520
|
+
avro_file: Optional[dict] = None,
|
420
521
|
) -> None:
|
421
522
|
super().__init__(
|
422
523
|
name,
|
@@ -426,6 +527,7 @@ class InstitutionAgent(Agent):
|
|
426
527
|
messager,
|
427
528
|
simulator,
|
428
529
|
memory,
|
530
|
+
avro_file,
|
429
531
|
)
|
430
532
|
# 添加响应收集器
|
431
533
|
self._gather_responses: Dict[str, asyncio.Future] = {}
|
@@ -435,7 +537,7 @@ class InstitutionAgent(Agent):
|
|
435
537
|
|
436
538
|
async def _bind_to_economy(self):
|
437
539
|
if self._economy_client is None:
|
438
|
-
|
540
|
+
logger.debug("Economy client is not set")
|
439
541
|
return
|
440
542
|
if not self._has_bound_to_economy:
|
441
543
|
# TODO: More general id generation
|
@@ -511,7 +613,7 @@ class InstitutionAgent(Agent):
|
|
511
613
|
}
|
512
614
|
)
|
513
615
|
except Exception as e:
|
514
|
-
|
616
|
+
logger.error(f"Failed to bind to Economy: {e}")
|
515
617
|
self._has_bound_to_economy = True
|
516
618
|
|
517
619
|
async def handle_gather_message(self, payload: dict):
|
@@ -527,11 +629,11 @@ class InstitutionAgent(Agent):
|
|
527
629
|
"content": content,
|
528
630
|
})
|
529
631
|
|
530
|
-
async def gather_messages(self,
|
632
|
+
async def gather_messages(self, agent_uuids: list[str], target: str) -> List[dict]:
|
531
633
|
"""从多个智能体收集消息
|
532
634
|
|
533
635
|
Args:
|
534
|
-
|
636
|
+
agent_uuids: 目标智能体UUID列表
|
535
637
|
target: 要收集的信息类型
|
536
638
|
|
537
639
|
Returns:
|
@@ -539,18 +641,17 @@ class InstitutionAgent(Agent):
|
|
539
641
|
"""
|
540
642
|
# 为每个agent创建Future
|
541
643
|
futures = {}
|
542
|
-
for
|
543
|
-
|
544
|
-
|
545
|
-
self._gather_responses[response_key] = futures[response_key]
|
644
|
+
for agent_uuid in agent_uuids:
|
645
|
+
futures[agent_uuid] = asyncio.Future()
|
646
|
+
self._gather_responses[agent_uuid] = futures[agent_uuid]
|
546
647
|
|
547
648
|
# 发送gather请求
|
548
649
|
payload = {
|
549
650
|
"from": self._uuid,
|
550
651
|
"target": target,
|
551
652
|
}
|
552
|
-
for
|
553
|
-
await self._send_message(
|
653
|
+
for agent_uuid in agent_uuids:
|
654
|
+
await self._send_message(agent_uuid, payload, "gather")
|
554
655
|
|
555
656
|
try:
|
556
657
|
# 等待所有响应
|
@@ -20,6 +20,7 @@ from shapely.strtree import STRtree
|
|
20
20
|
from .sim import CityClient, ControlSimEnv
|
21
21
|
from .utils.const import *
|
22
22
|
|
23
|
+
logger = logging.getLogger("pycityagent")
|
23
24
|
|
24
25
|
class Simulator:
|
25
26
|
"""
|
@@ -72,7 +73,7 @@ class Simulator:
|
|
72
73
|
else:
|
73
74
|
self._client = CityClient(config["simulator"]["server"], secure=False)
|
74
75
|
else:
|
75
|
-
|
76
|
+
logger.warning(
|
76
77
|
"No simulator config found, no simulator client will be used"
|
77
78
|
)
|
78
79
|
self.map = SimMap(
|
@@ -285,7 +286,7 @@ class Simulator:
|
|
285
286
|
reset_position["aoi_position"] = {"aoi_id": aoi_id}
|
286
287
|
if poi_id is not None:
|
287
288
|
reset_position["aoi_position"]["poi_id"] = poi_id
|
288
|
-
|
289
|
+
logger.debug(
|
289
290
|
f"Setting person {person_id} pos to AoiPosition {reset_position}"
|
290
291
|
)
|
291
292
|
await self._client.person_service.ResetPersonPosition(
|
@@ -298,14 +299,14 @@ class Simulator:
|
|
298
299
|
}
|
299
300
|
if s is not None:
|
300
301
|
reset_position["lane_position"]["s"] = s
|
301
|
-
|
302
|
+
logger.debug(
|
302
303
|
f"Setting person {person_id} pos to LanePosition {reset_position}"
|
303
304
|
)
|
304
305
|
await self._client.person_service.ResetPersonPosition(
|
305
306
|
{"person_id": person_id, "position": reset_position}
|
306
307
|
)
|
307
308
|
else:
|
308
|
-
|
309
|
+
logger.debug(
|
309
310
|
f"Neither aoi or lane pos provided for person {person_id} position reset!!"
|
310
311
|
)
|
311
312
|
|
pycityagent/memory/const.py
CHANGED
pycityagent/memory/memory.py
CHANGED
@@ -13,6 +13,7 @@ from .profile import ProfileMemory
|
|
13
13
|
from .self_define import DynamicMemory
|
14
14
|
from .state import StateMemory
|
15
15
|
|
16
|
+
logger = logging.getLogger("pycityagent")
|
16
17
|
|
17
18
|
class Memory:
|
18
19
|
"""
|
@@ -83,7 +84,7 @@ class Memory:
|
|
83
84
|
_type.extend(_value)
|
84
85
|
_value = deepcopy(_type)
|
85
86
|
else:
|
86
|
-
|
87
|
+
logger.warning(f"type `{_type}` is not supported!")
|
87
88
|
pass
|
88
89
|
except TypeError as e:
|
89
90
|
pass
|
@@ -99,7 +100,7 @@ class Memory:
|
|
99
100
|
or k in STATE_ATTRIBUTES
|
100
101
|
or k == TIME_STAMP_KEY
|
101
102
|
):
|
102
|
-
|
103
|
+
logger.warning(f"key `{k}` already declared in memory!")
|
103
104
|
continue
|
104
105
|
|
105
106
|
_dynamic_config[k] = deepcopy(_value)
|
@@ -112,19 +113,19 @@ class Memory:
|
|
112
113
|
if profile is not None:
|
113
114
|
for k, v in profile.items():
|
114
115
|
if k not in PROFILE_ATTRIBUTES:
|
115
|
-
|
116
|
+
logger.warning(f"key `{k}` is not a correct `profile` field!")
|
116
117
|
continue
|
117
118
|
_profile_config[k] = v
|
118
119
|
if motion is not None:
|
119
120
|
for k, v in motion.items():
|
120
121
|
if k not in STATE_ATTRIBUTES:
|
121
|
-
|
122
|
+
logger.warning(f"key `{k}` is not a correct `motion` field!")
|
122
123
|
continue
|
123
124
|
_state_config[k] = v
|
124
125
|
if base is not None:
|
125
126
|
for k, v in base.items():
|
126
127
|
if k not in STATE_ATTRIBUTES:
|
127
|
-
|
128
|
+
logger.warning(f"key `{k}` is not a correct `base` field!")
|
128
129
|
continue
|
129
130
|
_state_config[k] = v
|
130
131
|
self._state = StateMemory(
|
@@ -182,7 +183,7 @@ class Memory:
|
|
182
183
|
"""更新记忆值并在必要时更新embedding"""
|
183
184
|
if protect_llm_read_only_fields:
|
184
185
|
if any(key in _attrs for _attrs in [STATE_ATTRIBUTES]):
|
185
|
-
|
186
|
+
logger.warning(f"Trying to write protected key `{key}`!")
|
186
187
|
return
|
187
188
|
for _mem in [self._state, self._profile, self._dynamic]:
|
188
189
|
try:
|
@@ -208,7 +209,7 @@ class Memory:
|
|
208
209
|
elif isinstance(original_value, deque):
|
209
210
|
original_value.extend(deque(value))
|
210
211
|
else:
|
211
|
-
|
212
|
+
logger.debug(
|
212
213
|
f"Type of {type(original_value)} does not support mode `merge`, using `replace` instead!"
|
213
214
|
)
|
214
215
|
await _mem.update(key, value, store_snapshot)
|
@@ -10,6 +10,8 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union
|
|
10
10
|
|
11
11
|
from .const import *
|
12
12
|
|
13
|
+
logger = logging.getLogger("pycityagent")
|
14
|
+
|
13
15
|
|
14
16
|
class MemoryUnit:
|
15
17
|
def __init__(
|
@@ -57,7 +59,7 @@ class MemoryUnit:
|
|
57
59
|
orig_v = self._content[k]
|
58
60
|
orig_type, new_type = type(orig_v), type(v)
|
59
61
|
if not orig_type == new_type:
|
60
|
-
|
62
|
+
logger.debug(
|
61
63
|
f"Type warning: The type of the value for key '{k}' is changing from `{orig_type.__name__}` to `{new_type.__name__}`!"
|
62
64
|
)
|
63
65
|
self._content.update(content)
|
@@ -82,7 +84,7 @@ class MemoryUnit:
|
|
82
84
|
await self._lock.acquire()
|
83
85
|
values = self._content[key]
|
84
86
|
if not isinstance(values, Sequence):
|
85
|
-
|
87
|
+
logger.warning(
|
86
88
|
f"the value stored in key `{key}` is not `sequence`, return value `{values}` instead!"
|
87
89
|
)
|
88
90
|
return values
|
@@ -93,7 +95,7 @@ class MemoryUnit:
|
|
93
95
|
)
|
94
96
|
top_k = len(values) if top_k is None else top_k
|
95
97
|
if len(_sorted_values_with_idx) < top_k:
|
96
|
-
|
98
|
+
logger.debug(
|
97
99
|
f"Length of values {len(_sorted_values_with_idx)} is less than top_k {top_k}, returning all values."
|
98
100
|
)
|
99
101
|
self._lock.release()
|
@@ -149,7 +151,7 @@ class MemoryBase(ABC):
|
|
149
151
|
if recent_n is None:
|
150
152
|
return _list_units
|
151
153
|
if len(_memories) < recent_n:
|
152
|
-
|
154
|
+
logger.debug(
|
153
155
|
f"Length of memory {len(_memories)} is less than recent_n {recent_n}, returning all available memories."
|
154
156
|
)
|
155
157
|
return _list_units[-recent_n:]
|